#include "theo.h"

namespace theo
{
	// Creates a mapping context from the caller-supplied routine tuple.
	// Tuple element 0 becomes `alloc`, element 1 becomes `mcopy` and element 2
	// becomes `resolve_symbol`. Elsewhere in this file they are invoked as
	// alloc(size, page_protection), mcopy(destination, source, size) and
	// resolve_symbol(symbol_name_c_str).
	hmm_ctx::hmm_ctx(const mapper_routines_t& routines)
		:
		alloc(std::get<0>(routines)),
		mcopy(std::get<1>(routines)),
		resolve_symbol(std::get<2>(routines))
	{}

	// Links and maps every object in `objs`. The stages run in this order and
	// the first one that reports failure aborts the whole operation:
	//   1. alloc_symbol_space            - reserve memory for normal symbols/sections
	//   2. alloc_obfuscated_symbol_space - reserve one allocation per obfuscated instruction
	//   3. map_obfuscated_symbols        - fix up and copy the obfuscated gadgets
	//   4. resolve_relocs                - patch relocations inside the object buffers
	//   5. map_symbols                   - copy the patched normal symbols into memory
	// Returns true only when all five stages succeed.
	auto hmm_ctx::map_objs(std::vector<lnk::obj_buffer_t>& objs) -> bool
	{
		DBG_PRINT("[+] allocating space for symbols...\n");
		const bool symbols_reserved = alloc_symbol_space(objs);

		if (!symbols_reserved)
		{
			DBG_PRINT("[!] failed to allocate symbol space...\n");
			return false;
		}

		DBG_PRINT("[+] allocating space for obfuscated symbols...\n");
		const bool gadgets_reserved = alloc_obfuscated_symbol_space(objs);

		if (!gadgets_reserved)
		{
			DBG_PRINT("[!] failed to allocate space for obfuscated functions...\n");
			return false;
		}

		DBG_PRINT("[+] mapping obfuscated symbols...\n");
		const bool gadgets_mapped = map_obfuscated_symbols(objs);

		if (!gadgets_mapped)
		{
			DBG_PRINT("[!] failed to resolve obfuscated relocs...\n");
			return false;
		}

		DBG_PRINT("[+] resolving non-obfuscated relocations...\n");
		const bool relocs_resolved = resolve_relocs(objs);

		if (!relocs_resolved)
		{
			DBG_PRINT("[!] failed to resolve relocations...\n");
			return false;
		}

		DBG_PRINT("[+] mapping non-obfuscated symbols...\n");
		const bool symbols_mapped = map_symbols(objs);

		if (!symbols_mapped)
		{
			DBG_PRINT("> failed to map symbols into memory...\n");
			return false;
		}

		DBG_PRINT("[+] linking complete...\n");
		return true;
	}

	// Returns the mapped address recorded for `symbol_name`, or 0 when the
	// symbol has no mapping. Uses find() so that querying an unknown name does
	// not insert a zero-valued entry into mapped_symbols.
	auto hmm_ctx::get_symbol(std::string symbol_name)->std::uintptr_t
	{
		const auto entry = mapped_symbols.find(symbol_name);
		return entry != mapped_symbols.end() ? entry->second : std::uintptr_t{};
	}

