
MJX step gives wrong results or fails to update the state (depending on init conditions) #2361

Open
2 tasks done
alaurenzi opened this issue Jan 15, 2025 · 1 comment
Labels
bug Something isn't working


alaurenzi commented Jan 15, 2025

Intro

Hi!

I am a software engineer at Istituto Italiano di Tecnologia (IIT); we use MuJoCo for RL.

My setup

MuJoCo version

```
(mjx1) alaurenzi@alaurenzi-iit-desktop:~/code/mjx_playground/playground/mre_bug_noupdate $ pip show mujoco-mjx
WARNING: Skipping /usr/lib/python3/dist-packages/pytest_repeat.egg-info due to invalid metadata entry 'name'
Name: mujoco-mjx
Version: 3.2.7
Summary: MuJoCo XLA (MJX)
[...]
```

OS

```
(mjx1) alaurenzi@alaurenzi-iit-desktop:~/code/mjx_playground/playground/mre_bug_noupdate $ lsb_release -a
No LSB modules are available.
Distributor ID:	Ubuntu
Description:	Ubuntu 24.04.1 LTS
Release:	24.04
Codename:	noble
```

Kernel / arch

```
(mjx1) alaurenzi@alaurenzi-iit-desktop:~/code/mjx_playground/playground/mre_bug_noupdate $ uname -a
Linux alaurenzi-iit-desktop 6.8.0-51-generic #52-Ubuntu SMP PREEMPT_DYNAMIC Thu Dec  5 13:09:44 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux
```

Output of `pip freeze`

```
absl-py==2.1.0
alabaster==0.7.12
appdirs==1.4.4
argcomplete==3.1.4
asttokens==2.4.1
attrs==23.2.0
Babel==2.10.3
bcc==0.29.1
beautifulsoup4==4.12.3
blinker==1.7.0
bloom==0.12.0
breezy==3.3.5
Brlapi==0.8.5
Brotli==1.1.0
catkin-pkg==1.0.0
catkin-pkg-modules==1.0.0
certifi==2023.11.17
chardet==5.2.0
click==8.1.6
cloud-init==24.4
colcon-argcomplete==0.3.3
colcon-bash==0.5.0
colcon-cd==0.2.1
colcon-cmake==0.2.29
colcon-common-extensions==0.3.0
colcon-core==0.18.4
colcon-defaults==0.2.9
colcon-devtools==0.3.0
colcon-installed-package-information==0.2.1
colcon-library-path==0.2.1
colcon-metadata==0.2.5
colcon-mixin==0.2.3
colcon-notification==0.3.0
colcon-output==0.2.13
colcon-override-check==0.0.1
colcon-package-information==0.4.0
colcon-package-selection==0.2.10
colcon-parallel-executor==0.3.0
colcon-pkg-config==0.1.0
colcon-powershell==0.4.0
colcon-python-setup-py==0.2.9
colcon-recursive-crawl==0.2.3
colcon-ros==0.5.0
colcon-test-result==0.3.8
colcon-zsh==0.5.0
colorama==0.4.6
command-not-found==0.3
configobj==5.0.8
contourpy==1.0.7
coverage==7.4.4
cryptography==41.0.7
cssselect==1.2.0
cupshelpers==1.0
cycler==0.11.0
dbus-python==1.3.2
decorator==5.1.1
defer==1.0.6
Deprecated==1.2.14
distlib==0.3.8
distro==1.9.0
distro-info==1.7+build1
docutils==0.20.1
dulwich==0.21.6
empy==3.3.4
etils==1.11.0
executing==2.0.1
fastbencode==0.2
fastimport==0.9.14
flake8==7.0.0
flake8-blind-except==0.2.1
flake8-builtins==2.1.0
flake8-class-newline==1.6.0
flake8-comprehensions==3.14.0
flake8-deprecated==2.2.1
flake8-docstrings==1.6.0
flake8-import-order==0.18.2
flake8-quotes==3.4.0
fonttools==4.46.0
fs==2.4.16
fsspec==2024.12.0
furo==0.0.0
glfw==2.8.0
gpg==1.18.0
html5lib==1.1
httplib2==0.20.4
idna==3.6
imagesize==1.4.1
importlib-metadata==4.12.0
importlib_resources==6.5.2
iniconfig==1.1.1
ipython==8.20.0
jax==0.4.38
jax-cuda12-pjrt==0.4.38
jax-cuda12-plugin==0.4.38
jaxlib==0.4.38
jedi==0.19.1
Jinja2==3.1.2
jsonpatch==1.32
jsonpointer==2.0
jsonschema==4.10.3
kiwisolver==0.0.0
language-selector==0.1
lark==1.1.9
launchpadlib==1.11.0
lazr.restfulclient==0.14.6
lazr.uri==1.0.6
louis==3.29.0
lxml==5.2.1
lz4==4.0.2+dfsg
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.6.3
matplotlib-inline==0.1.6
mccabe==0.7.0
mdurl==0.1.2
mercurial==6.7.2
merge3==0.0.8
meson==1.3.2
ml_dtypes==0.5.1
more-itertools==10.2.0
mpi4py==3.1.5
mpmath==0.0.0
mujoco==3.2.7
mujoco-mjx==3.2.7
mypy==1.9.0
mypy-extensions==1.0.0
netaddr==0.8.0
netifaces==0.11.0
notify2==0.3
ntplib==0.3.3
numpy==1.26.4
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvcc-cu12==12.6.85
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.6.0.74
nvidia-cufft-cu12==11.3.0.4
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-nccl-cu12==2.24.3
nvidia-nvjitlink-cu12==12.6.85
oauthlib==3.2.2
olefile==0.46
opt_einsum==3.4.0
packaging==24.0
parso==0.8.3
patiencediff==0.2.13
pexpect==4.9.0
pillow==10.2.0
pluggy==1.4.0
prompt-toolkit==3.0.43
protobuf==4.21.12
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.0.0
PyAudio==0.2.13
pycairo==1.25.1
pycodestyle==2.11.1
pycups==2.0.1
pydocstyle==6.3.0
pydot==1.4.2
pyflakes==3.2.0
PyGithub==2.2.0
Pygments==2.17.2
PyGObject==3.48.2
PyJWT==2.7.0
PyNaCl==1.5.0
PyOpenGL==3.1.7
pyparsing==3.1.1
PyQt5==5.15.10
PyQt5-sip==12.13.0
pyrsistent==0.20.0
pyserial==3.5
pytest==7.4.4
pytest-cov==4.1.0
pytest-mock==3.12.0
pytest-repeat==0.9.3
pytest-rerunfailures==12.0
pytest-runner==2.11.1
pytest-timeout==2.2.0
python-apt==2.7.7+ubuntu3
python-dateutil==2.8.2
python-debian==0.1.49+ubuntu2
pytz==2024.1
pyxdg==0.28
PyYAML==6.0.1
requests==2.31.0
rich==13.7.1
roman==3.3
rosdep==0.25.1
rosdep-modules==0.25.1
rosdistro==1.0.1
rosdistro-modules==1.0.1
rospkg-modules==1.5.1
SciPy==1.11.4
screen-resolution-extra==0.0.0
semver==2.10.2
setuptools==68.1.2
six==1.16.0
snowballstemmer==2.2.0
soupsieve==2.5
Sphinx==7.2.6
sphinx-basic-ng==1.0.0b2
sphinx_inline_tabs==2023.4.21
ssh-import-id==5.11
stack-data==0.6.3
sympy==1.12
systemd-python==235
termcolor==1.1.0
terminator==2.1.3
tqdm==0.0.0
traitlets==5.5.0
transforms3d==0.4.1
trimesh==4.5.3
typeguard==4.1.5
types-aiofiles==23.2
types-aws-xray-sdk==2.12
types-beautifulsoup4==4.12
types-bleach==6.1
types-boltons==23.0
types-boto==2.49
types-braintree==4.24
types-cachetools==5.3
types-caldav==1.3
types-cffi==1.16
types-chevron==0.14
types-click-default-group==1.2
types-click-spinner==0.1
types-colorama==0.4
types-commonmark==0.9
types-console-menu==0.8
types-croniter==2.0
types-dateparser==1.1
types-decorator==5.1
types-Deprecated==1.2
types-dockerfile-parse==2.0
types-docopt==0.6
types-docutils==0.20
types-editdistance==0.6
types-entrypoints==0.4
types-ExifRead==3.0
types-first==2.0
types-flake8-2020==1.8
types-flake8-bugbear==23.9.16
types-flake8-builtins==2.2
types-flake8-docstrings==1.7
types-flake8-plugin-utils==1.3
types-flake8-rst-docstrings==0.3
types-flake8-simplify==0.21
types-flake8-typing-imports==1.15
types-Flask-Cors==4.0
types-Flask-Migrate==4.0
types-Flask-SocketIO==5.3
types-fpdf2==2.7.4
types-gdb==12.1
types-google-cloud-ndb==2.2
types-greenlet==3.0
types-hdbcli==2.18
types-html5lib==1.1
types-httplib2==0.22
types-humanfriendly==10.0
types-ibm-db==3.2
types-influxdb-client==1.38
types-inifile==0.4
types-JACK-Client==0.5
types-jmespath==1.0
types-jsonschema==4.19
types-keyboard==0.13
types-ldap3==2.9
types-libsass==0.22
types-Markdown==3.5
types-mock==5.1
types-mypy-extensions==1.0
types-mysqlclient==2.2
types-netaddr==0.9
types-oauthlib==3.2
types-openpyxl==3.1
types-opentracing==2.4
types-paho-mqtt==1.6
types-paramiko==3.3
types-parsimonious==0.10
types-passlib==1.7
types-passpy==1.0
types-peewee==3.17
types-pep8-naming==0.13
types-pexpect==4.8
types-pika-ts==1.3
types-Pillow==10.1
types-playsound==1.3
types-pluggy==1.2.0
types-polib==1.2
types-portpicker==1.6
types-protobuf==4.24
types-psutil==5.9
types-psycopg2==2.9
types-pyasn1==0.5
types-pyaudio==0.2
types-PyAutoGUI==0.9
types-pycocotools==2.0
types-pycurl==7.45.2
types-pyfarmhash==0.3
types-pyflakes==3.1
types-Pygments==2.16
types-pyinstaller==6.1
types-pyjks==20.0
types-PyMySQL==1.1
types-pynput==1.7
types-pyOpenSSL==23.3
types-pyRFC3339==1.1
types-PyScreeze==0.1.29
types-pyserial==3.5
types-pysftp==0.2
types-pytest-lazy-fixture==0.6
types-python-crontab==3.0
types-python-datemath==1.5
types-python-dateutil==2.8
types-python-gflags==3.1
types-python-jose==3.3
types-python-nmap==0.7
types-python-slugify==8.0
types-python-xlib==0.33
types-pytz==2023.3.post1
types-pywin32==306
types-pyxdg==0.28
types-PyYAML==6.0
types-qrcode==7.4
types-redis==4.6.0
types-regex==2023.10.3
types-requests==2.31
types-requests-oauthlib==1.3
types-retry==0.9
types-s2clientprotocol==5
types-seaborn==0.13
types-Send2Trash==1.8
types-setuptools==68.2
types-simplejson==3.19
types-singledispatch==4.1
types-six==1.16
types-slumber==0.7
types-stdlib-list==0.8
types-stripe==3.5
types-tabulate==0.9
types-tensorflow==2.12
types-toml==0.10
types-toposort==1.10
types-tqdm==4.66
types-translationstring==1.4
types-tree-sitter==0.20.1
types-tree-sitter-languages==1.8
types-ttkthemes==3.2
types-tzlocal==5.1
types-ujson==5.8
types-untangle==1.2
types-usersettings==1.1
types-uWSGI==2.0
types-vobject==0.9
types-waitress==2.1
types-WebOb==1.8
types-whatthepatch==1.0
types-workalendar==17.0
types-WTForms==3.1
types-xmltodict==0.13
types-zstd==1.5
types-zxcvbn==4.4
typing_extensions==4.10.0
tzlocal==5.2
ubuntu-drivers-common==0.0.0
ubuntu-pro-client==8001
ufoLib2==0.16.0
ufw==0.36.2
unattended-upgrades==0.1
unicodedata2==15.1.0
urllib3==2.0.7
vcstool==0.3.0
vcstools==0.1.42
wadllib==1.3.6
wcwidth==0.2.5
webencodings==0.5.1
wheel==0.42.0
wrapt==1.15.0
xdg==5
xkit==0.0.0
zipp==1.0.0
```

What's happening? What did you expect?

Depending on initial conditions, `mjx.step` either gives wrong results (compared to plain MuJoCo's `mj_step`) or does not update the state (qpos, qvel) at all. The model is a plain cartpole, which I have attached together with the rest of the code and a requirements.txt file.

The MRE is, as expected, very simple:

```python
import jax
from jax import numpy as jp

import mujoco
from mujoco import mjx
from mujoco import viewer as mj_viewer

# print device
print(jax.devices())

# open mjcf (just the standard cartpole model)
xml = open('cartpole.xml', 'r').read()

# make model
mj_model = mujoco.MjModel.from_xml_string(xml)

# make data (and, optionally, a passive viewer)
mj_data = mujoco.MjData(mj_model)

# NOTE!! uncommenting this changes the output O.O
# mj_viewer_handle = mj_viewer.launch_passive(mj_model, mj_data)

# to gpu
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

# vectorize the mjx data over some envs, set initial condition
num_envs = 1
qpos = jp.array([0.20, -0.25])

# only used for random initial conditions, not used for this test
# as we fix initial conditions
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, num_envs)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=qpos))(rng)
mj_data.qpos = qpos

# jit mjx step
jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))

for i in range(20):
    mujoco.mj_step(mj_model, mj_data)
    batch = jit_step(mjx_model, batch)
    print(i, ' mjx', batch.qpos)
    print(i, ' mj', mj_data.qpos)
    print(i, ' err', mj_data.qpos - batch.qpos[0])
    print('---')
```

Output on my system:

```
[CudaDevice(id=0)]
0  mjx [[ 0.2  -0.25]]
0  mj [ 0.20064749 -0.25108757]
0  err [ 0.00064749 -0.00108758]
---
1  mjx [[ 0.2  -0.25]]
1  mj [ 0.20194727 -0.25344699]
1  err [ 0.00194727 -0.003447  ]
---
2  mjx [[ 0.2  -0.25]]
2  mj [ 0.2034819  -0.25649398]
2  err [ 0.00348189 -0.00649399]
---
3  mjx [[ 0.2  -0.25]]
3  mj [ 0.20510108 -0.26002131]
3  err [ 0.00510108 -0.0100213 ]
---
4  mjx [[ 0.2  -0.25]]
4  mj [ 0.20675072 -0.26395774]
4  err [ 0.00675072 -0.01395774]
---
5  mjx [[ 0.2  -0.25]]
5  mj [ 0.20841131 -0.26828134]
5  err [ 0.0084113  -0.01828134]
---
6  mjx [[ 0.2  -0.25]]
6  mj [ 0.21007585 -0.2729882 ]
6  err [ 0.01007585 -0.0229882 ]
---
7  mjx [[ 0.2  -0.25]]
7  mj [ 0.21174182 -0.27808125]
7  err [ 0.01174182 -0.02808124]
---
8  mjx [[ 0.2  -0.25]]
8  mj [ 0.2134083 -0.2835662]
8  err [ 0.01340829 -0.03356621]
---
9  mjx [[ 0.2  -0.25]]
9  mj [ 0.21507496 -0.28945009]
9  err [ 0.01507495 -0.03945008]
---
10  mjx [[ 0.2  -0.25]]
10  mj [ 0.21674169 -0.2957408 ]
10  err [ 0.01674169 -0.04574081]
---
11  mjx [[ 0.2  -0.25]]
11  mj [ 0.21840845 -0.30244686]
11  err [ 0.01840845 -0.05244687]
---
12  mjx [[ 0.2  -0.25]]
12  mj [ 0.22007522 -0.30957739]
12  err [ 0.02007522 -0.05957738]
---
13  mjx [[ 0.2  -0.25]]
13  mj [ 0.22174199 -0.31714205]
13  err [ 0.02174199 -0.06714204]
---
14  mjx [[ 0.2  -0.25]]
14  mj [ 0.22340876 -0.3251511 ]
14  err [ 0.02340876 -0.07515112]
---
15  mjx [[ 0.2  -0.25]]
15  mj [ 0.22507554 -0.33361538]
15  err [ 0.02507554 -0.08361536]
---
16  mjx [[ 0.2  -0.25]]
16  mj [ 0.22674232 -0.34254628]
16  err [ 0.02674231 -0.09254628]
---
17  mjx [[ 0.2  -0.25]]
17  mj [ 0.2284091  -0.35195582]
17  err [ 0.02840909 -0.10195583]
---
18  mjx [[ 0.2  -0.25]]
18  mj [ 0.23007588 -0.3618566 ]
18  err [ 0.03007588 -0.11185661]
---
19  mjx [[ 0.2  -0.25]]
19  mj [ 0.23174266 -0.37226185]
19  err [ 0.03174266 -0.12226185]
---
```

As you can see, the MJX qpos does not change at all. It gets even weirder, though: if I uncomment the line that only launches a passive viewer (which does nothing in the background), the result changes to the following:

```
[CudaDevice(id=0)]
0  mjx [[ 0.20000008 -0.25017637]]
0  mj [ 0.20064749 -0.25108757]
0  err [ 0.00064741 -0.00091121]
---
1  mjx [[ 0.2000003  -0.25070545]]
1  mj [ 0.20194727 -0.25344699]
1  err [ 0.00194697 -0.00274155]
---
2  mjx [[ 0.20000066 -0.25158724]]
2  mj [ 0.2034819  -0.25649398]
2  err [ 0.00348124 -0.00490674]
---
3  mjx [[ 0.20000117 -0.25282174]]
3  mj [ 0.20510108 -0.26002131]
3  err [ 0.00509992 -0.00719956]
---
4  mjx [[ 0.20000182 -0.254409  ]]
4  mj [ 0.20675072 -0.26395774]
4  err [ 0.0067489  -0.00954875]
---
5  mjx [[ 0.20000263 -0.25634894]]
5  mj [ 0.20841131 -0.26828134]
5  err [ 0.00840868 -0.0119324 ]
---
6  mjx [[ 0.20000356 -0.2586416 ]]
6  mj [ 0.21007585 -0.2729882 ]
6  err [ 0.01007229 -0.0143466 ]
---
7  mjx [[ 0.20000465 -0.261287  ]]
7  mj [ 0.21174182 -0.27808125]
7  err [ 0.01173717 -0.01679423]
---
8  mjx [[ 0.20000589 -0.26428512]]
8  mj [ 0.2134083 -0.2835662]
8  err [ 0.0134024  -0.01928109]
---
9  mjx [[ 0.20000727 -0.26763594]]
9  mj [ 0.21507496 -0.28945009]
9  err [ 0.01506768 -0.02181414]
---
10  mjx [[ 0.2000088  -0.27133948]]
10  mj [ 0.21674169 -0.2957408 ]
10  err [ 0.0167329  -0.02440134]
---
11  mjx [[ 0.20001046 -0.27539575]]
11  mj [ 0.21840845 -0.30244686]
11  err [ 0.01839799 -0.02705112]
---
12  mjx [[ 0.20001228 -0.27980474]]
12  mj [ 0.22007522 -0.30957739]
12  err [ 0.02006294 -0.02977264]
---
13  mjx [[ 0.20001423 -0.28456643]]
13  mj [ 0.22174199 -0.31714205]
13  err [ 0.02172776 -0.03257561]
---
14  mjx [[ 0.20001633 -0.28968084]]
14  mj [ 0.22340876 -0.3251511 ]
14  err [ 0.02339242 -0.03547028]
---
15  mjx [[ 0.20001858 -0.295148  ]]
15  mj [ 0.22507554 -0.33361538]
15  err [ 0.02505696 -0.03846738]
---
16  mjx [[ 0.20002098 -0.30096784]]
16  mj [ 0.22674232 -0.34254628]
16  err [ 0.02672133 -0.04157844]
---
17  mjx [[ 0.20002352 -0.3071404 ]]
17  mj [ 0.2284091  -0.35195582]
17  err [ 0.02838558 -0.04481542]
---
18  mjx [[ 0.2000262  -0.31366572]]
18  mj [ 0.23007588 -0.3618566 ]
18  err [ 0.03004968 -0.04819089]
---
19  mjx [[ 0.20002903 -0.32054374]]
19  mj [ 0.23174266 -0.37226185]
19  err [ 0.03171363 -0.05171812]
---
```

It is probably something specific to my system, but I cannot figure out what it could be.
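To put the two runs above into numbers, the printed qpos values can be fed to a few small helpers: one checks whether a trajectory stays at the initial condition (run 1), and one compares how far MJX and MuJoCo moved after 20 steps (run 2). This is a minimal stdlib-only sketch. The helper names `max_abs_error`, `is_frozen` and `displacement` and the `1e-7` tolerance are hypothetical and not part of MuJoCo or MJX. The numbers are copied verbatim from the outputs above.

```python
from typing import List, Sequence

# A state vector (e.g. qpos) as a plain sequence of floats.
State = Sequence[float]


def max_abs_error(a: State, b: State) -> float:
    """Largest absolute component-wise difference between two states."""
    if len(a) != len(b):
        raise ValueError("states must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


def is_frozen(initial: State, trajectory: Sequence[State], atol: float = 1e-7) -> bool:
    """True if every state in the trajectory matches the initial state within atol."""
    return all(max_abs_error(initial, s) <= atol for s in trajectory)


def displacement(initial: State, final: State) -> List[float]:
    """Component-wise change from the initial state to the final state."""
    return [f - i for i, f in zip(initial, final)]


qpos0 = [0.20, -0.25]

# Run 1 (no viewer): MJX qpos for steps 0-2.
mjx_run1 = [[0.2, -0.25], [0.2, -0.25], [0.2, -0.25]]
# MuJoCo qpos for the same steps (identical in both runs).
mj_steps = [
    [0.20064749, -0.25108757],
    [0.20194727, -0.25344699],
    [0.2034819, -0.25649398],
]
print("run 1, MJX frozen:", is_frozen(qpos0, mjx_run1))
print("MuJoCo frozen:", is_frozen(qpos0, mj_steps))

# Run 2 (viewer line uncommented): qpos at step 19.
mjx_run2_final = [0.20002903, -0.32054374]
mj_final = [0.23174266, -0.37226185]

# Fraction of MuJoCo's motion that MJX reproduces, per coordinate (qpos[0], qpos[1]).
ratios = [a / b for a, b in zip(displacement(qpos0, mjx_run2_final),
                                displacement(qpos0, mj_final))]
print("run 2, qpos[0] ratio: %.4f, qpos[1] ratio: %.4f" % tuple(ratios))
```

For run 2 the ratios come out at roughly 0.0009 for qpos[0] and 0.58 for qpos[1]: after 20 steps MJX reproduces only a small part of the MuJoCo motion, almost none of it in the first coordinate.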

Steps for reproduction

Run the run.py script from the ZIP archive below; the archive also contains a requirements.txt file for reproducing my pip environment.
Then uncomment line 21 (the `mj_viewer.launch_passive` call) and run it again.

mre_bug_noupdate.zip

Minimal model for reproduction

No response

Code required for reproduction

No response

Confirmations

alaurenzi added the bug (Something isn't working) label Jan 15, 2025
alaurenzi changed the title from "MJX step gives f" to "MJX step gives wrong results or fails to update the state (depending on init conditions)" Jan 15, 2025
alaurenzi (Author) commented:

After further investigation, the problem may arise from internal collisions that are not filtered appropriately.
This could explain the different behavior between MuJoCo and MJX, which IIRC use different collision-checking pipelines.
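For reference when checking this hypothesis, MuJoCo filters candidate geom pairs in several ways. It considers a pair for collision only if `(contype1 & conaffinity2) || (contype2 & conaffinity1)` is nonzero. Geoms in the same body never collide, and geoms in a parent and a child body are filtered out unless the parent is the world body. The stdlib-only sketch below models just the bitmask rule, so it can be applied by hand to the geoms in cartpole.xml. The function name `bitmask_allows_collision` is hypothetical, not a MuJoCo API, and the sketch deliberately ignores the same-body and parent-child filters as well as `<exclude>` and explicit `<pair>` elements.

```python
def bitmask_allows_collision(contype1: int, conaffinity1: int,
                             contype2: int, conaffinity2: int) -> bool:
    """True if a geom pair passes MuJoCo's contype/conaffinity bitmask test.

    Only the bitmask rule is modeled; same-body, parent-child, <exclude>
    and explicit <pair> handling are intentionally left out.
    """
    return bool((contype1 & conaffinity2) or (contype2 & conaffinity1))


# MuJoCo's default is contype=1, conaffinity=1, so two default geoms pass.
print(bitmask_allows_collision(1, 1, 1, 1))
# A geom with contype=0 and conaffinity=0 never passes the test.
print(bitmask_allows_collision(0, 0, 1, 1))
# Disjoint bits: 2 & 1 == 0 and 1 & 4 == 0, so the pair is filtered out.
print(bitmask_allows_collision(2, 4, 1, 1))
```

The per-geom values can also be read from a loaded model via `mj_model.geom_contype` and `mj_model.geom_conaffinity`. A pair that passes this test, is not a same-body or parent-child pair, and is not expected to touch would be consistent with the unfiltered-collision hypothesis, though this alone does not confirm it.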
