#include "td-tut-2-main.h"
#include <raymath.h>
#include <stdlib.h>
#include <math.h>
#include <stdint.h>
#include <string.h>

//# Declarations

// Capacity of the tower array and the tower type ids.
// TOWER_TYPE_NONE marks an unused or destroyed slot (EnemyTriggerExplode sets
// it once a tower's accumulated damage reaches its max health).
#define TOWER_MAX_COUNT 400
#define TOWER_TYPE_NONE 0
#define TOWER_TYPE_BASE 1
#define TOWER_TYPE_GUN 2
#define TOWER_TYPE_WALL 3

// A tower placed at integer grid coordinates (x, y); y is used as the world
// z coordinate when drawing (see TowerDraw).
typedef struct Tower
{
  int16_t x, y;
  uint8_t towerType;  // one of TOWER_TYPE_*
  float cooldown;     // seconds until a gun tower may fire again (TowerGunUpdate)
  float damage;       // accumulated damage; the tower is destroyed when this reaches TowerGetMaxHealth
} Tower;

// Simulation clock, advanced once per frame by GameUpdate.
typedef struct GameTime
{
  float time;       // accumulated simulated seconds
  float deltaTime;  // duration of the current frame (GameUpdate caps it at 0.1 s)
} GameTime;

GameTime gameTime = {0};

// Tower storage: slots [0, towerCount) have been handed out by TowerTryAdd;
// some of them may have become TOWER_TYPE_NONE after being destroyed.
Tower towers[TOWER_MAX_COUNT];
int towerCount = 0;

// Forward declaration: EnemyTriggerExplode (enemy section) needs it before the
// tower section defines it.
float TowerGetMaxHealth(Tower *tower);

//# Pathfinding map
// Offset (in map cells) from the cell a search node was reached from to the
// node's own cell. PathFindingGetGradient negates it to step back towards the
// castle. Both components are 0 for the castle cell and for unreached cells.
typedef struct DeltaSrc
{
  char x, y;
} DeltaSrc;

// Distance field around the castle. All per-cell arrays hold width * height
// entries stored row-major (index = y * width + x).
typedef struct PathfindingMap
{
  int width, height;     // map size in cells
  float scale;           // scale passed to PathfindingMapInit (world size of a cell)
  float *distances;      // path cost from each cell to the castle; -1 = not reached
  long *towerIndex;      // index into `towers` of a tower blocking the cell, -1 if none
  DeltaSrc *deltaSrc;    // direction the search arrived from, per cell
  float maxDistance;     // largest cost assigned during the last update
  Matrix toMapSpace;     // world -> map transform (inverse of toWorldSpace)
  Matrix toWorldSpace;   // map -> world transform
} PathfindingMap;

// When we execute the pathfinding algorithm, we need to store the active nodes
// in a queue. Each node has a position, a distance from the start, and the
// position of the node that we came from.
typedef struct PathfindingNode
{
  int16_t x, y, fromX, fromY;
  float distance;
} PathfindingNode;

// The queue is a simple array of nodes: we add nodes to the end and remove
// them from the front (FIFO). The array is kept between updates to avoid
// repeated allocations; only the count is reset.
static PathfindingNode *pathfindingNodeQueue = 0;
static int pathfindingNodeQueueCount = 0;
static int pathfindingNodeQueueCapacity = 0;

// The pathfinding map stores the distances from the castle to each cell in the map.
PathfindingMap pathfindingMap = {0};

// Sets up (or re-creates) the pathfinding grid.
//   width, height - map size in cells
//   translate     - translation used to build the map -> world transform
//   scale         - uniform scale used to build the map -> world transform
// Transforming between map space and world space lets us move and scale the
// map without touching the pathfinding data.
// Buffers from a previous call are released first, so re-initializing does
// not leak. If an allocation fails, the map is left empty (0 x 0), which every
// other pathfinding function treats as "everything is outside the map".
void PathfindingMapInit(int width, int height, Vector3 translate, float scale)
{
  pathfindingMap.toWorldSpace = MatrixTranslate(translate.x, translate.y, translate.z);
  pathfindingMap.toWorldSpace = MatrixMultiply(pathfindingMap.toWorldSpace, MatrixScale(scale, scale, scale));
  pathfindingMap.toMapSpace = MatrixInvert(pathfindingMap.toWorldSpace);
  pathfindingMap.scale = scale;
  pathfindingMap.maxDistance = 0.0f;

  // release the buffers of a previous initialization (MemFree on a null pointer is a no-op)
  MemFree(pathfindingMap.distances);
  MemFree(pathfindingMap.towerIndex);
  MemFree(pathfindingMap.deltaSrc);

  int cellCount = width * height;
  pathfindingMap.distances = (float *)MemAlloc(cellCount * sizeof(float));
  pathfindingMap.towerIndex = (long *)MemAlloc(cellCount * sizeof(long));
  pathfindingMap.deltaSrc = (DeltaSrc *)MemAlloc(cellCount * sizeof(DeltaSrc));

  if (!pathfindingMap.distances || !pathfindingMap.towerIndex || !pathfindingMap.deltaSrc)
  {
    // out of memory: leave a valid, empty map instead of dangling partial buffers
    MemFree(pathfindingMap.distances);
    MemFree(pathfindingMap.towerIndex);
    MemFree(pathfindingMap.deltaSrc);
    pathfindingMap.distances = 0;
    pathfindingMap.towerIndex = 0;
    pathfindingMap.deltaSrc = 0;
    pathfindingMap.width = 0;
    pathfindingMap.height = 0;
    return;
  }

  pathfindingMap.width = width;
  pathfindingMap.height = height;

  // same "empty" state PathFindingMapUpdate resets to: unreached, no tower, no direction
  for (int i = 0; i < cellCount; i++)
  {
    pathfindingMap.distances[i] = -1.0f;
    pathfindingMap.towerIndex[i] = -1;
    pathfindingMap.deltaSrc[i] = (DeltaSrc){0, 0};
  }
}