	// Copies every non-obfuscated symbol of each object into the allocation that
	// mapped_symbols records for it. It then copies whole non-executable data
	// sections, because section-relative relocations can reference them even
	// when they have no symbols.
	//
	// Anything in a ".bss" section is zero-filled instead of copied. This applies
	// both to individual .bss symbols and to whole .bss sections (COFF .bss holds
	// uninitialised data, so there is nothing meaningful at PointerToRawData).
	//
	// It copies straight out of the object buffers, so it is expected to run
	// after resolve_relocs() has patched them (that is the order map_objs uses).
	// Returns false when a symbol or section has no recorded allocation.
	bool hmm_ctx::map_symbols(std::vector<lnk::obj_buffer_t>& objs)
	{
		// COFF section names are 8 bytes and are not null terminated when all 8
		// bytes are used, so never read past the end of the Name array...
		const auto section_name_of = [](const auto& section) -> std::string
		{
			const auto name = reinterpret_cast<const char*>(section.Name);

			std::size_t length = 0u;
			while (length < sizeof(section.Name) && name[length])
				++length;

			return std::string(name, length);
		};

		for (auto& obj : objs)
		{
			const auto coff_header =
				reinterpret_cast<PIMAGE_FILE_HEADER>(obj.data());

			const auto section_headers =
				reinterpret_cast<PIMAGE_SECTION_HEADER>(
					obj.data() + sizeof IMAGE_FILE_HEADER);

			for (auto& symbol : lnk::sym::symbols(obj))
			{
				// dont map obfuscated routines into memory as they
				// get mapped differently...
				if (symbol.obfuscate_type)
					continue;

				const auto symbol_mapped =
					reinterpret_cast<void*>(
						mapped_symbols[symbol.symbol_name]);

				if (!symbol_mapped)
				{
					DBG_PRINT("\t> failed to resolve symbol allocation = %s\n", 
						symbol.symbol_name.c_str());

					return false;
				}

				const auto section_name =
					reinterpret_cast<const char*>(
						section_headers[symbol.section_number - 1].Name);

				if (!strncmp(section_name, ".bss", sizeof(".bss") - 1))
				{
					DBG_PRINT("\t> zero symbol = %s, at = 0x%p, size = 0x%x\n",
						symbol.symbol_name.c_str(), symbol_mapped, symbol.size);

					// value-initialised, so every byte is zero...
					std::vector<std::uint8_t> zeros(symbol.size);
					mcopy(symbol_mapped, zeros.data(), symbol.size);
				}
				else
				{
					DBG_PRINT("\t> mapping symbol = %s, at = 0x%p, from = 0x%p, size = 0x%x\n",
						symbol.symbol_name.c_str(), symbol_mapped, obj.data() +
						symbol.file_offset, symbol.size);

					mcopy(symbol_mapped, obj.data() +
						symbol.file_offset, symbol.size);
				}
			}

			// there are .data/.rdata sections that have 0 symbols but
			// still have relocations to them... find them and map them...
			for (auto idx = 0u; idx < coff_header->NumberOfSections; ++idx)
			{
				const auto& section = section_headers[idx];
				const auto raw_name = reinterpret_cast<const char*>(section.Name);

				// for some reason the compiler makes some empty sections...
				if (!section.SizeOfRawData)
					continue;

				// skip sections that we will not be using, discardable
				// sections and executable sections...
				if (section.Characteristics & (IMAGE_SCN_LNK_REMOVE |
					IMAGE_SCN_MEM_DISCARDABLE | IMAGE_SCN_MEM_EXECUTE))
					continue;

				// skip both the .pdata and the .xdata sections...
				// these are used for exceptions...
				if (!strncmp(raw_name, ".pdata", sizeof(".pdata") - 1) ||
					!strncmp(raw_name, ".xdata", sizeof(".xdata") - 1))
					continue;

				const auto data_section_sym =
					section_name_of(section)
					.append("#")
					.append(std::to_string(idx + 1))
					.append("!")
					.append(std::to_string(obj.size()));

				const auto section_mapped = mapped_symbols[data_section_sym];
				if (!section_mapped)
				{
					DBG_PRINT("\t> failed to resolve symbol allocation = %s\n",
						data_section_sym.c_str());

					return false;
				}

				if (!strncmp(raw_name, ".bss", sizeof(".bss") - 1))
				{
					// alloc_symbol_space reserves .bss sections too, but there is no
					// initialised data to copy... zero the whole allocation so bytes
					// not covered by any symbol (section-relative relocs) are zero...
					std::vector<std::uint8_t> zeros(section.SizeOfRawData);
					mcopy(reinterpret_cast<void*>(section_mapped),
						zeros.data(), section.SizeOfRawData);

					continue;
				}

				// copy this section into memory... there are no symbols to it
				// but there are relocations to it... this is done by static init data...
				mcopy(reinterpret_cast<void*>(section_mapped),
					obj.data() + section.PointerToRawData, section.SizeOfRawData);
			}
		}
		return true;
	}

