Skip to content

Instantly share code, notes, and snippets.

@adammyhre
Created February 9, 2025 12:23
Show Gist options
  • Save adammyhre/6317b014d4050514771b6020bf79ecc0 to your computer and use it in GitHub Desktop.
Save adammyhre/6317b014d4050514771b6020bf79ecc0 to your computer and use it in GitHub Desktop.
Tactical Pathfinding for A*
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.AI;
using TMPro;
/// <summary>
/// One cell of the tactical grid: a world-space position plus a walkability flag.
/// </summary>
public struct Node {
    /// <summary>World-space position of the cell (GridManager uses index * cellSize on X/Z, y = 0).</summary>
    public Vector3 position;

    /// <summary>Whether the cell may be traversed (GridManager fills this from a NavMesh sample).</summary>
    public bool isWalkable;

    /// <summary>Creates a node at <paramref name="position"/>.</summary>
    /// <param name="position">World-space position of the cell.</param>
    /// <param name="isWalkable">Walkability flag; false when omitted.</param>
    public Node(Vector3 position, bool isWalkable = false) {
        this.isWalkable = isWalkable;
        this.position = position;
    }
}
/// <summary>
/// Builds a gridSizeX x gridSizeZ grid of <see cref="Node"/>s on the XZ plane, marks each node
/// walkable if the NavMesh is within half a cell of it, precomputes per-node directional danger
/// from objects tagged "Enemy", and spawns a quad + label over every cell with non-zero danger.
/// </summary>
public class GridManager : MonoBehaviour {
    public int gridSizeX = 10;
    public int gridSizeZ = 10;
    public float cellSize = 1.0f;
    public TextMeshPro textPrefab;
    public Material dangerZoneMaterial;

    // Layout of the packed sector mask: 8 sectors x 4 bits, each counter saturating at 15.
    const int SectorCount = 8;
    const int BitsPerSector = 4;
    const int SectorValueMask = 0b1111;

    Node[,] grid;

    // A sector represents an 8-way directional segment (45-degree slices)
    // around a grid node, used to track enemy influence and danger levels in different directions.
    readonly Dictionary<Vector3, int> nodeSectorData = new (); // Each position maps to a packed 8-sector bitmask

    // Offset from a node's stored position to the center of its cell, as used for drawing and line of sight.
    Vector3 CellCenterOffset => new Vector3(cellSize / 2, 0, cellSize / 2);

    /// <summary>
    /// Returns the grid node nearest to <paramref name="worldPosition"/> on X/Z, clamped to the grid bounds.
    /// </summary>
    /// <remarks>
    /// NOTE(review): rounding treats node positions as cell centers, while DrawGrid and
    /// PrecomputeSectors add cellSize/2 to reach the center (i.e. treat them as corners) — confirm intent.
    /// </remarks>
    public Node GetNodeFromWorldPosition(Vector3 worldPosition) {
        int x = Mathf.RoundToInt(worldPosition.x / cellSize);
        int z = Mathf.RoundToInt(worldPosition.z / cellSize);
        x = Mathf.Clamp(x, 0, gridSizeX - 1); // Clamp to grid bounds.
        z = Mathf.Clamp(z, 0, gridSizeZ - 1); // Clamp to grid bounds.
        return grid[x, z];
    }

    void Awake() {
        InitializeGrid();
        PrecomputeSectors();
        DrawGrid();
    }

    /// <summary>Spawns a quad and (if a prefab is assigned) a danger label over every cell with danger &gt; 0.</summary>
    void DrawGrid() {
        Quaternion faceUp = Quaternion.Euler(90, 0, 0);
        for (int x = 0; x < gridSizeX; x++) {
            for (int z = 0; z < gridSizeZ; z++) {
                Node node = grid[x, z];
                int dangerScore = CalculateNodeDanger(GetSectors(node.position));
                if (dangerScore == 0) continue;

                Vector3 center = node.position + CellCenterOffset;

                GameObject cellVisual = GameObject.CreatePrimitive(PrimitiveType.Quad);
                Destroy(cellVisual.GetComponent<Collider>()); // purely visual; must not block raycasts
                cellVisual.transform.SetParent(transform);
                cellVisual.transform.SetPositionAndRotation(center + new Vector3(0, 0.01f, 0), faceUp);
                cellVisual.transform.localScale = new Vector3(cellSize, cellSize, 1);
                cellVisual.GetComponent<Renderer>().sharedMaterial = dangerZoneMaterial;

                // The label prefab is optional: skip labels instead of throwing when it is unassigned.
                if (textPrefab == null) continue;
                TextMeshPro text3D = Instantiate(textPrefab, center + new Vector3(0, 0.1f, 0), faceUp);
                text3D.text = dangerScore.ToString();
                text3D.transform.SetParent(transform);
            }
        }
    }