float PathFindingGetDistance(int mapX, int mapY)
{
  if (mapX < 0 || mapX >= pathfindingMap.width || mapY < 0 || mapY >= pathfindingMap.height)
  {
    // when outside the map, we return the manhattan distance to the castle (0,0)
    return fabsf((float)mapX) + fabsf((float)mapY);
  }

  return pathfindingMap.distances[mapY * pathfindingMap.width + mapX];
}

void PathFindingNodePush(int16_t x, int16_t y, int16_t fromX, int16_t fromY, float distance)
{
  if (pathfindingNodeQueueCount >= pathfindingNodeQueueCapacity)
  {
    pathfindingNodeQueueCapacity = pathfindingNodeQueueCapacity == 0 ? 256 : pathfindingNodeQueueCapacity * 2;
    // we use MemAlloc/MemRealloc to allocate memory for the queue
    // I am not entirely sure if MemRealloc allows passing a null pointer
    // so we check if the pointer is null and use MemAlloc in that case
    if (pathfindingNodeQueue == 0)
    {
      pathfindingNodeQueue = (PathfindingNode *)MemAlloc(pathfindingNodeQueueCapacity * sizeof(PathfindingNode));
    }
    else
    {
      pathfindingNodeQueue = (PathfindingNode *)MemRealloc(pathfindingNodeQueue, pathfindingNodeQueueCapacity * sizeof(PathfindingNode));
    }
  }

  PathfindingNode *node = &pathfindingNodeQueue[pathfindingNodeQueueCount++];
  node->x = x;
  node->y = y;
  node->fromX = fromX;
  node->fromY = fromY;
  node->distance = distance;
}

// Removes and returns the node at the front of the queue (FIFO order), or 0
// when the queue is empty.
// The returned pointer refers to a static copy, not into the queue. The queue
// may be reallocated by a later push, and its front slot is overwritten by the
// shift below. The static copy itself is overwritten by the next pop, so callers
// must finish with (or copy) the node before popping again.
// Each pop shifts the remaining nodes, so it costs O(queue length).
PathfindingNode *PathFindingNodePop()
{
  if (pathfindingNodeQueueCount == 0)
  {
    return 0;
  }

  static PathfindingNode node;
  node = pathfindingNodeQueue[0];
  --pathfindingNodeQueueCount;
  // move the remaining nodes one slot towards the front; source and destination overlap, hence memmove
  memmove(pathfindingNodeQueue, pathfindingNodeQueue + 1,
          (size_t)pathfindingNodeQueueCount * sizeof *pathfindingNodeQueue);
  return &node;
}

// Rounds a map-space coordinate to the nearest cell index. The result is
// clamped to the int16_t range so the float -> int16_t conversion is always
// defined; NaN maps to INT16_MIN, i.e. outside the map.
static int16_t PathFindingMapCoordToCell(float mapCoord)
{
  float cell = floorf(mapCoord + 0.5f);
  if (!(cell >= (float)INT16_MIN))
  {
    return INT16_MIN;
  }
  if (cell > (float)INT16_MAX)
  {
    return INT16_MAX;
  }
  return (int16_t)cell;
}

// Transforms a world position into the map cell that contains it.
// Cells are centered on integer map coordinates: PathFindingMapDraw draws cell
// (x, y) at map position (x, y). So the coordinate is rounded to the nearest
// integer. Truncation toward zero would make cell 0 cover all of (-1, 1), and
// float noise just below an integer would land in the wrong cell.
// mapX/mapY are written even when the position is outside the map, because
// PathFindingGetGradient uses them for its out-of-map fallback.
// Returns nonzero if the cell lies inside the map.
int PathFindingFromWorldToMapPosition(Vector3 worldPosition, int16_t *mapX, int16_t *mapY)
{
  Vector3 mapPosition = Vector3Transform(worldPosition, pathfindingMap.toMapSpace);
  *mapX = PathFindingMapCoordToCell(mapPosition.x);
  *mapY = PathFindingMapCoordToCell(mapPosition.z);
  return *mapX >= 0 && *mapX < pathfindingMap.width && *mapY >= 0 && *mapY < pathfindingMap.height;
}

void PathFindingMapUpdate()
{
  const int castleX = 0, castleY = 0;
  int16_t castleMapX, castleMapY;
  if (!PathFindingFromWorldToMapPosition((Vector3){castleX, 0.0f, castleY}, &castleMapX, &castleMapY))
  {
    return;
  }
  int width = pathfindingMap.width, height = pathfindingMap.height;

  // reset the distances to -1
  for (int i = 0; i < width * height; i++)
  {
    pathfindingMap.distances[i] = -1.0f;
  }
  // reset the tower indices
  for (int i = 0; i < width * height; i++)
  {
    pathfindingMap.towerIndex[i] = -1;
  }
  // reset the delta src
  for (int i = 0; i < width * height; i++)
  {
    pathfindingMap.deltaSrc[i].x = 0;
    pathfindingMap.deltaSrc[i].y = 0;
  }

  for (int i = 0; i < towerCount; i++)
  {
    Tower *tower = &towers[i];
    if (tower->towerType == TOWER_TYPE_NONE || tower->towerType == TOWER_TYPE_BASE)
    {
      continue;
    }
    int16_t mapX, mapY;
    // technically, if the tower cell scale is not in sync with the pathfinding map scale,
    // this would not work correctly and needs to be refined to allow towers covering multiple cells
    // or having multiple towers in one cell; for simplicity, we assume that the tower covers exactly
    // one cell. For now.
    if (!PathFindingFromWorldToMapPosition((Vector3){tower->x, 0.0f, tower->y}, &mapX, &mapY))
    {
      continue;
    }
    int index = mapY * width + mapX;
    pathfindingMap.towerIndex[index] = i;
  }

  // we start at the castle and add the castle to the queue
  pathfindingMap.maxDistance = 0.0f;
  pathfindingNodeQueueCount = 0;
  PathFindingNodePush(castleMapX, castleMapY, castleMapX, castleMapY, 0.0f);
  PathfindingNode *node = 0;
  while ((node = PathFindingNodePop()))
  {
    if (node->x < 0 || node->x >= width || node->y < 0 || node->y >= height)
    {
      continue;
    }
    int index = node->y * width + node->x;
    if (pathfindingMap.distances[index] >= 0 && pathfindingMap.distances[index] <= node->distance)
    {
      continue;
    }

    int deltaX = node->x - node->fromX;
    int deltaY = node->y - node->fromY;
    // even if the cell is blocked by a tower, we still may want to store the direction
    // (though this might not be needed, IDK right now)
    pathfindingMap.deltaSrc[index].x = (char) deltaX;
    pathfindingMap.deltaSrc[index].y = (char) deltaY;

    // we skip nodes that are blocked by towers
    if (pathfindingMap.towerIndex[index] >= 0)
    {
      node->distance += 8.0f;
    }
    pathfindingMap.distances[index] = node->distance;
    pathfindingMap.maxDistance = fmaxf(pathfindingMap.maxDistance, node->distance);
    PathFindingNodePush(node->x, node->y + 1, node->x, node->y, node->distance + 1.0f);
    PathFindingNodePush(node->x, node->y - 1, node->x, node->y, node->distance + 1.0f);
    PathFindingNodePush(node->x + 1, node->y, node->x, node->y, node->distance + 1.0f);
    PathFindingNodePush(node->x - 1, node->y, node->x, node->y, node->distance + 1.0f);
  }
}

