#include <cstdio>
#include <cstring>
#include <filesystem>
#include <iostream>
#include <utility>

#include "theo.h"
#include "linker.hpp"
#include "vdm_ctx/vdm_ctx.hpp"
#include "utils.hpp"

using extern_symbols_t = std::vector<std::pair<std::string, lnk::map_symbols_t>>;
using objs_buffer_t = std::vector<lnk::obj_buffer_t>;

// Splits the command line into the object files to link and the map files used to
// resolve symbols that no loaded kernel module exports.
//
// Expected layout (argv[1] == "--libs" is validated by main):
//   argv[2..]            .lib paths, any "--flag" (e.g. "--debug") is skipped
//   "--maps" then [..]   .map paths, any "--flag" is skipped as well
//
// Returns { objects parsed from every .lib, (map file stem, map symbols) pairs }.
// The map file stem is the module name it describes ("ntoskrnl.exe.map" -> "ntoskrnl.exe").
// If any .lib fails to parse, an empty pair is returned.
auto get_mapping_info(int argc, char** argv) -> std::pair<objs_buffer_t, extern_symbols_t>
{
	// true for arguments of the form "--something"...
	const auto is_flag = [](const char* arg) -> bool
	{
		return arg[0] == '-' && arg[1] == '-';
	};

	int maps_offset = 0; // index of the first argument after "--maps", 0 if absent...
	objs_buffer_t image_objs;
	extern_symbols_t extern_symbols;

	for (int idx = 2; idx < argc; ++idx)
	{
		if (!strcmp(argv[idx], "--maps"))
		{
			maps_offset = idx + 1;
			break;
		}

		// some other flag (not a .lib path) so skip it...
		if (is_flag(argv[idx]))
			continue;

		if (!lnk::get_objs(argv[idx], image_objs))
		{
			std::printf("> failed to parse lib...\n");
			return {};
		}
	}

	if (maps_offset)
	{
		for (int idx = maps_offset; idx < argc; ++idx)
		{
			// flags such as "--debug" may follow the map files, they are not maps...
			if (is_flag(argv[idx]))
				continue;

			extern_symbols.emplace_back(
				std::filesystem::path(argv[idx]).stem().string(),
				lnk::get_map_symbols(argv[idx]));
		}
	}

	// move rather than copy, the object buffers can be large...
	return { std::move(image_objs), std::move(extern_symbols) };
}

