#include "linker.hpp"
// an archive member header paired with the member's size in bytes (parsed from the header's Size field in get_objs)...
using obj_info_t = std::pair<PIMAGE_ARCHIVE_MEMBER_HEADER, std::uint32_t>;
// archive members keyed by the file offset of their member header (ordered, so members are visited by ascending offset)...
using lib_info_t = std::map<unsigned long, obj_info_t>;

namespace lnk
{
	// computes the size of a symbol as the distance from its offset to the next symbol
	// in the same section, or to the end of the section's raw data if none follows.
	//
	// sym - a symbol previously produced from this object (section_number is 1-based).
	// obj - the raw COFF object the symbol belongs to.
	// returns the size in bytes; 0 if the symbol sits at or beyond the end of its section.
	auto get_symbol_size(symbol_t& sym, obj_buffer_t& obj) -> std::uint32_t
	{
		const auto coff_header =
			reinterpret_cast<PIMAGE_FILE_HEADER>(obj.data());

		const auto section_headers =
			reinterpret_cast<PIMAGE_SECTION_HEADER>(
				obj.data() + sizeof(IMAGE_FILE_HEADER));

		const auto symbol_table =
			reinterpret_cast<PIMAGE_SYMBOL>(
				coff_header->PointerToSymbolTable + obj.data());

		// a symbol can extend no further than the end of its section...
		std::uint32_t result =
			section_headers[sym.section_number - 1].SizeOfRawData;

		// the closest symbol that starts after this one in the same section marks its end...
		for (auto idx = 0u; idx < coff_header->NumberOfSymbols; ++idx)
		{
			const auto& candidate = symbol_table[idx];

			if (candidate.SectionNumber == sym.section_number &&
				candidate.Value > sym.section_offset && candidate.Value < result)
				result = candidate.Value;

			// auxiliary records are not symbols: their Value/SectionNumber bytes hold other
			// data (e.g. a section checksum and associated section number), so skip them...
			idx += candidate.NumberOfAuxSymbols;
		}

		// avoid unsigned wrap-around if the offset is not inside the section's raw data...
		if (sym.section_offset >= result)
			return 0;

		return static_cast<std::uint32_t>(result - sym.section_offset);
	}

	// parses a map file into symbol name -> { section number, section offset }.
	//
	// each line is expected (per the offsets used below) to carry a hex section number
	// between index 1 and the first ':', 16 hex digits of section offset right after the
	// colon, and the symbol name starting 24 characters after the colon.
	// parsing stops at the first line that has no ':' at all.
	//
	// map_path - path of the map file to read.
	// returns the parsed symbols; an empty result if the file cannot be opened.
	auto get_map_symbols(std::string map_path) -> map_symbols_t
	{
		std::ifstream map_file(map_path);

		// value-initialize: `{ {}, {} }` could select an initializer_list constructor and
		// yield a single default entry instead of an empty result...
		if (!map_file.is_open())
			return {};

		constexpr std::size_t offset_digits = 16;
		constexpr std::size_t name_distance = offset_digits + 8;

		std::string line;
		map_symbols_t result;
		while (std::getline(map_file, line))
		{
			const auto colon_index = line.find(':');
			if (colon_index == std::string::npos)
				break;

			// too short to hold a symbol name; substr would otherwise throw std::out_of_range...
			const auto name_index = colon_index + name_distance;
			if (name_index >= line.length())
				continue;

			// strtoul stops at the ':' so the section number field needs no exact length...
			const auto section_number =
				std::strtoul(line.substr(1, colon_index).c_str(), nullptr, 16);

			const auto section_offset =
				std::strtoull(line.substr(colon_index + 1, offset_digits).c_str(), nullptr, 16);

			result[line.substr(name_index)] = { section_number, section_offset };
		}
		return result;
	}

	auto get_objs(std::string lib_path, std::vector<obj_buffer_t>& objs) -> bool
	{
		std::vector<std::uint8_t> lib_file;
		utils::open_binary_file(lib_path, lib_file);

		if (strncmp(reinterpret_cast<const char*>(lib_file.data()),
			IMAGE_ARCHIVE_START, sizeof IMAGE_ARCHIVE_START - 1))
			return false;

		const auto linker_header =
			reinterpret_cast<unsigned long*>(
				lib_file.data() + (sizeof IMAGE_ARCHIVE_START - 1) +
					sizeof IMAGE_ARCHIVE_MEMBER_HEADER);

		lib_info_t offsets;
		for (auto idx = 1u; idx < _byteswap_ulong(linker_header[0]); ++idx)
		{
			const auto archive_header = 
				reinterpret_cast<PIMAGE_ARCHIVE_MEMBER_HEADER>(
					lib_file.data() + _byteswap_ulong(linker_header[idx]));

			const auto obj_size = std::atoi(reinterpret_cast<const char*>(archive_header->Size));
			offsets[_byteswap_ulong(linker_header[idx])] = { archive_header, obj_size };
		}

		for (auto& [file_offset, obj_info] : offsets)
		{
			const auto obj_start = lib_file.data() + 
				file_offset + sizeof IMAGE_ARCHIVE_MEMBER_HEADER;

			objs.push_back(obj_buffer_t(
				obj_start, obj_start + obj_info.second));
		}

		return true;
	}

