// Search.cpp (12 KB)
#include <algorithm>
#include <iostream>
#include <limits>
#include <queue>
#include <unordered_map>
#include <vector>
  5. struct GraphNode
  6. {
  7. // Adjacency list
  8. std::vector<GraphNode*> mAdjacent;
  9. };
  10. struct Graph
  11. {
  12. // A graph contains nodes
  13. std::vector<GraphNode*> mNodes;
  14. };
  15. struct WeightedEdge
  16. {
  17. // Which nodes are connected by this edge?
  18. struct WeightedGraphNode* mFrom;
  19. struct WeightedGraphNode* mTo;
  20. // Weight of this edge
  21. float mWeight;
  22. };
  23. struct WeightedGraphNode
  24. {
  25. std::vector<WeightedEdge*> mEdges;
  26. };
  27. struct WeightedGraph
  28. {
  29. std::vector<WeightedGraphNode*> mNodes;
  30. };
  31. struct GBFSScratch
  32. {
  33. const WeightedEdge* mParentEdge = nullptr;
  34. float mHeuristic = 0.0f;
  35. bool mInOpenSet = false;
  36. bool mInClosedSet = false;
  37. };
  38. using GBFSMap =
  39. std::unordered_map<const WeightedGraphNode*, GBFSScratch>;
  40. struct AStarScratch
  41. {
  42. const WeightedEdge* mParentEdge = nullptr;
  43. float mHeuristic = 0.0f;
  44. float mActualFromStart = 0.0f;
  45. bool mInOpenSet = false;
  46. bool mInClosedSet = false;
  47. };
  48. using AStarMap =
  49. std::unordered_map<const WeightedGraphNode*, AStarScratch>;
  50. float ComputeHeuristic(const WeightedGraphNode* a, const WeightedGraphNode* b)
  51. {
  52. return 0.0f;
  53. }
  54. bool AStar(const WeightedGraph& g, const WeightedGraphNode* start,
  55. const WeightedGraphNode* goal, AStarMap& outMap)
  56. {
  57. std::vector<const WeightedGraphNode*> openSet;
  58. // Set current node to start, and mark in closed set
  59. const WeightedGraphNode* current = start;
  60. outMap[current].mInClosedSet = true;
  61. do
  62. {
  63. // Add adjacent nodes to open set
  64. for (const WeightedEdge* edge : current->mEdges)
  65. {
  66. const WeightedGraphNode* neighbor = edge->mTo;
  67. // Get scratch data for this node
  68. AStarScratch& data = outMap[neighbor];
  69. // Only check nodes that aren't in the closed set
  70. if (!data.mInClosedSet)
  71. {
  72. if (!data.mInOpenSet)
  73. {
  74. // Not in the open set, so parent must be current
  75. data.mParentEdge = edge;
  76. data.mHeuristic = ComputeHeuristic(neighbor, goal);
  77. // Actual cost is the parent's plus cost of traversing edge
  78. data.mActualFromStart = outMap[current].mActualFromStart +
  79. edge->mWeight;
  80. data.mInOpenSet = true;
  81. openSet.emplace_back(neighbor);
  82. }
  83. else
  84. {
  85. // Compute what new actual cost is if current becomes parent
  86. float newG = outMap[current].mActualFromStart + edge->mWeight;
  87. if (newG < data.mActualFromStart)
  88. {
  89. // Current should adopt this node
  90. data.mParentEdge = edge;
  91. data.mActualFromStart = newG;
  92. }
  93. }
  94. }
  95. }
  96. // If open set is empty, all possible paths are exhausted
  97. if (openSet.empty())
  98. {
  99. break;
  100. }
  101. // Find lowest cost node in open set
  102. auto iter = std::min_element(openSet.begin(), openSet.end(),
  103. [&outMap](const WeightedGraphNode* a, const WeightedGraphNode* b) {
  104. // Calculate f(x) for nodes a/b
  105. float fOfA = outMap[a].mHeuristic + outMap[a].mActualFromStart;
  106. float fOfB = outMap[b].mHeuristic + outMap[b].mActualFromStart;
  107. return fOfA < fOfB;
  108. });
  109. // Set to current and move from open to closed
  110. current = *iter;
  111. openSet.erase(iter);
  112. outMap[current].mInOpenSet = true;
  113. outMap[current].mInClosedSet = true;
  114. } while (current != goal);
  115. // Did we find a path?
  116. return (current == goal) ? true : false;
  117. }
  118. bool GBFS(const WeightedGraph& g, const WeightedGraphNode* start,
  119. const WeightedGraphNode* goal, GBFSMap& outMap)
  120. {
  121. std::vector<const WeightedGraphNode*> openSet;
  122. // Set current node to start, and mark in closed set
  123. const WeightedGraphNode* current = start;
  124. outMap[current].mInClosedSet = true;
  125. do
  126. {
  127. // Add adjacent nodes to open set
  128. for (const WeightedEdge* edge : current->mEdges)
  129. {
  130. // Get scratch data for this node
  131. GBFSScratch& data = outMap[edge->mTo];
  132. // Add it only if it's not in the closed set
  133. if (!data.mInClosedSet)
  134. {
  135. // Set the adjacent node's parent edge
  136. data.mParentEdge = edge;
  137. if (!data.mInOpenSet)
  138. {
  139. // Compute the heuristic for this node, and add to open set
  140. data.mHeuristic = ComputeHeuristic(edge->mTo, goal);
  141. data.mInOpenSet = true;
  142. openSet.emplace_back(edge->mTo);
  143. }
  144. }
  145. }
  146. // If open set is empty, all possible paths are exhausted
  147. if (openSet.empty())
  148. {
  149. break;
  150. }
  151. // Find lowest cost node in open set
  152. auto iter = std::min_element(openSet.begin(), openSet.end(),
  153. [&outMap](const WeightedGraphNode* a, const WeightedGraphNode* b) {
  154. return outMap[a].mHeuristic < outMap[b].mHeuristic;
  155. });
  156. // Set to current and move from open to closed
  157. current = *iter;
  158. openSet.erase(iter);
  159. outMap[current].mInOpenSet = false;
  160. outMap[current].mInClosedSet = true;
  161. } while (current != goal);
  162. // Did we find a path?
  163. return (current == goal) ? true : false;
  164. }
  165. using NodeToParentMap =
  166. std::unordered_map<const GraphNode*, const GraphNode*>;
  167. bool BFS(const Graph& graph, const GraphNode* start, const GraphNode* goal, NodeToParentMap& outMap)
  168. {
  169. // Whether we found a path
  170. bool pathFound = false;
  171. // Nodes to consider
  172. std::queue<const GraphNode*> q;
  173. // Enqueue the first node
  174. q.emplace(start);
  175. while (!q.empty())
  176. {
  177. // Dequeue a node
  178. const GraphNode* current = q.front();
  179. q.pop();
  180. if (current == goal)
  181. {
  182. pathFound = true;
  183. break;
  184. }
  185. // Enqueue adjacent nodes that aren't already in the queue
  186. for (const GraphNode* node : current->mAdjacent)
  187. {
  188. // If the parent is null, it hasn't been enqueued
  189. // (except for the start node)
  190. const GraphNode* parent = outMap[node];
  191. if (parent == nullptr && node != start)
  192. {
  193. // Enqueue this node, setting its parent
  194. outMap[node] = current;
  195. q.emplace(node);
  196. }
  197. }
  198. }
  199. return pathFound;
  200. }
  201. void testBFS()
  202. {
  203. Graph g;
  204. for (int i = 0; i < 5; i++)
  205. {
  206. for (int j = 0; j < 5; j++)
  207. {
  208. GraphNode* node = new GraphNode;
  209. g.mNodes.emplace_back(node);
  210. }
  211. }
  212. for (int i = 0; i < 5; i++)
  213. {
  214. for (int j = 0; j < 5; j++)
  215. {
  216. GraphNode* node = g.mNodes[i * 5 + j];
  217. if (i > 0)
  218. {
  219. node->mAdjacent.emplace_back(g.mNodes[(i - 1) * 5 + j]);
  220. }
  221. if (i < 4)
  222. {
  223. node->mAdjacent.emplace_back(g.mNodes[(i + 1) * 5 + j]);
  224. }
  225. if (j > 0)
  226. {
  227. node->mAdjacent.emplace_back(g.mNodes[i * 5 + j - 1]);
  228. }
  229. if (j < 4)
  230. {
  231. node->mAdjacent.emplace_back(g.mNodes[i * 5 + j + 1]);
  232. }
  233. }
  234. }
  235. NodeToParentMap map;
  236. bool found = BFS(g, g.mNodes[0], g.mNodes[9], map);
  237. std::cout << found << '\n';
  238. }
  239. void testHeuristic(bool useAStar)
  240. {
  241. WeightedGraph g;
  242. for (int i = 0; i < 5; i++)
  243. {
  244. for (int j = 0; j < 5; j++)
  245. {
  246. WeightedGraphNode* node = new WeightedGraphNode;
  247. g.mNodes.emplace_back(node);
  248. }
  249. }
  250. for (int i = 0; i < 5; i++)
  251. {
  252. for (int j = 0; j < 5; j++)
  253. {
  254. WeightedGraphNode* node = g.mNodes[i * 5 + j];
  255. if (i > 0)
  256. {
  257. WeightedEdge* e = new WeightedEdge;
  258. e->mFrom = node;
  259. e->mTo = g.mNodes[(i - 1) * 5 + j];
  260. e->mWeight = 1.0f;
  261. node->mEdges.emplace_back(e);
  262. }
  263. if (i < 4)
  264. {
  265. WeightedEdge* e = new WeightedEdge;
  266. e->mFrom = node;
  267. e->mTo = g.mNodes[(i + 1) * 5 + j];
  268. e->mWeight = 1.0f;
  269. node->mEdges.emplace_back(e);
  270. }
  271. if (j > 0)
  272. {
  273. WeightedEdge* e = new WeightedEdge;
  274. e->mFrom = node;
  275. e->mTo = g.mNodes[i * 5 + j - 1];
  276. e->mWeight = 1.0f;
  277. node->mEdges.emplace_back(e);
  278. }
  279. if (j < 4)
  280. {
  281. WeightedEdge* e = new WeightedEdge;
  282. e->mFrom = node;
  283. e->mTo = g.mNodes[i * 5 + j + 1];
  284. e->mWeight = 1.0f;
  285. node->mEdges.emplace_back(e);
  286. }
  287. }
  288. }
  289. bool found = false;
  290. if (useAStar)
  291. {
  292. AStarMap map;
  293. found = AStar(g, g.mNodes[0], g.mNodes[9], map);
  294. }
  295. else
  296. {
  297. GBFSMap map;
  298. found = GBFS(g, g.mNodes[0], g.mNodes[9], map);
  299. }
  300. std::cout << found << '\n';
  301. }
  302. struct GameState
  303. {
  304. // (For tic-tac-toe, array of board)
  305. enum SquareState { Empty, X, O };
  306. SquareState mBoard[3][3];
  307. };
  308. struct GTNode
  309. {
  310. // Children nodes
  311. std::vector<GTNode*> mChildren;
  312. // State of game
  313. GameState mState;
  314. };
  315. void GenStates(GTNode* root, bool xPlayer)
  316. {
  317. for (int i = 0; i < 3; i++)
  318. {
  319. for (int j = 0; j < 3; j++)
  320. {
  321. if (root->mState.mBoard[i][j] == GameState::Empty)
  322. {
  323. GTNode* node = new GTNode;
  324. root->mChildren.emplace_back(node);
  325. node->mState = root->mState;
  326. node->mState.mBoard[i][j] = xPlayer ? GameState::X : GameState::O;
  327. GenStates(node, !xPlayer);
  328. }
  329. }
  330. }
  331. }
  332. float GetScore(const GameState& state)
  333. {
  334. // Are any of the rows the same?
  335. for (int i = 0; i < 3; i++)
  336. {
  337. bool same = true;
  338. GameState::SquareState v = state.mBoard[i][0];
  339. for (int j = 1; j < 3; j++)
  340. {
  341. if (state.mBoard[i][j] != v)
  342. {
  343. same = false;
  344. }
  345. }
  346. if (same)
  347. {
  348. if (v == GameState::X)
  349. {
  350. return 1.0f;
  351. }
  352. else
  353. {
  354. return -1.0f;
  355. }
  356. }
  357. }
  358. // Are any of the columns the same?
  359. for (int j = 0; j < 3; j++)
  360. {
  361. bool same = true;
  362. GameState::SquareState v = state.mBoard[0][j];
  363. for (int i = 1; i < 3; i++)
  364. {
  365. if (state.mBoard[i][j] != v)
  366. {
  367. same = false;
  368. }
  369. }
  370. if (same)
  371. {
  372. if (v == GameState::X)
  373. {
  374. return 1.0f;
  375. }
  376. else
  377. {
  378. return -1.0f;
  379. }
  380. }
  381. }
  382. // What about diagonals?
  383. if (((state.mBoard[0][0] == state.mBoard[1][1]) &&
  384. (state.mBoard[1][1] == state.mBoard[2][2])) ||
  385. ((state.mBoard[2][0] == state.mBoard[1][1]) &&
  386. (state.mBoard[1][1] == state.mBoard[0][2])))
  387. {
  388. if (state.mBoard[1][1] == GameState::X)
  389. {
  390. return 1.0f;
  391. }
  392. else
  393. {
  394. return -1.0f;
  395. }
  396. }
  397. // We tied
  398. return 0.0f;
  399. }
  400. float MinPlayer(const GTNode* node);
  401. float MaxPlayer(const GTNode* node)
  402. {
  403. // If this is a leaf, return score
  404. if (node->mChildren.empty())
  405. {
  406. return GetScore(node->mState);
  407. }
  408. float maxValue = -std::numeric_limits<float>::infinity();
  409. // Find the subtree with the maximum value
  410. for (const GTNode* child : node->mChildren)
  411. {
  412. maxValue = std::max(maxValue, MinPlayer(child));
  413. }
  414. return maxValue;
  415. }
  416. float MinPlayer(const GTNode* node)
  417. {
  418. // If this is a leaf, return score
  419. if (node->mChildren.empty())
  420. {
  421. return GetScore(node->mState);
  422. }
  423. float minValue = std::numeric_limits<float>::infinity();
  424. // Find the subtree with the minimum value
  425. for (const GTNode* child : node->mChildren)
  426. {
  427. minValue = std::min(minValue, MaxPlayer(child));
  428. }
  429. return minValue;
  430. }
  431. const GTNode* MinimaxDecide(const GTNode* root)
  432. {
  433. // Find the subtree with the maximum value, and save the choice
  434. const GTNode* choice = nullptr;
  435. float maxValue = -std::numeric_limits<float>::infinity();
  436. for (const GTNode* child : root->mChildren)
  437. {
  438. float v = MinPlayer(child);
  439. if (v > maxValue)
  440. {
  441. maxValue = v;
  442. choice = child;
  443. }
  444. }
  445. return choice;
  446. }
  447. float AlphaBetaMin(const GTNode* node, float alpha, float beta);
  448. float AlphaBetaMax(const GTNode* node, float alpha, float beta)
  449. {
  450. // If this is a leaf, return score
  451. if (node->mChildren.empty())
  452. {
  453. return GetScore(node->mState);
  454. }
  455. float maxValue = -std::numeric_limits<float>::infinity();
  456. // Find the subtree with the maximum value
  457. for (const GTNode* child : node->mChildren)
  458. {
  459. maxValue = std::max(maxValue, AlphaBetaMin(child, alpha, beta));
  460. if (maxValue >= beta)
  461. {
  462. return maxValue; // Beta prune
  463. }
  464. alpha = std::max(maxValue, alpha);
  465. }
  466. return maxValue;
  467. }
  468. float AlphaBetaMin(const GTNode* node, float alpha, float beta)
  469. {
  470. // If this is a leaf, return score
  471. if (node->mChildren.empty())
  472. {
  473. return GetScore(node->mState);
  474. }
  475. float minValue = std::numeric_limits<float>::infinity();
  476. // Find the subtree with the minimum value
  477. for (const GTNode* child : node->mChildren)
  478. {
  479. minValue = std::min(minValue, AlphaBetaMax(child, alpha, beta));
  480. if (minValue <= alpha)
  481. {
  482. return minValue; // Alpha prune
  483. }
  484. beta = std::min(minValue, beta);
  485. }
  486. return minValue;
  487. }
  488. const GTNode* AlphaBetaDecide(const GTNode* root)
  489. {
  490. // Find the subtree with the maximum value, and save the choice
  491. const GTNode* choice = nullptr;
  492. float maxValue = -std::numeric_limits<float>::infinity();
  493. float beta = std::numeric_limits<float>::infinity();
  494. for (const GTNode* child : root->mChildren)
  495. {
  496. float v = AlphaBetaMin(child, maxValue, beta);
  497. if (v > maxValue)
  498. {
  499. maxValue = v;
  500. choice = child;
  501. }
  502. }
  503. return choice;
  504. }
  505. void testTicTac()
  506. {
  507. GTNode* root = new GTNode;
  508. root->mState.mBoard[0][0] = GameState::O;
  509. root->mState.mBoard[0][1] = GameState::Empty;
  510. root->mState.mBoard[0][2] = GameState::X;
  511. root->mState.mBoard[1][0] = GameState::X;
  512. root->mState.mBoard[1][1] = GameState::O;
  513. root->mState.mBoard[1][2] = GameState::O;
  514. root->mState.mBoard[2][0] = GameState::X;
  515. root->mState.mBoard[2][1] = GameState::Empty;
  516. root->mState.mBoard[2][2] = GameState::Empty;
  517. GenStates(root, true);
  518. const GTNode* choice = AlphaBetaDecide(root);
  519. std::cout << choice->mChildren.size();
  520. }