#include "td-tut-2-main.h"
#include <raymath.h>
#include <stdlib.h>
#include <math.h>

// Simulation clock shared by the enemy, tower and projectile systems.
typedef struct GameTime
{
  // accumulated game time in seconds; advanced by GameUpdate
  float time;
  // length of the current frame's step in seconds (GameUpdate caps it at 0.1)
  float deltaTime;
} GameTime;

// the single global game clock; starts at zero
GameTime gameTime = {0};

//# Enemies

// Number of slots in the enemies[] pool.
#define ENEMY_MAX_COUNT 400
// enemyType values; ENEMY_TYPE_NONE marks a free (unused or dead) slot.
#define ENEMY_TYPE_NONE 0
#define ENEMY_TYPE_MINION 1

// Handle to an enemy that survives slot reuse: `index` selects the slot in
// enemies[], and `generation` must equal that slot's current generation for
// the handle to resolve (see EnemyTryResolve). Member order matters: handles
// are built with a positional compound literal in EnemyGetId.
typedef struct EnemyId
{
  uint16_t index;
  uint16_t generation;
} EnemyId;

// Per-enemy-type tuning values, stored in enemyClassConfigs[] indexed by enemyType.
typedef struct EnemyClassConfig
{
  // maximum velocity magnitude in grid units per second
  float speed;
  // hit points; the enemy dies once its accumulated damage reaches this value
  float health;
  // NOTE(review): not read by any code in this file
  float radius;
  // magnitude of the steering acceleration (grid units per second squared)
  float maxAcceleration;
} EnemyClassConfig;

// One slot of the enemy pool. simPosition/simVelocity are the movement state
// as of startMovingTime; EnemyGetPosition extrapolates from there.
typedef struct Enemy
{
  // grid cell of the waypoint the enemy most recently passed (its spawn cell initially)
  int16_t currentX, currentY;
  // grid cell of the waypoint the enemy is steering towards
  int16_t nextX, nextY;
  // continuous position and velocity at startMovingTime (x -> world X, y -> world Z)
  Vector2 simPosition;
  Vector2 simVelocity;
  // bumped every time the slot is (re)used so stale EnemyIds stop resolving
  uint16_t generation;
  // gameTime.time at which simPosition/simVelocity were last brought up to date
  float startMovingTime;
  // damage already applied / total damage of all bullets fired at this enemy
  float damage, futureDamage;
  // ENEMY_TYPE_*; ENEMY_TYPE_NONE means the slot is free
  uint8_t enemyType;
} Enemy;

// Enemy pool. Slots [0, enemyCount) have been handed out (some may be free
// again); enemyCount only grows until EnemyInit resets it.
Enemy enemies[ENEMY_MAX_COUNT];
int enemyCount = 0;

// Tuning table indexed by enemyType. Entry 0 (ENEMY_TYPE_NONE) is not listed
// and is therefore zero-initialized.
EnemyClassConfig enemyClassConfigs[] = {
    [ENEMY_TYPE_MINION] = {.health = 3.0f, .speed = 1.0f, .radius = 0.25f, .maxAcceleration = 1.0f},
};

void EnemyInit()
{
  for (int i = 0; i < ENEMY_MAX_COUNT; i++)
  {
    enemies[i] = (Enemy){0};
  }
  enemyCount = 0;
}

// Top speed of the enemy's class, looked up in enemyClassConfigs.
float EnemyGetCurrentMaxSpeed(Enemy *enemy)
{
  uint8_t type = enemy->enemyType;
  return enemyClassConfigs[type].speed;
}

// Hit points of the enemy's class, looked up in enemyClassConfigs.
float EnemyGetMaxHealth(Enemy *enemy)
{
  uint8_t type = enemy->enemyType;
  return enemyClassConfigs[type].health;
}

