nearest_neighbor_graph_ann.cpp 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. /*************************************************************************
  2. * Copyright (c) 2011 AT&T Intellectual Property
  3. * All rights reserved. This program and the accompanying materials
  4. * are made available under the terms of the Eclipse Public License v1.0
  5. * which accompanies this distribution, and is available at
  6. * https://www.eclipse.org/legal/epl-v10.html
  7. *
  8. * Contributors: Details at https://graphviz.org
  9. *************************************************************************/
  10. #include <ANN/ANN.h> // ANN declarations
  11. #include <mingle/nearest_neighbor_graph_ann.h>
  12. #include <utility>
  13. #include <vector>
  14. static const int dim = 4; // dimension
  15. static void sortPtsX(int n, ANNpointArray pts){
  16. /* sort so that edges always go from left to right in x-doordinate */
  17. for (int i = 0; i < n; i++){
  18. ANNpoint p = pts[i];
  19. if (p[0] < p[2] || (p[0] == p[2] && p[1] < p[3])) continue;
  20. std::swap(p[0], p[2]);
  21. std::swap(p[1], p[3]);
  22. }
  23. }
  24. static void sortPtsY(int n, ANNpointArray pts){
  25. /* sort so that edges always go from left to right in x-doordinate */
  26. for (int i = 0; i < n; i++){
  27. ANNpoint p = pts[i];
  28. if (p[1] < p[3] || (p[1] == p[3] && p[0] < p[2])) continue;
  29. std::swap(p[0], p[2]);
  30. std::swap(p[1], p[3]);
  31. }
  32. }
  33. void nearest_neighbor_graph_ann(int nPts, int k, const std::vector<double> &x,
  34. int &nz0, std::vector<int> &irn,
  35. std::vector<int> &jcn,
  36. std::vector<double> &val) {
  37. /* Gives a nearest neighbor graph is a list of dim-dimendional points. The connectivity is in irn/jcn, and the distance in val.
  38. nPts: number of points
  39. dim: dimension
  40. k: number of neighbors needed
  41. x: nPts*dim vector. The i-th point is x[i*dim : i*dim + dim - 1]
  42. nz: number of entries in the connectivity matrix irn/jcn/val
  43. irn, jcn: the connectivity
  44. val: the distance
  45. note that there could be repeates
  46. */
  47. // error tolerance
  48. const double eps = 0;
  49. ANNpointArray dataPts = annAllocPts(nPts, dim); // allocate data points
  50. std::vector<ANNidx> nnIdx(k); // allocate near neighbor indices
  51. std::vector<ANNdist> dists(k); // allocate near neighbor dists
  52. for (int i = 0; i < nPts; i++){
  53. double *xx = dataPts[i];
  54. for (int j = 0; j < dim; j++) xx[j] = x[i*dim + j];
  55. }
  56. //========= graph when sort based on x ========
  57. int nz = 0;
  58. sortPtsX(nPts, dataPts);
  59. ANNkd_tree kdTree( // build search structure
  60. dataPts, // the data points
  61. nPts, // number of points
  62. dim); // dimension of space
  63. for (int ip = 0; ip < nPts; ip++){
  64. kdTree.annkSearch( // search
  65. dataPts[ip], // query point
  66. k, // number of near neighbors
  67. nnIdx.data(), // nearest neighbors (returned)
  68. dists.data(), // distance (returned)
  69. eps); // error bound
  70. for (int i = 0; i < k; i++) { // print summary
  71. if (nnIdx[i] == ip) continue;
  72. val[nz] = dists[i];
  73. irn[nz] = ip;
  74. jcn[nz++] = nnIdx[i];
  75. }
  76. }
  77. //========= graph when sort based on y ========
  78. sortPtsY(nPts, dataPts);
  79. kdTree = ANNkd_tree( // build search structure
  80. dataPts, // the data points
  81. nPts, // number of points
  82. dim); // dimension of space
  83. for (int ip = 0; ip < nPts; ip++){
  84. kdTree.annkSearch( // search
  85. dataPts[ip], // query point
  86. k, // number of near neighbors
  87. nnIdx.data(), // nearest neighbors (returned)
  88. dists.data(), // distance (returned)
  89. eps); // error bound
  90. for (int i = 0; i < k; i++) { // print summary
  91. if (nnIdx[i] == ip) continue;
  92. val[nz] = dists[i];
  93. irn[nz] = ip;
  94. jcn[nz++] = nnIdx[i];
  95. }
  96. }
  97. nz0 = nz;
  98. annDeallocPts(dataPts);
  99. annClose(); // done with ANN
  100. }