#include <iostream>
#include <Windows.h>
#include <psapi.h>
#include <filesystem>

#include "theo.h"
#include "linker.hpp"

// Symbol maps supplied on the command line: each entry pairs a map file's stem
// (file name without directory or extension) with the symbols lnk::get_map_symbols
// parsed from that file.
using extern_symbols_t = std::vector<std::pair<std::string, lnk::map_symbols_t>>;
// Object-file buffers collected by lnk::get_objs from the libraries passed after "--libs".
using objs_buffer_t = std::vector<lnk::obj_buffer_t>;

// Returns true when `arg` is a command-line switch (i.e. starts with "--").
static auto is_cli_flag(const char* arg) -> bool
{
	return arg[0] == '-' && arg[1] == '-';
}

/// Splits the command line into the object files to map and the optional symbol maps.
///
/// Expected layout: argv[1] is "--libs" (validated by main), followed by one or more
/// library/object paths up to the first "--" switch. If a "--maps" switch appears
/// anywhere, the arguments after it (up to the next switch) are treated as map files.
///
/// @param argc  argument count as received by main
/// @param argv  argument vector as received by main
/// @return the parsed object buffers and the per-map-file symbol tables; an empty pair
///         if any library fails to parse (main treats zero objects as failure).
auto get_mapping_info(int argc, char** argv) -> std::pair<objs_buffer_t, extern_symbols_t>
{
	objs_buffer_t image_objs;
	extern_symbols_t extern_symbols;

	// Library/object paths run from argv[2] until the first other switch.
	for (auto idx = 2; idx < argc && !is_cli_flag(argv[idx]); ++idx)
	{
		if (!lnk::get_objs(argv[idx], image_objs))
		{
			std::printf("> failed to parse lib...\n");
			return {};
		}
	}

	// Map files follow "--maps" and stop at the next switch, so trailing flags such as
	// "--pid 1234" or "--debug" are not mistaken for map files.
	for (auto idx = 2; idx < argc; ++idx)
	{
		if (strcmp(argv[idx], "--maps"))
			continue;

		for (++idx; idx < argc && !is_cli_flag(argv[idx]); ++idx)
		{
			extern_symbols.emplace_back(
				std::filesystem::path(argv[idx]).stem().string(),
				lnk::get_map_symbols(argv[idx]));
		}
		break;
	}

	return { std::move(image_objs), std::move(extern_symbols) };
}

int main(int argc, char** argv)
{
	if (argc < 3 || strcmp(argv[1], "--libs"))
	{
		std::printf("[!] invalid usage... please use one of the following:\n");
		std::printf("		> theo.exe --libs one.lib two.lib three.lib --pid 1234\n");
		std::printf("		> theo.exe --libs one.lib --pid 1234\n");
		return -1;
	}

	auto [image_objs, extern_symbols] = get_mapping_info(argc, argv);
	std::printf("[+] number of objs = %d\n", image_objs.size());

	if (!image_objs.size())
	{
		std::printf("[!] failed to parse .lib...\n");
		return -1;
	}

	std::uint32_t pid_offset = 0u, pid = 0u;
	for (auto idx = 3; idx < argc; ++idx)
		if (!strcmp(argv[idx], "--pid"))
			pid_offset = idx + 1;

	for (auto idx = 3; idx < argc; ++idx)
		if (!strcmp(argv[idx], "--debug"))
			dbg_print = true;

	if (!pid_offset || !(pid = std::atoi(argv[pid_offset])))
	{
		std::printf("[!] invalid pid...\n");
		return -1;
	}

	const auto phandle = 
		OpenProcess(PROCESS_ALL_ACCESS, FALSE, pid);

	if (phandle == INVALID_HANDLE_VALUE)
	{
		std::printf("[!] failed to open handle...\n");
		return -1;
	}

	theo::malloc_t _alloc = [&](std::size_t size, std::uint32_t prot) -> void*
	{
		const auto result = 
			VirtualAllocEx
			(
				phandle, 
				nullptr, 
				size, 
				MEM_COMMIT | MEM_RESERVE, 
				prot
			);

		if (!result)
		{
			std::printf("[!] failed to allocate virtual memory...\n");
			exit(-1);
		}

		return result;
	};

	theo::memcpy_t _memcpy =
		[&](void* dest, const void* src, std::size_t size) -> void*
	{
		SIZE_T bytes_handled;
		if (!WriteProcessMemory(phandle, dest, src, size, &bytes_handled))
		{
			std::printf("[!] failed to write memory... reason = 0x%x\n", GetLastError());
			std::getchar();
		}
		return dest;
	};

	theo::resolve_symbol_t _resolver = 
		[&, &extern_symbols = extern_symbols](const char* symbol_name) -> std::uintptr_t
	{
		static std::map<std::string, std::uintptr_t> symbol_table;

		if (!symbol_table[symbol_name])
		{
			auto loaded_modules = std::make_unique<HMODULE[]>(64);
			std::uintptr_t result = 0u, loaded_module_sz = 0u;

			if (!EnumProcessModules(phandle,
				loaded_modules.get(), 512, (PDWORD)&loaded_module_sz))
				return {};

			for (auto i = 0u; i < loaded_module_sz / 8u; i++)
			{
				wchar_t file_name[MAX_PATH] = L"";
				if (!GetModuleFileNameExW(phandle,
					loaded_modules.get()[i], file_name, _countof(file_name)))
					continue;

				if ((result = reinterpret_cast<std::uintptr_t>(
					GetProcAddress(LoadLibrary(file_name), symbol_name))))
					break;
			}

			symbol_table[symbol_name] = result;
			return result;
		}

		return symbol_table[symbol_name];
	};

	theo::hmm_ctx mapper({ _alloc, _memcpy, _resolver });
	if (!mapper.map_objs(image_objs))
	{
		std::printf("[!] failed to map object files...\n");
		std::getchar();
		return -1;
	}

	const auto module_entry = 
		reinterpret_cast<LPTHREAD_START_ROUTINE>(
			mapper.get_symbol("main"));

	std::printf("[+] module entry -> 0x%p\n", module_entry);
	std::getchar();

	if (module_entry)
	{
		std::uint32_t tid = 0u;
		CreateRemoteThread(phandle, NULL,
			NULL, module_entry, NULL, NULL, (LPDWORD)&tid);

		std::printf("[+] thread id = %d\n", tid);
	}
	std::printf("[+] press enter to close...\n");
	std::getchar();
}