invdyn_bullet_comparison.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. #include "invdyn_bullet_comparison.hpp"
  2. #include <cmath>
  3. #include "BulletInverseDynamics/IDConfig.hpp"
  4. #include "BulletInverseDynamics/MultiBodyTree.hpp"
  5. #include "btBulletDynamicsCommon.h"
  6. #include "BulletDynamics/Featherstone/btMultiBodyConstraintSolver.h"
  7. #include "BulletDynamics/Featherstone/btMultiBodyDynamicsWorld.h"
  8. #include "BulletDynamics/Featherstone/btMultiBodyLinkCollider.h"
  9. #include "BulletDynamics/Featherstone/btMultiBodyPoint2Point.h"
  10. namespace btInverseDynamics {
  11. int compareInverseAndForwardDynamics(vecx &q, vecx &u, vecx &dot_u, btVector3 &gravity, bool verbose,
  12. btMultiBody *btmb, MultiBodyTree *id_tree, double *pos_error,
  13. double *acc_error) {
  14. // call function and return -1 if it does, printing an error_message
  15. #define RETURN_ON_FAILURE(x) \
  16. do { \
  17. if (-1 == x) { \
  18. error_message("calling " #x "\n"); \
  19. return -1; \
  20. } \
  21. } while (0)
  22. if (verbose) {
  23. printf("\n ===================================== \n");
  24. }
  25. vecx joint_forces(q.size());
  26. // set positions and velocities for btMultiBody
  27. // base link
  28. mat33 world_T_base;
  29. vec3 world_pos_base;
  30. btTransform base_transform;
  31. vec3 base_velocity;
  32. vec3 base_angular_velocity;
  33. RETURN_ON_FAILURE(id_tree->setGravityInWorldFrame(gravity));
  34. RETURN_ON_FAILURE(id_tree->getBodyOrigin(0, &world_pos_base));
  35. RETURN_ON_FAILURE(id_tree->getBodyTransform(0, &world_T_base));
  36. RETURN_ON_FAILURE(id_tree->getBodyAngularVelocity(0, &base_angular_velocity));
  37. RETURN_ON_FAILURE(id_tree->getBodyLinearVelocityCoM(0, &base_velocity));
  38. base_transform.setBasis(world_T_base);
  39. base_transform.setOrigin(world_pos_base);
  40. btmb->setBaseWorldTransform(base_transform);
  41. btmb->setBaseOmega(base_angular_velocity);
  42. btmb->setBaseVel(base_velocity);
  43. btmb->setLinearDamping(0);
  44. btmb->setAngularDamping(0);
  45. // remaining links
  46. int q_index;
  47. if (btmb->hasFixedBase()) {
  48. q_index = 0;
  49. } else {
  50. q_index = 6;
  51. }
  52. if (verbose) {
  53. printf("bt:num_links= %d, num_dofs= %d\n", btmb->getNumLinks(), btmb->getNumDofs());
  54. }
  55. for (int l = 0; l < btmb->getNumLinks(); l++) {
  56. const btMultibodyLink &link = btmb->getLink(l);
  57. if (verbose) {
  58. printf("link %d, pos_var_count= %d, dof_count= %d\n", l, link.m_posVarCount,
  59. link.m_dofCount);
  60. }
  61. if (link.m_posVarCount == 1) {
  62. btmb->setJointPosMultiDof(l, &q(q_index));
  63. btmb->setJointVelMultiDof(l, &u(q_index));
  64. if (verbose) {
  65. printf("set q[%d]= %f, u[%d]= %f\n", q_index, q(q_index), q_index, u(q_index));
  66. }
  67. q_index++;
  68. }
  69. }
  70. // sanity check
  71. if (q_index != q.size()) {
  72. error_message("error in number of dofs for btMultibody and MultiBodyTree\n");
  73. return -1;
  74. }
  75. // run inverse dynamics to determine joint_forces for given q, u, dot_u
  76. if (-1 == id_tree->calculateInverseDynamics(q, u, dot_u, &joint_forces)) {
  77. error_message("calculating inverse dynamics\n");
  78. return -1;
  79. }
  80. // set up bullet forward dynamics model
  81. btScalar dt = 0;
  82. btAlignedObjectArray<btScalar> scratch_r;
  83. btAlignedObjectArray<btVector3> scratch_v;
  84. btAlignedObjectArray<btMatrix3x3> scratch_m;
  85. // this triggers switch between using either appliedConstraintForce or appliedForce
  86. bool isConstraintPass = false;
  87. // apply gravity forces for btMultiBody model. Must be done manually.
  88. btmb->addBaseForce(btmb->getBaseMass() * gravity);
  89. for (int link = 0; link < btmb->getNumLinks(); link++) {
  90. btmb->addLinkForce(link, gravity * btmb->getLinkMass(link));
  91. if (verbose) {
  92. printf("link %d, applying gravity %f %f %f\n", link,
  93. gravity[0] * btmb->getLinkMass(link), gravity[1] * btmb->getLinkMass(link),
  94. gravity[2] * btmb->getLinkMass(link));
  95. }
  96. }
  97. // apply generalized forces
  98. if (btmb->hasFixedBase()) {
  99. q_index = 0;
  100. } else {
  101. vec3 base_force;
  102. base_force(0) = joint_forces(3);
  103. base_force(1) = joint_forces(4);
  104. base_force(2) = joint_forces(5);
  105. vec3 base_moment;
  106. base_moment(0) = joint_forces(0);
  107. base_moment(1) = joint_forces(1);
  108. base_moment(2) = joint_forces(2);
  109. btmb->addBaseForce(world_T_base * base_force);
  110. btmb->addBaseTorque(world_T_base * base_moment);
  111. if (verbose) {
  112. printf("base force from id: %f %f %f\n", joint_forces(3), joint_forces(4),
  113. joint_forces(5));
  114. printf("base moment from id: %f %f %f\n", joint_forces(0), joint_forces(1),
  115. joint_forces(2));
  116. }
  117. q_index = 6;
  118. }
  119. for (int l = 0; l < btmb->getNumLinks(); l++) {
  120. const btMultibodyLink &link = btmb->getLink(l);
  121. if (link.m_posVarCount == 1) {
  122. if (verbose) {
  123. printf("id:joint_force[%d]= %f, applied to link %d\n", q_index,
  124. joint_forces(q_index), l);
  125. }
  126. btmb->addJointTorque(l, joint_forces(q_index));
  127. q_index++;
  128. }
  129. }
  130. // sanity check
  131. if (q_index != q.size()) {
  132. error_message("error in number of dofs for btMultibody and MultiBodyTree\n");
  133. return -1;
  134. }
  135. // run forward kinematics & forward dynamics
  136. btAlignedObjectArray<btQuaternion> world_to_local;
  137. btAlignedObjectArray<btVector3> local_origin;
  138. btmb->forwardKinematics(world_to_local, local_origin);
  139. btmb->computeAccelerationsArticulatedBodyAlgorithmMultiDof(dt, scratch_r, scratch_v, scratch_m, isConstraintPass);
  140. // read generalized accelerations back from btMultiBody
  141. // the mapping from scratch variables to accelerations is taken from the implementation
  142. // of stepVelocitiesMultiDof
  143. btScalar *base_accel = &scratch_r[btmb->getNumDofs()];
  144. btScalar *joint_accel = base_accel + 6;
  145. *acc_error = 0;
  146. int dot_u_offset = 0;
  147. if (btmb->hasFixedBase()) {
  148. dot_u_offset = 0;
  149. } else {
  150. dot_u_offset = 6;
  151. }
  152. if (true == btmb->hasFixedBase()) {
  153. for (int i = 0; i < btmb->getNumDofs(); i++) {
  154. if (verbose) {
  155. printf("bt:ddot_q[%d]= %f, id:ddot_q= %e, diff= %e\n", i, joint_accel[i],
  156. dot_u(i + dot_u_offset), joint_accel[i] - dot_u(i));
  157. }
  158. *acc_error += std::pow(joint_accel[i] - dot_u(i + dot_u_offset), 2);
  159. }
  160. } else {
  161. vec3 base_dot_omega;
  162. vec3 world_dot_omega;
  163. world_dot_omega(0) = base_accel[0];
  164. world_dot_omega(1) = base_accel[1];
  165. world_dot_omega(2) = base_accel[2];
  166. base_dot_omega = world_T_base.transpose() * world_dot_omega;
  167. // com happens to coincide with link origin here. If that changes, we need to calculate
  168. // ddot_com
  169. vec3 base_ddot_com;
  170. vec3 world_ddot_com;
  171. world_ddot_com(0) = base_accel[3];
  172. world_ddot_com(1) = base_accel[4];
  173. world_ddot_com(2) = base_accel[5];
  174. base_ddot_com = world_T_base.transpose()*world_ddot_com;
  175. for (int i = 0; i < 3; i++) {
  176. if (verbose) {
  177. printf("bt::base_dot_omega(%d)= %e dot_u[%d]= %e, diff= %e\n", i, base_dot_omega(i),
  178. i, dot_u[i], base_dot_omega(i) - dot_u[i]);
  179. }
  180. *acc_error += std::pow(base_dot_omega(i) - dot_u(i), 2);
  181. }
  182. for (int i = 0; i < 3; i++) {
  183. if (verbose) {
  184. printf("bt::base_ddot_com(%d)= %e dot_u[%d]= %e, diff= %e\n", i, base_ddot_com(i),
  185. i, dot_u[i + 3], base_ddot_com(i) - dot_u[i + 3]);
  186. }
  187. *acc_error += std::pow(base_ddot_com(i) - dot_u(i + 3), 2);
  188. }
  189. for (int i = 0; i < btmb->getNumDofs(); i++) {
  190. if (verbose) {
  191. printf("bt:ddot_q[%d]= %f, id:ddot_q= %e, diff= %e\n", i, joint_accel[i],
  192. dot_u(i + 6), joint_accel[i] - dot_u(i + 6));
  193. }
  194. *acc_error += std::pow(joint_accel[i] - dot_u(i + 6), 2);
  195. }
  196. }
  197. *acc_error = std::sqrt(*acc_error);
  198. if (verbose) {
  199. printf("======dynamics-err: %e\n", *acc_error);
  200. }
  201. *pos_error = 0.0;
  202. {
  203. mat33 world_T_body;
  204. if (-1 == id_tree->getBodyTransform(0, &world_T_body)) {
  205. error_message("getting transform for body %d\n", 0);
  206. return -1;
  207. }
  208. vec3 world_com;
  209. if (-1 == id_tree->getBodyCoM(0, &world_com)) {
  210. error_message("getting com for body %d\n", 0);
  211. return -1;
  212. }
  213. if (verbose) {
  214. printf("id:com: %f %f %f\n", world_com(0), world_com(1), world_com(2));
  215. printf("id:transform: %f %f %f\n"
  216. " %f %f %f\n"
  217. " %f %f %f\n",
  218. world_T_body(0, 0), world_T_body(0, 1), world_T_body(0, 2), world_T_body(1, 0),
  219. world_T_body(1, 1), world_T_body(1, 2), world_T_body(2, 0), world_T_body(2, 1),
  220. world_T_body(2, 2));
  221. }
  222. }
  223. for (int l = 0; l < btmb->getNumLinks(); l++) {
  224. const btMultibodyLink &bt_link = btmb->getLink(l);
  225. vec3 bt_origin = bt_link.m_cachedWorldTransform.getOrigin();
  226. mat33 bt_basis = bt_link.m_cachedWorldTransform.getBasis();
  227. if (verbose) {
  228. printf("------------- link %d\n", l + 1);
  229. printf("bt:com: %f %f %f\n", bt_origin(0), bt_origin(1), bt_origin(2));
  230. printf("bt:transform: %f %f %f\n"
  231. " %f %f %f\n"
  232. " %f %f %f\n",
  233. bt_basis(0, 0), bt_basis(0, 1), bt_basis(0, 2), bt_basis(1, 0), bt_basis(1, 1),
  234. bt_basis(1, 2), bt_basis(2, 0), bt_basis(2, 1), bt_basis(2, 2));
  235. }
  236. mat33 id_world_T_body;
  237. vec3 id_world_com;
  238. if (-1 == id_tree->getBodyTransform(l + 1, &id_world_T_body)) {
  239. error_message("getting transform for body %d\n", l);
  240. return -1;
  241. }
  242. if (-1 == id_tree->getBodyCoM(l + 1, &id_world_com)) {
  243. error_message("getting com for body %d\n", l);
  244. return -1;
  245. }
  246. if (verbose) {
  247. printf("id:com: %f %f %f\n", id_world_com(0), id_world_com(1), id_world_com(2));
  248. printf("id:transform: %f %f %f\n"
  249. " %f %f %f\n"
  250. " %f %f %f\n",
  251. id_world_T_body(0, 0), id_world_T_body(0, 1), id_world_T_body(0, 2),
  252. id_world_T_body(1, 0), id_world_T_body(1, 1), id_world_T_body(1, 2),
  253. id_world_T_body(2, 0), id_world_T_body(2, 1), id_world_T_body(2, 2));
  254. }
  255. vec3 diff_com = bt_origin - id_world_com;
  256. mat33 diff_basis = bt_basis - id_world_T_body;
  257. if (verbose) {
  258. printf("diff-com: %e %e %e\n", diff_com(0), diff_com(1), diff_com(2));
  259. printf("diff-transform: %e %e %e %e %e %e %e %e %e\n", diff_basis(0, 0),
  260. diff_basis(0, 1), diff_basis(0, 2), diff_basis(1, 0), diff_basis(1, 1),
  261. diff_basis(1, 2), diff_basis(2, 0), diff_basis(2, 1), diff_basis(2, 2));
  262. }
  263. double total_pos_err =
  264. std::sqrt(std::pow(diff_com(0), 2) + std::pow(diff_com(1), 2) +
  265. std::pow(diff_com(2), 2) + std::pow(diff_basis(0, 0), 2) +
  266. std::pow(diff_basis(0, 1), 2) + std::pow(diff_basis(0, 2), 2) +
  267. std::pow(diff_basis(1, 0), 2) + std::pow(diff_basis(1, 1), 2) +
  268. std::pow(diff_basis(1, 2), 2) + std::pow(diff_basis(2, 0), 2) +
  269. std::pow(diff_basis(2, 1), 2) + std::pow(diff_basis(2, 2), 2));
  270. if (verbose) {
  271. printf("======kin-pos-err: %e\n", total_pos_err);
  272. }
  273. if (total_pos_err > *pos_error) {
  274. *pos_error = total_pos_err;
  275. }
  276. }
  277. return 0;
  278. }
  279. }