    /// <summary>Sums the eight 4-bit sector counters packed into <paramref name="sectorMask"/>.</summary>
    /// <returns>Total danger, between 0 and 8 * 15.</returns>
    public int CalculateNodeDanger(int sectorMask) {
        int danger = 0;
        for (int sector = 0; sector < SectorCount; sector++) {
            danger += (sectorMask >> (sector * BitsPerSector)) & SectorValueMask;
        }
        return danger;
    }

    /// <summary>Returns the packed sector mask for a node position, or 0 if none was computed.</summary>
    public int GetSectors(Vector3 nodePosition) {
        return nodeSectorData.GetValueOrDefault(nodePosition);
    }

    /// <summary>
    /// For every walkable node, accumulates a range value (1..3) into the sector facing away from each
    /// enemy that is within its detection radius and has an unobstructed ray to the cell center.
    /// </summary>
    void PrecomputeSectors() {
        // Resolve enemy data once rather than once per node. TryGetComponent (unlike an `is { }`
        // pattern on GetComponent) is not fooled by Unity's editor-only fake-null components.
        var enemies = new List<(Vector3 position, float detectionRadius)>();
        foreach (var enemyObject in GameObject.FindGameObjectsWithTag("Enemy")) {
            if (enemyObject.TryGetComponent(out Enemy enemyComponent)) {
                enemies.Add((enemyObject.transform.position, enemyComponent.detectionRadius));
            }
        }

        Vector3 centerOffset = CellCenterOffset;
        foreach (var node in grid) {
            if (!node.isWalkable) continue;

            int sectorMask = 0;
            foreach (var (enemyPosition, detectionRadius) in enemies) {
                float distance = Vector3.Distance(node.position, enemyPosition);
                if (distance > detectionRadius) continue;
                if (IsObstructed(enemyPosition, node.position + centerOffset)) continue;

                Vector3 direction = (node.position - enemyPosition).normalized;
                int sectorShift = GetSectorForDirection(direction) * BitsPerSector;
                int currentSectorValue = (sectorMask >> sectorShift) & SectorValueMask;
                // Saturate each 4-bit counter so it cannot spill into the neighbouring sector.
                int newSectorValue = Mathf.Min(SectorValueMask, currentSectorValue + GetRangeValue(distance, detectionRadius));
                sectorMask &= ~(SectorValueMask << sectorShift);
                sectorMask |= newSectorValue << sectorShift;
            }
            nodeSectorData[node.position] = sectorMask;
        }
    }

    /// <summary>3 inside half the radius, 2 inside three quarters, 1 up to the radius, else 0.</summary>
    int GetRangeValue(float distance, float detectionRadius) {
        if (distance < detectionRadius * 0.5f) return 3;
        if (distance < detectionRadius * 0.75f) return 2;
        if (distance <= detectionRadius) return 1;
        return 0;
    }

    /// <summary>Maps a direction's XZ angle (measured from +X toward +Z) to a 45-degree sector index 0..7.</summary>
    int GetSectorForDirection(Vector3 direction) {
        return Mathf.FloorToInt((Mathf.Atan2(direction.z, direction.x) * Mathf.Rad2Deg + 360) % 360 / 45f);
    }

    /// <summary>True if any collider lies on the segment from <paramref name="from"/> to <paramref name="to"/>.</summary>
    bool IsObstructed(Vector3 from, Vector3 to) {
        return Physics.Raycast(from, (to - from).normalized, Vector3.Distance(from, to));
    }

    /// <summary>Allocates the grid and samples the NavMesh at each node position.</summary>
    void InitializeGrid() {
        grid = new Node[gridSizeX, gridSizeZ];
        for (int x = 0; x < gridSizeX; x++) {
            for (int z = 0; z < gridSizeZ; z++) {
                var pos = new Vector3(x, 0, z) * cellSize;
                grid[x, z] = new Node(pos, IsPositionOnNavMesh(pos));
            }
        }
    }