	// Patches every relocation of every object in place, inside the object
	// buffer itself, so that map_symbols() later copies already-fixed-up bytes.
	//
	// Only IMAGE_REL_AMD64_ADDR64 is supported. The 64-bit value already stored
	// at the fixup site is the relocation addend. It is added to the resolved
	// target for every kind of target:
	//   * symbols that already have an allocation in mapped_symbols,
	//   * external symbols with a nonzero Value and no allocation. These get a
	//     zeroed PAGE_READWRITE allocation of Value bytes. The allocation is
	//     recorded in mapped_symbols so every later reference shares the storage,
	//   * section symbols (name starts with '.'), looked up through the
	//     "<section>#<section number>!<object size>" key,
	//   * anything else, resolved through resolve_symbol.
	//
	// Returns false on an unsupported relocation type, a failed allocation, a
	// section without an allocation, or an unresolved external symbol.
	bool hmm_ctx::resolve_relocs(std::vector<lnk::obj_buffer_t>& objs)
	{
		for (auto& obj : objs)
		{
			for (auto& reloc : lnk::sym::relocations(obj))
			{
				if (reloc.type != IMAGE_REL_AMD64_ADDR64)
				{
					DBG_PRINT("[!] error... unsupported relocation at file offset = 0x%x\n", reloc.file_offset);
					DBG_PRINT("\t> symbol = %s\n", reloc.resolve_symbol_name.c_str());
					DBG_PRINT("\t> reloc type = 0x%x\n", reloc.type);
					DBG_PRINT("\t> object size = 0x%x\n", obj.size());
					return false;
				}

				const auto reloc_addr =
					reinterpret_cast<std::uintptr_t*>(
						obj.data() + reloc.file_offset);

				// ADDR64 fixups keep their addend in place: target = symbol + addend...
				const auto addend = *reloc_addr;

				// check obj symbol table for this relocation...
				if (const auto internal_symbol = mapped_symbols[reloc.resolve_symbol_name])
				{
					DBG_PRINT("\t> resolving internal symbol...\n");
					DBG_PRINT("\t\t> address = 0x%p\n", internal_symbol);
					DBG_PRINT("\t\t> symbol = %s\n", reloc.resolve_symbol_name.c_str());
					*reloc_addr = internal_symbol + addend;
				}
				// this is a symbol that is defined but is given no value... the linker must allocate space for it...
				// these symbols are generated by llvm-obfuscator as "x" and "y"... 
				else if (reloc.raw_symbol.StorageClass == IMAGE_SYM_CLASS_EXTERNAL && reloc.raw_symbol.Value)
				{
					const auto zero_me = alloc(reloc.raw_symbol.Value, PAGE_READWRITE);
					if (!zero_me)
					{
						DBG_PRINT("[!] failed to allocate 0x%x bytes for symbol = %s\n",
							reloc.raw_symbol.Value, reloc.resolve_symbol_name.c_str());

						return false;
					}

					// value-initialised, so every byte is zero...
					std::vector<std::uint8_t> zeros(reloc.raw_symbol.Value);
					mcopy(zero_me, zeros.data(), reloc.raw_symbol.Value);

					// remember the allocation so every other reference to this symbol
					// (in this object or a later one) resolves to the same storage...
					const auto storage = reinterpret_cast<std::uintptr_t>(zero_me);
					mapped_symbols[reloc.resolve_symbol_name] = storage;
					*reloc_addr = storage + addend;
				}
				else if (reloc.resolve_symbol_name[0] == '.') // these relocations are offsets into already existing symbols...
				{
					const auto section_offset = addend;

					const auto section_name = 
						std::string(reloc.resolve_symbol_name)
						.append("#")
						.append(std::to_string(reloc.raw_symbol.SectionNumber))
						.append("!")
						.append(std::to_string(obj.size()));

					DBG_PRINT("[+] reloc to section %s\n", section_name.c_str());
					DBG_PRINT("\t> relocation was going to be applied at = 0x%p\n", reloc_addr);
					DBG_PRINT("\t> reloc file offset = 0x%x\n", reloc.file_offset);
					DBG_PRINT("\t> object size = 0x%x\n", obj.size());
					DBG_PRINT("\t\t> symbol section number = %d\n", reloc.raw_symbol.SectionNumber);
					DBG_PRINT("\t\t> symbol storage class = %d\n", reloc.raw_symbol.StorageClass);
					DBG_PRINT("\t\t> symbol type = %d\n", reloc.raw_symbol.Type);
					DBG_PRINT("\t\t> section offset = 0x%x\n", section_offset);

					if (!mapped_symbols[section_name])
					{
						DBG_PRINT("[!] no allocation for section %s\n", section_name.c_str());
						return false;
					}

					*reloc_addr = mapped_symbols[section_name] + section_offset;
				}
				else // else check external symbol table...
				{
					const auto extern_symbol =
						resolve_symbol(reloc.resolve_symbol_name.c_str());

					if (!extern_symbol)
					{
						DBG_PRINT("[!] unresolved external symbol = %s...\n",
							reloc.resolve_symbol_name.c_str());

						DBG_PRINT("\t> relocation was going to be applied at = 0x%p\n", reloc_addr);
						DBG_PRINT("\t> reloc file offset = 0x%x\n", reloc.file_offset);
						DBG_PRINT("\t> object size = 0x%x\n", obj.size());
						DBG_PRINT("\t\t> symbol section number = %d\n", reloc.raw_symbol.SectionNumber);
						DBG_PRINT("\t\t> symbol storage class = %d\n", reloc.raw_symbol.StorageClass);
						DBG_PRINT("\t\t> symbol type = %d\n", reloc.raw_symbol.Type);
						DBG_PRINT("\t\t> symbol offset = 0x%x\n", reloc.raw_symbol.Value);
						return false;
					}

					*reloc_addr = extern_symbol + addend;

					DBG_PRINT("\t> resolving external symbol...\n");
					DBG_PRINT("\t\t> address = 0x%p\n", *reloc_addr);
					DBG_PRINT("\t\t> symbol = %s\n", reloc.resolve_symbol_name.c_str());
				}
			}
		}
		return true;
	}

