---
license: apache-2.0
base_model: meta-llama/Llama-3.2-1B-Instruct
tags:
- neuromem
- sda
- sparse-distributed-attention
- lora
- peft
---

# NeuroMem SDA LoRA Calibration Checkpoints

Trained LoRA adapters + saved SDA address banks for the NeuroMem SDA project. Each subfolder is one calibration run with its own hyperparameters.

**Base model**: `meta-llama/Llama-3.2-1B-Instruct`
**Eval set**: WikiText-103 test, causal language modeling, `--sda_causal_eval`
**Eval pipeline**: `evals/eval_perplexity.py` from `https://github.com/Tejas-JB/NeuroMem` (branch `vihaan/phase1-energy-baseline`)

## Current Leaderboard (in our local env; see Known Limitation below)

| Run | M | k | LoRA r | Steps | Train CE | Eval PPL (5-win ctx=512, causal) |
|---|---:|---:|---:|---:|---:|---:|
| **run_d** (lost, evaluated only) | 16384 | 32 | 64 | 4000 | 4.93 | **414** ← best |
| run_f (lost, evaluated only) | 8192 | 16 | 128 | 4000 | 3.65 | 432 (ctx=1024) / 447 (ctx=512) |
| run_e (lost, evaluated only) | 8192 | 16 | 64 | 4000 | 5.07 | 523 |
| **run_k_M16k_k32_r256_6k** | 16384 | 32 | 256 | 6000 | 1.25 | 1,059 (overfit) |
| sweep_rank128_2k | 8192 | 16 | 128 | 2000 | n/a | _pending Benji eval_ |
| sweep_rank64_2k | 8192 | 16 | 64 | 2000 | n/a | _pending Benji eval_ |
| run_001_rank16_2k | 8192 | 16 | 16 | 2000 | 7.94 | 5,828 |

In-flight (auto-uploaded as they finish):

- `run_j_M32k_r128_8k`: M=32k, rank=128, 8000 steps
- `run_l_M32k_r256_8k`: M=32k, rank=256, 8000 steps
- `run_p_M16k_k32_r64_8k`: M=16k/k=32, rank=64, 8000 steps (sweet-spot push)
- `run_n_M32k_r256_12k`: chained after run_j
- `run_o_M16k_k32_r256_12k`: chained after run_l

## Known Limitation (please read before using)

The Eval PPL numbers above are from our local environment, where raw uncalibrated SDA produces PPL ~9,781 (under causal eval) instead of the published baseline of PPL ~17.95.
We have a documented but unresolved environmental regression in our SDA forward pass, likely a CUDA/cuDNN-level fp16 numerical issue that doesn't affect Benji's setup.

**The relative improvements are valid** (rank-16 → 64 yields ~10× PPL drop in our env, confirming capacity scaling). **The absolute PPL values are pessimistic.** Re-evaluating these checkpoints in the original SDA paper environment is needed for paper-quality absolute numbers.

## How to use a checkpoint

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel
import torch

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    dtype=torch.float16,
    attn_implementation="eager",
)

# Patch with SDA; see https://github.com/Tejas-JB/NeuroMem/blob/vihaan/phase1-energy-baseline/src/sda/wrapper.py
from src.sda.wrapper import patch_all_layers
patch_all_layers(model, M=8192, r=0.3601)  # match the checkpoint's M, r

# Load LoRA + SDA banks
model = PeftModel.from_pretrained(model, "")
sda_banks = torch.load("", map_location="cpu", weights_only=True)
# (See evals/eval_perplexity.py:268 `load_sda_checkpoint` for the full restore flow)
```

## Run

```bash
python evals/eval_perplexity.py \
  --model meta-llama/Llama-3.2-1B-Instruct \
  --attention sda --M --k --r \
  --patch_layers all --attn_implementation eager \
  --sda_checkpoint \
  --dataset wikitext-103 --split test \
  --batch_size 1 --context_length 2048 \
  --output eval_result.json
```

Add `--sda_causal_eval` for spec-compliant causal evaluation (slow but correct).
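
## Reading Train CE against Eval PPL

The leaderboard reports train cross-entropy next to eval perplexity, which are on different scales. The following is a minimal sketch of how the two relate, assuming that perplexity is the exponential of the mean per-token negative log-likelihood in nats, pooled over evaluation windows. The function names and the toy window values are illustrative and are not taken from `evals/eval_perplexity.py`, whose exact aggregation may differ.

```python
import math


def perplexity_from_ce(mean_ce_nats: float) -> float:
    """Convert a mean per-token cross-entropy (in nats) to perplexity."""
    return math.exp(mean_ce_nats)


def pooled_perplexity(windows):
    """Perplexity pooled over windows.

    `windows` is an iterable of (summed_nll_nats, num_scored_tokens) pairs,
    one per evaluation window. Pooling weights every token equally, so a
    short final window does not count as much as a full one.
    """
    total_nll = 0.0
    total_tokens = 0
    for summed_nll, num_tokens in windows:
        total_nll += summed_nll
        total_tokens += num_tokens
    if total_tokens == 0:
        raise ValueError("no tokens were scored")
    return math.exp(total_nll / total_tokens)


if __name__ == "__main__":
    # Train CE column value for run_d, assumed here to be in nats.
    print(f"train PPL for CE=4.93: {perplexity_from_ce(4.93):.1f}")

    # Two toy windows (hypothetical numbers) with 2 scored tokens each.
    toy = [(2 * math.log(4), 2), (2 * math.log(16), 2)]
    print(f"pooled PPL over toy windows: {pooled_perplexity(toy):.2f}")
```

Under that assumption, run_d's train CE of 4.93 corresponds to a training perplexity of roughly 138, so the Train CE and Eval PPL columns should only be compared after this conversion.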