void PathFindingMapDraw()
{
  float cellSize = pathfindingMap.scale * 0.9f;
  float highlightDistance = fmodf(GetTime() * 4.0f, pathfindingMap.maxDistance);
  for (int x = 0; x < pathfindingMap.width; x++)
  {
    for (int y = 0; y < pathfindingMap.height; y++)
    {
      float distance = pathfindingMap.distances[y * pathfindingMap.width + x];
      float colorV = distance < 0 ? 0 : fminf(distance / pathfindingMap.maxDistance, 1.0f);
      Color color = distance < 0 ? BLUE : (Color){fminf(colorV, 1.0f) * 255, 0, 0, 255};
      Vector3 position = Vector3Transform((Vector3){x, -0.25f, y}, pathfindingMap.toWorldSpace);
      // animate the distance "wave" to show how the pathfinding algorithm expands
      // from the castle
      if (distance + 0.5f > highlightDistance && distance - 0.5f < highlightDistance)
      {
        color = BLACK;
      }
      DrawCube(position, cellSize, 0.1f, cellSize, color);
    }
  }
}

// Direction to move from `world` to approach the castle.
// Inside the map it is the reverse of the direction the search arrived from,
// i.e. one cell step back along the shortest path. That is (0, 0) at the
// castle and at cells the search did not reach.
// Outside the map it is a central difference of PathFindingGetDistance plus a
// small constant bias (+0.25 on x, +0.125 on y), presumably to break ties.
// NOTE(review): confirm intent of the bias.
Vector2 PathFindingGetGradient(Vector3 world)
{
  int16_t cellX, cellY;
  int inside = PathFindingFromWorldToMapPosition(world, &cellX, &cellY);
  if (!inside)
  {
    float north = PathFindingGetDistance(cellX, cellY - 1);
    float south = PathFindingGetDistance(cellX, cellY + 1);
    float west = PathFindingGetDistance(cellX - 1, cellY);
    float east = PathFindingGetDistance(cellX + 1, cellY);
    return (Vector2){west - east + 0.25f, north - south + 0.125f};
  }

  DeltaSrc arrivedFrom = pathfindingMap.deltaSrc[cellY * pathfindingMap.width + cellX];
  return (Vector2){(float)-arrivedFrom.x, (float)-arrivedFrom.y};
}

//# Enemies

// Trail length, enemy pool capacity and enemy type ids.
// ENEMY_TYPE_NONE marks a free slot. It also indexes entry 0 of
// enemyClassConfigs, which is zero-initialized because only
// [ENEMY_TYPE_MINION] is given explicitly.
#define ENEMY_MAX_PATH_COUNT 8
#define ENEMY_MAX_COUNT 400
#define ENEMY_TYPE_NONE 0
#define ENEMY_TYPE_MINION 1

// Handle to an enemy slot. `generation` must match the slot's current
// generation; EnemyTryAdd bumps it on every reuse, so stale handles stop
// resolving (see EnemyTryResolve).
typedef struct EnemyId
{
  uint16_t index;
  uint16_t generation;
} EnemyId;

// Per-type tuning values, indexed by enemyType.
typedef struct EnemyClassConfig
{
  float speed;            // maximum speed, used as the velocity clamp in EnemyGetPosition
  float health;           // damage needed to kill the enemy (EnemyAddDamage)
  float radius;           // collision radius against other enemies and towers
  float maxAcceleration;  // acceleration towards the current waypoint
  float explosionDamage;  // damage dealt to a tower on contact; > 0 makes the enemy explode on contact
} EnemyClassConfig;

typedef struct Enemy
{
  int16_t currentX, currentY;              // last grid waypoint reached
  int16_t nextX, nextY;                    // grid waypoint currently steered towards
  Vector2 simPosition;                     // simulated position at startMovingTime
  Vector2 simVelocity;                     // simulated velocity at startMovingTime
  uint16_t generation;                     // bumped each time the slot is reused
  float startMovingTime;                   // game time of the last simulation update
  float damage, futureDamage;              // damage taken / damage of projectiles already fired at it
  uint8_t enemyType;                       // one of ENEMY_TYPE_*
  uint8_t movePathCount;                   // number of valid entries in movePath
  Vector2 movePath[ENEMY_MAX_PATH_COUNT];  // trail of recent positions, newest first
} Enemy;

// Enemy pool: slots [0, enemyCount) have been handed out; free ones are ENEMY_TYPE_NONE.
Enemy enemies[ENEMY_MAX_COUNT];
int enemyCount = 0;

EnemyClassConfig enemyClassConfigs[] = {
    [ENEMY_TYPE_MINION] = {
      .health = 3.0f, 
      .speed = 1.0f, 
      .radius = 0.25f, 
      .maxAcceleration = 1.0f,
      .explosionDamage = 1.0f,
    },
};