    bool IsPositionOnNavMesh(Vector3 position) => NavMesh.SamplePosition(position, out _, cellSize / 2, NavMesh.AllAreas);
}
using System.Collections.Generic;
using System.Linq;
/// <summary>
/// Minimal min-priority queue: <see cref="Dequeue"/> returns an element with the smallest priority,
/// and elements that share a priority are returned in insertion (FIFO) order.
/// </summary>
public class PriorityQueue<TElement, TPriority> {
    // One FIFO bucket per distinct priority; SortedList keeps the keys in ascending order.
    readonly SortedList<TPriority, Queue<TElement>> priorities = new ();

    /// <summary>Number of elements currently queued.</summary>
    /// <remarks>The setter stays public for source compatibility; assigning it externally desynchronises it from the stored elements.</remarks>
    public int Count { get; set; }

    /// <summary>Adds <paramref name="element"/> with the given priority.</summary>
    /// <exception cref="System.ArgumentNullException"><paramref name="priority"/> is null.</exception>
    public void Enqueue(TElement element, TPriority priority) {
        if (!priorities.TryGetValue(priority, out var bucket)) {
            bucket = new Queue<TElement>();
            priorities.Add(priority, bucket);
        }
        bucket.Enqueue(element);
        Count++;
    }

    /// <summary>Removes and returns an element with the smallest priority.</summary>
    /// <exception cref="System.InvalidOperationException">The queue is empty.</exception>
    public TElement Dequeue() {
        if (priorities.Count == 0) throw new System.InvalidOperationException("The queue is empty");
        // Index 0 is the smallest key; indexed access avoids allocating a LINQ enumerator.
        var bucket = priorities.Values[0];
        var element = bucket.Dequeue();
        if (bucket.Count == 0) priorities.RemoveAt(0);
        Count--;
        return element;
    }

    /// <summary>Non-throwing variant of <see cref="Dequeue"/>.</summary>
    /// <returns>False (and <paramref name="element"/> = default) when the queue is empty.</returns>
    public bool TryDequeue(out TElement element) {
        if (priorities.Count == 0) {
            element = default;
            return false;
        }
        element = Dequeue();
        return true;
    }
}
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.AI;
/// <summary>
/// Computes a tactical path once at Start and drives a child/own NavMeshAgent through its
/// waypoints (cell centers) one at a time.
/// </summary>
public class TacticalAgent : MonoBehaviour {
    public GridManager gridManager;
    public TacticalPathfinder pathfinder;
    public Transform destination;

    [Tooltip("Remaining distance below which the current waypoint counts as reached.")]
    public float waypointTolerance = 1f;

    NavMeshAgent navMeshAgent;
    List<Vector3> waypoints;
    int currentWaypointIndex;
    // Index of the waypoint last handed to SetDestination; -1 when none has been issued yet.
    int issuedWaypointIndex = -1;

    /// <summary>Current agent velocity, or zero when no NavMeshAgent was found.</summary>
    public Vector3 GetMovementVelocity() => navMeshAgent != null ? navMeshAgent.velocity : Vector3.zero;

    void Start() {
        navMeshAgent = GetComponentInChildren<NavMeshAgent>();
        if (navMeshAgent == null) {
            Debug.LogError($"{nameof(TacticalAgent)} on '{name}' needs a NavMeshAgent on itself or a child.", this);
            enabled = false; // stop Update from dereferencing a missing agent every frame
            return;
        }
        SetCustomPath();
    }

    void Update() {
        FollowCustomPath();
    }

    /// <summary>Asks the pathfinder for a route to <see cref="destination"/> and stores it as cell-center waypoints.</summary>
    void SetCustomPath() {
        if (gridManager == null || pathfinder == null || destination == null) {
            Debug.LogWarning($"{nameof(TacticalAgent)} on '{name}' is missing a GridManager, TacticalPathfinder or destination.", this);
            return;
        }

        Node startNode = gridManager.GetNodeFromWorldPosition(transform.position);
        Node endNode = gridManager.GetNodeFromWorldPosition(destination.position);
        List<Node> path = pathfinder.FindTacticalPath(startNode, endNode);
        if (path is not { Count: > 0 }) return;

        // Steer toward cell centers, using the same cellSize / 2 offset GridManager uses for drawing.
        float half = gridManager.cellSize / 2;
        Vector3 centerOffset = new Vector3(half, 0, half);
        waypoints = path.ConvertAll(node => node.position + centerOffset);
        currentWaypointIndex = 0;
        issuedWaypointIndex = -1;
    }

