#include "client.hpp"

#include <memory>

namespace theo
{
	client::client(SOCKET client_socket)
		:
		client_socket(client_socket),
		handler_thread(std::thread(&client::handler, this))
	{ handler_thread.detach(); }

	// Closes this client's socket and removes its entry from `connections`.
	client::~client()
	{
		const auto sock = client_socket;

		// socket might already be closed (handler() closes it on its own
		// exit paths); the return value is intentionally not inspected.
		closesocket(sock);
		connections.erase(sock);
	}

	void* client::wrapper_memcpy(void* dest, const void* src, std::size_t size) const
	{
		while (size)
		{
			std::uint32_t copy_size = size;
			if (copy_size > PACKET_DATA_SIZE)
				copy_size = PACKET_DATA_SIZE;

			theo_data* packet = new theo_data;
			memset(packet, NULL, sizeof theo_data);

			packet->type = theo_packet_type::copy_memory;
			packet->copy_memory.dest_addr = dest;
			packet->copy_memory.size = copy_size;
			memcpy(packet->copy_memory.data, src, copy_size);

			if (send(client_socket, reinterpret_cast<char*>(packet),
				sizeof theo_data, NULL) == SOCKET_ERROR)
			{
				std::printf("[!] failed to send data... reason = %d\n",
					WSAGetLastError());

				delete packet;
				return nullptr;
			}

			if (recv(client_socket, reinterpret_cast<char*>(packet),
				sizeof theo_data, MSG_WAITALL) == SOCKET_ERROR)
			{
				std::printf("[!] failed to recv alloc data... reason = %d\n",
					WSAGetLastError());

				delete packet;
				return nullptr;
			}

			delete packet;
			dest = reinterpret_cast<void*>(
				reinterpret_cast<std::uintptr_t>(dest) + copy_size);

			src = reinterpret_cast<void*>(
				reinterpret_cast<std::uintptr_t>(src) + copy_size);

			size -= copy_size;
		}
		return dest;
	}

	void* client::wrapper_alloc(std::size_t size, std::uint32_t prot) const
	{
		theo_data* packet = new theo_data;
		memset(packet, NULL, sizeof theo_data);

		packet->type = theo_packet_type::alloc_memory;
		packet->alloc.alloc_size = size;
		packet->alloc.prot = prot;

		if (send(client_socket, reinterpret_cast<char*>(packet), 
			sizeof theo_data, NULL) == SOCKET_ERROR)
		{
			std::printf("[!] failed to send data... reason = %d\n",
				WSAGetLastError());

			delete packet;
			return nullptr;
		}

		if (recv(client_socket, reinterpret_cast<char*>(packet),
			sizeof theo_data, MSG_WAITALL) == SOCKET_ERROR)
		{
			std::printf("[!] failed to recv alloc data... reason = %d\n",
				WSAGetLastError());

			delete packet;
			return nullptr;
		}

		const auto result = 
			packet->alloc.addr;

		delete packet;
		return result;
	}

	std::uintptr_t client::wrapper_resolve_symbol(const char* symbol_name) const
	{
		theo_data* packet = new theo_data;
		memset(packet, NULL, sizeof theo_data);

		packet->type = theo_packet_type::resolve_symbol;
		packet->resolve.symbol_size = strlen(symbol_name);
		strcpy(packet->resolve.symbol, symbol_name);

		if (send(client_socket, reinterpret_cast<char*>(packet), 
			sizeof theo_data, NULL) == SOCKET_ERROR)
		{
			std::printf("[!] failed to send data... reason = %d\n", 
				WSAGetLastError());

			delete packet;
			return {};
		}

		if (recv(client_socket, reinterpret_cast<char*>(packet),
			sizeof theo_data, MSG_WAITALL) == SOCKET_ERROR)
		{
			std::printf("[!] failed to recv alloc data... reason = %d\n",
				WSAGetLastError());

			delete packet;
			return {};
		}

		const auto result = 
			packet->resolve.symbol_addr;

		delete packet;
		return result;
	}

	void client::handler() const
	{
		int result{};
		theo_data* packet = new theo_data;
		memset(packet, NULL, sizeof theo_data);

		while ((result = recv(client_socket, reinterpret_cast<char*>(packet),
			sizeof theo_data, MSG_WAITALL)) != SOCKET_ERROR)
		{
			switch (packet->type)
			{
				case theo_packet_type::init:
				{
					theo::malloc_t alloc = [&](std::size_t size, std::uint32_t prot) -> void*
					{ return this->wrapper_alloc(size, prot); };

					theo::memcpy_t mcopy = 
						[&](void* dest, const void* src, std::size_t size) -> void*
					{ return this->wrapper_memcpy(dest, src, size); };

					theo::resolve_symbol_t resolve_symbol =
						[&](const char* symbol_name) -> std::uintptr_t
					{ return this->wrapper_resolve_symbol(symbol_name); };

					theo::hmm_ctx linker({ alloc, mcopy, resolve_symbol });
					std::vector<lnk::obj_buffer_t> objs = lib_files[packet->file];

					// map objs using a copy of the objs....
					if (!linker.map_objs(objs))
					{
						std::printf("[!] failed to map obj files... closing socket...\n");
						closesocket(client_socket);
						break; // cannot recover from this...
					}

					theo_data* response = new theo_data;
					memset(response, NULL, sizeof theo_data);
					response->type = theo_packet_type::disconnect;

					switch (packet->file)
					{
						case theo_file_type::demo_drv:
						{
							response->entry_point =
								linker.get_symbol("DrvEntry");
							break;
						}
						case theo_file_type::demo_dll:
						case theo_file_type::demo_imgui:
						{
							response->entry_point =
								linker.get_symbol("main");
							break;
						}
						default:
						{
							std::printf("[!] unsupported file... type = %d\n", response->file);
							closesocket(client_socket);
						}
					}

					std::printf("[+] completed linking...\n");
					std::printf("[+] module entry = 0x%p\n", response->entry_point);
					send(client_socket, reinterpret_cast<char*>(response), sizeof theo_data, NULL);
					closesocket(client_socket);
					break;
				}
				default:
				{
					std::printf("[!] unknown command = %d\n", packet->type);
					closesocket(client_socket);
				}
			}

			memset(packet, NULL, sizeof theo_data);
		}

		delete packet;
		std::printf("[+] socket closed with reason = %d\n", WSAGetLastError());
	}
}