// Computes the next grid cell on the way from (currentX, currentY) to the
// castle at (0, 0). The step is one cell along the axis with the larger
// remaining distance; on a tie the step is taken along Y.
// Returns 1 (and writes the input cell back) when already at the castle,
// otherwise 0.
int EnemyGetNextPosition(int16_t currentX, int16_t currentY, int16_t *nextX, int16_t *nextY)
{
  const int16_t castleX = 0;
  const int16_t castleY = 0;
  int16_t deltaX = castleX - currentX;
  int16_t deltaY = castleY - currentY;

  // default: stay put; one of the two coordinates is overwritten below
  *nextX = currentX;
  *nextY = currentY;
  if (deltaX == 0 && deltaY == 0)
  {
    return 1;
  }

  if (abs(deltaY) >= abs(deltaX))
  {
    *nextY = currentY + (deltaY > 0 ? 1 : -1);
  }
  else
  {
    *nextX = currentX + (deltaX > 0 ? 1 : -1);
  }
  return 0;
}


// Predicts where `enemy` will be after deltaT more seconds of movement.
//
// Starting from enemy->simPosition and *velocity, the enemy steers towards its
// waypoint (enemy->nextX/nextY). The simulation advances in sub-steps of at
// most maxSimStepTime (1/64 s). In each sub-step:
//  * a look-ahead point (position + velocity * speed, i.e. speed^2 units ahead)
//    is computed; when it is within pointReachedDistance of the waypoint, the
//    waypoint advances via EnemyGetNextPosition and one passed waypoint is
//    counted;
//  * the velocity is accelerated by maxAcceleration towards the waypoint (as
//    seen from the look-ahead point) and clamped to the class's max speed;
//  * the position is integrated with the new velocity.
//
// Parameters:
//   enemy               - read for simPosition, nextX/nextY and enemyType.
//   deltaT              - seconds to simulate; <= 0 returns simPosition as is.
//   velocity            - in/out velocity. Callers pass either a scratch copy
//                         or &enemy->simVelocity to commit the result.
//   waypointPassedCount - optional (may be 0); receives the number of waypoint
//                         advances performed.
// Returns the predicted position.
//
// NOTE: every sub-step draws a red wire cube at the current waypoint for
// debugging, so this function has a rendering side effect.
Vector2 EnemyGetPosition(Enemy *enemy, float deltaT, Vector2 *velocity, int *waypointPassedCount)
{
  const float pointReachedDistance = 0.25f;
  const float pointReachedDistance2 = pointReachedDistance * pointReachedDistance;
  const float maxSimStepTime = 0.015625f;

  float maxAcceleration = enemyClassConfigs[enemy->enemyType].maxAcceleration;
  float maxSpeed = EnemyGetCurrentMaxSpeed(enemy);
  int16_t nextX = enemy->nextX;
  int16_t nextY = enemy->nextY;
  Vector2 position = enemy->simPosition;
  int passedCount = 0;
  for (float t = 0.0f; t < deltaT; t += maxSimStepTime)
  {
    float stepTime = fminf(deltaT - t, maxSimStepTime);
    Vector2 target = (Vector2){nextX, nextY};
    float speed = Vector2Length(*velocity);
    // draw the target position for debugging
    DrawCubeWires((Vector3){target.x, 0.2f, target.y}, 0.1f, 0.4f, 0.1f, RED);
    Vector2 lookForwardPos = Vector2Add(position, Vector2Scale(*velocity, speed));
    if (Vector2DistanceSqr(target, lookForwardPos) <= pointReachedDistance2)
    {
      // we reached the target position, let's move to the next waypoint
      EnemyGetNextPosition(nextX, nextY, &nextX, &nextY);
      target = (Vector2){nextX, nextY};
      // track how many waypoints we passed
      passedCount++;
    }

    // acceleration towards the target
    Vector2 unitDirection = Vector2Normalize(Vector2Subtract(target, lookForwardPos));
    Vector2 acceleration = Vector2Scale(unitDirection, maxAcceleration * stepTime);
    *velocity = Vector2Add(*velocity, acceleration);

    // Limit the speed to the maximum speed. The speed must be measured *after*
    // the acceleration was added: clamping with the pre-acceleration speed let
    // the velocity exceed maxSpeed by up to one step's acceleration.
    float newSpeed = Vector2Length(*velocity);
    if (newSpeed > maxSpeed)
    {
      // newSpeed > maxSpeed >= 0, so the division is safe
      *velocity = Vector2Scale(*velocity, maxSpeed / newSpeed);
    }

    // move the enemy
    position = Vector2Add(position, Vector2Scale(*velocity, stepTime));
  }

  if (waypointPassedCount)
  {
    (*waypointPassedCount) = passedCount;
  }

  return position;
}

