
Commit 17c8521

leon2k2k2k and claude committed
swap spec numbering: 010 = port_1695 online rotation, 011 = tapered WD
User flagged that port_1695 should be the next spec (higher-impact, natural follow-up to 009) rather than tapered WD. Reshuffled:

- 010-port-1695-online-rotation.md (NEW) — port openai#1695's online Hadamard rotation with rotated-basis GPTQ. Hotstart off spec 008 pre_gptq.pt. Expected Delta -0.003 to -0.005 bpb vs spec 009 baseline. ~$10, 8xH100.
- 011-tapered-wd.md (renumbered from 010) — Muon WD taper from openai#1729. Full retrain, ~$20. Independent of specs 009/010, can run in parallel.

Spec 010 inherits the design analysis from research/ideas/spinquant-integration-notes.md (addendum section). Depends on spec 009 baseline measurement for apples-to-apples Delta.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 611811d commit 17c8521

2 files changed: 187 additions & 13 deletions

**010-port-1695-online-rotation.md** — new file, 174 additions:
# Spec 010 — Port #1695's online Hadamard rotation scheme

**Slug:** `port-1695-online-rotation`
**Created:** 2026-04-20
**Links to idea:** `research/ideas/1736-improvement.md` and `research/ideas/spinquant-integration-notes.md` (Addendum section: "how #1695 actually does it").
**Depends on:** spec 008 complete (`pre_gptq.pt` available). Should run **after** spec 009 — we want spec 009's `internal_only` number first, to judge whether this more invasive port delivers additional signal.

## Hypothesis

PR #1695's scheme — **online activation rotation with rotated-basis GPTQ** — delivers ~−0.005 bpb on the #1529 base. Porting it to #1736's stack should yield a similar or better gain, because #1736's stack is strictly richer (CaseOps + gates + phased TTT) and has no conflicts with the rotation design.

## Baseline

Spec 009's `baseline` mode (our reproduced #1736 seed-42 number, measured end-to-end by spec 009).

## Expected Δ

−0.003 to −0.005 bpb vs baseline. Stronger than spec 009's `internal_only` mode (~−0.002) because it rotates in four positions instead of one and handles the MLP via the post-nonlinearity hook.

If `internal_only` already delivered ≥ −0.003 in spec 009, this lever's incremental gain on top may be smaller (~−0.001 to −0.002) — in that case the combined delta against the spec 009 baseline could be ~−0.004 total.

## Approach overview (see integration notes addendum for full design)

#1695 uses **four Hadamard rotations applied online in the forward pass** — not baked into weights, not folded through nonlinearities.

| Rotation | Dim | Site |
|---|---|---|
| `R_attn_in` | d_model (512) | `x_qkv = x @ R_attn_in` before the Q/K/V linear |
| `R_attn_proj_in` | d_model (512) | `y = y @ R_attn_proj_in` before the attention output projection |
| `R_mlp_in` | d_model (512) | `x = x @ R_mlp_in` before fc |
| `R_mlp_proj_in` | d_ff (2048) | `hidden = hidden @ R_mlp_proj_in` before proj (applied AFTER `LeakyReLU.square`) |

Rotations are `register_buffer`s (non-persistent, regenerated deterministically from `SPINQUANT_SEED`). They are gated by the `CastedLinear._sq_active` class flag — OFF during training (Dynamo constant-folds the branch away), ON after `deserialize()` for quantized eval + TTT.
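
To make the mechanics concrete, here is a minimal sketch of the gating pattern (not #1736's actual code): a simplified MLP stand-in showing where the two MLP-side rotations apply and how the class flag keeps them out of the training forward. Buffer and flag names follow this spec; `_hadamard_rotation` is the utility described under Code changes below.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CastedLinear(nn.Linear):
    _sq_active: bool = False  # class flag: False during training, True after deserialize()

class MLP(nn.Module):
    """Simplified stand-in for the real MLP; the fused-kernel path is omitted."""
    def __init__(self, d_model: int = 512, d_ff: int = 2048):
        super().__init__()
        self.fc = CastedLinear(d_model, d_ff)
        self.proj = CastedLinear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if CastedLinear._sq_active:            # constant-folded away by Dynamo when False
            x = x @ self._sq_R_mlp_in          # rotate fc's input
        hidden = F.leaky_relu(self.fc(x)).square()
        if CastedLinear._sq_active:
            hidden = hidden @ self._sq_R_mlp_proj_in  # AFTER LeakyReLU.square
        return self.proj(hidden)

def install_mlp_rotations(mlp: MLP, seed: int) -> None:
    # _hadamard_rotation(n, seed, tag) -> orthogonal (n, n) tensor; sketched under Code changes
    # non-persistent: regenerated from SPINQUANT_SEED, never serialized into the artifact
    mlp.register_buffer("_sq_R_mlp_in",
                        _hadamard_rotation(512, seed, "mlp_in"), persistent=False)
    mlp.register_buffer("_sq_R_mlp_proj_in",
                        _hadamard_rotation(2048, seed, "mlp_proj_in"), persistent=False)
```

In the real port the attention-side rotations follow the same pattern at the pre-QKV and pre-out_proj sites.
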
GPTQ Hessian must be rotated to match: `H_new = R.T @ H @ R` for each linear whose input is rotated.
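
The identity is one line of algebra: GPTQ accumulates each linear's Hessian from that linear's input activations, so if the forward feeds $x_i R$ in place of $x_i$:

$$
H' = \sum_i (x_i R)^\top (x_i R) = R^\top \Big(\sum_i x_i^\top x_i\Big) R = R^\top H R.
$$
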
**Why it works where static rotation doesn't:**

- `R_mlp_proj_in` applies after LeakyReLU² → no nonlinearity to commute through.
- Rotations operate on per-linear inputs, never on the residual stream → the per-channel multipliers (`attn_scale`, `mlp_scale`, `resid_mix`, `skip_weights`) stay in the trained basis, untouched.
- The float pass differs from the unrotated trained model — **no invariance check**. The bet: rotated-basis GPTQ error is lower, and the perturbation is ≪ the savings.

## Accept criteria

### Preflight

- CPU-side sanity test: rotate a tiny model, verify GPTQ calibration runs without numerical blow-up on the rotated Hessian (no NaN, no inf in the rotated H's eigenvalues). Optional — this is less critical than spec 009's invariance test because we're not claiming float invariance. A minimal version is sketched below.
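
A minimal version of that test, with the rotation stubbed as a seeded random orthogonal matrix rather than the real `_hadamard_rotation`:

```python
import torch

def preflight_rotated_hessian(n: int = 64, seed: int = 42) -> None:
    g = torch.Generator().manual_seed(seed)
    R, _ = torch.linalg.qr(torch.randn(n, n, generator=g))  # stand-in orthogonal rotation
    x = torch.randn(4096, n, generator=g)                   # stand-in calibration activations
    H = x.T @ x / x.shape[0]                                # GPTQ-style input Hessian
    H_rot = R.T @ H @ R                                     # rotated-basis Hessian
    eig = torch.linalg.eigvalsh(H_rot)                      # symmetric, so eigvalsh
    assert torch.isfinite(H_rot).all() and torch.isfinite(eig).all(), "non-finite rotated Hessian"
    print(f"rotated-Hessian eig range: {eig.min().item():.3e} .. {eig.max().item():.3e}")
```
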
### On-pod

- Script loads `pre_gptq.pt`, installs rotation buffers via `install_spinquant_rotations(...)`, sets `CastedLinear._sq_active = True`.
- GPTQ runs (rotated Hessian path) without error.
- Artifact < 16 MB.
- Phased TTT completes within 600 s.
- `final.json` with pre-quant, quantized, and post-TTT bpb.

### Primary success

- **val_bpb at least 0.002 below the spec 009 baseline** → SpinQuant online rotation lands on #1736, matching #1695's witnessed gain.
- Ideally it also beats spec 009's `internal_only` by ≥ 0.001 → confirms the 4-rotation approach is worth the invasiveness.

## Config diff

```
SPINQUANT_ENABLED=1
SPINQUANT_SEED=42
HOTSTART_FP_CKPT=/workspace/runs/008-1736-reproduction/seed_42/pre_gptq.pt
```

Plus `ARTIFACT_DIR=/workspace/runs/010-port-1695/`.

## Code changes

- **Branch:** `research`.
- **Patch target:** `records/track_10min_16mb/2026-04-19_SP8192_CaseOps_GatedAttn_QuantGate_Loop45_PhasedTTT/train_gpt.py`.
- **Additions (~150 LOC, ported directly from #1695's diff):**
  1. `_hadamard_rotation(n, seed, tag)` utility — Sylvester-Hadamard × random-sign diag, QR re-orthogonalization. Uses `_SPINQUANT_CACHE` keyed by `(seed, tag, n)`. A hedged sketch appears at the end of this section.
  2. `install_spinquant_rotations(model, h, seed, log_fn)` — registers buffers `_sq_R_attn_in`, `_sq_R_attn_proj_in` on every `CausalSelfAttention` module; `_sq_R_mlp_in`, `_sq_R_mlp_proj_in` on every `MLP`.
  3. `CastedLinear._sq_active` class-level bool flag, default `False`.
  4. Forward-pass hooks in:
     - `CausalSelfAttention.forward` — lines ~765 and ~808 (pre-QKV and pre-out_proj).
     - `MLP.forward` — lines ~818 (pre-fc) and ~822 (pre-proj, AFTER the LeakyReLU square). Also disable the fused kernel when `_sq_active` is set.
     - `forward_ttt` (both parallel and sequential variants) — matching hooks.
  5. Rotation of the GPTQ-collected Hessian in the `serialize()` path — a `_spinquant_rotate_sd_and_H` function that applies `H_new = R.T @ H @ R` for each matrix whose forward input is rotated.
- **New file (optional):** `spinquant_online_hotstart.py` — standalone driver that:
  1. Loads the FP state_dict from `HOTSTART_FP_CKPT`.
  2. Calls `install_spinquant_rotations(...)`.
  3. Sets `CastedLinear._sq_active = True`.
  4. Calls `serialize(h, base_model, code)` — GPTQ runs in the rotated forward.
  5. Calls `deserialize(h, device)`.
  6. Runs quantized eval + phased TTT.
  7. Writes `final.json`.

  Very similar in structure to `spinquant_hotstart.py` from spec 009, just with `install_spinquant_rotations` replacing the R_a rotation.

- **Reference:** `gh pr diff 1695` — copy their rotation primitives and install function directly. Their forward-pass hook pattern works for both the training-path and TTT-path forwards.
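
For orientation, a sketch of the item-1 primitive. The authoritative version is whatever `gh pr diff 1695` shows; the cache layout and seed derivation below are assumptions.

```python
import zlib
import torch

_SPINQUANT_CACHE: dict = {}  # keyed by (seed, tag, n), per item 1 above

def _hadamard_rotation(n: int, seed: int, tag: str) -> torch.Tensor:
    """Sylvester-Hadamard × random-sign diagonal, then QR re-orthogonalization.
    Assumes n is a power of two (true for d_model=512 and d_ff=2048)."""
    key = (seed, tag, n)
    if key not in _SPINQUANT_CACHE:
        H = torch.ones(1, 1)
        while H.shape[0] < n:                   # Sylvester: H_{2k} = [[H, H], [H, -H]]
            H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
        H = H / n ** 0.5                        # orthonormal columns
        g = torch.Generator().manual_seed(seed ^ zlib.crc32(tag.encode()))
        signs = torch.randint(0, 2, (n,), generator=g).float() * 2 - 1
        Q, _ = torch.linalg.qr(H * signs)       # right-multiply by diag(signs), re-orthogonalize
        _SPINQUANT_CACHE[key] = Q
    return _SPINQUANT_CACHE[key]
```
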
## Hardware ladder

8×H100, single seed (42). Same pod shape as spec 009. ~10 min compute + eval + TTT.

## Seed plan

Single seed 42. If it wins clearly (improves on baseline by more than 0.002 bpb and beats spec 009's `internal_only`), a 3-seed confirmation becomes the next spec.

## Inputs

- **FP checkpoint:** `runs/008-1736-reproduction/seed_42/pre_gptq.pt` (spec 008 output).
- **Data:** same CaseOps dataset as spec 008.
- **Tokenizer:** bundled.
- **Prior result needed first:** spec 009's `baseline` mode (gives us a measured spec-008-equivalent post-TTT number to compare against).

## Execution protocol

```bash
cd /workspace/parameter-golf/records/track_10min_16mb/2026-04-19_SP8192_CaseOps_GatedAttn_QuantGate_Loop45_PhasedTTT

mkdir -p /workspace/runs/010-port-1695

NCCL_NET=Socket DATA_DIR=./data \
ARTIFACT_DIR=/workspace/runs/010-port-1695 \
CASEOPS_ENABLED=1 \
PHASED_TTT_ENABLED=1 PHASED_TTT_PREFIX_DOCS=2000 PHASED_TTT_NUM_PHASES=3 \
MLP_CLIP_SIGMAS=12.0 ATTN_CLIP_SIGMAS=13.0 \
EMBED_BITS=7 EMBED_CLIP_SIGMAS=15.0 \
GPTQ_RESERVE_SECONDS=4 GPTQ_CALIBRATION_BATCHES=16 \
GATED_ATTN_ENABLED=1 GATED_ATTN_INIT_STD=0.005 GATED_ATTN_QUANT_GATE=1 \
SPINQUANT_ENABLED=1 SPINQUANT_SEED=42 \
HOTSTART_FP_CKPT=/workspace/runs/008-1736-reproduction/seed_42/pre_gptq.pt \
SEED=42 \
torchrun --standalone --nproc_per_node=8 spinquant_online_hotstart.py \
> /workspace/runs/010-port-1695/run.log 2>&1
```

## Stop-early criteria

- GPTQ Hessian rotation produces non-finite values → halt, debug the Hessian math.
- Artifact > 16 MB → halt.
- val_bpb > spec 009 baseline + 0.003 → likely a forward-pass hook bug; halt.

## Checkpoints to emit

None. Reuses spec 008's `pre_gptq.pt` as the sole input. Output is the rotated-and-quantized `.ptz` artifact.

## Cost estimate

| Item | Cost |
|---|---|
| Pod spin-up + compile warm-up | $2 |
| Port setup (Hessian rotation debug if needed) | $3 |
| Single run (8×H100, ~10 min GPU) | $5 |
| **Total** | **~$10** |

Cheaper than spec 008 because there is no training run.

## Extra artifacts

- `runs/010-port-1695/run.log`
- `runs/010-port-1695/final_model.int6.ptz`
- `runs/010-port-1695/rotation_manifest.json`
- `runs/010-port-1695/final.json`

## Open questions for interview

1. **Hessian-rotation math:** does `H_new = R.T @ H @ R` correctly capture the relationship for all four rotation sites? `R_mlp_proj_in` acts on the post-nonlinearity hidden, so its corresponding Hessian is collected from `hidden.detach()` at line ~822. Double-check the collected-tensor identity before rotating (a check is sketched after this list).
2. **GPTQ clip-sigma behavior:** `MLP_CLIP_SIGMAS=12.0` and `ATTN_CLIP_SIGMAS=13.0` were tuned for #1736's unrotated distributions. After rotation, weight/activation variance may shift. Start with the original sigmas — if calibration fails or clipping triggers excessively, sweep `*_CLIP_SIGMAS` wider.
3. **Training-time flag:** `CastedLinear._sq_active` must be `False` during any TTT training step (so LoRA trains on the unrotated forward consistently). The spec-009 TTT code path would be affected too if we ever composed the two. For spec 010 alone this is fine — we never retrain.
4. **Exact contents of `_spinquant_rotate_sd_and_H`:** read the function in #1695's diff and port it verbatim; their implementation handles the state_dict side too (are any weights rotated statically in addition to activations? check during porting).
5. **Seed convention:** #1695 uses `SPINQUANT_SEED=20260416` (their date). Spec 010 uses 42 to match our seed convention. If sensitivity to the rotation seed is detectable, a sweep can come later.
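
For question 1, a self-contained check of the collected-tensor identity (pure algebra, so it should hold to fp tolerance); the `hidden` here is a stand-in for the post-LeakyReLU² tensor and the rotation is a stubbed random orthogonal matrix:

```python
import torch
import torch.nn.functional as F

def check_hessian_identity(n: int = 64, seed: int = 0) -> None:
    g = torch.Generator().manual_seed(seed)
    R, _ = torch.linalg.qr(torch.randn(n, n, generator=g))
    hidden = F.leaky_relu(torch.randn(4096, n, generator=g)).square()
    H_direct = (hidden @ R).T @ (hidden @ R) / hidden.shape[0]   # collect in rotated forward
    H_rotated = R.T @ (hidden.T @ hidden / hidden.shape[0]) @ R  # rotate after collecting
    assert torch.allclose(H_direct, H_rotated, atol=1e-3), "identity violated beyond fp error"
```
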
## What this spec does NOT do

- Does not touch any non-quant lever.
- Does not retrain. Hotstart only.
- Does not sweep rotation seed, clip sigmas, or schedule — single-config port.
- Does not attempt a hybrid with spec 009's `internal_only` R_a rotation. If both land positive, a follow-up spec can try the combination.
- Does not modify #1736's training loop — rotation is only active post-deserialize for eval.
**011-tapered-wd.md** — renumbered from spec 010, 13 additions & 13 deletions:
````diff
@@ -1,9 +1,9 @@
-# Spec 010 — Tapered weight decay (training lever, port from #1729)
+# Spec 011 — Tapered weight decay (training lever, port from #1729)
 
 **Slug:** `tapered-wd`
 **Created:** 2026-04-20
-**Links to idea:** `research/ideas/1736-improvement.md` (spec-010 section).
-**Can run in parallel with:** spec 009 (separate pod, independent work).
+**Links to idea:** `research/ideas/1736-improvement.md`.
+**Can run in parallel with:** specs 009 and 010 (separate pod, independent work).
 
 ## Hypothesis
 
@@ -93,10 +93,10 @@ Single pod, single run:
 ```bash
 cd /workspace/parameter-golf/records/track_10min_16mb/2026-04-19_SP8192_CaseOps_GatedAttn_QuantGate_Loop45_PhasedTTT
 
-mkdir -p /workspace/runs/010-tapered-wd/seed_42
+mkdir -p /workspace/runs/011-tapered-wd/seed_42
 
 NCCL_NET=Socket DATA_DIR=./data \
-ARTIFACT_DIR=/workspace/runs/010-tapered-wd/seed_42 \
+ARTIFACT_DIR=/workspace/runs/011-tapered-wd/seed_42 \
 CASEOPS_ENABLED=1 \
 PHASED_TTT_ENABLED=1 PHASED_TTT_PREFIX_DOCS=2000 PHASED_TTT_NUM_PHASES=3 \
 MLP_CLIP_SIGMAS=12.0 ATTN_CLIP_SIGMAS=13.0 \
@@ -108,7 +108,7 @@ WD_TAPER_START_FRAC=0.70 \
 WD_TAPER_FINAL_MULT=0.50 \
 SEED=42 \
 torchrun --standalone --nproc_per_node=8 train_gpt.py \
-> /workspace/runs/010-tapered-wd/seed_42/train.log 2>&1
+> /workspace/runs/011-tapered-wd/seed_42/train.log 2>&1
 ```
 
 Verify in log: `muon_wd` value at step >= 0.7×total_steps should show the ramp. Add a one-time log line at the start of the taper zone:
@@ -120,7 +120,7 @@ log(f"WD_TAPER: start_step={start_step} total_steps={total_steps} "
 
 ## Checkpoints to emit
 
-**Exactly one:** `runs/010-tapered-wd/seed_42/final_model.pt` — auto-saved by `serialize()` before GPTQ. Same convention as spec 008. Reusable for future quant-family experiments (SpinQuant, per-group bit, AR-selfgen) on top of tapered-WD weights if this lever lands.
+**Exactly one:** `runs/011-tapered-wd/seed_42/final_model.pt` — auto-saved by `serialize()` before GPTQ. Same convention as spec 008. Reusable for future quant-family experiments (SpinQuant, per-group bit, AR-selfgen) on top of tapered-WD weights if this lever lands.
 
 Plus the submission `.ptz` artifact and `final.json` as usual.
 
@@ -145,16 +145,16 @@ Same rough cost as spec 008, since it's a full retrain with a tiny config change
 
 ## Extra artifacts
 
-- `runs/010-tapered-wd/seed_42/train.log` — full training log
-- `runs/010-tapered-wd/seed_42/final_model.pt` — pre-GPTQ FP checkpoint
-- `runs/010-tapered-wd/seed_42/final_model.int6.ptz` — quantized submission artifact
-- `runs/010-tapered-wd/seed_42/final.json` — post-TTT val_bpb, Δ vs spec 008, wall times
-- `runs/010-tapered-wd/seed_42/notes.md` — execution narrative
+- `runs/011-tapered-wd/seed_42/train.log` — full training log
+- `runs/011-tapered-wd/seed_42/final_model.pt` — pre-GPTQ FP checkpoint
+- `runs/011-tapered-wd/seed_42/final_model.int6.ptz` — quantized submission artifact
+- `runs/011-tapered-wd/seed_42/final.json` — post-TTT val_bpb, Δ vs spec 008, wall times
+- `runs/011-tapered-wd/seed_42/notes.md` — execution narrative
 
 ## Open questions for interview
 
 1. **Which optimizer(s) get the taper?** PR #1729's body suggests their taper applied to *Muon WD only*. Our implementation should probably follow that — the lever as they measured it is Muon-specific. Adam WD can be left at 0.02 throughout. Confirm at interview; if unclear, run Muon-only for the first pass.
-2. **Parallel to spec 009?** Yes — spec 009 hotstarts off spec 008's `pre_gptq.pt` on one pod; spec 010 retrains on a separate pod. Independent. Total combined cost ~$35 if run simultaneously, vs ~$35 sequentially anyway — simultaneity just parallelizes wall time.
+2. **Parallel to spec 009?** Yes — spec 009 hotstarts off spec 008's `pre_gptq.pt` on one pod; spec 011 retrains on a separate pod. Independent. Total combined cost ~$35 if run simultaneously, vs ~$35 sequentially anyway — simultaneity just parallelizes wall time.
 3. **Is the taper linear or cosine?** PR #1729's README implies linear from start_frac to end. If cosine decay is preferred, we can change to `mult = h.wd_taper_final_mult + 0.5 * (1 - h.wd_taper_final_mult) * (1 + cos(pi * progress))`. For the first pass, linear is simpler and cheaper to reason about.
 4. **Does WD taper interact with MATRIX_LR decay?** #1736 already has a cosine LR schedule during warmdown. Tapering WD on top is an additional schedule — need to verify no weird interaction (e.g., LR near-zero + reduced WD = almost no parameter movement, which shouldn't matter but worth glancing at training log).
````
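
A minimal sketch of the taper multiplier configured by the hunks above (`WD_TAPER_START_FRAC`, `WD_TAPER_FINAL_MULT`), assuming the linear ramp the spec defaults to, with question 3's cosine variant included for comparison; the function name and call site are hypothetical:

```python
import math

def wd_taper_mult(step: int, total_steps: int,
                  start_frac: float = 0.70, final_mult: float = 0.50,
                  cosine: bool = False) -> float:
    start_step = int(start_frac * total_steps)
    if step < start_step:
        return 1.0                                   # full WD before the taper zone
    progress = (step - start_step) / max(1, total_steps - start_step)
    if cosine:                                       # question 3's alternative schedule
        return final_mult + 0.5 * (1 - final_mult) * (1 + math.cos(math.pi * progress))
    return 1.0 + (final_mult - 1.0) * progress       # linear: 1.0 -> final_mult

# per-step usage, Muon group only (question 1): muon_wd = base_muon_wd * wd_taper_mult(step, total)
```
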