.. _Mjx:

================
MuJoCo XLA (MJX)
================

Starting with version 3.0.0, MuJoCo includes MuJoCo XLA (MJX) under the
`mjx <https://github.com/google-deepmind/mujoco/tree/main/mjx>`__ directory. MJX allows MuJoCo to run on compute
hardware supported by the `XLA <https://www.tensorflow.org/xla>`__ compiler via the
`JAX <https://github.com/google/jax#readme>`__ framework. MJX runs on
`all platforms supported by JAX <https://jax.readthedocs.io/en/latest/installation.html#supported-platforms>`__: Nvidia
and AMD GPUs, Apple Silicon, and `Google Cloud TPUs <https://cloud.google.com/tpu>`__.

The MJX API is consistent with the main simulation functions in the MuJoCo API, although it is currently missing some
features. While the :ref:`API documentation <Mainsimulation>` is applicable to both libraries, we indicate features
unsupported by MJX in the :ref:`notes <MjxFeatureParity>` below.

MJX is distributed as a separate package called ``mujoco-mjx`` on `PyPI <https://pypi.org/project/mujoco-mjx>`__.
Although it depends on the main ``mujoco`` package for model compilation and visualization, it is a re-implementation of
MuJoCo that uses the same algorithms as the MuJoCo implementation. However, in order to properly leverage JAX, MJX
deliberately diverges from the MuJoCo API in a few places; see below.

MJX is a successor to the `generalized physics pipeline <https://github.com/google/brax/tree/main/brax/generalized>`__
in Google's `Brax <https://github.com/google/brax>`__ physics and reinforcement learning library. MJX was built
by core contributors to both MuJoCo and Brax, who will together continue to support both Brax (for its reinforcement
learning algorithms and included environments) and MJX (for its physics algorithms). A future version of Brax will
depend on the ``mujoco-mjx`` package, and Brax's existing
`generalized pipeline <https://github.com/google/brax/tree/main/brax/generalized>`__ will be deprecated. This change
will be largely transparent to users of Brax.

.. _MjxNotebook:

Tutorial notebook
=================

The following IPython notebook demonstrates the use of MJX along with reinforcement learning to train humanoid and
quadruped robots to locomote: |colab|.

.. |colab| image:: https://colab.research.google.com/assets/colab-badge.svg
           :target: https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb

.. _MjxInstallation:

Installation
============

The recommended way to install this package is via `PyPI <https://pypi.org/project/mujoco-mjx/>`__:

.. code-block:: shell

   pip install mujoco-mjx

A copy of the MuJoCo library is provided as part of this package's dependencies and does **not** need to be downloaded
or installed separately.

.. _MjxUsage:

Basic usage
===========

Once installed, the package can be imported via ``from mujoco import mjx``. Structs, functions, and enums are available
directly from the top-level ``mjx`` module.

.. _MjxStructs:

Structs
-------

Before running MJX functions on an accelerator device, structs must be copied onto the device via the
``mjx.put_model`` and ``mjx.put_data`` functions. Placing an :ref:`mjModel` on device yields an ``mjx.Model``. Placing
an :ref:`mjData` on device yields an ``mjx.Data``:

.. code-block:: python

   import mujoco
   from mujoco import mjx

   model = mujoco.MjModel.from_xml_string("...")
   data = mujoco.MjData(model)
   mjx_model = mjx.put_model(model)
   mjx_data = mjx.put_data(model, data)

These MJX variants mirror their MuJoCo counterparts but have a few key differences:

#. ``mjx.Model`` and ``mjx.Data`` contain JAX arrays that are copied onto device.
#. Some fields are missing from ``mjx.Model`` and ``mjx.Data`` for features that are
   :ref:`unsupported <MjxFeatureParity>` in MJX.
#. JAX arrays in ``mjx.Model`` and ``mjx.Data`` support adding batch dimensions. Batch dimensions are a natural way to
   express domain randomization (in the case of ``mjx.Model``) or high-throughput simulation for reinforcement learning
   (in the case of ``mjx.Data``).
#. Numpy arrays in ``mjx.Model`` and ``mjx.Data`` are structural fields that control the output of JIT compilation.
   Modifying these arrays will force JAX to recompile MJX functions. As an example, ``jnt_limited`` is a numpy array
   passed by reference from :ref:`mjModel`, which determines if joint limit constraints should be applied. If
   ``jnt_limited`` is modified, JAX will recompile MJX functions. On the other hand, ``jnt_range`` is a JAX array that
   can be modified at runtime, and will only apply to joints with limits as specified by the ``jnt_limited`` field.

Neither ``mjx.Model`` nor ``mjx.Data`` is meant to be constructed manually. An ``mjx.Data`` may be created by calling
``mjx.make_data``, which mirrors the :ref:`mj_makeData` function in MuJoCo:

.. code-block:: python

   model = mujoco.MjModel.from_xml_string("...")
   mjx_model = mjx.put_model(model)
   mjx_data = mjx.make_data(model)

Using ``mjx.make_data`` may be preferable when constructing batched ``mjx.Data`` structures inside of a ``vmap``.
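
As a rough illustration of the fourth difference listed above, the sketch below swaps in a new ``jnt_range`` at runtime
using ``replace``, the same dataclass method that appears in the minimal example. The one-hinge XML model and the
halving factor are assumptions invented for this example, not part of MJX.

.. code-block:: python

   import jax
   import mujoco
   from mujoco import mjx

   # Hypothetical one-joint model, used only for illustration.
   XML = """
   <mujoco>
     <worldbody>
       <body>
         <joint type="hinge" axis="0 1 0" limited="true" range="-30 30"/>
         <geom size=".1" type="sphere" mass="1"/>
       </body>
     </worldbody>
   </mujoco>
   """

   model = mujoco.MjModel.from_xml_string(XML)
   mjx_model = mjx.put_model(model)
   mjx_data = mjx.make_data(mjx_model)

   step = jax.jit(mjx.step)
   mjx_data = step(mjx_model, mjx_data)

   # jnt_range is a JAX array, so a new value can be swapped in at runtime.
   narrow_model = mjx_model.replace(jnt_range=mjx_model.jnt_range * 0.5)
   mjx_data = step(narrow_model, mjx_data)
   print(narrow_model.jnt_range)

Because only array values change here, and not the structural numpy fields such as ``jnt_limited``, the second call to
``step`` should be able to reuse the compiled function.
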

.. _MjxFunctions:

Functions
---------

MuJoCo functions are exposed as MJX functions of the same name, but with
`PEP 8 <https://peps.python.org/pep-0008/>`__-compliant names. Most of the :ref:`main simulation <Mainsimulation>` and
some of the :ref:`sub-components <Subcomponents>` for forward simulation are available from the top-level ``mjx`` module.

MJX functions are not `JIT compiled <https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html>`__ by default; we
leave it to the user to JIT MJX functions, or JIT their own functions that reference MJX functions. See the
:ref:`minimal example <MjxExample>` below.
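
The two options above (JIT an MJX function directly, or JIT your own function that calls it) might look like the
following sketch. The ``rollout`` helper, its ``nstep`` argument, and the falling-sphere model are hypothetical names
introduced for this example only.

.. code-block:: python

   import jax
   import mujoco
   from mujoco import mjx

   # Hypothetical model: a single free-falling sphere.
   XML = """
   <mujoco>
     <worldbody>
       <body pos="0 0 1">
         <freejoint/>
         <geom size=".15" mass="1" type="sphere"/>
       </body>
     </worldbody>
   </mujoco>
   """

   model = mujoco.MjModel.from_xml_string(XML)
   mjx_model = mjx.put_model(model)
   mjx_data = mjx.make_data(mjx_model)

   def rollout(m, d, nstep=10):
     # fori_loop traces mjx.step once instead of unrolling nstep copies.
     return jax.lax.fori_loop(0, nstep, lambda _, carry: mjx.step(m, carry), d)

   # Option 1: JIT an MJX function directly.
   jit_step = jax.jit(mjx.step)
   # Option 2: JIT a user function that references MJX functions.
   jit_rollout = jax.jit(rollout, static_argnames="nstep")

   d1 = jit_step(mjx_model, mjx_data)
   d10 = jit_rollout(mjx_model, mjx_data, nstep=10)
   print(d1.time, d10.time)

Marking ``nstep`` as static is a design choice for this sketch: each distinct step count compiles its own program.
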

.. _MjxEnums:

Enums and constants
-------------------

MJX enums are available as ``mjx.EnumType.ENUM_VALUE``, for example ``mjx.JointType.FREE``. Enums for unsupported MJX
features are omitted from the MJX enum declaration. MJX declares no constants but references MuJoCo constants directly.

.. _MjxExample:

Minimal example
---------------

.. code-block:: python

   # Throw a ball at 100 different velocities.

   import jax
   import mujoco
   from mujoco import mjx

   XML=r"""
   <mujoco>
     <worldbody>
       <body>
         <freejoint/>
         <geom size=".15" mass="1" type="sphere"/>
       </body>
     </worldbody>
   </mujoco>
   """

   model = mujoco.MjModel.from_xml_string(XML)
   mjx_model = mjx.put_model(model)

   @jax.vmap
   def batched_step(vel):
     mjx_data = mjx.make_data(mjx_model)
     qvel = mjx_data.qvel.at[0].set(vel)
     mjx_data = mjx_data.replace(qvel=qvel)
     pos = mjx.step(mjx_model, mjx_data).qpos[0]
     return pos

   vel = jax.numpy.arange(0.0, 1.0, 0.01)
   pos = jax.jit(batched_step)(vel)
   print(pos)

.. _MjxFeatureParity:

Feature Parity
==============

MJX supports most of the main simulation features of MuJoCo, with a few exceptions. MJX will raise an exception if
asked to copy to device an :ref:`mjModel` with field values referencing unsupported features.
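
One way to handle this at load time is sketched below. The ``try_put_model`` helper is a hypothetical name, and since
this page does not state which exception type is raised, the sketch catches ``Exception`` broadly; treat both as
assumptions.

.. code-block:: python

   import mujoco
   from mujoco import mjx

   def try_put_model(xml):
     """Returns (mjx_model, None) on success, or (None, message) on failure."""
     model = mujoco.MjModel.from_xml_string(xml)
     try:
       return mjx.put_model(model), None
     except Exception as e:  # assumption: exact exception type not documented here
       return None, f"{type(e).__name__}: {e}"

   BODY = '<worldbody><body><freejoint/><geom size=".1" mass="1"/></body></worldbody>'
   SUPPORTED_XML = f"<mujoco>{BODY}</mujoco>"
   # The PGS solver is listed as unsupported in this section; whether this
   # particular model is rejected may depend on the installed MJX version.
   PGS_XML = f'<mujoco><option solver="PGS"/>{BODY}</mujoco>'

   mjx_model, err = try_put_model(SUPPORTED_XML)
   print("supported model error:", err)
   print("PGS model error:", try_put_model(PGS_XML)[1])
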

The following features are **fully supported** in MJX:

.. list-table::
   :width: 90%
   :align: left
   :widths: 2 5
   :header-rows: 1

   * - Category
     - Feature
   * - Dynamics
     - :ref:`Forward <mj_forward>`
   * - :ref:`Joint <mjtJoint>`
     - ``FREE``, ``BALL``, ``SLIDE``, ``HINGE``
   * - :ref:`Transmission <mjtTrn>`
     - ``JOINT``, ``JOINTINPARENT``, ``SITE``, ``TENDON``
   * - :ref:`Actuator Dynamics <mjtDyn>`
     - ``NONE``, ``INTEGRATOR``, ``FILTER``, ``FILTEREXACT``
   * - :ref:`Actuator Gain <mjtGain>`
     - ``FIXED``, ``AFFINE``
   * - :ref:`Actuator Bias <mjtBias>`
     - ``NONE``, ``AFFINE``
   * - :ref:`Tendon Wrapping <mjtWrap>`
     - ``JOINT``, ``SITE``, ``PULLEY``
   * - :ref:`Geom <mjtGeom>`
     - ``PLANE``, ``HFIELD``, ``SPHERE``, ``CAPSULE``, ``BOX``, ``MESH`` are fully implemented. ``ELLIPSOID`` and
       ``CYLINDER`` are implemented but only collide with other primitives. Note that ``BOX`` is implemented as a mesh.
   * - :ref:`Constraint <mjtConstraint>`
     - ``EQUALITY``, ``LIMIT_JOINT``, ``CONTACT_FRICTIONLESS``, ``CONTACT_PYRAMIDAL``, ``CONTACT_ELLIPTIC``,
       ``FRICTION_DOF``, ``FRICTION_TENDON``
   * - :ref:`Equality <mjtEq>`
     - ``CONNECT``, ``WELD``, ``JOINT``, ``TENDON``
   * - :ref:`Integrator <mjtIntegrator>`
     - ``EULER``, ``RK4``, ``IMPLICITFAST`` (``IMPLICITFAST`` not supported with :doc:`fluid drag <computation/fluid>`)
   * - :ref:`Cone <mjtCone>`
     - ``PYRAMIDAL``, ``ELLIPTIC``
   * - :ref:`Condim <coContact>`
     - 1, 3, 4, 6
   * - :ref:`Solver <mjtSolver>`
     - ``CG``, ``NEWTON``
   * - Fluid Model
     - :ref:`flInertia`
   * - :ref:`Tendons <tendon>`
     - :ref:`Fixed <tendon-fixed>`
   * - :ref:`Sensors <mjtSensor>`
     - ``MAGNETOMETER``, ``CAMPROJECTION``, ``RANGEFINDER``, ``JOINTPOS``, ``TENDONPOS``, ``ACTUATORPOS``, ``BALLQUAT``,
       ``FRAMEPOS``, ``FRAMEXAXIS``, ``FRAMEYAXIS``, ``FRAMEZAXIS``, ``FRAMEQUAT``, ``SUBTREECOM``, ``CLOCK``,
       ``VELOCIMETER``, ``GYRO``, ``JOINTVEL``, ``TENDONVEL``, ``ACTUATORVEL``, ``BALLANGVEL``, ``FRAMELINVEL``,
       ``FRAMEANGVEL``, ``SUBTREELINVEL``, ``SUBTREEANGMOM``, ``TOUCH``, ``ACCELEROMETER``, ``FORCE``, ``TORQUE``,
       ``ACTUATORFRC``, ``JOINTACTFRC``, ``FRAMELINACC``, ``FRAMEANGACC``.

The following features are **in development** and coming soon:

.. list-table::
   :width: 90%
   :align: left
   :widths: 2 5
   :header-rows: 1

   * - Category
     - Feature
   * - :ref:`Geom <mjtGeom>`
     - ``SDF``. Collisions between (``SPHERE``, ``BOX``, ``MESH``, ``HFIELD``) and ``CYLINDER``. Collisions between
       (``BOX``, ``MESH``, ``HFIELD``) and ``ELLIPSOID``.
   * - :ref:`Integrator <mjtIntegrator>`
     - ``IMPLICIT``
   * - Dynamics
     - :ref:`Inverse <mj_inverse>`
   * - :ref:`Actuator Dynamics <mjtDyn>`
     - ``MUSCLE``
   * - :ref:`Actuator Gain <mjtGain>`
     - ``MUSCLE``
   * - :ref:`Actuator Bias <mjtBias>`
     - ``MUSCLE``
   * - :ref:`Tendon Wrapping <mjtWrap>`
     - ``SPHERE``, ``CYLINDER``
   * - Fluid Model
     - :ref:`flEllipsoid`
   * - :ref:`Tendons <tendon>`
     - :ref:`Spatial <tendon-spatial>`
   * - :ref:`Sensors <mjtSensor>`
     - All except ``PLUGIN``, ``USER``
   * - Lights
     - Positions and directions of lights

The following features are **unsupported**:

.. list-table::
   :width: 90%
   :align: left
   :widths: 2 5
   :header-rows: 1

   * - Category
     - Feature
   * - :ref:`margin<body-geom-margin>` and :ref:`gap<body-geom-gap>`
     - Unimplemented for collisions with ``Mesh`` :ref:`Geom <mjtGeom>`.
   * - :ref:`Transmission <mjtTrn>`
     - ``SLIDERCRANK``, ``BODY``
   * - :ref:`Actuator Dynamics <mjtDyn>`
     - ``USER``
   * - :ref:`Actuator Gain <mjtGain>`
     - ``USER``
   * - :ref:`Actuator Bias <mjtBias>`
     - ``USER``
   * - :ref:`Solver <mjtSolver>`
     - ``PGS``
   * - :ref:`Sensors <mjtSensor>`
     - ``PLUGIN``, ``USER``

.. _MjxSharpBits:

🔪 MJX - The Sharp Bits 🔪
==========================

GPUs and TPUs have unique performance tradeoffs that MJX is subject to. MJX specializes in simulating big batches of
parallel identical physics scenes using algorithms that can be efficiently vectorized on
`SIMD hardware <https://en.wikipedia.org/wiki/Single_instruction,_multiple_data>`__. This specialization is useful
for machine learning workloads such as `reinforcement learning <https://en.wikipedia.org/wiki/Reinforcement_learning>`__
that require massive data throughput.

There are certain workflows that MJX is ill-suited for:

Single scene simulation
  When simulating a single scene (1 instance of :ref:`mjData`), MJX can be **10x** slower than MuJoCo, which has been
  carefully optimized for CPU. MJX works best when simulating thousands or tens of thousands of scenes in parallel.

Collisions between large meshes
  MJX supports collisions between convex mesh geometries. However, the convex collision algorithms
  in MJX are implemented differently than in MuJoCo. MJX uses a branchless version of the
  `Separating Axis Test <https://ubm-twvideo01.s3.amazonaws.com/o1/vault/gdc2013/slides/822403Gregorius_Dirk_TheSeparatingAxisTest.pdf>`__
  (SAT) to determine if geometries are colliding with convex meshes, while MuJoCo uses the Minkowski Portal Refinement
  (MPR) algorithm as implemented in `libccd <https://github.com/danfis/libccd>`__.
  SAT works well for smaller meshes but suffers in both runtime and memory for larger meshes.

  For collisions between convex meshes and primitives, the convex decomposition of the mesh should have
  roughly **200 vertices or fewer** for reasonable performance. For convex-convex collisions,
  the convex mesh should have roughly **fewer than 32 vertices**. We recommend using
  :ref:`maxhullvert<asset-mesh-maxhullvert>` in the MuJoCo compiler to achieve desired convex mesh properties.

  With careful tuning, MJX can simulate scenes with mesh collisions; see the MJX
  `shadow hand <https://github.com/google-deepmind/mujoco/tree/main/mjx/mujoco/mjx/test_data/shadow_hand>`__
  config for an example. Speeding up mesh collision detection is an active area of development for MJX.

Large, complex scenes with many contacts
  Accelerators exhibit poor performance for
  `branching code <https://aschrein.github.io/jekyll/update/2019/06/13/whatsup-with-my-branches-on-gpu.html#tldr>`__.
  Branching is used in broad-phase collision detection, when identifying potential collisions between large numbers of
  bodies in a scene. MJX ships with a simple branchless broad-phase algorithm (see
  :ref:`performance tuning <MjxPerformance>`) but it is not as powerful as the one in MuJoCo.

  To see how this affects simulation, let us consider a physics scene with increasing numbers of humanoid bodies,
  varied from 1 to 10. We simulate this scene using CPU MuJoCo on an Apple M3 Max and a 64-core AMD 3995WX and time
  it using :ref:`testspeed<saTestspeed>`, using ``2 x numcore`` threads. We time the MJX simulation on an Nvidia
  A100 GPU using a batch size of 8192 and an 8-chip
  `v5 TPU <https://cloud.google.com/blog/products/compute/announcing-cloud-tpu-v5e-and-a3-gpus-in-ga>`__
  machine using a batch size of 16384. Note the vertical scale is logarithmic.

  .. figure:: images/mjx/SPS.svg
     :width: 95%
     :align: center

  The values for a single humanoid (leftmost datapoints) for the four timed architectures are **650K**, **1.8M**,
  **950K** and **2.7M** steps per second, respectively. Note that as we increase the number of humanoids (which
  increases the number of potential contacts in a scene), MJX throughput decreases more rapidly than MuJoCo.
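
The rough vertex budgets quoted under "Collisions between large meshes" can be turned into a quick pre-flight check.
The sketch below is plain Python and makes no MJX calls; the ``over_budget`` helper and the example hull sizes are
hypothetical, and the budgets are the approximate guidelines from this section rather than hard limits.

.. code-block:: python

   from typing import Dict, List

   # Approximate budgets quoted in this section (guidelines, not hard limits).
   MESH_PRIMITIVE_BUDGET = 200
   CONVEX_CONVEX_BUDGET = 32

   def over_budget(hull_vertices: Dict[str, int], budget: int) -> List[str]:
     """Returns sorted names of meshes whose hull has more than `budget` vertices."""
     return sorted(name for name, n in hull_vertices.items() if n > budget)

   # Hypothetical hull sizes, e.g. collected from your own asset pipeline.
   hulls = {"palm": 180, "fingertip": 24, "forearm": 412}

   print(over_budget(hulls, MESH_PRIMITIVE_BUDGET))  # ['forearm']
   print(over_budget(hulls, CONVEX_CONVEX_BUDGET))   # ['forearm', 'palm']

Meshes flagged by such a check are candidates for a lower :ref:`maxhullvert<asset-mesh-maxhullvert>` setting.
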

.. _MjxPerformance:

Performance tuning
==================

For MJX to perform well, some configuration parameters should be adjusted from their default MuJoCo values:

:ref:`option/iterations<option-iterations>` and :ref:`option/ls_iterations<option-ls_iterations>`
  The :ref:`iterations<option-iterations>` and :ref:`ls_iterations<option-ls_iterations>` attributes, which control
  solver and linesearch iterations respectively, should be brought down to just low enough that the simulation remains
  stable. Accurate solver forces are less important in reinforcement learning, where domain randomization is often
  used to add noise to the physics for sim-to-real transfer. The ``NEWTON`` :ref:`Solver <mjtSolver>` delivers
  excellent convergence with very few (often just one) solver iterations, and performs well on GPU. ``CG`` is currently
  a better choice for TPU.

:ref:`contact/pair<contact-pair>`
  Consider explicitly marking geoms for collision detection to reduce the number of contacts that MJX must consider
  during each step. Enabling only an explicit list of valid contacts can have a dramatic effect on simulation
  performance in MJX. Doing this well often requires an understanding of the task. For example, the
  `OpenAI Gym Humanoid <https://github.com/openai/gym/blob/master/gym/envs/mujoco/humanoid_v4.py>`__ task resets when
  the humanoid starts to fall, so full contact with the floor is not needed.

:ref:`maxhullvert<asset-mesh-maxhullvert>`
  Set :ref:`maxhullvert<asset-mesh-maxhullvert>` to ``64`` or less for better convex mesh collision performance.

:ref:`option/flag/eulerdamp<option-flag-eulerdamp>`
  Disabling ``eulerdamp`` can help performance and is often not needed for stability. Read the
  :ref:`Numerical Integration<geIntegration>` section for details regarding the semantics of this flag.

:ref:`option/jacobian<option-jacobian>`
  Explicitly setting "dense" or "sparse" may speed up simulation depending on your device. Modern TPUs have specialized
  hardware for rapidly operating over sparse matrices, whereas GPUs tend to be faster with dense matrices as long as
  they fit onto the device. As such, the behavior in MJX for the default "auto" setting is sparse if ``nv >= 60`` (60 or
  more degrees of freedom), or if MJX detects a TPU as the default backend, otherwise "dense". For TPU, using "sparse"
  with the Newton solver can speed up simulation by 2x to 3x. For GPU, choosing "dense" may impart a more modest speedup
  of 10% to 20%, as long as the dense matrices can fit on the device.

Broadphase
  While MuJoCo handles broadphase culling out of the box, MJX requires additional parameters. For an approximate
  version of broadphase, use the experimental custom numeric parameters ``max_contact_points`` and ``max_geom_pairs``.
  ``max_contact_points`` caps the number of contact points sent to the solver for each condim type.
  ``max_geom_pairs`` caps the total number of geom-pairs sent to respective collision functions for each geom-type
  pair. As an example, the
  `shadow hand <https://github.com/google-deepmind/mujoco/tree/main/mjx/mujoco/mjx/test_data/shadow_hand>`__
  environment makes use of these parameters.
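
Several of the settings above can be combined in one MJCF file. The sketch below loads such a file with the regular
``mujoco`` package and reads the values back. Every numeric value (``1``, ``4``, ``8``), the geom names, and the
choice of "dense" are illustrative assumptions, not recommendations; tune them for your own task and device.

.. code-block:: python

   import mujoco

   # Illustrative values only: tune iterations, ls_iterations and the caps per task.
   XML = """
   <mujoco>
     <option iterations="1" ls_iterations="4" jacobian="dense">
       <flag eulerdamp="disable"/>
     </option>
     <custom>
       <numeric name="max_contact_points" data="8"/>
       <numeric name="max_geom_pairs" data="8"/>
     </custom>
     <worldbody>
       <!-- contype/conaffinity of 0 turn off automatic pair generation. -->
       <geom name="floor" type="plane" size="1 1 .1" contype="0" conaffinity="0"/>
       <body pos="0 0 .2">
         <freejoint/>
         <geom name="ball" type="sphere" size=".1" contype="0" conaffinity="0"/>
       </body>
     </worldbody>
     <contact>
       <!-- Only this explicitly listed pair is checked for collision. -->
       <pair geom1="floor" geom2="ball"/>
     </contact>
   </mujoco>
   """

   model = mujoco.MjModel.from_xml_string(XML)

   # Read back one of the custom numeric parameters by name.
   num_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_NUMERIC, "max_contact_points")
   max_contact_points = model.numeric_data[model.numeric_adr[num_id]]

   print(model.opt.iterations, model.opt.ls_iterations, model.npair, max_contact_points)

The resulting ``model`` can then be passed to ``mjx.put_model`` as usual.
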

GPU performance
---------------

The following environment variables should be set:

``XLA_FLAGS=--xla_gpu_triton_gemm_any=true``
  This enables the Triton-based GEMM (matmul) emitter for any GEMM that it supports. This can yield a 30% speedup on
  NVIDIA GPUs. If you have multiple GPUs, you may also benefit from enabling flags related to
  `communication between GPUs <https://jax.readthedocs.io/en/latest/gpu_performance_tips.html>`__.
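
If exporting the variable in the shell is inconvenient, it can also be set from Python. The sketch below assumes the
flag has to be in place before JAX initializes its backend, so the simplest approach is to run it before the first
``import jax``. The ``add_xla_flag`` helper is a hypothetical name; it appends to any existing ``XLA_FLAGS`` value
instead of overwriting it.

.. code-block:: python

   import os

   FLAG = "--xla_gpu_triton_gemm_any=true"

   def add_xla_flag(flag: str) -> str:
     """Appends `flag` to XLA_FLAGS unless it is already present."""
     current = os.environ.get("XLA_FLAGS", "")
     if flag not in current.split():
       os.environ["XLA_FLAGS"] = f"{current} {flag}".strip()
     return os.environ["XLA_FLAGS"]

   # Call this before the first `import jax`.
   print(add_xla_flag(FLAG))
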