void EnemyDraw()
{
  for (int i = 0; i < enemyCount; i++)
  {
    Enemy enemy = enemies[i];
    Vector2 position = EnemyGetPosition(&enemy, gameTime.time - enemy.startMovingTime, &enemy.simVelocity, 0);
    
    switch (enemy.enemyType)
    {
    case ENEMY_TYPE_MINION:
      DrawCubeWires((Vector3){position.x, 0.2f, position.y}, 0.4f, 0.4f, 0.4f, GREEN);
      break;
    }
  }
}

// Brings every live enemy's movement state up to gameTime.time and advances
// its waypoint bookkeeping (currentX/Y, nextX/Y) by the number of waypoints
// it passed. An enemy whose current cell becomes the castle is removed
// (its slot is marked ENEMY_TYPE_NONE).
void EnemyUpdate()
{
  for (int i = 0; i < enemyCount; i++)
  {
    Enemy *enemy = &enemies[i];
    if (enemy->enemyType == ENEMY_TYPE_NONE)
    {
      continue;
    }

    int waypointPassedCount = 0;
    enemy->simPosition = EnemyGetPosition(enemy, gameTime.time - enemy->startMovingTime, &enemy->simVelocity, &waypointPassedCount);
    enemy->startMovingTime = gameTime.time;

    // EnemyGetPosition advances its own copy of the waypoint once per passed
    // waypoint. Mirror every advance here, not just the first one; otherwise
    // nextX/nextY could point at a waypoint the simulation already left behind,
    // and the enemy would turn back towards it.
    for (int w = 0; w < waypointPassedCount; w++)
    {
      enemy->currentX = enemy->nextX;
      enemy->currentY = enemy->nextY;
      if (EnemyGetNextPosition(enemy->currentX, enemy->currentY, &enemy->nextX, &enemy->nextY))
      {
        // enemy reached the castle; remove it
        enemy->enemyType = ENEMY_TYPE_NONE;
        break;
      }
    }
  }
}

// Builds a handle for `enemy`: its slot index in enemies[] plus the slot's
// current generation. `enemy` must point into enemies[].
EnemyId EnemyGetId(Enemy *enemy)
{
  EnemyId id;
  id.index = (uint16_t)(enemy - enemies);
  id.generation = enemy->generation;
  return id;
}

// Turns a handle back into an enemy pointer. Returns 0 when the index is out
// of range, the slot has been reused since the handle was taken (generation
// mismatch), or the slot is free.
Enemy *EnemyTryResolve(EnemyId enemyId)
{
  if (enemyId.index < ENEMY_MAX_COUNT)
  {
    Enemy *candidate = enemies + enemyId.index;
    int sameGeneration = candidate->generation == enemyId.generation;
    if (sameGeneration && candidate->enemyType != ENEMY_TYPE_NONE)
    {
      return candidate;
    }
  }
  return 0;
}