	namespace sym
	{
		// collects the relocations of every section that will be kept in the output.
		//
		// empty, link-removed, discardable and exception (.pdata/.xdata) sections are skipped.
		// obj - the raw COFF object to scan.
		// returns one entry per relocation, with the file offset it patches and the name of
		// the symbol it resolves to.
		auto relocations(obj_buffer_t& obj) -> std::vector<image_reloc_t>
		{
			const auto coff_header =
				reinterpret_cast<PIMAGE_FILE_HEADER>(obj.data());

			const auto section_headers =
				reinterpret_cast<PIMAGE_SECTION_HEADER>(
					obj.data() + sizeof(IMAGE_FILE_HEADER));

			const auto symbol_table =
				reinterpret_cast<PIMAGE_SYMBOL>(
					coff_header->PointerToSymbolTable + obj.data());

			// the string table starts immediately after the last symbol record...
			const auto string_table =
				reinterpret_cast<const char*>(
					reinterpret_cast<std::uintptr_t>(symbol_table) +
						(coff_header->NumberOfSymbols * sizeof(IMAGE_SYMBOL)));

			// a short name fills all 8 bytes when it is exactly 8 characters long and is then
			// not nul-terminated, so never let std::string scan past the field...
			const auto read_name = [string_table](const IMAGE_SYMBOL& entry) -> std::string
			{
				if (!entry.N.Name.Short)
					return std::string(string_table + entry.N.Name.Long);

				const std::string raw(
					reinterpret_cast<const char*>(entry.N.ShortName), sizeof(entry.N.ShortName));

				return raw.substr(0, raw.find('\0'));
			};

			std::vector<image_reloc_t> result;
			for (auto idx = 0u; idx < coff_header->NumberOfSections; ++idx)
			{
				const auto& section = section_headers[idx];
				if (!section.PointerToRelocations)
					continue;

				// for some reason the compiler makes some empty sections...
				if (!section.SizeOfRawData)
					continue;

				// skip over sections that we will not be using...
				if (section.Characteristics & IMAGE_SCN_LNK_REMOVE)
					continue;

				// skip over discardable sections...
				if (section.Characteristics & IMAGE_SCN_MEM_DISCARDABLE)
					continue;

				// skip both the .pdata and the .xdata sections... these are used for exceptions...
				// (compare the whole literal, not one character short of it)
				const auto section_name = reinterpret_cast<const char*>(section.Name);
				if (!strncmp(section_name, ".pdata", sizeof(".pdata") - 1))
					continue;

				if (!strncmp(section_name, ".xdata", sizeof(".xdata") - 1))
					continue;

				const auto reloc_dir =
					reinterpret_cast<PIMAGE_RELOCATION>(
						section.PointerToRelocations + obj.data());

				for (auto reloc_idx = 0u; reloc_idx < section.NumberOfRelocations; ++reloc_idx)
				{
					const auto& reloc = reloc_dir[reloc_idx];
					const auto& target = symbol_table[reloc.SymbolTableIndex];

					image_reloc_t entry;
					entry.file_offset = reloc.VirtualAddress + section.PointerToRawData;
					entry.resolve_symbol_name = read_name(target);
					entry.raw_reloc = reloc;
					entry.raw_symbol = target;
					entry.type = reloc.Type;
					result.push_back(entry);
				}
			}

			return result;
		}

