Sparse.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889
  1. // This code is in the public domain -- Ignacio Castaño <[email protected]>
  2. #include "Sparse.h"
  3. #include "KahanSum.h"
  4. #include "nvcore/Array.inl"
  5. #define USE_KAHAN_SUM 0
  6. using namespace nv;
  7. FullVector::FullVector(uint dim)
  8. {
  9. m_array.resize(dim);
  10. }
  11. FullVector::FullVector(const FullVector & v) : m_array(v.m_array)
  12. {
  13. }
  14. const FullVector & FullVector::operator=(const FullVector & v)
  15. {
  16. nvCheck(dimension() == v.dimension());
  17. m_array = v.m_array;
  18. return *this;
  19. }
  20. void FullVector::fill(float f)
  21. {
  22. const uint dim = dimension();
  23. for (uint i = 0; i < dim; i++)
  24. {
  25. m_array[i] = f;
  26. }
  27. }
  28. void FullVector::operator+= (const FullVector & v)
  29. {
  30. nvDebugCheck(dimension() == v.dimension());
  31. const uint dim = dimension();
  32. for (uint i = 0; i < dim; i++)
  33. {
  34. m_array[i] += v.m_array[i];
  35. }
  36. }
  37. void FullVector::operator-= (const FullVector & v)
  38. {
  39. nvDebugCheck(dimension() == v.dimension());
  40. const uint dim = dimension();
  41. for (uint i = 0; i < dim; i++)
  42. {
  43. m_array[i] -= v.m_array[i];
  44. }
  45. }
  46. void FullVector::operator*= (const FullVector & v)
  47. {
  48. nvDebugCheck(dimension() == v.dimension());
  49. const uint dim = dimension();
  50. for (uint i = 0; i < dim; i++)
  51. {
  52. m_array[i] *= v.m_array[i];
  53. }
  54. }
  55. void FullVector::operator+= (float f)
  56. {
  57. const uint dim = dimension();
  58. for (uint i = 0; i < dim; i++)
  59. {
  60. m_array[i] += f;
  61. }
  62. }
  63. void FullVector::operator-= (float f)
  64. {
  65. const uint dim = dimension();
  66. for (uint i = 0; i < dim; i++)
  67. {
  68. m_array[i] -= f;
  69. }
  70. }
  71. void FullVector::operator*= (float f)
  72. {
  73. const uint dim = dimension();
  74. for (uint i = 0; i < dim; i++)
  75. {
  76. m_array[i] *= f;
  77. }
  78. }
  79. void nv::saxpy(float a, const FullVector & x, FullVector & y)
  80. {
  81. nvDebugCheck(x.dimension() == y.dimension());
  82. const uint dim = x.dimension();
  83. for (uint i = 0; i < dim; i++)
  84. {
  85. y[i] += a * x[i];
  86. }
  87. }
  88. void nv::copy(const FullVector & x, FullVector & y)
  89. {
  90. nvDebugCheck(x.dimension() == y.dimension());
  91. const uint dim = x.dimension();
  92. for (uint i = 0; i < dim; i++)
  93. {
  94. y[i] = x[i];
  95. }
  96. }
  97. void nv::scal(float a, FullVector & x)
  98. {
  99. const uint dim = x.dimension();
  100. for (uint i = 0; i < dim; i++)
  101. {
  102. x[i] *= a;
  103. }
  104. }
  105. float nv::dot(const FullVector & x, const FullVector & y)
  106. {
  107. nvDebugCheck(x.dimension() == y.dimension());
  108. const uint dim = x.dimension();
  109. #if USE_KAHAN_SUM
  110. KahanSum kahan;
  111. for (uint i = 0; i < dim; i++)
  112. {
  113. kahan.add(x[i] * y[i]);
  114. }
  115. return kahan.sum();
  116. #else
  117. float sum = 0;
  118. for (uint i = 0; i < dim; i++)
  119. {
  120. sum += x[i] * y[i];
  121. }
  122. return sum;
  123. #endif
  124. }
  125. FullMatrix::FullMatrix(uint d) : m_width(d), m_height(d)
  126. {
  127. m_array.resize(d*d, 0.0f);
  128. }
  129. FullMatrix::FullMatrix(uint w, uint h) : m_width(w), m_height(h)
  130. {
  131. m_array.resize(w*h, 0.0f);
  132. }
  133. FullMatrix::FullMatrix(const FullMatrix & m) : m_width(m.m_width), m_height(m.m_height)
  134. {
  135. m_array = m.m_array;
  136. }
  137. const FullMatrix & FullMatrix::operator=(const FullMatrix & m)
  138. {
  139. nvCheck(width() == m.width());
  140. nvCheck(height() == m.height());
  141. m_array = m.m_array;
  142. return *this;
  143. }
  144. float FullMatrix::getCoefficient(uint x, uint y) const
  145. {
  146. nvDebugCheck( x < width() );
  147. nvDebugCheck( y < height() );
  148. return m_array[y * width() + x];
  149. }
  150. void FullMatrix::setCoefficient(uint x, uint y, float f)
  151. {
  152. nvDebugCheck( x < width() );
  153. nvDebugCheck( y < height() );
  154. m_array[y * width() + x] = f;
  155. }
  156. void FullMatrix::addCoefficient(uint x, uint y, float f)
  157. {
  158. nvDebugCheck( x < width() );
  159. nvDebugCheck( y < height() );
  160. m_array[y * width() + x] += f;
  161. }
  162. void FullMatrix::mulCoefficient(uint x, uint y, float f)
  163. {
  164. nvDebugCheck( x < width() );
  165. nvDebugCheck( y < height() );
  166. m_array[y * width() + x] *= f;
  167. }
  168. float FullMatrix::dotRow(uint y, const FullVector & v) const
  169. {
  170. nvDebugCheck( v.dimension() == width() );
  171. nvDebugCheck( y < height() );
  172. float sum = 0;
  173. const uint count = v.dimension();
  174. for (uint i = 0; i < count; i++)
  175. {
  176. sum += m_array[y * count + i] * v[i];
  177. }
  178. return sum;
  179. }
  180. void FullMatrix::madRow(uint y, float alpha, FullVector & v) const
  181. {
  182. nvDebugCheck( v.dimension() == width() );
  183. nvDebugCheck( y < height() );
  184. const uint count = v.dimension();
  185. for (uint i = 0; i < count; i++)
  186. {
  187. v[i] += m_array[y * count + i];
  188. }
  189. }
  190. // y = M * x
  191. void nv::mult(const FullMatrix & M, const FullVector & x, FullVector & y)
  192. {
  193. mult(NoTransposed, M, x, y);
  194. }
  195. void nv::mult(Transpose TM, const FullMatrix & M, const FullVector & x, FullVector & y)
  196. {
  197. const uint w = M.width();
  198. const uint h = M.height();
  199. if (TM == Transposed)
  200. {
  201. nvDebugCheck( h == x.dimension() );
  202. nvDebugCheck( w == y.dimension() );
  203. y.fill(0.0f);
  204. for (uint i = 0; i < h; i++)
  205. {
  206. M.madRow(i, x[i], y);
  207. }
  208. }
  209. else
  210. {
  211. nvDebugCheck( w == x.dimension() );
  212. nvDebugCheck( h == y.dimension() );
  213. for (uint i = 0; i < h; i++)
  214. {
  215. y[i] = M.dotRow(i, x);
  216. }
  217. }
  218. }
  219. // y = alpha*A*x + beta*y
  220. void nv::sgemv(float alpha, const FullMatrix & A, const FullVector & x, float beta, FullVector & y)
  221. {
  222. sgemv(alpha, NoTransposed, A, x, beta, y);
  223. }
  224. void nv::sgemv(float alpha, Transpose TA, const FullMatrix & A, const FullVector & x, float beta, FullVector & y)
  225. {
  226. const uint w = A.width();
  227. const uint h = A.height();
  228. if (TA == Transposed)
  229. {
  230. nvDebugCheck( h == x.dimension() );
  231. nvDebugCheck( w == y.dimension() );
  232. for (uint i = 0; i < h; i++)
  233. {
  234. A.madRow(i, alpha * x[i], y);
  235. }
  236. }
  237. else
  238. {
  239. nvDebugCheck( w == x.dimension() );
  240. nvDebugCheck( h == y.dimension() );
  241. for (uint i = 0; i < h; i++)
  242. {
  243. y[i] = alpha * A.dotRow(i, x) + beta * y[i];
  244. }
  245. }
  246. }
  247. // Multiply a row of A by a column of B.
  248. static float dot(uint j, Transpose TA, const FullMatrix & A, uint i, Transpose TB, const FullMatrix & B)
  249. {
  250. const uint w = (TA == NoTransposed) ? A.width() : A.height();
  251. nvDebugCheck(w == ((TB == NoTransposed) ? B.height() : A.width()));
  252. float sum = 0.0f;
  253. for (uint k = 0; k < w; k++)
  254. {
  255. const float a = (TA == NoTransposed) ? A.getCoefficient(k, j) : A.getCoefficient(j, k); // @@ Move branches out of the loop?
  256. const float b = (TB == NoTransposed) ? B.getCoefficient(i, k) : A.getCoefficient(k, i);
  257. sum += a * b;
  258. }
  259. return sum;
  260. }
  261. // C = A * B
  262. void nv::mult(const FullMatrix & A, const FullMatrix & B, FullMatrix & C)
  263. {
  264. mult(NoTransposed, A, NoTransposed, B, C);
  265. }
  266. void nv::mult(Transpose TA, const FullMatrix & A, Transpose TB, const FullMatrix & B, FullMatrix & C)
  267. {
  268. sgemm(1.0f, TA, A, TB, B, 0.0f, C);
  269. }
  270. // C = alpha*A*B + beta*C
  271. void nv::sgemm(float alpha, const FullMatrix & A, const FullMatrix & B, float beta, FullMatrix & C)
  272. {
  273. sgemm(alpha, NoTransposed, A, NoTransposed, B, beta, C);
  274. }
  275. void nv::sgemm(float alpha, Transpose TA, const FullMatrix & A, Transpose TB, const FullMatrix & B, float beta, FullMatrix & C)
  276. {
  277. const uint w = C.width();
  278. const uint h = C.height();
  279. uint aw = (TA == NoTransposed) ? A.width() : A.height();
  280. uint ah = (TA == NoTransposed) ? A.height() : A.width();
  281. uint bw = (TB == NoTransposed) ? B.width() : B.height();
  282. uint bh = (TB == NoTransposed) ? B.height() : B.width();
  283. nvDebugCheck(aw == bh);
  284. nvDebugCheck(bw == ah);
  285. nvDebugCheck(w == bw);
  286. nvDebugCheck(h == ah);
  287. for (uint y = 0; y < h; y++)
  288. {
  289. for (uint x = 0; x < w; x++)
  290. {
  291. float c = alpha * ::dot(x, TA, A, y, TB, B) + beta * C.getCoefficient(x, y);
  292. C.setCoefficient(x, y, c);
  293. }
  294. }
  295. }
  296. /// Ctor. Init the size of the sparse matrix.
  297. SparseMatrix::SparseMatrix(uint d) : m_width(d)
  298. {
  299. m_array.resize(d);
  300. }
  301. /// Ctor. Init the size of the sparse matrix.
  302. SparseMatrix::SparseMatrix(uint w, uint h) : m_width(w)
  303. {
  304. m_array.resize(h);
  305. }
  306. SparseMatrix::SparseMatrix(const SparseMatrix & m) : m_width(m.m_width)
  307. {
  308. m_array = m.m_array;
  309. }
  310. const SparseMatrix & SparseMatrix::operator=(const SparseMatrix & m)
  311. {
  312. nvCheck(width() == m.width());
  313. nvCheck(height() == m.height());
  314. m_array = m.m_array;
  315. return *this;
  316. }
  317. // x is column, y is row
  318. float SparseMatrix::getCoefficient(uint x, uint y) const
  319. {
  320. nvDebugCheck( x < width() );
  321. nvDebugCheck( y < height() );
  322. const uint count = m_array[y].count();
  323. for (uint i = 0; i < count; i++)
  324. {
  325. if (m_array[y][i].x == x) return m_array[y][i].v;
  326. }
  327. return 0.0f;
  328. }
  329. void SparseMatrix::setCoefficient(uint x, uint y, float f)
  330. {
  331. nvDebugCheck( x < width() );
  332. nvDebugCheck( y < height() );
  333. const uint count = m_array[y].count();
  334. for (uint i = 0; i < count; i++)
  335. {
  336. if (m_array[y][i].x == x)
  337. {
  338. m_array[y][i].v = f;
  339. return;
  340. }
  341. }
  342. if (f != 0.0f)
  343. {
  344. Coefficient c = { x, f };
  345. m_array[y].append( c );
  346. }
  347. }
  348. void SparseMatrix::addCoefficient(uint x, uint y, float f)
  349. {
  350. nvDebugCheck( x < width() );
  351. nvDebugCheck( y < height() );
  352. if (f != 0.0f)
  353. {
  354. const uint count = m_array[y].count();
  355. for (uint i = 0; i < count; i++)
  356. {
  357. if (m_array[y][i].x == x)
  358. {
  359. m_array[y][i].v += f;
  360. return;
  361. }
  362. }
  363. Coefficient c = { x, f };
  364. m_array[y].append( c );
  365. }
  366. }
  367. void SparseMatrix::mulCoefficient(uint x, uint y, float f)
  368. {
  369. nvDebugCheck( x < width() );
  370. nvDebugCheck( y < height() );
  371. const uint count = m_array[y].count();
  372. for (uint i = 0; i < count; i++)
  373. {
  374. if (m_array[y][i].x == x)
  375. {
  376. m_array[y][i].v *= f;
  377. return;
  378. }
  379. }
  380. if (f != 0.0f)
  381. {
  382. Coefficient c = { x, f };
  383. m_array[y].append( c );
  384. }
  385. }
  386. float SparseMatrix::sumRow(uint y) const
  387. {
  388. nvDebugCheck( y < height() );
  389. const uint count = m_array[y].count();
  390. #if USE_KAHAN_SUM
  391. KahanSum kahan;
  392. for (uint i = 0; i < count; i++)
  393. {
  394. kahan.add(m_array[y][i].v);
  395. }
  396. return kahan.sum();
  397. #else
  398. float sum = 0;
  399. for (uint i = 0; i < count; i++)
  400. {
  401. sum += m_array[y][i].v;
  402. }
  403. return sum;
  404. #endif
  405. }
  406. float SparseMatrix::dotRow(uint y, const FullVector & v) const
  407. {
  408. nvDebugCheck( y < height() );
  409. const uint count = m_array[y].count();
  410. #if USE_KAHAN_SUM
  411. KahanSum kahan;
  412. for (uint i = 0; i < count; i++)
  413. {
  414. kahan.add(m_array[y][i].v * v[m_array[y][i].x]);
  415. }
  416. return kahan.sum();
  417. #else
  418. float sum = 0;
  419. for (uint i = 0; i < count; i++)
  420. {
  421. sum += m_array[y][i].v * v[m_array[y][i].x];
  422. }
  423. return sum;
  424. #endif
  425. }
  426. void SparseMatrix::madRow(uint y, float alpha, FullVector & v) const
  427. {
  428. nvDebugCheck(y < height());
  429. const uint count = m_array[y].count();
  430. for (uint i = 0; i < count; i++)
  431. {
  432. v[m_array[y][i].x] += alpha * m_array[y][i].v;
  433. }
  434. }
  435. void SparseMatrix::clearRow(uint y)
  436. {
  437. nvDebugCheck( y < height() );
  438. m_array[y].clear();
  439. }
  440. void SparseMatrix::scaleRow(uint y, float f)
  441. {
  442. nvDebugCheck( y < height() );
  443. const uint count = m_array[y].count();
  444. for (uint i = 0; i < count; i++)
  445. {
  446. m_array[y][i].v *= f;
  447. }
  448. }
  449. void SparseMatrix::normalizeRow(uint y)
  450. {
  451. nvDebugCheck( y < height() );
  452. float norm = 0.0f;
  453. const uint count = m_array[y].count();
  454. for (uint i = 0; i < count; i++)
  455. {
  456. float f = m_array[y][i].v;
  457. norm += f * f;
  458. }
  459. scaleRow(y, 1.0f / sqrtf(norm));
  460. }
  461. void SparseMatrix::clearColumn(uint x)
  462. {
  463. nvDebugCheck(x < width());
  464. for (uint y = 0; y < height(); y++)
  465. {
  466. const uint count = m_array[y].count();
  467. for (uint e = 0; e < count; e++)
  468. {
  469. if (m_array[y][e].x == x)
  470. {
  471. m_array[y][e].v = 0.0f;
  472. break;
  473. }
  474. }
  475. }
  476. }
  477. void SparseMatrix::scaleColumn(uint x, float f)
  478. {
  479. nvDebugCheck(x < width());
  480. for (uint y = 0; y < height(); y++)
  481. {
  482. const uint count = m_array[y].count();
  483. for (uint e = 0; e < count; e++)
  484. {
  485. if (m_array[y][e].x == x)
  486. {
  487. m_array[y][e].v *= f;
  488. break;
  489. }
  490. }
  491. }
  492. }
  493. const Array<SparseMatrix::Coefficient> & SparseMatrix::getRow(uint y) const
  494. {
  495. return m_array[y];
  496. }
  497. bool SparseMatrix::isSymmetric() const
  498. {
  499. for (uint y = 0; y < height(); y++)
  500. {
  501. const uint count = m_array[y].count();
  502. for (uint e = 0; e < count; e++)
  503. {
  504. const uint x = m_array[y][e].x;
  505. if (x > y) {
  506. float v = m_array[y][e].v;
  507. if (!equal(getCoefficient(y, x), v)) { // @@ epsilon
  508. return false;
  509. }
  510. }
  511. }
  512. }
  513. return true;
  514. }
  515. // y = M * x
  516. void nv::mult(const SparseMatrix & M, const FullVector & x, FullVector & y)
  517. {
  518. mult(NoTransposed, M, x, y);
  519. }
  520. void nv::mult(Transpose TM, const SparseMatrix & M, const FullVector & x, FullVector & y)
  521. {
  522. const uint w = M.width();
  523. const uint h = M.height();
  524. if (TM == Transposed)
  525. {
  526. nvDebugCheck( h == x.dimension() );
  527. nvDebugCheck( w == y.dimension() );
  528. y.fill(0.0f);
  529. for (uint i = 0; i < h; i++)
  530. {
  531. M.madRow(i, x[i], y);
  532. }
  533. }
  534. else
  535. {
  536. nvDebugCheck( w == x.dimension() );
  537. nvDebugCheck( h == y.dimension() );
  538. for (uint i = 0; i < h; i++)
  539. {
  540. y[i] = M.dotRow(i, x);
  541. }
  542. }
  543. }
  544. // y = alpha*A*x + beta*y
  545. void nv::sgemv(float alpha, const SparseMatrix & A, const FullVector & x, float beta, FullVector & y)
  546. {
  547. sgemv(alpha, NoTransposed, A, x, beta, y);
  548. }
  549. void nv::sgemv(float alpha, Transpose TA, const SparseMatrix & A, const FullVector & x, float beta, FullVector & y)
  550. {
  551. const uint w = A.width();
  552. const uint h = A.height();
  553. if (TA == Transposed)
  554. {
  555. nvDebugCheck( h == x.dimension() );
  556. nvDebugCheck( w == y.dimension() );
  557. for (uint i = 0; i < h; i++)
  558. {
  559. A.madRow(i, alpha * x[i], y);
  560. }
  561. }
  562. else
  563. {
  564. nvDebugCheck( w == x.dimension() );
  565. nvDebugCheck( h == y.dimension() );
  566. for (uint i = 0; i < h; i++)
  567. {
  568. y[i] = alpha * A.dotRow(i, x) + beta * y[i];
  569. }
  570. }
  571. }
  572. // dot y-row of A by x-column of B
  573. static float dotRowColumn(int y, const SparseMatrix & A, int x, const SparseMatrix & B)
  574. {
  575. const Array<SparseMatrix::Coefficient> & row = A.getRow(y);
  576. const uint count = row.count();
  577. #if USE_KAHAN_SUM
  578. KahanSum kahan;
  579. for (uint i = 0; i < count; i++)
  580. {
  581. const SparseMatrix::Coefficient & c = row[i];
  582. kahan.add(c.v * B.getCoefficient(x, c.x));
  583. }
  584. return kahan.sum();
  585. #else
  586. float sum = 0.0f;
  587. for (uint i = 0; i < count; i++)
  588. {
  589. const SparseMatrix::Coefficient & c = row[i];
  590. sum += c.v * B.getCoefficient(x, c.x);
  591. }
  592. return sum;
  593. #endif
  594. }
  595. // dot y-row of A by x-row of B
  596. static float dotRowRow(int y, const SparseMatrix & A, int x, const SparseMatrix & B)
  597. {
  598. const Array<SparseMatrix::Coefficient> & row = A.getRow(y);
  599. const uint count = row.count();
  600. #if USE_KAHAN_SUM
  601. KahanSum kahan;
  602. for (uint i = 0; i < count; i++)
  603. {
  604. const SparseMatrix::Coefficient & c = row[i];
  605. kahan.add(c.v * B.getCoefficient(c.x, x));
  606. }
  607. return kahan.sum();
  608. #else
  609. float sum = 0.0f;
  610. for (uint i = 0; i < count; i++)
  611. {
  612. const SparseMatrix::Coefficient & c = row[i];
  613. sum += c.v * B.getCoefficient(c.x, x);
  614. }
  615. return sum;
  616. #endif
  617. }
  618. // dot y-column of A by x-column of B
  619. static float dotColumnColumn(int y, const SparseMatrix & A, int x, const SparseMatrix & B)
  620. {
  621. nvDebugCheck(A.height() == B.height());
  622. const uint h = A.height();
  623. #if USE_KAHAN_SUM
  624. KahanSum kahan;
  625. for (uint i = 0; i < h; i++)
  626. {
  627. kahan.add(A.getCoefficient(y, i) * B.getCoefficient(x, i));
  628. }
  629. return kahan.sum();
  630. #else
  631. float sum = 0.0f;
  632. for (uint i = 0; i < h; i++)
  633. {
  634. sum += A.getCoefficient(y, i) * B.getCoefficient(x, i);
  635. }
  636. return sum;
  637. #endif
  638. }
  639. void nv::transpose(const SparseMatrix & A, SparseMatrix & B)
  640. {
  641. nvDebugCheck(A.width() == B.height());
  642. nvDebugCheck(B.width() == A.height());
  643. const uint w = A.width();
  644. for (uint x = 0; x < w; x++)
  645. {
  646. B.clearRow(x);
  647. }
  648. const uint h = A.height();
  649. for (uint y = 0; y < h; y++)
  650. {
  651. const Array<SparseMatrix::Coefficient> & row = A.getRow(y);
  652. const uint count = row.count();
  653. for (uint i = 0; i < count; i++)
  654. {
  655. const SparseMatrix::Coefficient & c = row[i];
  656. nvDebugCheck(c.x < w);
  657. B.setCoefficient(y, c.x, c.v);
  658. }
  659. }
  660. }
  661. // C = A * B
  662. void nv::mult(const SparseMatrix & A, const SparseMatrix & B, SparseMatrix & C)
  663. {
  664. mult(NoTransposed, A, NoTransposed, B, C);
  665. }
  666. void nv::mult(Transpose TA, const SparseMatrix & A, Transpose TB, const SparseMatrix & B, SparseMatrix & C)
  667. {
  668. sgemm(1.0f, TA, A, TB, B, 0.0f, C);
  669. }
  670. // C = alpha*A*B + beta*C
  671. void nv::sgemm(float alpha, const SparseMatrix & A, const SparseMatrix & B, float beta, SparseMatrix & C)
  672. {
  673. sgemm(alpha, NoTransposed, A, NoTransposed, B, beta, C);
  674. }
  675. void nv::sgemm(float alpha, Transpose TA, const SparseMatrix & A, Transpose TB, const SparseMatrix & B, float beta, SparseMatrix & C)
  676. {
  677. const uint w = C.width();
  678. const uint h = C.height();
  679. uint aw = (TA == NoTransposed) ? A.width() : A.height();
  680. uint ah = (TA == NoTransposed) ? A.height() : A.width();
  681. uint bw = (TB == NoTransposed) ? B.width() : B.height();
  682. uint bh = (TB == NoTransposed) ? B.height() : B.width();
  683. nvDebugCheck(aw == bh);
  684. nvDebugCheck(bw == ah);
  685. nvDebugCheck(w == bw);
  686. nvDebugCheck(h == ah);
  687. for (uint y = 0; y < h; y++)
  688. {
  689. for (uint x = 0; x < w; x++)
  690. {
  691. float c = beta * C.getCoefficient(x, y);
  692. if (TA == NoTransposed && TB == NoTransposed)
  693. {
  694. // dot y-row of A by x-column of B.
  695. c += alpha * dotRowColumn(y, A, x, B);
  696. }
  697. else if (TA == Transposed && TB == Transposed)
  698. {
  699. // dot y-column of A by x-row of B.
  700. c += alpha * dotRowColumn(x, B, y, A);
  701. }
  702. else if (TA == Transposed && TB == NoTransposed)
  703. {
  704. // dot y-column of A by x-column of B.
  705. c += alpha * dotColumnColumn(y, A, x, B);
  706. }
  707. else if (TA == NoTransposed && TB == Transposed)
  708. {
  709. // dot y-row of A by x-row of B.
  710. c += alpha * dotRowRow(y, A, x, B);
  711. }
  712. C.setCoefficient(x, y, c);
  713. }
  714. }
  715. }
  716. // C = At * A
  717. void nv::sqm(const SparseMatrix & A, SparseMatrix & C)
  718. {
  719. // This is quite expensive...
  720. mult(Transposed, A, NoTransposed, A, C);
  721. }