Enemy *EnemyTryAdd(uint8_t enemyType, int16_t currentX, int16_t currentY)
{
  Enemy *spawn = 0;
  for (int i = 0; i < enemyCount; i++)
  {
    Enemy *enemy = &enemies[i];
    if (enemy->enemyType == ENEMY_TYPE_NONE)
    {
      spawn = enemy;
      break;
    }
  }

  if (enemyCount < ENEMY_MAX_COUNT && !spawn)
  {
    spawn = &enemies[enemyCount++];
  }

  if (spawn)
  {
    spawn->currentX = currentX;
    spawn->currentY = currentY;
    spawn->nextX = currentX;
    spawn->nextY = currentY;
    spawn->simPosition = (Vector2){currentX, currentY};
    spawn->simVelocity = (Vector2){0, 0};
    spawn->enemyType = enemyType;
    spawn->startMovingTime = gameTime.time;
    spawn->damage = 0.0f;
    spawn->futureDamage = 0.0f;
    spawn->generation++;
  }

  return spawn;
}

// Applies `damage` to `enemy`. When the accumulated damage reaches the class's
// max health, the enemy's slot is freed (ENEMY_TYPE_NONE) and 1 is returned;
// otherwise 0.
int EnemyAddDamage(Enemy *enemy, float damage)
{
  enemy->damage += damage;
  int killed = enemy->damage >= EnemyGetMaxHealth(enemy);
  if (killed)
  {
    enemy->enemyType = ENEMY_TYPE_NONE;
  }
  return killed;
}

Enemy* EnemyGetClosestToCastle(int16_t towerX, int16_t towerY, float range)
{
  int16_t castleX = 0;
  int16_t castleY = 0;
  Enemy* closest = 0;
  int16_t closestDistance = 0;
  float range2 = range * range;
  for (int i = 0; i < enemyCount; i++)
  {
    Enemy* enemy = &enemies[i];
    if (enemy->enemyType == ENEMY_TYPE_NONE)
    {
      continue;
    }
    float maxHealth = EnemyGetMaxHealth(enemy);
    if (enemy->futureDamage >= maxHealth)
    {
      // ignore enemies that will die soon
      continue;
    }
    int16_t dx = castleX - enemy->currentX;
    int16_t dy = castleY - enemy->currentY;
    int16_t distance = abs(dx) + abs(dy);
    if (!closest || distance < closestDistance)
    {
      float tdx = towerX - enemy->currentX;
      float tdy = towerY - enemy->currentY;
      float tdistance2 = tdx * tdx + tdy * tdy;
      if (tdistance2 <= range2)
      {
        closest = enemy;
        closestDistance = distance;
      }
    }
  }
  return closest;
}

int EnemyCount()
{
  int count = 0;
  for (int i = 0; i < enemyCount; i++)
  {
    if (enemies[i].enemyType != ENEMY_TYPE_NONE)
    {
      count++;
    }
  }
  return count;
}

//# Projectiles
// Number of slots in the projectiles[] pool.
#define PROJECTILE_MAX_COUNT 1200
// projectileType values; PROJECTILE_TYPE_NONE marks a free slot.
#define PROJECTILE_TYPE_NONE 0
#define PROJECTILE_TYPE_BULLET 1

// A projectile travelling in a straight line from `position` to `target`
// between shootTime and arrivalTime. On arrival it damages targetEnemy if that
// handle still resolves.
typedef struct Projectile
{
  uint8_t projectileType;
  // gameTime.time when fired / when it reaches `target`
  float shootTime;
  float arrivalTime;
  // damage dealt on arrival
  float damage;
  // launch point and aim point (x -> world X, y -> world Z)
  Vector2 position;
  Vector2 target;
  // unit vector from position towards target; used to lay out the drawn trail
  Vector2 directionNormal;
  // handle of the enemy the shot was aimed at
  EnemyId targetEnemy;
} Projectile;

// Projectile pool. projectileCount bounds the update/draw loops; ProjectileTryAdd
// raises it to one past the highest slot it fills.
Projectile projectiles[PROJECTILE_MAX_COUNT];
int projectileCount = 0;

void ProjectileInit()
{
  for (int i = 0; i < PROJECTILE_MAX_COUNT; i++)
  {
    projectiles[i] = (Projectile){0};
  }
}