	// Builds and copies the obfuscated gadget for every instruction of every
	// obfuscated routine. Each gadget goes into the allocation recorded in
	// mapped_symbols.
	//
	// Gadgets are keyed as "<routine>" for the instruction at byte offset 0 and
	// as "<routine>@<offset>" for every later instruction. The walk over a
	// routine stops at the first offset that has no allocation.
	//
	// Each gadget's own relocations are resolved before it is copied:
	//   * jcc                   -> address of the gadget at the branch target offset,
	//   * next_instruction_addr -> address of the gadget for the next instruction,
	//   * none                  -> nothing to do,
	//   * anything else         -> the COFF relocation (ADDR64 only) that falls
	//                              inside the original instruction, written into
	//                              the gadget bytes. Its addend is read from the
	//                              object buffer, which map_objs has not yet
	//                              patched (resolve_relocs runs afterwards).
	//
	// Returns false on an unsupported relocation type, a failed allocation, a
	// section without an allocation, or an unresolved external symbol.
	bool hmm_ctx::map_obfuscated_symbols(std::vector<lnk::obj_buffer_t>& objs)
	{
		// key of the gadget for the instruction `offset` bytes into `routine`...
		// the first instruction uses the bare routine name (there is no "@0" key)...
		const auto gadget_symbol = [](const std::string& routine, auto offset) -> std::string
		{
			auto name = routine;
			if (offset)
				name.append("@").append(std::to_string(offset));

			return name;
		};

		for (auto& obj : objs)
		{
			for (auto& symbol : lnk::sym::symbols(obj))
			{
				if (!symbol.obfuscate_type)
					continue;

				DBG_PRINT("\t> mapping obfuscated routine %s into memory...\n", symbol.symbol_name.c_str());
				std::int32_t instruc_offset = 0u;

				// walk the routine one instruction gadget at a time...
				while (true)
				{
					const auto symbol_name = gadget_symbol(symbol.symbol_name, instruc_offset);
					const auto gadget_base = mapped_symbols[symbol_name];

					// if there is no allocation for this symbol then we are done...
					if (!gadget_base)
						break;

					const auto& obfuscated = obfuscated_gadgets[gadget_base];
					const auto instruc_len = obfuscated->get_instruc().length;
					auto gadget_stack = obfuscated->get_gadget();

					std::vector<std::uint8_t> gadget_raw;
					for (auto& [gadget, reloc] : gadget_stack)
					{
						const auto fix_reloc_addr =
							reinterpret_cast<std::uintptr_t*>(gadget.data() + reloc.offset);

						switch (reloc.type)
						{
							case obfuscation::reloc_type::jcc:
							{
								// a branch back to offset 0 resolves to the bare routine name...
								*fix_reloc_addr = mapped_symbols[gadget_symbol(symbol.symbol_name,
									instruc_offset + instruc_len + reloc.rva)];

								break;
							}
							case obfuscation::reloc_type::next_instruction_addr:
							{
								*fix_reloc_addr = mapped_symbols[gadget_symbol(symbol.symbol_name,
									instruc_offset + instruc_len)];

								break; // we resolved our own relocation...
							}
							case obfuscation::reloc_type::none:
							{
								break;
							}
							default:
							{
								// check this instruction to see if it needs any relocs...
								const auto instruc_begin = symbol.file_offset + instruc_offset;
								const auto instruc_end = instruc_begin + instruc_len;

								for (auto& coff_reloc : lnk::sym::relocations(obj))
								{
									if (coff_reloc.file_offset < instruc_begin ||
										coff_reloc.file_offset >= instruc_end)
										continue;

									DBG_PRINT("\t\t> resolving relocation for instruction...\n");
									if (coff_reloc.type != IMAGE_REL_AMD64_ADDR64)
									{
										DBG_PRINT("[!] error, cannot resolve reloc = %s, type = 0x%x\n",
											coff_reloc.resolve_symbol_name.c_str(), coff_reloc.type);

										// cant relocate anything but IMAGE_REL_AMD64_ADDR64...
										// this is fine since the compiler shouldnt ever make any rip relative code
										// besides JCC's...
										return false;
									}

									const auto reloc_addr =
										reinterpret_cast<std::uintptr_t*>(
											gadget.data() + (coff_reloc.file_offset - instruc_begin));

									// ADDR64 fixups keep their addend in place... read it from the
									// object buffer, which resolve_relocs has not patched yet...
									const auto addend =
										*reinterpret_cast<std::uintptr_t*>(
											obj.data() + coff_reloc.file_offset);

									// check obj symbol table for this relocation...
									if (const auto internal_symbol = mapped_symbols[coff_reloc.resolve_symbol_name])
									{
										*reloc_addr = internal_symbol + addend;
									}
									// this is a symbol that is defined but is given no value... the linker must allocate space for it...
									// these symbols are generated by llvm-obfuscator as "x" and "y"... 
									else if (coff_reloc.raw_symbol.StorageClass == IMAGE_SYM_CLASS_EXTERNAL && coff_reloc.raw_symbol.Value)
									{
										const auto zero_me = alloc(coff_reloc.raw_symbol.Value, PAGE_READWRITE);
										if (!zero_me)
										{
											DBG_PRINT("[!] failed to allocate 0x%x bytes for symbol = %s\n",
												coff_reloc.raw_symbol.Value, coff_reloc.resolve_symbol_name.c_str());

											return false;
										}

										// value-initialised, so every byte is zero...
										std::vector<std::uint8_t> zeros(coff_reloc.raw_symbol.Value);
										mcopy(zero_me, zeros.data(), coff_reloc.raw_symbol.Value);

										// remember the allocation so other references share it...
										const auto storage = reinterpret_cast<std::uintptr_t>(zero_me);
										mapped_symbols[coff_reloc.resolve_symbol_name] = storage;
										*reloc_addr = storage + addend;
									}
									else if (coff_reloc.resolve_symbol_name[0] == '.') // these relocations are offsets into already existing symbols...
									{
										const auto section_offset = addend;

										const auto section_name =
											std::string(coff_reloc.resolve_symbol_name)
											.append("#")
											.append(std::to_string(coff_reloc.raw_symbol.SectionNumber))
											.append("!")
											.append(std::to_string(obj.size()));

										DBG_PRINT("[+] reloc to section %s\n", section_name.c_str());
										DBG_PRINT("\t> relocation was going to be applied at = 0x%p\n", reloc_addr);
										DBG_PRINT("\t> reloc file offset = 0x%x\n", coff_reloc.file_offset);
										DBG_PRINT("\t> object size = 0x%x\n", obj.size());
										DBG_PRINT("\t\t> symbol section number = %d\n", coff_reloc.raw_symbol.SectionNumber);
										DBG_PRINT("\t\t> symbol storage class = %d\n", coff_reloc.raw_symbol.StorageClass);
										DBG_PRINT("\t\t> symbol type = %d\n", coff_reloc.raw_symbol.Type);
										DBG_PRINT("\t\t> section offset = 0x%x\n", section_offset);

										if (!mapped_symbols[section_name])
										{
											DBG_PRINT("[!] no allocation for section %s\n", section_name.c_str());
											return false;
										}

										*reloc_addr = mapped_symbols[section_name] + section_offset;
									}
									else // else check external symbol table...
									{
										const auto extern_symbol =
											resolve_symbol(coff_reloc.resolve_symbol_name.c_str());

										if (!extern_symbol)
										{
											DBG_PRINT("[!] unresolved external symbol = %s...\n",
												coff_reloc.resolve_symbol_name.c_str());

											return false;
										}

										*reloc_addr = extern_symbol + addend;
									}

									DBG_PRINT("\t\t\t> address = 0x%p\n", *reloc_addr);
									DBG_PRINT("\t\t\t> symbol = %s\n", coff_reloc.resolve_symbol_name.c_str());
									break; // break out of for loop... we resolve the symbol...
								}
								break;
							}
						}

						gadget_raw.insert(gadget_raw.end(), gadget.begin(), gadget.end());
					}

					DBG_PRINT("\t> copying gadget at = 0x%p\n", reinterpret_cast<void*>(gadget_base));
					mcopy(reinterpret_cast<void*>(gadget_base), gadget_raw.data(), gadget_raw.size());

					// used to calc symbol for next instruction...
					instruc_offset += instruc_len;
				}
			}
		}
		return true;
	}

