testspeed.cc 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. // Copyright 2021 DeepMind Technologies Limited
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <chrono>
  15. #include <cstdio>
  16. #include <cstring>
  17. #include <ratio>
  18. #include <string>
  19. #include <thread>
  20. #include <vector>
  21. #include <mujoco/mujoco.h>
  22. // maximum number of threads
  23. const int maxthread = 512;
  24. // model and per-thread data
  25. mjModel* m = NULL;
  26. mjData* d[maxthread];
  27. // per-thread statistics
  28. int contacts[maxthread];
  29. int constraints[maxthread];
  30. mjtNum simtime[maxthread];
  31. // timer
  32. mjtNum gettm(void) {
  33. using std::chrono::steady_clock;
  34. using Microseconds = std::chrono::duration<double, std::micro>;
  35. static steady_clock::time_point tm_start = steady_clock::now();
  36. auto elapsed = Microseconds(steady_clock::now() - tm_start);
  37. return elapsed.count();
  38. }
  39. // deallocate and print message
  40. int finish(const char* msg = NULL, mjModel* m = NULL) {
  41. // deallocate model
  42. if (m) {
  43. mj_deleteModel(m);
  44. }
  45. // print message
  46. if (msg) {
  47. std::printf("%s\n", msg);
  48. }
  49. return 0;
  50. }
  51. std::vector<mjtNum> CtrlNoise(const mjModel* m, int nsteps, mjtNum ctrlnoise) {
  52. std::vector<mjtNum> ctrl;
  53. for (int step=0; step < nsteps; step++) {
  54. for (int i = 0; i < m->nu; i++) {
  55. mjtNum center = 0.0;
  56. mjtNum radius = 1.0;
  57. mjtNum* range = m->actuator_ctrlrange + 2 * i;
  58. if (m->actuator_ctrllimited[i]) {
  59. center = (range[1] + range[0]) / 2;
  60. radius = (range[1] - range[0]) / 2;
  61. }
  62. radius *= ctrlnoise;
  63. ctrl.push_back(center + radius * (2 * mju_Halton(step, i+2) - 1));
  64. }
  65. }
  66. return ctrl;
  67. }
  68. // thread function
  69. void simulate(int id, int nstep, mjtNum* ctrl) {
  70. // clear statistics
  71. contacts[id] = 0;
  72. constraints[id] = 0;
  73. // run and time
  74. mjtNum start = gettm();
  75. for (int i=0; i < nstep; i++) {
  76. // inject pseudo-random control noise
  77. mju_copy(d[id]->ctrl, ctrl + i*m->nu, m->nu);
  78. // advance simulation
  79. mj_step(m, d[id]);
  80. // accumulate statistics
  81. contacts[id] += d[id]->ncon;
  82. constraints[id] += d[id]->nefc;
  83. }
  84. simtime[id] = 1e-6 * (gettm() - start);
  85. }
  86. // main function
  87. int main(int argc, char** argv) {
  88. // print help if arguments are missing
  89. if (argc < 2 || argc > 6) {
  90. return finish(
  91. "\n"
  92. "Usage: testspeed modelfile [nstep nthread ctrlnoise npoolthread]\n"
  93. "\n"
  94. " argument default semantic\n"
  95. " -------- ------- --------\n"
  96. " modelfile path to model (required)\n"
  97. " nstep 10000 number of steps per rollout\n"
  98. " nthread 1 number of threads running parallel rollouts\n"
  99. " ctrlnoise 0.01 scale of pseudo-random noise injected into actuators\n"
  100. " npoolthread 0 number of threads in engine-internal threadpool\n"
  101. "\n"
  102. "Note: If the model has a keyframe named \"test\", it will be loaded prior to simulation\n");
  103. }
  104. // read arguments
  105. int nstep = 10000, nthread = 0, npoolthread = 0;
  106. // inject small noise by default, to avoid fixed contact state
  107. double ctrlnoise = 0.01;
  108. if (argc > 2 && (std::sscanf(argv[2], "%d", &nstep) != 1 || nstep <= 0)) {
  109. return finish("Invalid nstep argument");
  110. }
  111. if (argc > 3 && std::sscanf(argv[3], "%d", &nthread) != 1) {
  112. return finish("Invalid nthread argument");
  113. }
  114. if (argc > 4 && std::sscanf(argv[4], "%lf", &ctrlnoise) != 1) {
  115. return finish("Invalid ctrlnoise argument");
  116. }
  117. if (argc > 5 && std::sscanf(argv[5], "%d", &npoolthread) != 1) {
  118. return finish("Invalid npoolthread argument");
  119. }
  120. // clamp ctrlnoise to [0.0, 1.0]
  121. ctrlnoise = mjMAX(0.0, mjMIN(ctrlnoise, 1.0));
  122. // clamp nthread to [1, maxthread]
  123. nthread = mjMAX(1, mjMIN(maxthread, nthread));
  124. npoolthread = mjMAX(1, mjMIN(maxthread, npoolthread));
  125. // get filename, determine file type
  126. std::string filename(argv[1]);
  127. bool binary = (filename.find(".mjb") != std::string::npos); // NOLINT
  128. // load model
  129. char error[1000] = "Could not load binary model";
  130. if (binary) {
  131. m = mj_loadModel(argv[1], 0);
  132. } else {
  133. m = mj_loadXML(argv[1], 0, error, 1000);
  134. }
  135. if (!m) {
  136. return finish(error);
  137. }
  138. // make per-thread data
  139. int testkey = mj_name2id(m, mjOBJ_KEY, "test");
  140. for (int id=0; id < nthread; id++) {
  141. // make mjData(s)
  142. d[id] = mj_makeData(m);
  143. if (!d[id]) {
  144. return finish("Could not allocate mjData", m);
  145. }
  146. // reset to keyframe
  147. if (testkey >= 0) {
  148. mj_resetDataKeyframe(m, d[id], testkey);
  149. }
  150. // make and bind threadpool
  151. if (npoolthread > 1) {
  152. mjThreadPool* threadpool = mju_threadPoolCreate(npoolthread);
  153. mju_bindThreadPool(d[id], threadpool);
  154. }
  155. }
  156. // install timer callback for profiling
  157. mjcb_time = gettm;
  158. // print start
  159. std::printf("\nRolling out %d steps%s, at dt = %g",
  160. nstep,
  161. nthread > 1 ? " per thread" : "",
  162. m->opt.timestep);
  163. if (sizeof(mjtNum) == 4) {
  164. std::printf(", using single-precision");
  165. }
  166. if (npoolthread > 1) {
  167. std::printf(", using %d threads", npoolthread);
  168. }
  169. std::printf("...\n\n");
  170. // create pseudo-random control sequence
  171. std::vector<mjtNum> ctrl = CtrlNoise(m, nstep, ctrlnoise);
  172. // run simulation, record total time
  173. std::thread th[maxthread];
  174. double starttime = gettm();
  175. for (int id=0; id < nthread; id++) {
  176. th[id] = std::thread(simulate, id, nstep, ctrl.data());
  177. }
  178. for (int id=0; id < nthread; id++) {
  179. th[id].join();
  180. }
  181. double tottime = 1e-6 * (gettm() - starttime); // total time, in seconds
  182. // all-thread summary
  183. constexpr char mu_str[3] = "\u00B5"; // unicode mu character
  184. if (nthread > 1) {
  185. std::printf("Summary for all %d threads\n\n", nthread);
  186. std::printf(" Total simulation time : %.2f s\n", tottime);
  187. std::printf(" Total steps per second : %.0f\n", nthread*nstep/tottime);
  188. std::printf(" Total realtime factor : %.2f x\n", nthread*nstep*m->opt.timestep/tottime);
  189. std::printf(" Total time per step : %.1f %ss\n\n", 1e6*tottime/(nthread*nstep), mu_str);
  190. std::printf("Details for thread 0\n\n");
  191. }
  192. // details for thread 0
  193. std::printf(" Simulation time : %.2f s\n", simtime[0]);
  194. std::printf(" Steps per second : %.0f\n", nstep/simtime[0]);
  195. std::printf(" Realtime factor : %.2f x\n", nstep*m->opt.timestep/simtime[0]);
  196. std::printf(" Time per step : %.1f %ss\n\n", 1e6*simtime[0]/nstep, mu_str);
  197. std::printf(" Contacts per step : %.2f\n", static_cast<float>(contacts[0])/nstep);
  198. std::printf(" Constraints per step : %.2f\n", static_cast<float>(constraints[0])/nstep);
  199. std::printf(" Degrees of freedom : %d\n\n", m->nv);
  200. // profiler, top-level
  201. printf(" Internal profiler%s, %ss per step\n", nthread > 1 ? " for thread 0" : "", mu_str);
  202. int number = d[0]->timer[mjTIMER_STEP].number;
  203. mjtNum tstep = number ? d[0]->timer[mjTIMER_STEP].duration/number : 0.0;
  204. mjtNum components = 0, total = 0;
  205. for (int i=0; i <= mjTIMER_ADVANCE; i++) {
  206. if (d[0]->timer[i].number > 0) {
  207. int number = d[0]->timer[i].number;
  208. mjtNum istep = number ? d[0]->timer[i].duration/number : 0.0;
  209. mjtNum percent = number ? 100*istep/tstep : 0.0;
  210. std::printf(" %17s : %6.1f (%6.2f %%)\n", mjTIMERSTRING[i], istep, percent);
  211. // save step time, add up timing of components
  212. if (i == 0) total = istep;
  213. if (i >= mjTIMER_POSITION) {
  214. components += istep;
  215. }
  216. }
  217. }
  218. // "other" (computation not covered by timers)
  219. if (tstep > 0) {
  220. mjtNum other = total - components;
  221. std::printf(" %17s : %6.1f (%6.2f %%)\n", "other", other, 100*other/tstep);
  222. }
  223. std::printf("\n");
  224. // mjTIMER_POSITION and its components
  225. for (int i : {mjTIMER_POSITION,
  226. mjTIMER_POS_KINEMATICS,
  227. mjTIMER_POS_INERTIA,
  228. mjTIMER_POS_COLLISION,
  229. mjTIMER_POS_MAKE,
  230. mjTIMER_POS_PROJECT}) {
  231. if (d[0]->timer[i].number > 0) {
  232. mjtNum istep = d[0]->timer[i].duration/d[0]->timer[i].number;
  233. if (i == mjTIMER_POSITION) {
  234. std::printf(" position total : %6.1f (%6.2f %%)\n", istep, 100*istep/tstep);
  235. } else {
  236. std::printf(" %-10s : %6.1f (%6.2f %%)\n",
  237. mjTIMERSTRING[i]+4, istep, 100*istep/tstep);
  238. }
  239. }
  240. // components of mjTIMER_POS_COLLISION
  241. if (i == mjTIMER_POS_COLLISION) {
  242. for (int j : {mjTIMER_COL_BROAD, mjTIMER_COL_NARROW}) {
  243. int number = d[0]->timer[j].number;
  244. mjtNum jstep = number ? d[0]->timer[j].duration/number : 0.0;
  245. mjtNum percent = number ? 100*jstep/tstep : 0.0;
  246. std::printf(" %-11s : %6.1f (%6.2f %%)\n", mjTIMERSTRING[j]+4, jstep, percent);
  247. }
  248. }
  249. }
  250. // free per-thread data
  251. for (int id=0; id < nthread; id++) {
  252. mjThreadPool* threadpool = (mjThreadPool*) d[id]->threadpool;
  253. mj_deleteData(d[id]);
  254. if (threadpool) {
  255. mju_threadPoolDestroy(threadpool);
  256. }
  257. }
  258. // finalize
  259. return finish();
  260. }