void EnemyInit()
{
  for (int i = 0; i < ENEMY_MAX_COUNT; i++)
  {
    enemies[i] = (Enemy){0};
  }
  enemyCount = 0;
}

// Maximum movement speed for the enemy, taken from its class config.
float EnemyGetCurrentMaxSpeed(Enemy *enemy)
{
  const EnemyClassConfig *config = &enemyClassConfigs[enemy->enemyType];
  return config->speed;
}

// Total damage the enemy can take before dying, taken from its class config.
float EnemyGetMaxHealth(Enemy *enemy)
{
  const EnemyClassConfig *config = &enemyClassConfigs[enemy->enemyType];
  return config->health;
}

// Chooses the neighboring grid cell an enemy at (currentX, currentY) should
// head for next, following PathFindingGetGradient along its dominant axis
// (ties go to the y axis).
// Returns 1 and writes the current cell when there is nowhere to go: the
// enemy is at the castle (0, 0) or the gradient is zero. Otherwise writes the
// chosen neighbor and returns 0.
int EnemyGetNextPosition(int16_t currentX, int16_t currentY, int16_t *nextX, int16_t *nextY)
{
  const int16_t castleX = 0;
  const int16_t castleY = 0;

  if (currentX != castleX || currentY != castleY)
  {
    Vector2 direction = PathFindingGetGradient((Vector3){currentX, 0, currentY});
    if (direction.x != 0 || direction.y != 0)
    {
      int stepX = 0;
      int stepY = 0;
      if (fabsf(direction.x) > fabsf(direction.y))
      {
        stepX = direction.x > 0.0f ? 1 : -1;
      }
      else
      {
        stepY = direction.y > 0.0f ? 1 : -1;
      }
      *nextX = currentX + stepX;
      *nextY = currentY + stepY;
      return 0;
    }
  }

  *nextX = currentX;
  *nextY = currentY;
  return 1;
}


// Predicts the enemy's position `deltaT` seconds after its last simulation
// update (enemy->simPosition). The enemy steers towards its grid waypoints in
// fixed sub-steps of at most 1/64 s.
//   enemy               - read only; waypoint progress is tracked in locals
//   deltaT              - seconds to simulate (nothing happens if <= 0)
//   velocity            - in: starting velocity, out: velocity at the end
//   waypointPassedCount - optional out (may be 0): number of waypoints reached
// Returns the predicted position.
// Note: it also issues a DrawCubeWires debug call for the current target on
// every sub-step.
Vector2 EnemyGetPosition(Enemy *enemy, float deltaT, Vector2 *velocity, int *waypointPassedCount)
{
  const float pointReachedDistance = 0.25f;
  const float pointReachedDistance2 = pointReachedDistance * pointReachedDistance;
  const float maxSimStepTime = 0.015625f;
  
  float maxAcceleration = enemyClassConfigs[enemy->enemyType].maxAcceleration;
  float maxSpeed = EnemyGetCurrentMaxSpeed(enemy);
  int16_t nextX = enemy->nextX;
  int16_t nextY = enemy->nextY;
  Vector2 position = enemy->simPosition;
  int passedCount = 0;
  for (float t = 0.0f; t < deltaT; t += maxSimStepTime)
  {
    float stepTime = fminf(deltaT - t, maxSimStepTime);
    Vector2 target = (Vector2){nextX, nextY};
    float speed = Vector2Length(*velocity);
    // draw the target position for debugging
    DrawCubeWires((Vector3){target.x, 0.2f, target.y}, 0.1f, 0.4f, 0.1f, RED);
    // look ahead along the velocity (scaled by the current speed), presumably so
    // waypoints are switched before the enemy overshoots them
    Vector2 lookForwardPos = Vector2Add(position, Vector2Scale(*velocity, speed));
    if (Vector2DistanceSqr(target, lookForwardPos) <= pointReachedDistance2)
    {
      // we reached the target position, let's move to the next waypoint
      EnemyGetNextPosition(nextX, nextY, &nextX, &nextY);
      target = (Vector2){nextX, nextY};
      // track how many waypoints we passed
      passedCount++;
    }
    
    // acceleration towards the target
    Vector2 unitDirection = Vector2Normalize(Vector2Subtract(target, lookForwardPos));
    Vector2 acceleration = Vector2Scale(unitDirection, maxAcceleration * stepTime);
    *velocity = Vector2Add(*velocity, acceleration);

    // limit the speed to the maximum speed. Measure it after accelerating:
    // clamping with the pre-acceleration speed lets the velocity exceed maxSpeed.
    float newSpeed = Vector2Length(*velocity);
    if (newSpeed > maxSpeed)
    {
      *velocity = Vector2Scale(*velocity, maxSpeed / newSpeed);
    }

    // move the enemy
    position = Vector2Add(position, Vector2Scale(*velocity, stepTime));
  }

  if (waypointPassedCount)
  {
    (*waypointPassedCount) = passedCount;
  }

  return position;
}

void EnemyDraw()
{
  for (int i = 0; i < enemyCount; i++)
  {
    Enemy enemy = enemies[i];
    if (enemy.enemyType == ENEMY_TYPE_NONE)
    {
      continue;
    }

    Vector2 position = EnemyGetPosition(&enemy, gameTime.time - enemy.startMovingTime, &enemy.simVelocity, 0);
    
    if (enemy.movePathCount > 0)
    {
      Vector3 p = {enemy.movePath[0].x, 0.2f, enemy.movePath[0].y};
      DrawLine3D(p, (Vector3){position.x, 0.2f, position.y}, GREEN);
    }
    for (int j = 1; j < enemy.movePathCount; j++)
    {
      Vector3 p = {enemy.movePath[j - 1].x, 0.2f, enemy.movePath[j - 1].y};
      Vector3 q = {enemy.movePath[j].x, 0.2f, enemy.movePath[j].y};
      DrawLine3D(p, q, GREEN);
    }

    switch (enemy.enemyType)
    {
    case ENEMY_TYPE_MINION:
      DrawCubeWires((Vector3){position.x, 0.2f, position.y}, 0.4f, 0.4f, 0.4f, GREEN);
      break;
    }
  }
}

