#ifndef _VULKAN_GAME_H
#define _VULKAN_GAME_H

#define GLM_FORCE_RADIANS
#define GLM_FORCE_DEPTH_ZERO_TO_ONE // Since, in Vulkan, the depth range is 0 to 1 instead of -1 to 1
#define GLM_FORCE_RIGHT_HANDED

#include <glm/glm.hpp>
#include <glm/gtc/matrix_transform.hpp>

#include "game-gui-sdl.hpp"
#include "graphics-pipeline_vulkan.hpp"

#include "vulkan-utils.hpp"

using namespace glm;

#ifdef NDEBUG
   const bool ENABLE_VALIDATION_LAYERS = false;
#else
   const bool ENABLE_VALIDATION_LAYERS = true;
#endif

struct ModelVertex {
   vec3 pos;
   vec3 color;
   vec2 texCoord;
};

struct OverlayVertex {
   vec3 pos;
   vec2 texCoord;
};

struct ShipVertex {
   vec3 pos;
   vec3 color;
   vec3 normal;
   unsigned int objIndex;
};

// TODO: Change the index type to uint32_t and check the Vulkan Tutorial loading model section as a reference
// TODO: Create a typedef for index type so I can easily change uin16_t to something else later
template<class VertexType>
struct SceneObject {
   vector<VertexType> vertices;
   vector<uint16_t> indices;

   mat4 model_base;
   mat4 model_transform;
};

struct UBO_VP_mats {
   alignas(16) mat4 view;
   alignas(16) mat4 proj;
};

struct SBO_SceneObject {
   alignas(16) mat4 model;
};

class VulkanGame {
   public:
      VulkanGame(int maxFramesInFlight);
      ~VulkanGame();

      void run(int width, int height, unsigned char guiFlags);

   private:
      const int MAX_FRAMES_IN_FLIGHT;

      const float NEAR_CLIP = 0.1f;
      const float FAR_CLIP = 100.0f;
      const float FOV_ANGLE = 67.0f; // means the camera lens goes from -33 deg to 33 def

      vec3 cam_pos;

      GameGui* gui;

      SDL_version sdlVersion;
      SDL_Window* window = nullptr;
      SDL_Renderer* renderer = nullptr;

      SDL_Texture* uiOverlay = nullptr;

      VkInstance instance;
      VkDebugUtilsMessengerEXT debugMessenger;
      VkSurfaceKHR surface; // TODO: Change the variable name to vulkanSurface
      VkPhysicalDevice physicalDevice = VK_NULL_HANDLE;
      VkDevice device;

      VkQueue graphicsQueue;
      VkQueue presentQueue;

      VkSwapchainKHR swapChain;
      vector<VkImage> swapChainImages;
      VkFormat swapChainImageFormat;
      VkExtent2D swapChainExtent;
      vector<VkImageView> swapChainImageViews;
      vector<VkFramebuffer> swapChainFramebuffers;

      VkRenderPass renderPass;
      VkCommandPool commandPool;
      vector<VkCommandBuffer> commandBuffers;

      VulkanImage depthImage;

      VkSampler textureSampler;

      VulkanImage floorTextureImage;
      VkDescriptorImageInfo floorTextureImageDescriptor;

      VulkanImage sdlOverlayImage;
      VkDescriptorImageInfo sdlOverlayImageDescriptor;

      TTF_Font* font;
      SDL_Texture* fontSDLTexture;

      SDL_Texture* imageSDLTexture;

      vector<VkSemaphore> imageAvailableSemaphores;
      vector<VkSemaphore> renderFinishedSemaphores;
      vector<VkFence> inFlightFences;

      size_t currentFrame;

      bool framebufferResized;

      // TODO: I should probably rename the uniformBuffer* and storageBuffer*
      // variables to better reflect the data they hold

      GraphicsPipeline_Vulkan<OverlayVertex> overlayPipeline;

      vector<SceneObject<OverlayVertex>> overlayObjects;

      // TODO: Rename all the variables related to modelPipeline to use the same pipelie name

      GraphicsPipeline_Vulkan<ModelVertex> modelPipeline;

      vector<SceneObject<ModelVertex>> modelObjects;

      vector<VkBuffer> uniformBuffers_scenePipeline;
      vector<VkDeviceMemory> uniformBuffersMemory_scenePipeline;

      vector<VkDescriptorBufferInfo> uniformBufferInfoList_scenePipeline;

      vector<VkBuffer> storageBuffers_scenePipeline;
      vector<VkDeviceMemory> storageBuffersMemory_scenePipeline;

      vector<VkDescriptorBufferInfo> storageBufferInfoList_scenePipeline;

      UBO_VP_mats object_VP_mats;
      SBO_SceneObject so_Object;

      GraphicsPipeline_Vulkan<ShipVertex> shipPipeline;

      vector<SceneObject<ShipVertex>> shipObjects;

      vector<VkBuffer> uniformBuffers_shipPipeline;
      vector<VkDeviceMemory> uniformBuffersMemory_shipPipeline;

      vector<VkDescriptorBufferInfo> uniformBufferInfoList_shipPipeline;

      vector<VkBuffer> storageBuffers_shipPipeline;
      vector<VkDeviceMemory> storageBuffersMemory_shipPipeline;

      vector<VkDescriptorBufferInfo> storageBufferInfoList_shipPipeline;

      UBO_VP_mats ship_VP_mats;
      SBO_SceneObject so_Ship;