		// collects the defined, non-section symbols of a COFF object along with their
		// location, size and the obfuscation type implied by their section's name.
		//
		// obj - the raw COFF object to scan.
		// returns one symbol_t per kept symbol, in symbol-table order.
		auto symbols(obj_buffer_t& obj) -> std::vector<symbol_t>
		{
			const auto coff_header =
				reinterpret_cast<PIMAGE_FILE_HEADER>(obj.data());

			const auto section_headers =
				reinterpret_cast<PIMAGE_SECTION_HEADER>(
					obj.data() + sizeof(IMAGE_FILE_HEADER));

			const auto symbol_table =
				reinterpret_cast<PIMAGE_SYMBOL>(
					coff_header->PointerToSymbolTable + obj.data());

			// the string table starts immediately after the last symbol record...
			const auto string_table =
				reinterpret_cast<const char*>(
					reinterpret_cast<std::uintptr_t>(symbol_table) +
					(coff_header->NumberOfSymbols * sizeof(IMAGE_SYMBOL)));

			// a short name fills all 8 bytes when it is exactly 8 characters long and is then
			// not nul-terminated, so never let std::string scan past the field...
			const auto read_name = [string_table](const IMAGE_SYMBOL& entry) -> std::string
			{
				if (!entry.N.Name.Short)
					return std::string(string_table + entry.N.Name.Long);

				const std::string raw(
					reinterpret_cast<const char*>(entry.N.ShortName), sizeof(entry.N.ShortName));

				return raw.substr(0, raw.find('\0'));
			};

			std::vector<symbol_t> result;
			for (auto idx = 0u; idx < coff_header->NumberOfSymbols; ++idx)
			{
				const auto& raw_symbol = symbol_table[idx];

				// there can be more than one aux record; they follow their symbol and must be stepped over...
				const auto aux_count = raw_symbol.NumberOfAuxSymbols;

				symbol_t symbol{};
				symbol.symbol_name = read_name(raw_symbol);

				// skip section symbols (names starting with '.'), special (non-positive) section
				// numbers and section numbers past the section table...
				// we only want .data, .rdata, and executable (function) symbols...
				if (symbol.symbol_name.empty() ||
					symbol.symbol_name.front() == '.' ||
					raw_symbol.SectionNumber < 1 ||
					raw_symbol.SectionNumber > coff_header->NumberOfSections)
				{
					idx += aux_count;
					continue;
				}

				const auto& section = section_headers[raw_symbol.SectionNumber - 1];

				symbol.file_offset = section.PointerToRawData + raw_symbol.Value;
				symbol.section_number = raw_symbol.SectionNumber;
				symbol.section_offset = raw_symbol.Value;
				symbol.type = raw_symbol.Type;
				symbol.size = get_symbol_size(symbol, obj);

				// every compare below is bounded by its literal, so the possibly
				// unterminated 8-byte Name field is never over-read...
				const auto section_name = reinterpret_cast<const char*>(section.Name);

				// the longer ".theo2"/".theo1" prefixes must be tested before the bare ".theo"...
				if (!strncmp(section_name, ".theo2", sizeof(".theo2") - 1))
					symbol.obfuscate_type = theo_type::encrypt;
				else if (!strncmp(section_name, ".theo1", sizeof(".theo1") - 1))
					symbol.obfuscate_type = theo_type::mutate;
				else if (!strncmp(section_name, ".theo", sizeof(".theo") - 1))
					symbol.obfuscate_type = theo_type::obfuscate;
				else
					symbol.obfuscate_type = theo_type{};

				idx += aux_count;
				result.push_back(symbol);
			}
			return result;
		}
	}

	namespace section
	{
		// finds the first section whose 8-byte short name starts with section_name.
		//
		// obj          - the raw COFF object to scan.
		// section_name - the name (prefix) to look for, e.g. ".text" also matches ".text$mn".
		// returns a pointer into obj's section table, or nullptr if no section matches
		// (or the name is null, empty, or longer than the 8-byte short-name field).
		auto get_header(obj_buffer_t& obj, const char* section_name) -> PIMAGE_SECTION_HEADER
		{
			const auto coff_header =
				reinterpret_cast<PIMAGE_FILE_HEADER>(obj.data());

			const auto section_headers =
				reinterpret_cast<PIMAGE_SECTION_HEADER>(
					obj.data() + sizeof(IMAGE_FILE_HEADER));

			if (!section_name)
				return nullptr;

			// compare the whole requested name: comparing one character less made e.g.
			// ".theo2" match a ".theo"/".theo1" section. A name longer than the short-name
			// field can never match it, and bounding the length keeps strncmp inside the
			// field, which is not nul-terminated when all 8 bytes are used...
			const auto name_length = strlen(section_name);
			if (!name_length || name_length > sizeof(section_headers->Name))
				return nullptr;

			for (auto idx = 0u; idx < coff_header->NumberOfSections; ++idx)
				if (!strncmp(reinterpret_cast<const char*>(section_headers[idx].Name),
					section_name, name_length))
					return section_headers + idx;

			return nullptr;
		}

		// invokes callback on each section header of obj, in section-table order.
		//
		// callback - receives (section header, obj); returning false stops the walk early.
		// obj      - the raw COFF object whose sections are visited.
		auto for_each(section_callback_t callback, obj_buffer_t& obj) -> void
		{
			const auto file_header =
				reinterpret_cast<PIMAGE_FILE_HEADER>(obj.data());

			const auto first_section =
				reinterpret_cast<PIMAGE_SECTION_HEADER>(
					obj.data() + sizeof(IMAGE_FILE_HEADER));

			// keep going while sections remain and the callback asks for more...
			auto section_idx = 0u;
			while (section_idx < file_header->NumberOfSections &&
				callback(first_section + section_idx, obj))
				++section_idx;
		}
	}
}