    /// <summary>Advances through the waypoints; stops the agent once they are exhausted.</summary>
    void FollowCustomPath() {
        if (waypoints == null || currentWaypointIndex >= waypoints.Count) {
            navMeshAgent.isStopped = true;
            navMeshAgent.ResetPath();
            return;
        }

        // Request a path only when the target waypoint changes: calling SetDestination every frame
        // forces a new path calculation each frame.
        if (issuedWaypointIndex != currentWaypointIndex) {
            navMeshAgent.isStopped = false;
            navMeshAgent.SetDestination(waypoints[currentWaypointIndex]);
            issuedWaypointIndex = currentWaypointIndex;
            return; // judge arrival from next frame, once the new path request has been processed
        }

        if (!navMeshAgent.pathPending && navMeshAgent.remainingDistance < waypointTolerance) {
            currentWaypointIndex++;
        }
    }
}
using System.Collections.Generic;
using System.Linq;
using TMPro;
using UnityEngine;
/// <summary>
/// A*-style search over <see cref="GridManager"/> nodes (4-connected) whose step cost grows with
/// the directional danger precomputed on each node and accumulates exposure across dangerous stretches.
/// </summary>
public class TacticalPathfinder : MonoBehaviour {
    public GridManager gridManager;
    public float dangerPenalty = 20.0f;

    // 4-connected neighbourhood on the XZ plane.
    static readonly Vector3[] NeighborDirections = { Vector3.forward, Vector3.back, Vector3.left, Vector3.right };

    /// <summary>Searches from <paramref name="startNode"/> to <paramref name="endNode"/>.</summary>
    /// <returns>The simplified node path (start and end included), or null if the end was not reached.</returns>
    public List<Node> FindTacticalPath(Node startNode, Node endNode) {
        var openList = new PriorityQueue<Node, float>();
        var closedSet = new HashSet<Vector3>();
        var cameFrom = new Dictionary<Vector3, Vector3>();
        var costSoFar = new Dictionary<Vector3, float> { [startNode.position] = 0 };
        var exposureSoFar = new Dictionary<Vector3, float> { [startNode.position] = 0 };
        openList.Enqueue(startNode, 0);

        while (openList.Count > 0) {
            Node currentNode = openList.Dequeue();
            // A node is re-queued on every cost improvement; only its first pop is expanded.
            if (!closedSet.Add(currentNode.position)) continue;

            if (currentNode.position == endNode.position)
                return ReconstructPath(cameFrom, startNode, endNode);

            float currentCost = costSoFar[currentNode.position];
            float currentExposure = exposureSoFar[currentNode.position];

            foreach (Node neighbor in GetNeighbors(currentNode)) {
                if (closedSet.Contains(neighbor.position)) continue;

                float neighborExposure = currentExposure;
                float newCost = currentCost + TacticalCostNodeToNode(currentNode, neighbor, ref neighborExposure, NodeDanger(neighbor.position));
                if (costSoFar.TryGetValue(neighbor.position, out float existingCost) && newCost >= existingCost) continue;

                costSoFar[neighbor.position] = newCost;
                exposureSoFar[neighbor.position] = neighborExposure;
                cameFrom[neighbor.position] = currentNode.position;
                openList.Enqueue(neighbor, newCost + Heuristic(neighbor, endNode));
                // DisplayTacticalCost(neighbor.position, newCost); // debug visualisation hook
            }
        }
        Debug.LogWarning("No path found!");
        return null;
    }

    // Total danger (sum of all sector counters) stored for the node at nodePosition.
    int NodeDanger(Vector3 nodePosition) => gridManager.CalculateNodeDanger(gridManager.GetSectors(nodePosition));