// Entry point of the mapper.
// 1. Parses "--libs ... [--maps ...]" (see get_mapping_info); "--debug" after the
//    first .lib enables dbg_print.
// 2. Loads the vulnerable driver used by VDM for physical memory read/write.
// 3. Maps the object files into kernel memory, resolving imports from kernel
//    exports or from the supplied .map files.
// 4. Calls "DrvEntry" through a VDM syscall, then unloads the driver.
// Returns -1 on any failure.
int main(int argc, char** argv)
{
	if (argc < 3 || strcmp(argv[1], "--libs"))
	{
		std::printf("[!] invalid usage... please use one of the following:\n");
		std::printf("\t\t> theo.exe --libs one.lib two.lib three.lib\n");
		std::printf("\t\t> theo.exe --libs one.lib --maps ntoskrnl.exe.map win32kbase.sys.map\n");
		return -1;
	}

	for (auto idx = 3; idx < argc; ++idx)
		if (!strcmp(argv[idx], "--debug"))
			dbg_print = true;

	auto [image_objs, extern_symbols] = get_mapping_info(argc, argv);

	// size() yields a std::size_t so %zu is the matching conversion (not %d)...
	std::printf("[+] number of objs = %zu\n", image_objs.size());

	if (image_objs.empty())
	{
		std::printf("[!] failed to parse .lib...\n");
		return -1;
	}

	const auto [drv_handle, drv_key, drv_status] = vdm::load_drv();
	if (drv_status != STATUS_SUCCESS || drv_handle == INVALID_HANDLE_VALUE)
	{
		std::printf("> failed to load driver... reason -> 0x%x\n", drv_status);
		return -1;
	}

	// read physical memory using the driver...
	vdm::read_phys_t _read_phys =
		[&](void* addr, void* buffer, std::size_t size) -> bool
	{
		return vdm::read_phys(addr, buffer, size);
	};

	// write physical memory using the driver...
	vdm::write_phys_t _write_phys =
		[&](void* addr, void* buffer, std::size_t size) -> bool
	{
		return vdm::write_phys(addr, buffer, size);
	};

	// use VDM to syscall into ExAllocatePool...
	vdm::vdm_ctx vdm(_read_phys, _write_phys);
	theo::malloc_t _kalloc =
		[&](std::size_t size, std::uint32_t prot) -> void*
	{
		// ExAllocatePool(POOL_TYPE PoolType, SIZE_T NumberOfBytes): the byte count is
		// pointer sized, declaring it 32-bit would truncate large sizes...
		using ex_alloc_pool_t =
			void* (*)(std::uint32_t, std::size_t);

		static const auto ex_alloc_pool =
			reinterpret_cast<void*>(
				utils::kmodule::get_export(
					"ntoskrnl.exe", "ExAllocatePool"));

		// NOTE: `prot` is not forwarded, pool type 0 (NonPagedPool) is always used...
		return vdm.syscall<ex_alloc_pool_t>(ex_alloc_pool, NULL, size);
	};

	// use VDM to syscall into memcpy exported by ntoskrnl.exe...
	theo::memcpy_t _kmemcpy =
		[&](void* dest, const void* src, std::size_t size) -> void*
	{
		static const auto kmemcpy =
			reinterpret_cast<void*>(
				utils::kmodule::get_export(
					"ntoskrnl.exe", "memcpy"));

		return vdm.syscall<decltype(&memcpy)>(kmemcpy, dest, src, size);
	};

	// resolve a symbol imported by the objects: first as an export of any loaded
	// kernel module, then via the --maps files (section index + offset)...
	theo::resolve_symbol_t resolve_symbol =
		[&, &extern_symbols = extern_symbols](const char* symbol_name) -> std::uintptr_t
	{
		std::uintptr_t result = 0u;
		utils::kmodule::each_module
		(
			[&](PRTL_PROCESS_MODULE_INFORMATION drv_info, const char* drv_path) -> bool
			{
				const auto drv_name =
					reinterpret_cast<const char*>(
						drv_info->OffsetToFileName + drv_info->FullPathName);

				// false if we found the symbol...
				return (!(result = utils::kmodule::get_export(drv_name, symbol_name)));
			}
		);

		if (result)
			return result;

		for (auto& [drv_name, drv_symbols] : extern_symbols)
		{
			// only consider maps that list this symbol. find() is used instead of
			// operator[] so that missing symbols are not inserted into the map...
			const auto symbol_entry = drv_symbols.find(symbol_name);
			if (symbol_entry == drv_symbols.end())
				continue;

			const auto section_idx = symbol_entry->second.first;
			const auto section_offset = symbol_entry->second.second;

			// each kernel module... find a driver with a matching map file name...
			// I.E ntoskrnl.exe.map == ntoskrnl.exe...
			utils::kmodule::each_module
			(
				[&, &drv_name = drv_name]
				(PRTL_PROCESS_MODULE_INFORMATION drv_info, const char* drv_path) -> bool
				{
					const auto _drv_name =
						reinterpret_cast<const char*>(
							drv_info->OffsetToFileName + drv_info->FullPathName);

					// not the driver this map describes... keep looking...
					if (strcmp(_drv_name, drv_name.c_str()))
						return true;

					// load the image from disk only to walk its section headers...
					const auto drv_module =
						LoadLibraryExA(drv_path, NULL, DONT_RESOLVE_DLL_REFERENCES);

					if (!drv_module)
						return true;

					// the map's section index is compared against a counter starting at 1...
					std::uint32_t section_count = 1u;
					utils::pe::each_section
					(
						[&](PIMAGE_SECTION_HEADER section_header, std::uintptr_t img_base) -> bool
						{
							if (section_count == section_idx)
							{
								result = reinterpret_cast<std::uintptr_t>(drv_info->ImageBase) +
									section_header->VirtualAddress + section_offset;

								// we found the symbol...
								return false;
							}

							++section_count;
							// keep going over sections...
							return true;
						}, reinterpret_cast<std::uintptr_t>(drv_module)
					);

					// balance the LoadLibraryExA above...
					FreeLibrary(drv_module);

					// keep looping over modules until we resolve the symbol...
					return !result;
				}
			);

			// if we found the symbol then break out of the loop... else keep looping...
			if (result) break;
		}

		// finally return the result...
		return result;
	};

	theo::hmm_ctx drv_mapper({ _kalloc, _kmemcpy, resolve_symbol });
	if (!drv_mapper.map_objs(image_objs))
	{
		std::printf("[!] failed to map object files...\n");

		// don't leave the vulnerable driver loaded on the failure path...
		const auto cleanup_status = vdm::unload_drv(drv_handle, drv_key);
		if (cleanup_status != STATUS_SUCCESS)
			std::printf("> failed to unload driver... reason -> 0x%x\n", cleanup_status);

		return -1;
	}

	const auto drv_entry = drv_mapper.get_symbol("DrvEntry");
	std::printf("> driver entry -> 0x%p\n", reinterpret_cast<void*>(drv_entry));
	std::getchar();

	if (drv_entry)
	{
		// call driver entry... its up to you to do this using whatever method...
		// with VDM you can syscall into it... with msrexec you will use msrexec::exec...
		const auto entry_result =
			vdm.syscall<NTSTATUS(*)()>(
				reinterpret_cast<void*>(drv_entry));

		std::printf("> driver entry result -> 0x%x\n", entry_result);
	}

	const auto unload_status = vdm::unload_drv(drv_handle, drv_key);
	if (unload_status != STATUS_SUCCESS)
	{
		std::printf("> failed to unload driver... reason -> 0x%x\n", unload_status);
		return -1;
	}

	std::printf("> press enter to close...\n");
	std::getchar();
}