LeWM + PRISM-MPPI for Franka PushT (v3: 411-ep dataset, 200-ep target-clean prior)
An action-conditioned latent world model (LeWM) paired with a PRISM action prior, both trained on real Franka FR3 PushT teleoperation data. Supports three planning modes, selectable at deploy time for A/B comparison: PRISM-MPPI (PoG fusion, default), warm-start, and vanilla MPPI.
What's new vs v2
| Component | v2 | v3 |
|---|---|---|
| Training dataset | 304 eps / 72 k frames | 411 eps / 94 k frames |
| WM val_pred_loss | 0.0045 | 0.0046 (similar) |
| CV @ H=5 | 0.180 | 0.195 |
| pred/id @ H=1 | 0.465 | 0.402 (better short-h fidelity) |
| PRISM action prior | not included | prior_head_pusht_fr3_v3.pt (H=3, val MSE 0.348) |
| PRISM-MPPI mode | vanilla only | 3 modes: pog / warm_start / none |
| Default H (plan-steps) | 3 | 3 (matches both the prior and the WM's deploy sweet spot) |
Model summary
| | World Model | PRISM Prior Head |
|---|---|---|
| Architecture | ViT-tiny encoder + 6-layer AR-Transformer predictor + Embedder action encoder | 3-layer MLP, hidden=512, β-NLL loss |
| Parameters | 18.03 M | 0.51 M |
| Input | (224, 224, 3) RGB obs + goal | (z_t, z_g) ∈ ℝ^(2 × 192) |
| Output | next-latent prediction | (μ, σ) over action sequence (H=3, A_block=5, action_dim=2) → 15 ticks = 1.5 s |
| Training data | All 411 eps (mixed: 200 target + 211 random terminations) | eps 0-199 only (target-completion subset) |
| Goal supervision | n/a (predictor) | HER hindsight, episode last frame (sim convention, Andrychowicz et al. 2017) |
| Final val_pred_loss / val MSE | 0.0046 | 0.348 (σ ≈ 0.53 ≈ √MSE, well-calibrated) |
Why H=3 (not the sim convention H=5)?
The v3 WM's per-step rollout fidelity (pred/id) is best at short horizons
(0.40 @ H=1, 0.186 @ H=5, 0.331 @ H=25; see docs/34). H=3 keeps both the WM
rollout and the prior's action sequence within the high-fidelity envelope.
Empirically, training the prior at H=3 (vs sim's H=5) gives −8.7 % val MSE
(0.348 vs 0.381) and a tighter σ (0.53 vs 0.58); the full ablation is in
docs/35 §11.
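The horizon arithmetic behind these numbers can be checked with a few lines. This is only a worked example: the constants are copied from this README, and the helper names (`plan_ticks`, `plan_seconds`, `relative_change`) are illustrative and not part of the bundle.

```python
# Worked arithmetic for the planning horizon. Constants come from this README;
# the function names are illustrative and do not exist in the bundle.
CONTROL_HZ = 10  # control frequency in ticks per second
A_BLOCK = 5      # raw actions per plan step


def plan_ticks(h: int, a_block: int = A_BLOCK) -> int:
    """Number of control ticks covered by h plan steps."""
    return h * a_block


def plan_seconds(h: int, a_block: int = A_BLOCK, hz: int = CONTROL_HZ) -> float:
    """Wall-clock span of h plan steps at the given control rate."""
    return plan_ticks(h, a_block) / hz


def relative_change(new: float, old: float) -> float:
    """Signed relative change of `new` with respect to `old`."""
    return (new - old) / old


print(plan_ticks(3), plan_seconds(3))                  # 15 1.5
print(plan_seconds(5))                                 # 2.5
print(round(100 * relative_change(0.348, 0.381), 1))   # -8.7
```

So H=3 spans 15 ticks (1.5 s), H=5 spans 25 ticks (2.5 s), and the val MSE change from 0.381 to 0.348 is the −8.7 % quoted above.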
Why train the prior on only the first 200 episodes?
Eps 200-410 were collected with no fixed-target pushing: the operator stopped
the T at arbitrary positions, making the episode-last-frame z_g a noisy
supervision signal. Training on the full 411 eps with HER endframe yields a
broken prior (val MSE 1.63, fails HARD GATE). Restricting to the target-clean
subset 0-199 recovers a useful prior (val MSE 0.38, well-calibrated σ). See
docs/35 §9 for the full ablation.
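To make the "HER endframe" supervision concrete, here is a minimal sketch of how training tuples could be built from one episode. It is not the project's training code: the function names, the array layout (one latent and one action per frame), and the sliding-window choice are all assumptions made for illustration.

```python
import numpy as np


def her_endframe_pairs(latents, actions, horizon_ticks=15):
    """Build (z_t, z_g, action-window) tuples with the last frame as the goal.

    latents: (T, D) per-frame latents of one episode (assumed layout).
    actions: (T, A) action taken at each frame (assumed layout).
    Every start frame with a full action window is paired with the latent of
    the episode's final frame, which acts as the hindsight goal z_g.
    """
    n_frames, dim = latents.shape
    n_samples = n_frames - horizon_ticks + 1
    if n_samples <= 0:  # episode shorter than one action window
        return (np.empty((0, dim)), np.empty((0, dim)),
                np.empty((0, horizon_ticks, actions.shape[1])))
    z_t = latents[:n_samples]
    z_g = np.broadcast_to(latents[-1], z_t.shape).copy()
    targets = np.stack([actions[t:t + horizon_ticks] for t in range(n_samples)])
    return z_t, z_g, targets


def target_clean_subset(episodes, n_clean=200):
    """Keep only the first n_clean episodes (eps 0-199 in this README)."""
    return episodes[:n_clean]


# Toy episode: 20 frames, 4-dim latents, 2-dim actions.
latents = np.arange(80, dtype=np.float32).reshape(20, 4)
actions = np.zeros((20, 2), dtype=np.float32)
z_t, z_g, targets = her_endframe_pairs(latents, actions)
print(z_t.shape, z_g.shape, targets.shape)  # (6, 4) (6, 4) (6, 15, 2)
```

The sketch shows why eps 200-410 hurt: if the final frame is an arbitrary stopping position, every tuple from that episode carries a goal that the recorded actions were never aiming at.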
Plan-worthiness diagnostics (WM only, measured on the train distribution)
| Metric @ H=5 | Value | Interpretation |
|---|---|---|
| CV | 0.195 ± 0.010 | Borderline, below the 0.30 plan-worthy threshold; PRISM prior helps |
| GT_rank | 36.6 % ± 2.1 | Direction correct ("weak-align" tier) |
| pred/id @ H=1 | 0.402 ± 0.016 | Good single-step action conditioning |
| pred/id @ H=5 | 0.186 ± 0.006 | Reasonable 2.5 s rollout fidelity |
| pred/id @ H=25 | 0.331 ± 0.006 | Long-horizon drift; use H ≤ 5 |
See docs/34 for the full data-scaling analysis (v1 → v2 → v3 monotone CV climb).
Three planning modes
| Mode | Description | When to use |
|---|---|---|
| pog (default) | PRISM-MPPI: Product-of-Gaussians fusion of prior (μ, σ) into MPPI init | Main deploy mode |
| warm_start | Use prior mean only, keep planner default σ | A/B test (isolates σ-fusion's contribution) |
| none | Vanilla LeWM-MPPI (no prior) | Paper-grade A/B baseline, or no-prior fallback |
All three modes share the same LeWM, MPPI loop, K=300, n_iters=30, and action scaler, so the comparison is apples-to-apples.
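The difference between the three modes comes down to how the MPPI sampling distribution is initialised. The sketch below uses the textbook per-dimension Product-of-Gaussians formula (precisions add, means are precision-weighted); the function names and the `var_scale` default are assumptions, and the bundled `pusht_lewm_inference.py` may differ in detail.

```python
import numpy as np


def pog_fuse(mu_prior, sigma_prior, mu_init, sigma_init):
    """Per-dimension product of two Gaussians.

    Precisions add, and the fused mean is the precision-weighted average.
    """
    prec_prior = 1.0 / np.square(sigma_prior)
    prec_init = 1.0 / np.square(sigma_init)
    var_fused = 1.0 / (prec_prior + prec_init)
    mu_fused = var_fused * (prec_prior * mu_prior + prec_init * mu_init)
    return mu_fused, np.sqrt(var_fused)


def init_distribution(mode, mu_prior, sigma_prior, var_scale=1.0):
    """Return the (mu, sigma) used to seed MPPI for a given mode.

    var_scale stands in for the planner's default sigma (assumed value).
    """
    mu0 = np.zeros_like(mu_prior)
    sigma0 = np.full_like(mu_prior, var_scale)
    if mode == "pog":         # fuse prior mean and sigma into the init
        return pog_fuse(mu_prior, sigma_prior, mu0, sigma0)
    if mode == "warm_start":  # prior mean only, planner default sigma
        return mu_prior.copy(), sigma0
    if mode == "none":        # vanilla MPPI init
        return mu0, sigma0
    raise ValueError(f"unknown injection mode: {mode!r}")


# Toy prior over a (15, 2) action sequence: mean 1.0, sigma 1.0 everywhere.
mu_p = np.full((15, 2), 1.0)
sigma_p = np.full((15, 2), 1.0)
for mode in ("pog", "warm_start", "none"):
    mu, sigma = init_distribution(mode, mu_p, sigma_p)
    print(mode, float(mu[0, 0]), round(float(sigma[0, 0]), 3))
```

With equal sigmas the fused mean lands halfway between the prior mean and zero, and the fused sigma shrinks by a factor of √2. A confident prior (small σ) pulls the init harder and narrows the search more.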
Quick start
pip install torch torchvision numpy einops transformers huggingface_hub
import sys
import time

import numpy as np
from huggingface_hub import snapshot_download

# Download the bundle (WM + prior + inference code)
local = snapshot_download("YuhaiW/lewm-pusht-fr3-v3")
sys.path.insert(0, local)  # make the bundled modules importable
from pusht_lewm_inference import PushtLewmInference

# -- PRISM-MPPI deploy (recommended default) --
planner = PushtLewmInference(
    lewm_ckpt      = f"{local}/lewm_pusht_fr3_v3.ckpt",
    prior_ckpt     = f"{local}/prior_head_pusht_fr3_v3.pt",
    injection_mode = "pog",   # "pog" | "warm_start" | "none"
    device         = "cuda",
)

# In the robot control loop (10 Hz).
# camera_rgb(), goal_rgb(), robot and done are placeholders for your own stack.
done = False
while not done:
    obs_uint8  = camera_rgb()   # (224, 224, 3) uint8
    goal_uint8 = goal_rgb()     # (224, 224, 3) uint8
    actions = planner.plan(obs_uint8, goal_uint8)
    # (5, 2) float32: Δxy in meters for the next 0.5 s
    for a in actions:
        robot.send_delta_target(a)
        time.sleep(0.1)         # 10 Hz tick
To A/B test against vanilla MPPI on the same WM:
vanilla = PushtLewmInference(
    lewm_ckpt      = f"{local}/lewm_pusht_fr3_v3.ckpt",
    prior_ckpt     = f"{local}/prior_head_pusht_fr3_v3.pt",  # loaded for scaler
    injection_mode = "none",                                 # disable PoG fusion
)
Robot expectations
| Aspect | Expectation |
|---|---|
| Robot | Franka FR3 (or compatible) with Cartesian impedance control |
| Action interpretation | Δ-target XY in meters (per tick) |
| Control frequency | 10 Hz |
| Camera | Top-down RGB at 224 × 224 |
| Goal image | Single RGB showing the desired final scene |
| Z, rotation, gripper | NOT controlled (XY-only by design; lock in your controller) |
| Teleop style assumed | "Decisive" pushes: the operator commits and pushes in one smooth motion |
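Because the planner expects a fixed image format, it can be worth checking frames before they reach it. The guard below is a hypothetical helper (`check_frame` is not part of the bundle) that only enforces the shape and dtype stated in this README.

```python
import numpy as np

EXPECTED_SHAPE = (224, 224, 3)  # top-down RGB, as stated in this README


def check_frame(frame, name="obs"):
    """Raise early if a camera or goal frame is not (224, 224, 3) uint8."""
    if not isinstance(frame, np.ndarray):
        raise TypeError(f"{name}: expected np.ndarray, got {type(frame).__name__}")
    if frame.shape != EXPECTED_SHAPE:
        raise ValueError(f"{name}: expected shape {EXPECTED_SHAPE}, got {frame.shape}")
    if frame.dtype != np.uint8:
        raise ValueError(f"{name}: expected dtype uint8, got {frame.dtype}")
    return frame


# A correctly formatted frame passes through unchanged.
good = np.zeros(EXPECTED_SHAPE, dtype=np.uint8)
assert check_frame(good) is good
```

Calling it on both the observation and the goal image before planning turns a silent preprocessing mismatch (for example a float image in [0, 1]) into an immediate error.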
What's in the bundle
lewm_pusht_fr3_v3.ckpt        # 72 MB, world model (pickled JEPA object)
prior_head_pusht_fr3_v3.pt    # 2 MB, PRISM prior head + StandardScaler + meta
action_scaler.json            # 0.5 KB, fallback scaler when no prior_ckpt
pusht_lewm_inference.py       # standalone PRISM-MPPI planner (3 modes)
jepa.py, module.py            # required for LeWM ckpt deserialization
prior_head.py                 # required for prior ckpt deserialization
requirements.txt              # minimal deps
README.md                     # this file
Architecture overview
  obs (224²)              goal (224²)
      │                        │
      ▼                        ▼
   ViT-tiny                 ViT-tiny          ← shared weights
      │                        │
  z_t ∈ ℝ^192              z_g ∈ ℝ^192
      │                        │
      └─── PRISM Prior Head (optional, "pog" / "warm_start" modes) ───┘
                   │
                   ▼
  concat(z_t, z_g) → MLP → (μ_p, σ_p) ∈ ℝ^(H × A_block × A_raw)
                   │
                   ▼
      ── PoG fusion with MPPI init μ=0, σ=var_scale ──
                   │
                   ▼
      N(μ_fused, σ_fused) sampled K=300 times
                   │
      ── MPPI iterations (LeWM AR rollout cost vs z_g) ──
                   │
                   ▼
      optimized action sequence (H × A_block, A_raw)
                   │
                   ▼
      first A_block actions → robot
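The sample, score and reweight loop at the bottom of the diagram can be sketched on a toy problem. This is a simplified stand-in, not the bundled planner: the cost is a quadratic distance to a fixed target instead of a LeWM rollout cost, σ is held fixed across iterations, and `mppi_plan`, `toy_cost`, and `temperature` are illustrative names.

```python
import numpy as np


def mppi_plan(cost_fn, mu_init, sigma, k=300, n_iters=30, temperature=1.0, seed=0):
    """Minimal MPPI mean update with a fixed sampling sigma.

    cost_fn maps a batch of action sequences (k, T, A) to costs of shape (k,).
    """
    rng = np.random.default_rng(seed)
    mu = np.asarray(mu_init, dtype=np.float64).copy()
    sigma = np.asarray(sigma, dtype=np.float64)
    for _ in range(n_iters):
        noise = rng.standard_normal((k,) + mu.shape)
        samples = mu + sigma * noise                    # (k, T, A)
        costs = np.asarray(cost_fn(samples), dtype=np.float64)
        # Softmax weights over negative cost; subtracting the minimum keeps
        # the exponent numerically stable.
        weights = np.exp(-(costs - costs.min()) / temperature)
        weights /= weights.sum()
        mu = np.tensordot(weights, samples, axes=1)     # weighted mean, (T, A)
    return mu


# Toy stand-in for the world-model cost: squared distance to a target sequence.
target = np.full((15, 2), 0.3)


def toy_cost(samples):
    return ((samples - target) ** 2).sum(axis=(1, 2))


plan = mppi_plan(toy_cost, np.zeros((15, 2)), np.full((15, 2), 0.2))
first_block = plan[:5]  # first A_block actions, the part that would be executed
print(plan.shape, first_block.shape)  # (15, 2) (5, 2)
```

In the real planner the initial (μ, σ) would come from the selected injection mode, and the cost would compare the LeWM rollout's latent against z_g. Only the first A_block actions are executed before replanning.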
Caveats and limitations
- PRISM prior trained on 200 eps (target subset). The full 411-ep dataset
  is heterogeneous (eps 200-410 have random T-final-positions). Training the
  prior on the clean subset gives a usable signal (val MSE 0.38) but is still
  ~3× worse than sim/red-cube counterparts. The next-best improvement would be
  collecting future datasets with an explicit goal_pixels field (one printed
  target image per session).
- WM cost surface is borderline (CV @ H=5 = 0.195 < 0.30). The PRISM prior is expected to help bridge the gap; PRISM-MPPI's cost-rescoring step tolerates the borderline cost surface in a way that vanilla MPPI cannot.
- Trained on top-down RGB only. Other camera angles are OOD.
- 2-D XY action space. Z, rotation, gripper are not controlled.
- 10 Hz tick. Faster/slower control loops mismatch the action scaler.
Provenance
- WM trained: 2026-06-03 (RTX 5090, ~4 h 53 min on all 411 eps)
- Prior trained: 2026-06-03 (RTX 5090, ~30 s on first 200 eps + sim-aligned HER)
- Dataset snapshot: Rongxuan-Zhou/pusht_lewm_fr3, sha 1b5dd5db801ef405b43d51dd5b9a3210d8d79ce6
- Project: PRISM-JEPA
- Companions:
  - YuhaiW/lewm-pusht-fr3-v2: previous release (304 eps, no prior, vanilla MPPI only)
  - YuhaiW/prism-jepa-red-cube-arx: same architecture on ARX cube task
Citation
@misc{prism-jepa-pusht-fr3-v3,
  title  = {LeWM + PRISM-MPPI for Franka PushT (v3: 411-ep)},
  author = {Wang, Yuhai and Zhou, Rongxuan and collaborators},
  year   = {2026},
  url    = {https://huggingface.co/YuhaiW/lewm-pusht-fr3-v3}
}
If you cite the PRISM action prior mechanism, also cite Andrychowicz et al. (2017) for hindsight experience replay, on which our prior training is based.
License
Apache 2.0.