// Detonates `enemy` against `tower`. The tower takes the enemy class's
// explosion damage and is destroyed (set to TOWER_TYPE_NONE) once its
// accumulated damage reaches its max health. The enemy is always removed.
void EnemyTriggerExplode(Enemy *enemy, Tower *tower)
{
  float explosionDamage = enemyClassConfigs[enemy->enemyType].explosionDamage;
  tower->damage += explosionDamage;

  int towerDestroyed = tower->damage >= TowerGetMaxHealth(tower);
  if (towerDestroyed)
  {
    tower->towerType = TOWER_TYPE_NONE;
  }

  // the enemy is consumed by its own explosion
  enemy->enemyType = ENEMY_TYPE_NONE;
}

void EnemyUpdate()
{
  const float castleX = 0;
  const float castleY = 0;
  const float maxPathDistance2 = 0.25f * 0.25f;
  
  for (int i = 0; i < enemyCount; i++)
  {
    Enemy *enemy = &enemies[i];
    if (enemy->enemyType == ENEMY_TYPE_NONE)
    {
      continue;
    }

    int waypointPassedCount = 0;
    enemy->simPosition = EnemyGetPosition(enemy, gameTime.time - enemy->startMovingTime, &enemy->simVelocity, &waypointPassedCount);
    enemy->startMovingTime = gameTime.time;
    // track path of unit
    if (enemy->movePathCount == 0 || Vector2DistanceSqr(enemy->simPosition, enemy->movePath[0]) > maxPathDistance2)
    {
      for (int j = ENEMY_MAX_PATH_COUNT - 1; j > 0; j--)
      {
        enemy->movePath[j] = enemy->movePath[j - 1];
      }
      enemy->movePath[0] = enemy->simPosition;
      if (++enemy->movePathCount > ENEMY_MAX_PATH_COUNT)
      {
        enemy->movePathCount = ENEMY_MAX_PATH_COUNT;
      }
    }

    if (waypointPassedCount > 0)
    {
      enemy->currentX = enemy->nextX;
      enemy->currentY = enemy->nextY;
      if (EnemyGetNextPosition(enemy->currentX, enemy->currentY, &enemy->nextX, &enemy->nextY) &&
        Vector2DistanceSqr(enemy->simPosition, (Vector2){castleX, castleY}) <= 0.25f * 0.25f)
      {
        // enemy reached the castle; remove it
        enemy->enemyType = ENEMY_TYPE_NONE;
        continue;
      }
    }
  }

  // handle collisions between enemies
  for (int i = 0; i < enemyCount - 1; i++)
  {
    Enemy *enemyA = &enemies[i];
    if (enemyA->enemyType == ENEMY_TYPE_NONE)
    {
      continue;
    }
    for (int j = i + 1; j < enemyCount; j++)
    {
      Enemy *enemyB = &enemies[j];
      if (enemyB->enemyType == ENEMY_TYPE_NONE)
      {
        continue;
      }
      float distanceSqr = Vector2DistanceSqr(enemyA->simPosition, enemyB->simPosition);
      float radiusA = enemyClassConfigs[enemyA->enemyType].radius;
      float radiusB = enemyClassConfigs[enemyB->enemyType].radius;
      float radiusSum = radiusA + radiusB;
      if (distanceSqr < radiusSum * radiusSum && distanceSqr > 0.001f)
      {
        // collision
        float distance = sqrtf(distanceSqr);
        float overlap = radiusSum - distance;
        // move the enemies apart, but softly; if we have a clog of enemies,
        // moving them perfectly apart can cause them to jitter
        float positionCorrection = overlap / 5.0f;
        Vector2 direction = (Vector2){
            (enemyB->simPosition.x - enemyA->simPosition.x) / distance * positionCorrection,
            (enemyB->simPosition.y - enemyA->simPosition.y) / distance * positionCorrection};
        enemyA->simPosition = Vector2Subtract(enemyA->simPosition, direction);
        enemyB->simPosition = Vector2Add(enemyB->simPosition, direction);
      }
    }
  }

  // handle collisions between enemies and towers
  for (int i = 0; i < enemyCount; i++)
  {
    Enemy *enemy = &enemies[i];
    if (enemy->enemyType == ENEMY_TYPE_NONE)
    {
      continue;
    }
    float enemyRadius = enemyClassConfigs[enemy->enemyType].radius;
    // linear search over towers; could be optimized by using path finding tower map,
    // but for now, we keep it simple
    for (int j = 0; j < towerCount; j++)
    {
      Tower *tower = &towers[j];
      if (tower->towerType == TOWER_TYPE_NONE)
      {
        continue;
      }
      float distanceSqr = Vector2DistanceSqr(enemy->simPosition, (Vector2){tower->x, tower->y});
      float combinedRadius = enemyRadius + 0.708; // sqrt(0.5^2 + 0.5^2), corner-center distance of square with side length 1
      if (distanceSqr > combinedRadius * combinedRadius)
      {
        continue;
      }
      // potential collision; square / circle intersection
      float dx = tower->x - enemy->simPosition.x;
      float dy = tower->y - enemy->simPosition.y;
      float absDx = fabsf(dx);
      float absDy = fabsf(dy);
      if (absDx <= 0.5f && absDx <= absDy) {
        // vertical collision; push the enemy out horizontally
        float overlap = enemyRadius + 0.5f - absDy;
        if (overlap < 0.0f)
        {
          continue;
        }
        float direction = dy > 0.0f ? -1.0f : 1.0f;
        enemy->simPosition.y += direction * overlap;
      }
      else if (absDy <= 0.5f && absDy <= absDx)
      {
        // horizontal collision; push the enemy out vertically
        float overlap = enemyRadius + 0.5f - absDx;
        if (overlap < 0.0f)
        {
          continue;
        }
        float direction = dx > 0.0f ? -1.0f : 1.0f;
        enemy->simPosition.x += direction * overlap;
      }
      else
      {
        // possible collision with a corner
        float cornerDX = dx > 0.0f ? -0.5f : 0.5f;
        float cornerDY = dy > 0.0f ? -0.5f : 0.5f;
        float cornerX = tower->x + cornerDX;
        float cornerY = tower->y + cornerDY;
        float cornerDistanceSqr = Vector2DistanceSqr(enemy->simPosition, (Vector2){cornerX, cornerY});
        if (cornerDistanceSqr > enemyRadius * enemyRadius)
        {
          continue;
        }
        // push the enemy out along the diagonal
        float cornerDistance = sqrtf(cornerDistanceSqr);
        float overlap = enemyRadius - cornerDistance;
        float directionX = cornerDistance > 0.0f ? (cornerX - enemy->simPosition.x) / cornerDistance : -cornerDX;
        float directionY = cornerDistance > 0.0f ? (cornerY - enemy->simPosition.y) / cornerDistance : -cornerDY;
        enemy->simPosition.x -= directionX * overlap;
        enemy->simPosition.y -= directionY * overlap;
      }

      if (enemyClassConfigs[enemy->enemyType].explosionDamage > 0.0f)
      {
        EnemyTriggerExplode(enemy, tower);
      }
    }
  }
}

