Example of rolling out Python-based policies and MJX-based environments with passive viewer #2324
Intro
Hello, I'm a MuJoCo user working on manipulation.
My setup
My question
I'm looking for an example of how to sync the mjpython passive viewer's state with external data. Specifically, I have an MJX-based environment and I want to visualize a rollout of a trained JAX-based policy while manually moving/perturbing bodies from the passive viewer. Say I have the following (somewhat incomplete and naive) passive viewer code:

    policy = jax_based_policy()
    environment = mjx_based_environment()
    states = [environment_reset(reset_rng)]
    with mujoco.viewer.launch_passive(model, data) as viewer:
        while viewer.is_running():
            step_start = time.time()
            action_rng, rng = jax.random.split(rng)
            action = policy_fn(action_rng, states[-1].obs)
            states.append(environment_step(states[-1], action))
            # Copy the new MJX state into the `MjData` that the viewer renders.
            mjx.get_data_into(
                data,
                model,
                extract_first_state(states[-1].pipeline_state),
            )
            viewer.sync()
            time_until_next_step = model.opt.timestep - (time.time() - step_start)
            if time_until_next_step > 0:
                time.sleep(time_until_next_step)

This will obviously not work, because none of the perturbations in `data` are synced into the pipeline states. What should I sync from
Minimal model and/or code that explain my question
No response
Replies: 2 comments
I think @btaba or @kevinzakka should answer this, but a question to you: wouldn't it be easier to run the MJX-trained policy on C MuJoCo? Did you try that and get strange results?
I managed to figure out a solution that satisfies my needs. I'll copy-paste a sketch of what my script looks like below in case someone else is looking for a similar solution. Before that, though, to answer Yuval's question:
It would, if I wanted to roll out just the MuJoCo model. However, my task logic is written with MJX, and in order to debug/visualize the task itself, I need to step with the environment. I could duplicate the task logic in a non-MJX environment, but that seems a bit clumsy and error-prone -- after all, the whole point of doing this is to enable debugging the original task. Maybe there exists a better and more flexible way of handling the environment/task logic. This is somewhat related to my earlier question in #2125; happy to hear more if anyone has thoughts on these environment/task things 🙂

But here's a sketch of the script I ended up using:

`mjx_passive_viewer.py`

    import collections
    import dataclasses
    import time

    from absl import app
    from absl import logging
    import brax.envs
    from etils import eapp
    from etils import epath
    import jax
    import jax.numpy as jnp
    import mujoco
    from mujoco import mjx
    import mujoco.viewer
    import numpy as np

    Path = epath.Path


    @dataclasses.dataclass
    class Args:
        policy_export_dir_path: Path
        seed: int = 0


    def main(args: Args) -> None:
        logging.info(f"args: {args}")

        policy_export_dir_path = args.policy_export_dir_path
        seed = args.seed

        def policy_fn(rng: jax.Array, observation: jax.Array) -> jax.Array:
            ...

        environment: brax.envs.PipelineEnv = ...
        environment_step = jax.jit(environment.step)
        environment_reset = jax.jit(environment.reset)

        rng = jax.random.PRNGKey(seed)
        reset_rng, rng = jax.random.split(rng, 2)

        model = environment.sys.mj_model
        data = mujoco.MjData(model)

        # Keep only a short history of environment states.
        states = collections.deque([environment_reset(reset_rng)], maxlen=10)
        assert np.all(np.isfinite(states[-1].obs))

        # Call policy function with reset state to make sure it's compiled.
        action_rng, rng = jax.random.split(rng)
        policy_fn(action_rng, states[-1].obs)

        paused = False
        simulation_speeds = [0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1.0]
        simulation_speed_index = simulation_speeds.index(1.0)

        def key_callback(keycode: int) -> None:
            nonlocal paused, simulation_speed_index
            if chr(keycode) == " ":
                # TODO(hartikainen): Render "paused" label into viewer.
                paused = not paused
            elif chr(keycode) == "-":
                # TODO(hartikainen): Render simulation speed label into viewer.
                simulation_speed_index = max(0, simulation_speed_index - 1)
            elif chr(keycode) == "+":
                # TODO(hartikainen): Render simulation speed label into viewer.
                simulation_speed_index = min(
                    len(simulation_speeds) - 1, simulation_speed_index + 1
                )
            else:
                print(f"{keycode=}")

        # Initialize the viewer-side `MjData` from the MJX reset state.
        mjx.get_data_into(data, model, states[-1].pipeline_state)

        with mujoco.viewer.launch_passive(model, data, key_callback=key_callback) as viewer:
            viewer.sync()

            while viewer.is_running():
                simulation_speed = simulation_speeds[simulation_speed_index]
                step_start = time.time()

                if paused:
                    # While paused, apply pose perturbations directly to `data`.
                    mujoco.mjv_applyPerturbPose(model, data, viewer.perturb, 1)
                    mujoco.mj_forward(model, data)
                else:
                    mujoco.mjv_applyPerturbPose(model, data, viewer.perturb, 0)
                    mujoco.mjv_applyPerturbForce(model, data, viewer.perturb)

                    # Write the (possibly perturbed) viewer-side state back into
                    # the MJX pipeline state before stepping the environment.
                    pipeline_state = states[-1].pipeline_state.replace(
                        qpos=jnp.array(data.qpos),
                        qvel=jnp.array(data.qvel),
                        mocap_pos=jnp.array(data.mocap_pos),
                        mocap_quat=jnp.array(data.mocap_quat),
                        xfrc_applied=jnp.array(data.xfrc_applied),
                    )
                    states[-1] = states[-1].replace(pipeline_state=pipeline_state)

                    action_rng, rng = jax.random.split(rng)
                    action = policy_fn(action_rng, states[-1].obs)
                    states.append(environment_step(states[-1], action))
                    assert np.all(np.isfinite(states[-1].obs))

                    # Copy the stepped MJX state into `data` for rendering.
                    mjx.get_data_into(data, model, states[-1].pipeline_state)

                # Can modify the viewer with `viewer.lock`.
                with viewer.lock():
                    pass

                viewer.sync()

                # Rudimentary time keeping, will drift relative to wall clock.
                time_until_next_step = model.opt.timestep / simulation_speed - (
                    time.time() - step_start
                )
                if time_until_next_step > 0:
                    time.sleep(time_until_next_step)


    if __name__ == "__main__":
        eapp.better_logging()
        app.run(main, flags_parser=eapp.make_flags_parser(Args))

There's nothing too surprising in it. The
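For readers who want the write-back step in isolation: the part of the script above that answers the original question is the `.replace(...)` call that copies `qpos`, `qvel`, `mocap_pos`, `mocap_quat`, and `xfrc_applied` from the viewer-side `data` into the pipeline state before each step. Below is a minimal sketch of that pattern, free of MuJoCo and JAX so it runs anywhere. The names `write_back`, `SYNCED_FIELDS`, `FakePipelineState`, and `FakeData` are hypothetical stand-ins, not part of any library. In the real script the state would be the MJX pipeline state and `convert` would presumably be `jnp.array`.

```python
import dataclasses
from typing import Any, Callable, Sequence

# The fields the script above copies from the viewer-side data into the
# pipeline state. Whether this set suffices for a given task is an assumption.
SYNCED_FIELDS = ("qpos", "qvel", "mocap_pos", "mocap_quat", "xfrc_applied")


def write_back(
    pipeline_state: Any,
    data: Any,
    fields: Sequence[str] = SYNCED_FIELDS,
    convert: Callable[[Any], Any] = lambda value: value,
) -> Any:
    """Returns a copy of `pipeline_state` with `fields` taken from `data`.

    `pipeline_state` only needs a `.replace(**kwargs)` method that returns a
    new object, which is how the script above updates its state.
    """
    updates = {name: convert(getattr(data, name)) for name in fields}
    return pipeline_state.replace(**updates)


# Hypothetical stand-ins for the immutable simulation state and the mutable
# viewer-side data, used only to exercise `write_back`.
@dataclasses.dataclass(frozen=True)
class FakePipelineState:
    qpos: tuple = (0.0, 0.0)
    qvel: tuple = (0.0, 0.0)
    mocap_pos: tuple = ()
    mocap_quat: tuple = ()
    xfrc_applied: tuple = ()
    time: float = 0.0

    def replace(self, **changes: Any) -> "FakePipelineState":
        return dataclasses.replace(self, **changes)


@dataclasses.dataclass
class FakeData:
    qpos: list
    qvel: list
    mocap_pos: list
    mocap_quat: list
    xfrc_applied: list


old_state = FakePipelineState(time=1.5)
viewer_data = FakeData(
    qpos=[0.1, 0.2],  # e.g. a body dragged in the viewer
    qvel=[0.0, 0.0],
    mocap_pos=[],
    mocap_quat=[],
    xfrc_applied=[0.0, 0.0, 5.0],  # e.g. a perturbation force
)
new_state = write_back(old_state, viewer_data, convert=tuple)
```

Only the listed fields change; everything else on the state (here `time`) carries over, and the old state is left untouched, which matches how the script keeps earlier entries in `states`.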
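On the "rudimentary time keeping" comment in the script: sleeping for `timestep - elapsed` on every iteration lets small overshoots accumulate, which is one way the loop can drift from the wall clock. One possible alternative, sketched below under stated assumptions, is to pace against an absolute deadline that advances by the step duration each iteration. `StepPacer` is a hypothetical helper, not part of MuJoCo or its viewer, and the choice to reset the deadline when the loop falls behind is a design assumption rather than something taken from the thread.

```python
import time
from typing import Callable


class StepPacer:
    """Paces a loop against absolute deadlines instead of per-step durations."""

    def __init__(
        self,
        clock: Callable[[], float] = time.monotonic,
        sleep: Callable[[float], None] = time.sleep,
    ) -> None:
        self._clock = clock
        self._sleep = sleep
        self._deadline = clock()

    def wait(self, step_duration: float) -> float:
        """Sleeps until the next deadline and returns the seconds slept."""
        self._deadline += step_duration
        remaining = self._deadline - self._clock()
        if remaining > 0:
            self._sleep(remaining)
            return remaining
        # The loop is running behind. Reset the deadline so later iterations
        # do not run back-to-back trying to catch up.
        self._deadline = self._clock()
        return 0.0
```

In the script above, this would replace the last four lines of the loop with something like `pacer.wait(model.opt.timestep / simulation_speed)`, with `pacer = StepPacer()` created just before the `while` loop.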