Commit b47a252

leon2k2k2k and claude committed
spec 010: port openai#1695's online Hadamard rotation into train_gpt.py
Implements the port_1695 SpinQuant variant from PR openai#1695 onto the openai#1736 stack. All changes are env-var-gated (SPINQUANT_ENABLED=0 default), so spec 008 and spec 009's baseline/internal_only modes are unaffected bit-for-bit.

train_gpt.py changes (+247 lines):
- import hashlib
- Hyperparameters.spinquant_enabled, spinquant_seed
- CastedLinear._sq_active class flag (default False)
- Utility block: _stable_seed, _hadamard_rotation, install_spinquant_rotations, _SQ_KEY_TO_TAG, _spinquant_rotate_sd_and_H
- 4 forward-path hook sites (2 each in CausalSelfAttention, MLP, _block_with_lora, _parallel_block_with_lora):
  - pre-QKV: x_qkv = x @ R_attn_in
  - pre-attn-proj: y @ R_attn_proj_in
  - pre-fc: x @ R_mlp_in
  - post-activation pre-proj: hidden @ R_mlp_proj_in
- serialize(): call _spinquant_rotate_sd_and_H after Hessian collection and before GPTQ. Rotates weights (W @ R) and Hessians (R.T @ H @ R).
- deserialize(): install_spinquant_rotations + set _sq_active=True after loading rotated weights.
- MLP.forward: disable the fused kernel when SpinQuant is active.
- LoRA (TTT path) uses the unrotated n; the base path uses the rotated n_qkv.

spinquant_hotstart.py changes:
- port_1695 mode no longer raises NotImplementedError. It sets h.spinquant_enabled=True and h.spinquant_seed; train_gpt.py's machinery does the rest.

Math: orthogonal R means R @ R.T == I, so x_rot @ W_rot.T = (x @ R) @ (W @ R).T = x @ R @ R.T @ W.T = x @ W.T. The pre-quant forward pass is bit-identical to the unrotated one; GPTQ sees a rotated basis in which outliers are spread more evenly, so quantization error drops.

The spec 010 doc is updated to reflect the implementation state. Execution runs via SPINQUANT_MODE=port_1695 on spinquant_hotstart.py. Not tested on GPU (flash_attn_3 is not available on the dev box); syntax is clean, and the first pod run will verify end-to-end behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
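The invariance claim above can be checked numerically. This is a minimal sketch, not the code from train_gpt.py: the helper below builds a random orthogonal matrix via QR as a stand-in for the commit's seeded Hadamard-based rotation, whose exact construction is not shown here.

```python
import torch

def random_orthogonal(n: int, seed: int) -> torch.Tensor:
    # Stand-in for the commit's _hadamard_rotation: any orthogonal R
    # (R @ R.T == I) exhibits the same forward-pass invariance.
    g = torch.Generator().manual_seed(seed)
    a = torch.randn(n, n, generator=g, dtype=torch.float64)
    q, r = torch.linalg.qr(a)
    # Fix column signs so the factorization is deterministic.
    return q * torch.sign(torch.diagonal(r))

n = 64
R = random_orthogonal(n, seed=1695)
x = torch.randn(8, n, dtype=torch.float64)   # activations
W = torch.randn(n, n, dtype=torch.float64)   # linear weight (out x in)
H = x.T @ x                                  # GPTQ-style Hessian proxy

# Forward pass is invariant: (x @ R) @ (W @ R).T == x @ W.T
y_plain = x @ W.T
y_rot = (x @ R) @ (W @ R).T
assert torch.allclose(y_plain, y_rot)

# The Hessian transforms consistently: rotated activations give R.T @ H @ R,
# matching the serialize()-side rotation described above.
H_rot = (x @ R).T @ (x @ R)
assert torch.allclose(H_rot, R.T @ H @ R)
```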
1 parent 9b87109 commit b47a252

3 files changed

Lines changed: 321 additions & 46 deletions

records/track_10min_16mb/2026-04-19_SP8192_CaseOps_GatedAttn_QuantGate_Loop45_PhasedTTT/spinquant_hotstart.py

Lines changed: 16 additions & 5 deletions
```diff
@@ -216,11 +216,22 @@ def apply_rotations(model: GPT, h: Hyperparameters, mode: str, base_seed: int) -
         )
 
     if mode == MODE_PORT_1695:
-        raise NotImplementedError(
-            "SPINQUANT_MODE=port_1695 is deferred until #1695's diff is read "
-            "and the rotation scheme is ported. Script will be updated in a "
-            "follow-up spec."
-        )
+        # #1695's scheme is implemented inside train_gpt.py itself
+        # (_spinquant_rotate_sd_and_H called inside serialize(), and
+        # install_spinquant_rotations called inside deserialize() when
+        # h.spinquant_enabled is True). The driver's job is just to flip
+        # the flag; no in-model mutation here. See research/specs/
+        # 010-port-1695-online-rotation.md.
+        h.spinquant_enabled = True
+        h.spinquant_seed = base_seed
+        manifest["rotations"] = {
+            "strategy": "online+static via train_gpt.py machinery",
+            "hyperparameters": {
+                "spinquant_enabled": True,
+                "spinquant_seed": base_seed,
+            },
+        }
+        return manifest
 
     raise ValueError(f"Unknown SPINQUANT_MODE={mode!r}; expected one of {_KNOWN_MODES}")
```
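The serialize()-side rotation this driver delegates to (weights folded as W @ R, Hessians as R.T @ H @ R) can be sketched as follows. This is a hypothetical analog, not the actual _spinquant_rotate_sd_and_H from train_gpt.py; the state-dict keys and the helper name are illustrative.

```python
import torch

def rotate_sd_and_hessians(sd: dict, hessians: dict, rotations: dict):
    # Hypothetical analog of _spinquant_rotate_sd_and_H: for each weight
    # tagged with a rotation R, fold R into the input dimension (W @ R)
    # and transform its input Hessian (R.T @ H @ R) so GPTQ quantizes
    # in the rotated basis.
    for key, R in rotations.items():
        sd[key] = sd[key] @ R
        hessians[key] = R.T @ hessians[key] @ R
    return sd, hessians

# Tiny demo with one linear layer (illustrative key name).
n = 16
g = torch.Generator().manual_seed(0)
Q, _ = torch.linalg.qr(torch.randn(n, n, generator=g, dtype=torch.float64))
W = torch.randn(n, n, dtype=torch.float64)
X = torch.randn(32, n, dtype=torch.float64)
sd = {"mlp.c_fc.weight": W.clone()}
hs = {"mlp.c_fc.weight": X.T @ X}
sd, hs = rotate_sd_and_hessians(sd, hs, {"mlp.c_fc.weight": Q})

# The rotated Hessian equals the Hessian of rotated activations, so GPTQ's
# error objective is evaluated consistently in the new basis.
assert torch.allclose(hs["mlp.c_fc.weight"], (X @ Q).T @ (X @ Q))
```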