	// Reserves executable memory for one gadget per instruction of every
	// obfuscated routine. Non-obfuscated symbols are skipped here.
	//
	// Each routine is decoded with Zydis in 64-bit mode. For every decoded
	// instruction, an obfuscation::obfuscate or obfuscation::mutation gadget is
	// built (selected by the routine's obfuscate_type) and get_size() bytes are
	// allocated PAGE_EXECUTE_READWRITE. The allocation is recorded in:
	//   * mapped_symbols, keyed "<routine>" at offset 0 and "<routine>@<offset>"
	//     otherwise,
	//   * obfuscated_gadgets, keyed by the allocation address.
	//
	// Returns false if the decoder cannot be initialised, a routine has an
	// unsupported obfuscation type, or a gadget allocation fails. An allocation
	// failure must abort, because map_obfuscated_symbols() stops walking a
	// routine at the first instruction without an allocation.
	bool hmm_ctx::alloc_obfuscated_symbol_space(std::vector<lnk::obj_buffer_t>& objs)
	{
		ZydisDecoder decoder;
		if (!ZYAN_SUCCESS(ZydisDecoderInit(&decoder, ZYDIS_MACHINE_MODE_LONG_64, ZYDIS_ADDRESS_WIDTH_64)))
		{
			DBG_PRINT("[!] failed to initialize zydis decoder...\n");
			return false;
		}

		for (auto& obj : objs)
		{
			for (auto& symbol : lnk::sym::symbols(obj))
			{
				// skip normal routines for now... those get scattered...
				if (!symbol.obfuscate_type)
					continue;

				const auto routine_begin =
					symbol.file_offset + obj.data();

				const ZyanUSize length = symbol.size;
				ZyanUSize offset = 0;
				ZydisDecodedInstruction instruction;

				while (ZYAN_SUCCESS(ZydisDecoderDecodeBuffer(
					&decoder, routine_begin + offset, length - offset, &instruction)))
				{
					// dont append @offset for the first instruction...
					auto new_symbol = symbol.symbol_name;
					if (offset)
						new_symbol.append("@")
							.append(std::to_string(offset));

					std::vector<std::uint8_t> instruc_bytes(
						routine_begin + offset, routine_begin + offset + instruction.length);

					std::shared_ptr<obfuscation::obfuscate> new_gadget{};
					switch (symbol.obfuscate_type)
					{
						case lnk::theo_type::obfuscate:
						{
							new_gadget.reset(
								new obfuscation::obfuscate(
									{ instruction, instruc_bytes }));
							break;
						}
						case lnk::theo_type::mutate:
						{
							new_gadget.reset(
								new obfuscation::mutation(
									{ instruction, instruc_bytes }));
							break;
						}
						default:
						{
							DBG_PRINT("[!] unsupported obfuscation type on routine = %s, type = %d\n", 
								symbol.symbol_name.c_str(), symbol.obfuscate_type);
							return false;
						}
					}

					const auto gadget_size = new_gadget->get_size();
					const auto allocation = alloc(gadget_size, PAGE_EXECUTE_READWRITE);

					if (!allocation)
					{
						DBG_PRINT("[!] failed to allocate gadget %s, size = %d\n",
							new_symbol.c_str(), gadget_size);

						return false;
					}

					const auto gadget_addr = reinterpret_cast<std::uintptr_t>(allocation);
					mapped_symbols[new_symbol] = gadget_addr;
					obfuscated_gadgets[gadget_addr] = new_gadget;

					DBG_PRINT("\t\t> %s allocated = 0x%p, size = %d\n", new_symbol.c_str(),
						gadget_addr, gadget_size);

					offset += instruction.length;
				}

				// decoding stopped before the end of the routine... the tail will
				// not get any gadgets...
				if (offset < length)
				{
					DBG_PRINT("[!] warning, decoding of routine = %s stopped at offset 0x%x of 0x%x\n",
						symbol.symbol_name.c_str(), static_cast<unsigned>(offset), static_cast<unsigned>(length));
				}
			}
		}
		return true;
	}