void ProjectileDraw()
{
  for (int i = 0; i < projectileCount; i++)
  {
    Projectile projectile = projectiles[i];
    if (projectile.projectileType == PROJECTILE_TYPE_NONE)
    {
      continue;
    }
    float transition = (gameTime.time - projectile.shootTime) / (projectile.arrivalTime - projectile.shootTime);
    if (transition >= 1.0f)
    {
      continue;
    }
    Vector2 position = Vector2Lerp(projectile.position, projectile.target, transition);
    float x = position.x;
    float y = position.y;
    float dx = projectile.directionNormal.x;
    float dy = projectile.directionNormal.y;
    for (float d = 1.0f; d > 0.0f; d -= 0.25f)
    {
      x -= dx * 0.1f;
      y -= dy * 0.1f;
      float size = 0.1f * d;
      DrawCube((Vector3){x, 0.2f, y}, size, size, size, RED);
    }
  }
}

// Resolves projectiles whose arrival time has been reached. Each such
// projectile frees its slot and, if its target handle still resolves to a live
// enemy, applies its damage via EnemyAddDamage.
void ProjectileUpdate()
{
  for (int i = 0; i < projectileCount; i++)
  {
    Projectile *projectile = &projectiles[i];
    if (projectile->projectileType == PROJECTILE_TYPE_NONE)
    {
      continue;
    }

    // Compare with the arrival time directly rather than dividing by the
    // flight duration. For a zero-length flight (arrivalTime == shootTime) the
    // old ratio was 0/0 = NaN on the spawn frame. NaN never compares >= 1, so
    // the hit was delayed; with a NaN arrivalTime it never happened at all.
    if (gameTime.time < projectile->arrivalTime)
    {
      continue;
    }

    projectile->projectileType = PROJECTILE_TYPE_NONE;
    Enemy *enemy = EnemyTryResolve(projectile->targetEnemy);
    if (enemy)
    {
      EnemyAddDamage(enemy, projectile->damage);
    }
  }
}

// Fires a projectile from `position` towards `target` at `speed` units per
// second, aimed at `enemy` (whose handle is stored for damage on arrival;
// `enemy` must be non-null). Uses the first free slot and returns it, or
// returns 0 when all PROJECTILE_MAX_COUNT slots are in flight.
Projectile *ProjectileTryAdd(uint8_t projectileType, Enemy *enemy, Vector2 position, Vector2 target, float speed, float damage)
{
  int slot = 0;
  while (slot < PROJECTILE_MAX_COUNT && projectiles[slot].projectileType != PROJECTILE_TYPE_NONE)
  {
    slot++;
  }
  if (slot == PROJECTILE_MAX_COUNT)
  {
    return 0;
  }

  float now = gameTime.time;
  Projectile *projectile = &projectiles[slot];
  *projectile = (Projectile){
      .projectileType = projectileType,
      .shootTime = now,
      .arrivalTime = now + Vector2Distance(position, target) / speed,
      .damage = damage,
      .position = position,
      .target = target,
      .directionNormal = Vector2Normalize(Vector2Subtract(target, position)),
      .targetEnemy = EnemyGetId(enemy),
  };

  // keep projectileCount one past the highest slot filled so far
  if (slot >= projectileCount)
  {
    projectileCount = slot + 1;
  }
  return projectile;
}

//# Towers

// Number of slots in the towers[] pool.
#define TOWER_MAX_COUNT 400
// towerType values
#define TOWER_TYPE_NONE 0
#define TOWER_TYPE_BASE 1
#define TOWER_TYPE_GUN 2

// A tower occupying one grid cell.
typedef struct Tower
{
  // grid cell (x -> world X, y -> world Z)
  int16_t x, y;
  // TOWER_TYPE_*
  uint8_t towerType;
  // seconds remaining before a gun tower may fire again
  float cooldown;
} Tower;

// Tower pool; towers are only ever appended, so slots [0, towerCount) are in use.
Tower towers[TOWER_MAX_COUNT];
int towerCount = 0;