      bool initWindow(int width, int height, unsigned char guiFlags);
      void initVulkan();
      void initGraphicsPipelines();
      void initMatrices();
      void mainLoop();
      void updateScene(uint32_t currentImage);
      void renderUI();
      void renderScene();
      void cleanup();

      void createVulkanInstance(const vector<const char*> &validationLayers);
      void setupDebugMessenger();
      void populateDebugMessengerCreateInfo(VkDebugUtilsMessengerCreateInfoEXT& createInfo);
      void createVulkanSurface();
      void pickPhysicalDevice(const vector<const char*>& deviceExtensions);
      bool isDeviceSuitable(VkPhysicalDevice physicalDevice, const vector<const char*>& deviceExtensions);
      void createLogicalDevice(
         const vector<const char*> validationLayers,
         const vector<const char*>& deviceExtensions);
      void createSwapChain();
      void createImageViews();
      void createRenderPass();
      VkFormat findDepthFormat();
      void createCommandPool();
      void createImageResources();

      void createTextureSampler();
      void createFramebuffers();
      void createCommandBuffers();
      void createSyncObjects();

      template<class VertexType>
      void addObject(vector<SceneObject<VertexType>>& objects, GraphicsPipeline_Vulkan<VertexType>& pipeline,
         const vector<VertexType>& vertices, vector<uint16_t> indices);

      template<class VertexType>
      vector<VertexType> addVertexNormals(vector<VertexType> vertices);

      template<class VertexType>
      vector<VertexType> addObjectIndex(unsigned int objIndex, vector<VertexType> vertices);

      template<class VertexType>
      vector<VertexType> centerObject(vector<VertexType> vertices);

      template<class VertexType>
      void transformObject(SceneObject<VertexType>& obj, mat4 mat);

      void createBufferSet(VkDeviceSize bufferSize, VkBufferUsageFlags flags,
         vector<VkBuffer>& buffers, vector<VkDeviceMemory>& buffersMemory, vector<VkDescriptorBufferInfo>& bufferInfoList);

      void recreateSwapChain();

      void cleanupSwapChain();

      static VKAPI_ATTR VkBool32 VKAPI_CALL debugCallback(
            VkDebugUtilsMessageSeverityFlagBitsEXT messageSeverity,
            VkDebugUtilsMessageTypeFlagsEXT messageType,
            const VkDebugUtilsMessengerCallbackDataEXT* pCallbackData,
            void* pUserData);
};

template<class VertexType>
void VulkanGame::addObject(vector<SceneObject<VertexType>>& objects, GraphicsPipeline_Vulkan<VertexType>& pipeline,
      const vector<VertexType>& vertices, vector<uint16_t> indices) {
   size_t numVertices = pipeline.getNumVertices();

   for (uint16_t& idx : indices) {
      idx += numVertices;
   }

   objects.push_back({ vertices, indices, mat4(1.0f), mat4(1.0f) });

   pipeline.addVertices(vertices, indices, commandPool, graphicsQueue);
}

template<class VertexType>
vector<VertexType> VulkanGame::addVertexNormals(vector<VertexType> vertices) {
   for (unsigned int i = 0; i < vertices.size(); i += 3) {
      vec3 p1 = vertices[i].pos;
      vec3 p2 = vertices[i+1].pos;
      vec3 p3 = vertices[i+2].pos;

      vec3 normal = normalize(cross(p2 - p1, p3 - p1));

      // Add the same normal for all 3 vertices
      vertices[i].normal = normal;
      vertices[i+1].normal = normal;
      vertices[i+2].normal = normal;
   }

   return vertices;
}

template<class VertexType>
vector<VertexType> VulkanGame::addObjectIndex(unsigned int objIndex, vector<VertexType> vertices) {
   for (VertexType& vertex : vertices) {
      vertex.objIndex = objIndex;
   }

   return vertices;
}

template<class VertexType>
vector<VertexType> VulkanGame::centerObject(vector<VertexType> vertices) {
   float min_x = vertices[0].pos.x;
   float max_x = vertices[0].pos.x;
   float min_y = vertices[0].pos.y;
   float max_y = vertices[0].pos.y;
   float min_z = vertices[0].pos.z;
   float max_z = vertices[0].pos.z;

   // start from the second point
   for (unsigned int i = 1; i < vertices.size(); i++) {
      if (min_x > vertices[i].pos.x) {
         min_x = vertices[i].pos.x;
      } else if (max_x < vertices[i].pos.x) {
         max_x = vertices[i].pos.x;
      }

      if (min_y > vertices[i].pos.y) {
         min_y = vertices[i].pos.y;
      } else if (max_y < vertices[i].pos.y) {
         max_y = vertices[i].pos.y;
      }

      if (min_z > vertices[i].pos.z) {
         min_z = vertices[i].pos.z;
      } else if (max_z < vertices[i].pos.z) {
         max_z = vertices[i].pos.z;
      }
   }

   vec3 center = vec3(min_x + max_x, min_y + max_y, min_z + max_z) / 2.0f;

   for (unsigned int i = 0; i < vertices.size(); i++) {
      vertices[i].pos -= center;
   }

   return vertices;
}

template<class VertexType>
void VulkanGame::transformObject(SceneObject<VertexType>& obj, mat4 mat) {
   obj.model_transform = mat * obj.model_transform;
}

#endif // _VULKAN_GAME_H