// Builds a handle for `enemy`: its slot index in `enemies` plus the slot's
// current generation. `enemy` must point into the `enemies` array.
EnemyId EnemyGetId(Enemy *enemy)
{
  EnemyId id;
  id.index = (uint16_t)(enemy - enemies);
  id.generation = enemy->generation;
  return id;
}

// Turns a handle back into an enemy pointer. Returns 0 when the index is out
// of range, the slot has been reused (generation mismatch) or the enemy is dead.
Enemy *EnemyTryResolve(EnemyId enemyId)
{
  Enemy *candidate = enemyId.index < ENEMY_MAX_COUNT ? &enemies[enemyId.index] : 0;
  int stillValid = candidate != 0
    && candidate->generation == enemyId.generation
    && candidate->enemyType != ENEMY_TYPE_NONE;
  return stillValid ? candidate : 0;
}

Enemy *EnemyTryAdd(uint8_t enemyType, int16_t currentX, int16_t currentY)
{
  Enemy *spawn = 0;
  for (int i = 0; i < enemyCount; i++)
  {
    Enemy *enemy = &enemies[i];
    if (enemy->enemyType == ENEMY_TYPE_NONE)
    {
      spawn = enemy;
      break;
    }
  }

  if (enemyCount < ENEMY_MAX_COUNT && !spawn)
  {
    spawn = &enemies[enemyCount++];
  }

  if (spawn)
  {
    spawn->currentX = currentX;
    spawn->currentY = currentY;
    spawn->nextX = currentX;
    spawn->nextY = currentY;
    spawn->simPosition = (Vector2){currentX, currentY};
    spawn->simVelocity = (Vector2){0, 0};
    spawn->enemyType = enemyType;
    spawn->startMovingTime = gameTime.time;
    spawn->damage = 0.0f;
    spawn->futureDamage = 0.0f;
    spawn->generation++;
    spawn->movePathCount = 0;
  }

  return spawn;
}

// Applies `damage` to the enemy. Returns 1 (and removes the enemy) if this
// brings its total damage to at least its max health, otherwise 0.
int EnemyAddDamage(Enemy *enemy, float damage)
{
  enemy->damage += damage;
  int killed = enemy->damage >= EnemyGetMaxHealth(enemy);
  if (killed)
  {
    enemy->enemyType = ENEMY_TYPE_NONE;
  }
  return killed;
}

Enemy* EnemyGetClosestToCastle(int16_t towerX, int16_t towerY, float range)
{
  int16_t castleX = 0;
  int16_t castleY = 0;
  Enemy* closest = 0;
  int16_t closestDistance = 0;
  float range2 = range * range;
  for (int i = 0; i < enemyCount; i++)
  {
    Enemy* enemy = &enemies[i];
    if (enemy->enemyType == ENEMY_TYPE_NONE)
    {
      continue;
    }
    float maxHealth = EnemyGetMaxHealth(enemy);
    if (enemy->futureDamage >= maxHealth)
    {
      // ignore enemies that will die soon
      continue;
    }
    int16_t dx = castleX - enemy->currentX;
    int16_t dy = castleY - enemy->currentY;
    int16_t distance = abs(dx) + abs(dy);
    if (!closest || distance < closestDistance)
    {
      float tdx = towerX - enemy->currentX;
      float tdy = towerY - enemy->currentY;
      float tdistance2 = tdx * tdx + tdy * tdy;
      if (tdistance2 <= range2)
      {
        closest = enemy;
        closestDistance = distance;
      }
    }
  }
  return closest;
}

int EnemyCount()
{
  int count = 0;
  for (int i = 0; i < enemyCount; i++)
  {
    if (enemies[i].enemyType != ENEMY_TYPE_NONE)
    {
      count++;
    }
  }
  return count;
}

//# Projectiles
// Capacity of the projectile pool and projectile type ids;
// PROJECTILE_TYPE_NONE marks a free slot.
#define PROJECTILE_MAX_COUNT 1200
#define PROJECTILE_TYPE_NONE 0
#define PROJECTILE_TYPE_BULLET 1

// A projectile flying in a straight line from `position` to `target` between
// shootTime and arrivalTime. Its damage is applied to `targetEnemy` on arrival
// (ProjectileUpdate).
typedef struct Projectile
{
  uint8_t projectileType;   // one of PROJECTILE_TYPE_*
  float shootTime;          // game time the projectile was fired
  float arrivalTime;        // game time it reaches `target`
  float damage;             // damage dealt to the target on arrival
  Vector2 position;         // start position
  Vector2 target;           // impact position
  Vector2 directionNormal;  // unit vector from position to target (used for the trail)
  EnemyId targetEnemy;      // handle of the targeted enemy
} Projectile;

// Projectile pool. projectileCount is a high-water mark maintained by
// ProjectileTryAdd: only slots below it may be in use.
Projectile projectiles[PROJECTILE_MAX_COUNT];
int projectileCount = 0;