void TowerInit()
{
  for (int i = 0; i < TOWER_MAX_COUNT; i++)
  {
    towers[i] = (Tower){0};
  }
  towerCount = 0;
}

// Returns the first tower standing on grid cell (x, y), or 0 if the cell is empty.
Tower *TowerGetAt(int16_t x, int16_t y)
{
  Tower *found = 0;
  for (int i = 0; i < towerCount && !found; i++)
  {
    if (towers[i].x == x && towers[i].y == y)
    {
      found = &towers[i];
    }
  }
  return found;
}

// Places a tower of `towerType` on grid cell (x, y). Returns the new tower, or
// 0 when the pool is full or the cell is already occupied.
Tower *TowerTryAdd(uint8_t towerType, int16_t x, int16_t y)
{
  int poolFull = towerCount >= TOWER_MAX_COUNT;
  if (poolFull || TowerGetAt(x, y))
  {
    return 0;
  }

  Tower *tower = &towers[towerCount];
  towerCount++;
  tower->towerType = towerType;
  tower->x = x;
  tower->y = y;
  return tower;
}

void TowerDraw()
{
  for (int i = 0; i < towerCount; i++)
  {
    Tower tower = towers[i];
    DrawCube((Vector3){tower.x, 0.125f, tower.y}, 1.0f, 0.25f, 1.0f, GRAY);
    switch (tower.towerType)
    {
    case TOWER_TYPE_BASE:
      DrawCube((Vector3){tower.x, 0.4f, tower.y}, 0.8f, 0.8f, 0.8f, MAROON);
      break;
    case TOWER_TYPE_GUN:
      DrawCube((Vector3){tower.x, 0.2f, tower.y}, 0.8f, 0.4f, 0.8f, DARKPURPLE);
      break;
    }
  }
}

// Per-frame logic for a gun tower.
//
// Counts down the reload cooldown. Once it has run out and EnemyGetClosestToCastle
// finds a target in range, the tower fires one bullet. It aims at the point where
// the target is predicted to be when the bullet arrives, and adds the bullet's
// damage to the target's futureDamage so that towers skip enemies already doomed.
void TowerGunUpdate(Tower *tower)
{
  const float range = 3.0f;
  const float reloadTime = 0.25f;
  const float bulletSpeed = 1.0f;
  const float bulletDamage = 3.0f;

  // Count down and test readiness in the same frame. Previously the countdown
  // and the readiness check ran in different frames, delaying every shot by one
  // extra frame (a shot every 0.3 s instead of 0.25 s at 30 FPS).
  if (tower->cooldown > 0.0f)
  {
    tower->cooldown -= gameTime.deltaTime;
    if (tower->cooldown > 0.0f)
    {
      return;
    }
  }

  Enemy *enemy = EnemyGetClosestToCastle(tower->x, tower->y, range);
  if (!enemy)
  {
    return;
  }

  // Keep the (non-positive, at most one frame long) remainder so the average
  // fire rate matches reloadTime instead of being rounded up to whole frames.
  tower->cooldown += reloadTime;

  // Aim where the enemy will be when the bullet arrives. Start with the flight
  // time to its current predicted position. Refine it by re-predicting at that
  // time and averaging old and new estimates, until they agree within 0.01 s.
  Vector2 towerPosition = {tower->x, tower->y};
  float timeSinceUpdate = gameTime.time - enemy->startMovingTime;
  Vector2 velocity = enemy->simVelocity;
  Vector2 futurePosition = EnemyGetPosition(enemy, timeSinceUpdate, &velocity, 0);
  float eta = Vector2Distance(towerPosition, futurePosition) / bulletSpeed;
  for (int i = 0; i < 8; i++)
  {
    velocity = enemy->simVelocity;
    futurePosition = EnemyGetPosition(enemy, timeSinceUpdate + eta, &velocity, 0);
    float eta2 = Vector2Distance(towerPosition, futurePosition) / bulletSpeed;
    if (fabsf(eta - eta2) < 0.01f)
    {
      break;
    }
    eta = (eta2 + eta) * 0.5f;
  }

  // Only count damage that is actually on its way. If the projectile pool is
  // full, booking the damage anyway would mark the enemy as doomed, and every
  // tower would ignore it from then on.
  if (ProjectileTryAdd(PROJECTILE_TYPE_BULLET, enemy, towerPosition, futurePosition,
                       bulletSpeed, bulletDamage))
  {
    enemy->futureDamage += bulletDamage;
  }
}

