Forráskód Böngészése

Reduce memory usage for edges in A* and add tests

Shiqing 6 éve
szülő
commit
c2b824687d
3 módosított fájl, 203 hozzáadás és 37 törlés
  1. 47 30
      core/math/a_star.cpp
  2. 22 7
      core/math/a_star.h
  3. 134 0
      main/tests/test_astar.cpp

+ 47 - 30
core/math/a_star.cpp

@@ -164,22 +164,24 @@ void AStar::connect_points(int p_id, int p_with_id, bool bidirectional) {
 	}
 
 	Segment s(p_id, p_with_id);
-	s.from_point = a;
-	s.to_point = b;
-	segments.insert(s);
-
-	if (bidirectional) {
-		SWAP(s.from, s.to);
-		SWAP(s.from_point, s.to_point);
-		segments.insert(s);
+	if (bidirectional) s.direction = Segment::BIDIRECTIONAL;
+
+	Set<Segment>::Element *element = segments.find(s);
+	if (element != NULL) {
+		s.direction |= element->get().direction;
+		if (s.direction == Segment::BIDIRECTIONAL) {
+			// Both are neighbours of each other now
+			a->unlinked_neighbours.remove(b->id);
+			b->unlinked_neighbours.remove(a->id);
+		}
+		segments.erase(element);
 	}
+
+	segments.insert(s);
 }
 
 void AStar::disconnect_points(int p_id, int p_with_id, bool bidirectional) {
 
-	Segment s(p_id, p_with_id);
-	Segment t(p_with_id, p_id);
-
 	Point *a;
 	bool a_exists = points.lookup(p_id, a);
 	CRASH_COND(!a_exists);
@@ -188,23 +190,32 @@ void AStar::disconnect_points(int p_id, int p_with_id, bool bidirectional) {
 	bool b_exists = points.lookup(p_with_id, b);
 	CRASH_COND(!b_exists);
 
-	bool warned = false;
+	Segment s(p_id, p_with_id);
+	int remove_direction = bidirectional ? (int)Segment::BIDIRECTIONAL : s.direction;
+
+	Set<Segment>::Element *element = segments.find(s);
+	if (element != NULL) {
+		// s is the new segment
+		// Erase the directions to be removed
+		s.direction = (element->get().direction & ~remove_direction);
 
-	if (segments.has(s)) {
-		segments.erase(s);
 		a->neighbours.remove(b->id);
-		b->unlinked_neighbours.remove(a->id);
-	} else {
-		warned = true;
-		WARN_PRINT("The edge to be removed does not exist.");
-	}
+		if (bidirectional) {
+			b->neighbours.remove(a->id);
+			if (element->get().direction != Segment::BIDIRECTIONAL) {
+				a->unlinked_neighbours.remove(b->id);
+				b->unlinked_neighbours.remove(a->id);
+			}
+		} else {
+			if (s.direction == Segment::NONE)
+				b->unlinked_neighbours.remove(a->id);
+			else
+				a->unlinked_neighbours.set(b->id, b);
+		}
 
-	if (bidirectional && segments.has(t)) {
-		segments.erase(t);
-		b->neighbours.remove(a->id);
-		a->unlinked_neighbours.remove(b->id);
-	} else if (bidirectional && !warned) {
-		WARN_PRINT("The reverse edge to be removed does not exist.");
+		segments.erase(element);
+		if (s.direction != Segment::NONE)
+			segments.insert(s);
 	}
 }
 
@@ -242,8 +253,10 @@ PoolVector<int> AStar::get_point_connections(int p_id) {
 bool AStar::are_points_connected(int p_id, int p_with_id, bool bidirectional) const {
 
 	Segment s(p_id, p_with_id);
-	Segment t(p_with_id, p_id);
-	return segments.has(s) || (bidirectional && segments.has(t));
+	const Set<Segment>::Element *element = segments.find(s);
+
+	return element != NULL &&
+		   (bidirectional || (element->get().direction & s.direction) == s.direction);
 }
 
 void AStar::clear() {
@@ -297,13 +310,17 @@ Vector3 AStar::get_closest_position_in_segment(const Vector3 &p_point) const {
 
 	for (const Set<Segment>::Element *E = segments.front(); E; E = E->next()) {
 
-		if (!(E->get().from_point->enabled && E->get().to_point->enabled)) {
+		Point *from_point = nullptr, *to_point = nullptr;
+		points.lookup(E->get().u, from_point);
+		points.lookup(E->get().v, to_point);
+
+		if (!(from_point->enabled && to_point->enabled)) {
 			continue;
 		}
 
 		Vector3 segment[2] = {
-			E->get().from_point->pos,
-			E->get().to_point->pos,
+			from_point->pos,
+			to_point->pos,
 		};
 
 		Vector3 p = Geometry::get_closest_point_to_segment(p_point, segment);

+ 22 - 7
core/math/a_star.h

@@ -81,20 +81,35 @@ class AStar : public Reference {
 	struct Segment {
 		union {
 			struct {
-				int32_t from;
-				int32_t to;
+				int32_t u;
+				int32_t v;
 			};
 			uint64_t key;
 		};
 
-		Point *from_point;
-		Point *to_point;
+		enum {
+			NONE = 0,
+			FORWARD = 1,
+			BACKWARD = 2,
+			BIDIRECTIONAL = FORWARD | BACKWARD
+		};
+		unsigned char direction;
 
 		bool operator<(const Segment &p_s) const { return key < p_s.key; }
-		Segment() { key = 0; }
+		Segment() {
+			key = 0;
+			direction = NONE;
+		}
 		Segment(int p_from, int p_to) {
-			from = p_from;
-			to = p_to;
+			if (p_from < p_to) {
+				u = p_from;
+				v = p_to;
+				direction = FORWARD;
+			} else {
+				u = p_to;
+				v = p_from;
+				direction = BACKWARD;
+			}
 		}
 	};
 

+ 134 - 0
main/tests/test_astar.cpp

@@ -87,11 +87,145 @@ bool test_abcx() {
 	return ok;
 }
 
+bool test_add_remove() {
+	AStar a;
+	bool ok = true;
+
+	// Manual tests
+	a.add_point(1, Vector3(0, 0, 0));
+	a.add_point(2, Vector3(0, 1, 0));
+	a.add_point(3, Vector3(1, 1, 0));
+	a.add_point(4, Vector3(2, 0, 0));
+	a.connect_points(1, 2, true);
+	a.connect_points(1, 3, true);
+	a.connect_points(1, 4, false);
+
+	ok = ok && (a.are_points_connected(2, 1) == true);
+	ok = ok && (a.are_points_connected(4, 1) == true);
+	ok = ok && (a.are_points_connected(2, 1, false) == true);
+	ok = ok && (a.are_points_connected(4, 1, false) == false);
+
+	a.disconnect_points(1, 2, true);
+	ok = ok && (a.get_point_connections(1).size() == 2); // 3, 4
+	ok = ok && (a.get_point_connections(2).size() == 0);
+
+	a.disconnect_points(4, 1, false);
+	ok = ok && (a.get_point_connections(1).size() == 2); // 3, 4
+	ok = ok && (a.get_point_connections(4).size() == 0);
+
+	a.disconnect_points(4, 1, true);
+	ok = ok && (a.get_point_connections(1).size() == 1); // 3
+	ok = ok && (a.get_point_connections(4).size() == 0);
+
+	a.connect_points(2, 3, false);
+	ok = ok && (a.get_point_connections(2).size() == 1); // 3
+	ok = ok && (a.get_point_connections(3).size() == 1); // 1
+
+	a.connect_points(2, 3, true);
+	ok = ok && (a.get_point_connections(2).size() == 1); // 3
+	ok = ok && (a.get_point_connections(3).size() == 2); // 1, 2
+
+	a.disconnect_points(2, 3, false);
+	ok = ok && (a.get_point_connections(2).size() == 0);
+	ok = ok && (a.get_point_connections(3).size() == 2); // 1, 2
+
+	a.connect_points(4, 3, true);
+	ok = ok && (a.get_point_connections(3).size() == 3); // 1, 2, 4
+	ok = ok && (a.get_point_connections(4).size() == 1); // 3
+
+	a.disconnect_points(3, 4, false);
+	ok = ok && (a.get_point_connections(3).size() == 2); // 1, 2
+	ok = ok && (a.get_point_connections(4).size() == 1); // 3
+
+	a.remove_point(3);
+	ok = ok && (a.get_point_connections(1).size() == 0);
+	ok = ok && (a.get_point_connections(2).size() == 0);
+	ok = ok && (a.get_point_connections(4).size() == 0);
+
+	a.add_point(0, Vector3(0, -1, 0));
+	a.add_point(3, Vector3(2, 1, 0));
+	// 0: (0, -1)
+	// 1: (0, 0)
+	// 2: (0, 1)
+	// 3: (2, 1)
+	// 4: (2, 0)
+
+	// Tests for get_closest_position_in_segment
+	a.connect_points(2, 3);
+	ok = ok && (a.get_closest_position_in_segment(Vector3(0.5, 0.5, 0)) == Vector3(0.5, 1, 0));
+
+	a.connect_points(3, 4);
+	a.connect_points(0, 3);
+	a.connect_points(1, 4);
+	a.disconnect_points(1, 4, false);
+	a.disconnect_points(4, 3, false);
+	a.disconnect_points(3, 4, false);
+	// Remaining edges: <2, 3>, <0, 3>, <1, 4> (directed)
+	ok = ok && (a.get_closest_position_in_segment(Vector3(2, 0.5, 0)) == Vector3(1.75, 0.75, 0));
+	ok = ok && (a.get_closest_position_in_segment(Vector3(-1, 0.2, 0)) == Vector3(0, 0, 0));
+	ok = ok && (a.get_closest_position_in_segment(Vector3(3, 2, 0)) == Vector3(2, 1, 0));
+
+	int seed = 0;
+
+	// Random tests for connectivity checks
+	for (int i = 0; i < 20000; i++) {
+		seed = (seed * 1103515245 + 12345) & 0x7fffffff;
+		int u = (seed / 5) % 5;
+		int v = seed % 5;
+		if (u == v) {
+			i--;
+			continue;
+		}
+		if (seed % 2 == 1) {
+			// Add a (possibly existing) directed edge and confirm connectivity
+			a.connect_points(u, v, false);
+			ok = ok && (a.are_points_connected(u, v, false) == true);
+		} else {
+			// Remove a (possibly nonexistent) directed edge and confirm disconnectivity
+			a.disconnect_points(u, v, false);
+			ok = ok && (a.are_points_connected(u, v, false) == false);
+		}
+	}
+
+	// Random tests for point removal
+	for (int i = 0; i < 20000; i++) {
+		a.clear();
+		for (int j = 0; j < 5; j++)
+			a.add_point(j, Vector3(0, 0, 0));
+
+		// Add or remove random edges
+		for (int j = 0; j < 10; j++) {
+			seed = (seed * 1103515245 + 12345) & 0x7fffffff;
+			int u = (seed / 5) % 5;
+			int v = seed % 5;
+			if (u == v) {
+				j--;
+				continue;
+			}
+			if (seed % 2 == 1)
+				a.connect_points(u, v, false);
+			else
+				a.disconnect_points(u, v, false);
+		}
+
+		// Remove point 0
+		a.remove_point(0);
+		// White box: this will check all edges remaining in the segments set
+		for (int j = 1; j < 5; j++) {
+			ok = ok && (a.are_points_connected(0, j, true) == false);
+		}
+	}
+
+	// It's been great work, cheers \(^ ^)/
+	return ok;
+}
+
 typedef bool (*TestFunc)(void);
 
 TestFunc test_funcs[] = {
 	test_abc,
 	test_abcx,
+	test_add_remove,
 	NULL
 };