void ProjectileInit()
{
  for (int i = 0; i < PROJECTILE_MAX_COUNT; i++)
  {
    projectiles[i] = (Projectile){0};
  }
}

void ProjectileDraw()
{
  for (int i = 0; i < projectileCount; i++)
  {
    Projectile projectile = projectiles[i];
    if (projectile.projectileType == PROJECTILE_TYPE_NONE)
    {
      continue;
    }
    float transition = (gameTime.time - projectile.shootTime) / (projectile.arrivalTime - projectile.shootTime);
    if (transition >= 1.0f)
    {
      continue;
    }
    Vector2 position = Vector2Lerp(projectile.position, projectile.target, transition);
    float x = position.x;
    float y = position.y;
    float dx = projectile.directionNormal.x;
    float dy = projectile.directionNormal.y;
    for (float d = 1.0f; d > 0.0f; d -= 0.25f)
    {
      x -= dx * 0.1f;
      y -= dy * 0.1f;
      float size = 0.1f * d;
      DrawCube((Vector3){x, 0.2f, y}, size, size, size, RED);
    }
  }
}

void ProjectileUpdate()
{
  for (int i = 0; i < projectileCount; i++)
  {
    Projectile *projectile = &projectiles[i];
    if (projectile->projectileType == PROJECTILE_TYPE_NONE)
    {
      continue;
    }
    float transition = (gameTime.time - projectile->shootTime) / (projectile->arrivalTime - projectile->shootTime);
    if (transition >= 1.0f)
    {
      projectile->projectileType = PROJECTILE_TYPE_NONE;
      Enemy *enemy = EnemyTryResolve(projectile->targetEnemy);
      if (enemy)
      {
        EnemyAddDamage(enemy, projectile->damage);
      }
      continue;
    }
  }
}

// Fires a projectile from `position` to `target` at `speed` units per second.
// It uses the first free slot and raises projectileCount if needed.
//   enemy - the target; may be 0, in which case the projectile has no target
//           and deals no damage on arrival
//   speed - must be positive; a zero or negative speed would give an arrival
//           time that is never reached or already past, so 0 is returned
// Returns the new projectile, or 0 if the pool is full or the speed is invalid.
Projectile *ProjectileTryAdd(uint8_t projectileType, Enemy *enemy, Vector2 position, Vector2 target, float speed, float damage)
{
  if (!(speed > 0.0f))
  {
    return 0;
  }

  for (int i = 0; i < PROJECTILE_MAX_COUNT; i++)
  {
    Projectile *projectile = &projectiles[i];
    if (projectile->projectileType != PROJECTILE_TYPE_NONE)
    {
      continue;
    }

    projectile->projectileType = projectileType;
    projectile->shootTime = gameTime.time;
    projectile->arrivalTime = gameTime.time + Vector2Distance(position, target) / speed;
    projectile->damage = damage;
    projectile->position = position;
    projectile->target = target;
    projectile->directionNormal = Vector2Normalize(Vector2Subtract(target, position));
    // ENEMY_MAX_COUNT is never a valid index, so EnemyTryResolve rejects this placeholder handle
    projectile->targetEnemy = enemy ? EnemyGetId(enemy) : (EnemyId){ENEMY_MAX_COUNT, 0};
    if (projectileCount <= i)
    {
      projectileCount = i + 1;
    }
    return projectile;
  }
  return 0;
}

//# Towers

void TowerInit()
{
  for (int i = 0; i < TOWER_MAX_COUNT; i++)
  {
    towers[i] = (Tower){0};
  }
  towerCount = 0;
}

Tower *TowerGetAt(int16_t x, int16_t y)
{
  for (int i = 0; i < towerCount; i++)
  {
    if (towers[i].x == x && towers[i].y == y)
    {
      return &towers[i];
    }
  }
  return 0;
}

Tower *TowerTryAdd(uint8_t towerType, int16_t x, int16_t y)
{
  if (towerCount >= TOWER_MAX_COUNT)
  {
    return 0;
  }

  Tower *tower = TowerGetAt(x, y);
  if (tower)
  {
    return 0;
  }

  tower = &towers[towerCount++];
  tower->x = x;
  tower->y = y;
  tower->towerType = towerType;
  tower->cooldown = 0.0f;
  tower->damage = 0.0f;
  return tower;
}

// Damage a tower of the given type can absorb before it is destroyed;
// 0 for TOWER_TYPE_NONE and unknown types.
float TowerGetMaxHealth(Tower *tower)
{
  float maxHealth;
  switch (tower->towerType)
  {
  case TOWER_TYPE_BASE:
    maxHealth = 10.0f;
    break;
  case TOWER_TYPE_GUN:
    maxHealth = 3.0f;
    break;
  case TOWER_TYPE_WALL:
    maxHealth = 5.0f;
    break;
  default:
    maxHealth = 0.0f;
    break;
  }
  return maxHealth;
}

void TowerDraw()
{
  for (int i = 0; i < towerCount; i++)
  {
    Tower tower = towers[i];
    DrawCube((Vector3){tower.x, 0.125f, tower.y}, 1.0f, 0.25f, 1.0f, GRAY);
    switch (tower.towerType)
    {
    case TOWER_TYPE_BASE:
      DrawCube((Vector3){tower.x, 0.4f, tower.y}, 0.8f, 0.8f, 0.8f, MAROON);
      break;
    case TOWER_TYPE_GUN:
      DrawCube((Vector3){tower.x, 0.2f, tower.y}, 0.8f, 0.4f, 0.8f, DARKPURPLE);
      break;
    case TOWER_TYPE_WALL:
      DrawCube((Vector3){tower.x, 0.5f, tower.y}, 1.0f, 1.0f, 1.0f, LIGHTGRAY);
      break;
    }
  }
}