// Runs the per-frame logic of every tower; only gun towers currently have any.
void TowerUpdate()
{
  for (Tower *tower = towers; tower < towers + towerCount; tower++)
  {
    if (tower->towerType == TOWER_TYPE_GUN)
    {
      TowerGunUpdate(tower);
    }
  }
}

//# Game

// gameTime.time at or after which GameUpdate may spawn the next enemy
float nextSpawnTime = 0.0f;

// Resets all entity pools and sets up the starting map: the base tower at the
// origin (the cell enemies path towards), gun towers at (2, 0) and (-2, 0),
// and one minion at (5, 4).
void InitGame()
{
  TowerInit();
  EnemyInit();
  ProjectileInit();

  static const struct
  {
    uint8_t type;
    int16_t x, y;
  } startTowers[] = {
      {TOWER_TYPE_BASE, 0, 0},
      {TOWER_TYPE_GUN, 2, 0},
      {TOWER_TYPE_GUN, -2, 0},
  };
  for (size_t i = 0; i < sizeof startTowers / sizeof startTowers[0]; i++)
  {
    TowerTryAdd(startTowers[i].type, startTowers[i].x, startTowers[i].y);
  }

  EnemyTryAdd(ENEMY_TYPE_MINION, 5, 4);
}

void GameUpdate()
{
  float dt = GetFrameTime();
  // cap maximum delta time to 0.1 seconds to prevent large time steps
  if (dt > 0.1f) dt = 0.1f;
  gameTime.time += dt;
  gameTime.deltaTime = dt;
  EnemyUpdate();
  TowerUpdate();
  ProjectileUpdate();

  // spawn a new enemy every second
  if (gameTime.time >= nextSpawnTime && EnemyCount() < 1)
  {
    nextSpawnTime = gameTime.time + 1.0f;
    // add a new enemy at the boundary of the map
    int randValue = GetRandomValue(-5, 5);
    int randSide = GetRandomValue(0, 3);
    int16_t x = randSide == 0 ? -5 : randSide == 1 ? 5 : randValue;
    int16_t y = randSide == 2 ? -5 : randSide == 3 ? 5 : randValue;
    EnemyTryAdd(ENEMY_TYPE_MINION, x, y);
  }
}

int main(void)
{
  int screenWidth, screenHeight;
  GetPreferredSize(&screenWidth, &screenHeight);
  InitWindow(screenWidth, screenHeight, "Tower defense");
  SetTargetFPS(30);

  Camera3D camera = {0};
  camera.position = (Vector3){0.0f, 10.0f, 5.0f};
  camera.target = (Vector3){0.0f, 0.0f, 0.0f};
  camera.up = (Vector3){0.0f, 0.0f, -1.0f};
  camera.fovy = 45.0f;
  camera.projection = CAMERA_PERSPECTIVE;

  InitGame();

  while (!WindowShouldClose())
  {
    if (IsPaused()) {
      // canvas is not visible in browser - do nothing
      continue;
    }
    BeginDrawing();
    ClearBackground(DARKBLUE);

    BeginMode3D(camera);
    DrawGrid(10, 1.0f);
    TowerDraw();
    EnemyDraw();
    ProjectileDraw();
    GameUpdate();
    EndMode3D();

    DrawText("Tower defense tutorial", 5, 5, 20, WHITE);
    EndDrawing();
  }

  CloseWindow();

  return 0;
}