	// Reserves memory for every non-obfuscated symbol of each object:
	//   * function symbols (symbol.type == IMAGE_SYM_FUNCTION) get their own
	//     PAGE_EXECUTE_READWRITE allocation of symbol.size bytes,
	//   * every other symbol points into a PAGE_READWRITE allocation of its whole
	//     section (SizeOfRawData bytes), at symbol.section_offset. That allocation
	//     is keyed in mapped_symbols as "<section name>#<section number>!<object size>".
	// Afterwards, every remaining non-empty section that is not LNK_REMOVE,
	// discardable, executable, .pdata or .xdata also gets a section allocation.
	// This is needed because section-relative relocations can point into
	// sections that have no symbols.
	//
	// Returns false if any allocation fails. Without that check, a failed section
	// allocation would place the section's symbols at 0 + section_offset, a bogus
	// but non-null address that map_symbols() would later write to.
	bool hmm_ctx::alloc_symbol_space(std::vector<lnk::obj_buffer_t>& objs)
	{
		// COFF section names are 8 bytes and are not null terminated when all 8
		// bytes are used, so never read past the end of the Name array...
		const auto section_name_of = [](const auto& section) -> std::string
		{
			const auto name = reinterpret_cast<const char*>(section.Name);

			std::size_t length = 0u;
			while (length < sizeof(section.Name) && name[length])
				++length;

			return std::string(name, length);
		};

		for (auto& obj : objs)
		{
			const auto coff_header =
				reinterpret_cast<PIMAGE_FILE_HEADER>(obj.data());

			const auto section_headers =
				reinterpret_cast<PIMAGE_SECTION_HEADER>(
					obj.data() + sizeof IMAGE_FILE_HEADER);

			// allocation for the 1-based section `number`, reserved on first use...
			// returns 0 when the allocation failed...
			const auto section_allocation = [&](auto number) -> std::uintptr_t
			{
				const auto& section = section_headers[number - 1];
				const auto data_section_sym =
					section_name_of(section)
					.append("#")
					.append(std::to_string(number))
					.append("!")
					.append(std::to_string(obj.size()));

				if (const auto existing = mapped_symbols[data_section_sym])
					return existing;

				const auto allocation = reinterpret_cast<std::uintptr_t>(
					alloc(section.SizeOfRawData, PAGE_READWRITE));

				mapped_symbols[data_section_sym] = allocation;

				DBG_PRINT("\t> section %s allocated at = 0x%p, size = %d\n",
					data_section_sym.c_str(), allocation, section.SizeOfRawData);

				return allocation;
			};

			for (auto& symbol : lnk::sym::symbols(obj))
			{
				// skip obfuscated routines for now... those get scattered...
				if (symbol.obfuscate_type)
					continue;

				// symbol is a function...
				if (symbol.type == IMAGE_SYM_FUNCTION)
				{
					const auto routine = reinterpret_cast<std::uintptr_t>(
						alloc(symbol.size, PAGE_EXECUTE_READWRITE));

					mapped_symbols[symbol.symbol_name] = routine;

					DBG_PRINT("\t> %s allocated at = 0x%p, size = %d\n",
						symbol.symbol_name.c_str(), routine, symbol.size);

					if (!routine)
					{
						DBG_PRINT("[!] failed to allocate space for symbol = %s\n",
							symbol.symbol_name.c_str());

						return false;
					}
				}
				else // else its a data/bss/rdata symbol... we map the entire section...
				{
					const auto section_base = section_allocation(symbol.section_number);
					if (!section_base)
					{
						DBG_PRINT("[!] failed to allocate section for symbol = %s\n",
							symbol.symbol_name.c_str());

						return false;
					}

					mapped_symbols[symbol.symbol_name] =
						section_base + symbol.section_offset;
				}
			}

			// there are .data/.rdata sections that have 0 symbols but
			// still have relocations to them... find them and map them...
			for (auto idx = 0u; idx < coff_header->NumberOfSections; ++idx)
			{
				const auto& section = section_headers[idx];
				const auto raw_name = reinterpret_cast<const char*>(section.Name);

				// for some reason the compiler makes some empty sections...
				if (!section.SizeOfRawData)
					continue;

				// skip sections that we will not be using, discardable
				// sections and executable sections...
				if (section.Characteristics & (IMAGE_SCN_LNK_REMOVE |
					IMAGE_SCN_MEM_DISCARDABLE | IMAGE_SCN_MEM_EXECUTE))
					continue;

				// skip both the .pdata and the .xdata sections... these are used for exceptions...
				if (!strncmp(raw_name, ".pdata", sizeof(".pdata") - 1) ||
					!strncmp(raw_name, ".xdata", sizeof(".xdata") - 1))
					continue;

				if (!section_allocation(idx + 1))
				{
					DBG_PRINT("[!] failed to allocate section number = %d\n", idx + 1);
					return false;
				}
			}
		}

		return true;
	}
}