// Gun tower behaviour for one frame. While the cooldown is positive it only
// counts down. Otherwise the tower picks the enemy closest to the castle within
// range (EnemyGetClosestToCastle) and fires a bullet at the enemy's predicted
// position.
// The prediction iterates the bullet's time of flight (up to 8 rounds, stopping
// once it changes by less than 0.01 s) against the enemy's extrapolated path.
// The bullet's damage is reserved in enemy->futureDamage only if a projectile
// was actually spawned. Otherwise an enemy would count as doomed and be
// ignored by every tower, although no bullet is on its way.
void TowerGunUpdate(Tower *tower)
{
  const float range = 3.0f;
  const float reloadTime = 0.125f;
  const float bulletSpeed = 1.0f;
  const float bulletDamage = 3.0f;

  if (tower->cooldown > 0)
  {
    tower->cooldown -= gameTime.deltaTime;
    return;
  }

  Enemy *enemy = EnemyGetClosestToCastle(tower->x, tower->y, range);
  if (!enemy)
  {
    return;
  }

  tower->cooldown = reloadTime;

  // shoot the enemy; determine the future position of the enemy
  Vector2 towerPosition = {tower->x, tower->y};
  Vector2 velocity = enemy->simVelocity;
  Vector2 futurePosition = EnemyGetPosition(enemy, gameTime.time - enemy->startMovingTime, &velocity, 0);
  float eta = Vector2Distance(towerPosition, futurePosition) / bulletSpeed;
  for (int i = 0; i < 8; i++)
  {
    velocity = enemy->simVelocity;
    futurePosition = EnemyGetPosition(enemy, gameTime.time - enemy->startMovingTime + eta, &velocity, 0);
    float eta2 = Vector2Distance(towerPosition, futurePosition) / bulletSpeed;
    if (fabsf(eta - eta2) < 0.01f)
    {
      break;
    }
    eta = (eta2 + eta) * 0.5f;
  }

  if (ProjectileTryAdd(PROJECTILE_TYPE_BULLET, enemy, towerPosition, futurePosition,
        bulletSpeed, bulletDamage))
  {
    enemy->futureDamage += bulletDamage;
  }
}

// Runs the per-frame behaviour of every tower; only gun towers currently act.
void TowerUpdate()
{
  for (Tower *tower = towers; tower < towers + towerCount; ++tower)
  {
    if (tower->towerType == TOWER_TYPE_GUN)
    {
      TowerGunUpdate(tower);
    }
  }
}

//# Game

// Game time at which GameUpdate may spawn the next wave of enemies.
float nextSpawnTime = 0.0f;

// Resets all subsystems and builds the starting layout:
// - a 20x20 pathfinding map translated by (-10, 0, -10) with cell scale 1
// - the castle at (0, 0) with a gun tower at (2, 0)
// - walls along y = 2, y = -2 and x = -2 for offsets -2..2 (the corner cells
//   are requested twice; TowerTryAdd rejects the duplicates)
// - one minion at (5, 4)
void InitGame()
{
  TowerInit();
  EnemyInit();
  ProjectileInit();
  PathfindingMapInit(20, 20, (Vector3){-10.0f, 0.0f, -10.0f}, 1.0f);

  TowerTryAdd(TOWER_TYPE_BASE, 0, 0);
  TowerTryAdd(TOWER_TYPE_GUN, 2, 0);

  for (int16_t offset = -2; offset <= 2; offset++)
  {
    TowerTryAdd(TOWER_TYPE_WALL, offset, 2);
    TowerTryAdd(TOWER_TYPE_WALL, offset, -2);
    TowerTryAdd(TOWER_TYPE_WALL, -2, offset);
  }

  EnemyTryAdd(ENEMY_TYPE_MINION, 5, 4);
}

void GameUpdate()
{
  float dt = GetFrameTime();
  // cap maximum delta time to 0.1 seconds to prevent large time steps
  if (dt > 0.1f) dt = 0.1f;
  gameTime.time += dt;
  gameTime.deltaTime = dt;
  PathFindingMapUpdate();
  EnemyUpdate();
  TowerUpdate();
  ProjectileUpdate();

  // spawn a new enemy every second
  if (gameTime.time >= nextSpawnTime && EnemyCount() < 50)
  {
    nextSpawnTime = gameTime.time + 0.2f;
    // add a new enemy at the boundary of the map
    int randValue = GetRandomValue(-5, 5);
    int randSide = GetRandomValue(0, 3);
    int16_t x = randSide == 0 ? -5 : randSide == 1 ? 5 : randValue;
    int16_t y = randSide == 2 ? -5 : randSide == 3 ? 5 : randValue;
    static int alternation = 0;
    alternation += 1;
    if (alternation % 3 == 0) {
      EnemyTryAdd(ENEMY_TYPE_MINION, 0, -5);
    }
    else if (alternation % 3 == 1)
    {
      EnemyTryAdd(ENEMY_TYPE_MINION, 0, 5);
    }
    EnemyTryAdd(ENEMY_TYPE_MINION, x, y);
  }
}

int main(void)
{
  int screenWidth, screenHeight;
  GetPreferredSize(&screenWidth, &screenHeight);
  InitWindow(screenWidth, screenHeight, "Tower defense");
  SetTargetFPS(30);

  Camera3D camera = {0};
  camera.position = (Vector3){0.0f, 10.0f, -0.5f};
  camera.target = (Vector3){0.0f, 0.0f, -0.5f};
  camera.up = (Vector3){0.0f, 0.0f, -1.0f};
  camera.fovy = 12.0f;
  camera.projection = CAMERA_ORTHOGRAPHIC;

  InitGame();

  while (!WindowShouldClose())
  {
    if (IsPaused()) {
      // canvas is not visible in browser - do nothing
      continue;
    }
    BeginDrawing();
    ClearBackground(DARKBLUE);

    BeginMode3D(camera);
    DrawGrid(10, 1.0f);
    TowerDraw();
    EnemyDraw();
    ProjectileDraw();
    PathFindingMapDraw();
    GameUpdate();
    EndMode3D();

    const char *title = "Tower defense tutorial";
    int titleWidth = MeasureText(title, 20);
    DrawText(title, (GetScreenWidth()  - titleWidth) * 0.5f + 2, 5 + 2, 20, BLACK);
    DrawText(title, (GetScreenWidth()  - titleWidth) * 0.5f, 5, 20, WHITE);
    EndDrawing();
  }

  CloseWindow();

  return 0;
}