    /// <summary>Squared straight-line distance to the goal plus a penalty for the goal's own danger.</summary>
    float Heuristic(Node fromNode, Node toNode) {
        // NOTE(review): squared distance can overestimate the remaining cost, so the search is not
        // guaranteed to return the cheapest path (it behaves more greedily toward the goal).
        float baseHeuristic = (fromNode.position - toNode.position).sqrMagnitude;
        // NOTE(review): this penalises the *goal's* danger, which is identical for every queued node and
        // so never changes expansion order — confirm whether fromNode was intended.
        float heuristicPenalty = NodeDanger(toNode.position) * dangerPenalty;
        return baseHeuristic + heuristicPenalty;
    }

    /// <summary>
    /// Cost of stepping into <paramref name="toNode"/>: distance + (3^danger)^2, plus exposure bookkeeping.
    /// Entering a dangerous node adds danger * dangerPenalty to <paramref name="cumulativeExposure"/>;
    /// entering a safe node refunds (subtracts) the accumulated exposure and resets it to 0.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the refund can make a step cost negative, which A* with a closed set does not
    /// handle correctly. (3^danger)^2 also overflows float to +Infinity once danger reaches 41.
    /// </remarks>
    float TacticalCostNodeToNode(Node fromNode, Node toNode, ref float cumulativeExposure, int dangerScore) {
        float travelTime = Vector3.Distance(fromNode.position, toNode.position);
        if (dangerScore > 0) {
            cumulativeExposure += dangerScore * dangerPenalty;
        }
        float exposureRefund = 0;
        if (dangerScore == 0) {
            exposureRefund = -cumulativeExposure;
            cumulativeExposure = 0;
        }
        float penalty = Mathf.Pow(3, dangerScore);
        return travelTime + penalty * penalty + exposureRefund;
    }

    /// <summary>Walkable 4-connected neighbours of <paramref name="node"/>.</summary>
    IEnumerable<Node> GetNeighbors(Node node) {
        return NeighborDirections
            .Select(direction => gridManager.GetNodeFromWorldPosition(node.position + direction * gridManager.cellSize))
            // GetNodeFromWorldPosition clamps, so at the grid edge a "neighbour" can be the node itself.
            .Where(neighbor => neighbor.isWalkable && neighbor.position != node.position);
    }

    /// <summary>Walks <paramref name="cameFrom"/> back from the end to the start and simplifies the result.</summary>
    List<Node> ReconstructPath(Dictionary<Vector3, Vector3> cameFrom, Node startNode, Node endNode) {
        var path = new List<Node>();
        for (Vector3 current = endNode.position; current != startNode.position; current = cameFrom[current]) {
            // Use the real grid node so callers see its walkability flag, not a default-constructed copy.
            path.Add(gridManager.GetNodeFromWorldPosition(current));
        }
        path.Add(startNode);
        path.Reverse();
        return SimplifyPath(path);
    }

    /// <summary>Keeps the first and last nodes plus every intermediate node whose danger is zero.</summary>
    List<Node> SimplifyPath(List<Node> fullPath) {
        if (fullPath == null || fullPath.Count < 2) return fullPath;
        var simplifiedPath = new List<Node> { fullPath[0] };
        for (int i = 1; i < fullPath.Count - 1; i++) {
            if (NodeDanger(fullPath[i].position) == 0) {
                simplifiedPath.Add(fullPath[i]);
            }
        }
        simplifiedPath.Add(fullPath[^1]);
        return simplifiedPath;
    }

    /// <summary>Debug helper: floats a colour-coded cost label above <paramref name="position"/> for 15 seconds.</summary>
    void DisplayTacticalCost(Vector3 position, float cost) {
        GameObject costTextObj = new GameObject("CostText");
        costTextObj.transform.parent = gridManager.transform;
        costTextObj.transform.position = position + new Vector3(0, 1.5f, 0); // Raise text above the ground
        TextMeshPro textMesh = costTextObj.AddComponent<TextMeshPro>();
        textMesh.text = cost.ToString("F1"); // Display cost with one decimal
        textMesh.fontSize = 2;
        textMesh.alignment = TextAlignmentOptions.Center;
        if (cost < 10) textMesh.color = Color.green;
        else if (cost < 20) textMesh.color = Color.yellow;
        else if (cost < 40) textMesh.color = Color.red;
        else textMesh.color = new Color(0.5f, 0, 0); // Dark red for extreme danger
        Destroy(costTextObj, 15.0f);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment