Commit c714a4d

Merge pull request openai#1412 from Robby955/submission/parallel-residuals-hessian-sdclip
Non-record: SP8192 + Parallel Residuals + Hessian-Aware SDClip — val_bpb 1.08354 (3-seed mean)
2 parents bac888c + 4b57791 commit c714a4d

8 files changed

Lines changed: 2669 additions & 0 deletions

Lines changed: 91 additions & 0 deletions
# Non-record: Parallel Residuals + Hessian-Aware SDClip (3-seed mean 1.08354 BPB)

**val bpb: 1.08354** (3-seed mean, std=0.00050)

Not a record. This is a small 3-seed experiment on top of [PR #1394](https://github.com/openai/parameter-golf/pull/1394); the numbers hold up on my runs, but three seeds is not enough evidence for a statistical claim. Posting because the changes are zero-cost, reproducible, and may be useful to others trying out different techniques.

| Seed | Steps | Pre-quant BPB | Post-quant BPB | **Sliding BPB** | Artifact (bytes) |
|-|-|-|-|-|-|
| 1337 | 5178 | 1.08765 | 1.09959 | **1.08301** | 15,976,275 |
| 42 | 5180 | 1.08816 | 1.10013 | **1.08363** | 15,978,439 |
| 3141 | 5182 | 1.08872 | 1.10044 | **1.08399** | 15,979,649 |
| **Mean** | | 1.08818 | 1.10005 | **1.08354** | 15,978,121 |

## Changes

Three zero-cost modifications on top of [PR #1394](https://github.com/openai/parameter-golf/pull/1394); none adds extra parameters or artifact bytes:
### 1. Parallel Residuals (Layers 7+)

GPT-J style parallel attention+MLP ([Wang & Komatsuzaki, 2021](https://github.com/kingoflolz/mesh-transformer-jax)) for the last 4 layers. Both attention and MLP read from the same input and their outputs are added in parallel:

```
# Parallel (layers 7-10):
x_out = x + attn_scale * Attn(norm(x)) + mlp_scale * MLP(norm(x))

# Sequential (layers 0-6, unchanged):
h = x + attn_scale * Attn(norm(x))
x_out = h + mlp_scale * MLP(norm(h))
```

I expected parallel residuals to reduce interference between attention and MLP during GPTQ calibration. Pre-quant BPB barely moved, but the quantization gap tightened across all 3 seeds, which made this the most useful change in practice.
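For concreteness, here is a minimal PyTorch sketch of the two wirings. The `Block` module, the use of `nn.RMSNorm` (PyTorch 2.4+), and the two separate norms are illustrative assumptions; only the residual equations above come from the run:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Illustrative transformer block supporting both residual wirings."""

    def __init__(self, attn: nn.Module, mlp: nn.Module, dim: int, parallel: bool):
        super().__init__()
        self.attn, self.mlp, self.parallel = attn, mlp, parallel
        self.norm_attn = nn.RMSNorm(dim)
        self.norm_mlp = nn.RMSNorm(dim)
        self.attn_scale = nn.Parameter(torch.ones(()))
        self.mlp_scale = nn.Parameter(torch.ones(()))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.parallel:
            # Layers 7-10: attention and MLP both read the block input x.
            return (x + self.attn_scale * self.attn(self.norm_attn(x))
                      + self.mlp_scale * self.mlp(self.norm_mlp(x)))
        # Layers 0-6: the MLP reads the post-attention stream h.
        h = x + self.attn_scale * self.attn(self.norm_attn(x))
        return h + self.mlp_scale * self.mlp(self.norm_mlp(h))
```

Since both branches read the same input in the parallel case, the two projections can also be fused or run concurrently, which is where GPT-J's original throughput argument comes from.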
### 2. Hessian-Aware SDClip

I used GPTQ's existing Hessian diagonal as a cheap importance signal to slightly modulate SDClip thresholds by row:

$$c_i = k \cdot \sigma_i \cdot [1 + \lambda(r_i - 1)], \quad \lambda = 0.175$$

where $\sigma_i$ is the standard deviation of row $i$ and $r_i$ is the row importance derived from Hessian-weighted magnitude. The effect is small but directionally useful. I initially used $\lambda = 0.30$, but $\lambda = 0.175$ was consistently better across seeds, giving both lower BPB and a smaller artifact: higher $\lambda$ reduces rounding error but increases entropy, which makes Brotli compression less effective.
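A minimal sketch of how these thresholds can be computed, assuming `h_diag` is GPTQ's Hessian diagonal for the layer (one entry per input column), $r_i$ is a Hessian-weighted row magnitude normalized to mean 1 (so $\lambda = 0$ recovers plain SDClip), and the thresholds are applied as a symmetric clamp before quantization. The function name and exact importance definition are my assumptions, not the repo's code; the default `k` matches `matrix_clip_sigmas` in the log below:

```python
import torch

def hessian_aware_clip(W: torch.Tensor, h_diag: torch.Tensor,
                       k: float = 12.85, lam: float = 0.175) -> torch.Tensor:
    """Clamp each row i of W to +/- c_i = k * sigma_i * (1 + lam * (r_i - 1))."""
    sigma = W.std(dim=1)                         # sigma_i: per-row standard deviation
    imp = (W.abs() * h_diag.sqrt()).mean(dim=1)  # Hessian-weighted row magnitude
    r = imp / imp.mean()                         # normalize so mean(r) == 1
    c = k * sigma * (1.0 + lam * (r - 1.0))      # per-row clip thresholds c_i
    return W.clamp(min=-c[:, None], max=c[:, None])
```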
### 3. Progressive Recurrence

Depth recurrence is split into two phases: the first loop is enabled at 50% of training, the second at 65%. The split points were not tuned; 50% matches the original and 65% was a single manual choice. Enabling both loops at once causes a sharper loss spike, while splitting gives the model time to adapt to each additional pass before adding the next.
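The schedule itself reduces to a step function from training fraction to enabled loop count. A hypothetical helper (in the run this is driven by the `enable_looping_at` and `loop_phase2_at` hyperparameters visible in the log):

```python
def active_loops(frac_done: float, phase1_at: float = 0.5, phase2_at: float = 0.65) -> int:
    """Number of depth-recurrence loops enabled at this point in training."""
    return int(frac_done >= phase1_at) + int(frac_done >= phase2_at)

assert active_loops(0.40) == 0  # plain sequential stack
assert active_loops(0.55) == 1  # phase 1: first loop enabled
assert active_loops(0.70) == 2  # phase 2: second loop enabled
```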
## Hessian Analysis (Cross-Seed)

Hessian diagnostics from 3 seeds, 67 matrices each:

- **Group-level traces** (early/loop/mid/late blocks): $r=0.997$ across seeds
- **Per-matrix traces**: $r=0.994$
- **Per-row importance**: $r=0.12$ (noise)

Importance hierarchy: early blocks (30x the trace of late blocks) >> loop >> mid >> late. Per-row importance is too noisy to be a reliable signal, but group-level traces are very stable across seeds, which suggests per-group clip allocation could be a useful direction.
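A sketch of how the cross-seed agreement can be recomputed from the saved diagnostics, assuming each run's `hessian_diagnostics.pt` (the file name appears in the log below) stores a dict from matrix name to Hessian diagonal; the storage layout is my assumption:

```python
import torch

def per_matrix_traces(path: str) -> torch.Tensor:
    """Trace of each Hessian = sum of its diagonal, in a fixed name order."""
    diag = torch.load(path)  # assumed layout: {matrix_name: 1-D Hessian diagonal}
    return torch.tensor([diag[name].sum() for name in sorted(diag)])

t_a = per_matrix_traces("seed1337/hessian_diagnostics.pt")
t_b = per_matrix_traces("seed42/hessian_diagnostics.pt")
r = torch.corrcoef(torch.stack([t_a, t_b]))[0, 1]  # Pearson r over 67 matrices
print(f"per-matrix trace correlation: {r.item():.3f}")
```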
## Future Directions

Several ideas I'd like to explore with more compute time:

- **Per-group clip allocation**: Non-uniform $k$ across layer groups, using the stable group-level trace hierarchy as a guide (see the sketch after this list).
- **Output-Hessian weighting**: Using backward-pass gradients for output-side row importance rather than input-side alone.
- **More seeds**: 3 seeds is not enough for strong statistical claims. I'd want 5+ to be confident about the gap vs PR #1394.
- **YAQA**: I like the idea of the paper ([arXiv:2505.22988](https://arxiv.org/abs/2505.22988)), but I couldn't get a working backward pass for it; it might adapt to the parameter-golf problem in an interesting way. I also like the math in Mousse ([arXiv:2603.09697](https://arxiv.org/abs/2603.09697)), but exploiting curvature in small LMs seems tough.
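For the first item, a sketch of what per-group allocation could look like. The group names, the lookup-by-substring, and the specific `k` values are hypothetical and untuned; the uniform baseline matches `matrix_clip_sigmas` in the log:

```python
# Hypothetical per-group clip multipliers: looser clips (less rounding error)
# for high-trace groups, tighter clips for low-trace ones.
GROUP_K = {"early": 14.0, "loop": 13.0, "mid": 12.5, "late": 12.0}

def clip_sigmas_for(matrix_name: str, baseline: float = 12.85) -> float:
    """Return a group-specific k, falling back to the uniform baseline."""
    for group, k in GROUP_K.items():
        if group in matrix_name:  # assumes the group tag appears in the name
            return k
    return baseline
```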
## Run Command

```bash
HESSIAN_CLIP_LAMBDA=0.175 LOOP_PHASE2_AT=0.65 PARALLEL_RESIDUAL_START=7 SEED=1337 \
torchrun --standalone --nproc_per_node=8 train_gpt_sweep.py
```

## Requirements

Flash Attention 3 (Hopper GPUs) is required. The SP8192 tokenizer is a SentencePiece BPE model with an 8192-token vocabulary, trained on FineWeb 10B.

```bash
pip install torch --index-url https://download.pytorch.org/whl/cu130
pip install --no-cache-dir \
  "https://download.pytorch.org/whl/cu130/flash_attn_3-3.0.0-cp39-abi3-manylinux_2_28_x86_64.whl"
pip install -r requirements.txt
```

## Compliance (Track A — Fixed Predictor)

- No TTT, SLOT, n-gram cache, or eval-time adaptation
- GPTQ calibration within the training budget
- Standard autoregressive sliding-window eval (stride=64); see the sketch below

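For reference, a minimal sketch of a stride-64 sliding-window evaluation, assuming `model(ids)` returns next-token logits of shape `(batch, seq, vocab)`; `bytes_per_token` is a placeholder, since real BPB divides by the exact byte length of the scored text under the tokenizer:

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_bpb(model, tokens: torch.Tensor, window: int = 2048,
                       stride: int = 64, bytes_per_token: float = 4.0) -> float:
    """Score only the last `stride` targets of each window; convert nats to bits/byte."""
    total_nats, n_scored = 0.0, 0
    for start in range(0, tokens.numel() - window - 1, stride):
        ctx = tokens[start : start + window + 1]          # window inputs + shifted targets
        logits = model(ctx[:-1].unsqueeze(0)).squeeze(0)  # (window, vocab)
        total_nats += F.cross_entropy(logits[-stride:], ctx[1:][-stride:],
                                      reduction="sum").item()
        n_scored += stride
    return total_nats / math.log(2) / (n_scored * bytes_per_token)
```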
## Credits

Learned from and inspired by [PR #1394](https://github.com/openai/parameter-golf/pull/1394) (@clarkkev) — SDClip, depth recurrence, and GPTQ embedding quantization ideas. Parallel residuals from GPT-J ([Wang & Komatsuzaki, 2021](https://github.com/kingoflolz/mesh-transformer-jax)). Additional credits: PR #1204 (@msisovic, depth recurrence), PR #1217 (@bigbag, MuonEq-R), PR #1019 (@abaybektursun, previous SOTA).
Lines changed: 34 additions & 0 deletions

{
  "author": "Robby Sneiderman",
  "github_id": "Robby955",
  "name": "Non-record: Parallel Residuals + Hessian-Aware SDClip",
  "blurb": "Three zero-cost modifications on PR #1394: GPT-J parallel residuals (layers 7+), Hessian-diagonal SDClip modulation (lambda=0.175), two-phase progressive recurrence. 3-seed mean 1.08354 BPB.",
  "date": "2026-04-06T00:00:00Z",
  "val_loss": 2.7995,
  "val_bpb": 1.08354,
  "val_bpb_std": 0.00050,
  "seeds": [1337, 42, 3141],
  "seed_results": {
    "1337": {
      "val_loss": 2.79749,
      "val_bpb": 1.08301,
      "artifact_bytes": 15976275,
      "steps": 5178
    },
    "42": {
      "val_loss": 2.79910,
      "val_bpb": 1.08363,
      "artifact_bytes": 15978439,
      "steps": 5180
    },
    "3141": {
      "val_loss": 2.80002,
      "val_bpb": 1.08399,
      "artifact_bytes": 15979649,
      "steps": 5182
    }
  },
  "hardware": "8xH100 80GB SXM",
  "bytes_total": 15978121,
  "based_on": "PR #1394 (@clarkkev)"
}
Lines changed: 165 additions & 0 deletions

W0406 05:48:52.138000 4182076 torch/distributed/run.py:803]
W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] *****************************************
W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] *****************************************
Hyperparameters:
adam_eps: 1e-08
adam_wd: 0.02
beta1: 0.9
beta2: 0.95
compressor: brotli
data_dir: ./data/
datasets_dir: ./data/datasets/fineweb10B_sp8192
distributed: True
ema_decay: 0.997
embed_bits: 8
embed_clip_sigmas: 20.0
embed_lr: 0.6
embed_wd: 0.085
embedding_dim: 512
enable_looping_at: 0.5
eval_seq_len: 2048
eval_stride: 64
gptq_calibration_batches: 64
gptq_reserve_seconds: 12.0
grad_accum_steps: 1
grad_clip_norm: 0.3
head_lr: 0.008
hessian_clip_lambda: 0.175
is_main_process: True
iterations: 20000
ln_scale: True
local_rank: 0
logfile: logs/f3971278-d577-499b-8fde-755434809ba9.txt
logit_softcap: 30.0
loop_end: 5
loop_layer_bits: 0
loop_layer_clip_sigmas: 0.0
loop_phase2_at: 0.65
loop_start: 4
matrix_bits: 6
matrix_clip_sigmas: 12.85
matrix_lr: 0.02
max_wallclock_seconds: 600.0
min_lr: 0.0
mlp_mult: 4.0
model_dim: 512
model_path: final_model.pt
muon_backend_steps: 5
muon_beta2: 0.95
muon_momentum: 0.99
muon_momentum_warmup_start: 0.92
muon_momentum_warmup_steps: 1500
muon_row_normalize: True
muon_wd: 0.085
num_heads: 8
num_kv_heads: 4
num_layers: 11
num_loops: 2
parallel_residual_start: 7
qk_gain_init: 4.0
quantized_model_path: final_model.int6.ptz
rank: 0
rope_base: 10000.0
rope_dims: 16
rope_train_seq_len: 2048
run_id: f3971278-d577-499b-8fde-755434809ba9
scalar_lr: 0.02
seed: 1337
skip_gates_enabled: True
sliding_window_enabled: True
tie_embeddings: True
tied_embed_init_std: 0.005
tied_embed_lr: 0.03
tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model
train_batch_tokens: 786432
train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin
train_log_every: 500
train_seq_len: 2048
ttt_chunk_tokens: 32768
ttt_enabled: False
ttt_entropy_high: 2.1
ttt_entropy_low: 1.75
ttt_epochs: 4
ttt_freeze_blocks: 2
ttt_lr: 0.0005
ttt_ns_steps: 3
untie_loop_mlps: False
val_batch_tokens: 524288
val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin
val_loss_every: 4000
vocab_size: 8192
warmdown_frac: 0.667
warmup_steps: 20
world_size: 8
xsa_last_n: 11
train_shards: 128
val_tokens: 40540160
model_params:35943512
hessian_clip: lambda=0.175
parallel_residuals: ON (layers 7-10)
progressive_recurrence: phase1=0.5 phase2=0.65
gptq:reserving 12s, effective=588000ms
warmup_step: 1/20
warmup_step: 2/20
warmup_step: 3/20
warmup_step: 4/20
warmup_step: 5/20
warmup_step: 6/20
warmup_step: 10/20
warmup_step: 20/20
loop_warmup_phase1: encoder:[0, 1, 2, 3, 4, 5] decoder:[4, 5, 6, 7, 8, 9, 10]
loop_warmup_p1_step: 1/20
loop_warmup_p1_step: 2/20
loop_warmup_p1_step: 3/20
loop_warmup_p1_step: 4/20
loop_warmup_p1_step: 5/20
loop_warmup_p1_step: 6/20
loop_warmup_p1_step: 10/20
loop_warmup_p1_step: 20/20
loop_warmup_phase2: encoder:[0, 1, 2, 3, 4, 5, 4] decoder:[5, 4, 5, 6, 7, 8, 9, 10]
loop_warmup_p2_step: 1/20
loop_warmup_p2_step: 2/20
loop_warmup_p2_step: 3/20
loop_warmup_p2_step: 4/20
loop_warmup_p2_step: 5/20
loop_warmup_p2_step: 6/20
loop_warmup_p2_step: 10/20
loop_warmup_p2_step: 20/20
0/20000 val_loss: 9.0052 val_bpb: 3.4862
1/20000 train_loss: 9.0086 train_time: 0.0m tok/s: 8101558 looping:False
2/20000 train_loss: 12.1911 train_time: 0.0m tok/s: 8047236 looping:False
3/20000 train_loss: 11.0242 train_time: 0.0m tok/s: 7763017 looping:False
4/20000 train_loss: 9.5010 train_time: 0.0m tok/s: 7744899 looping:False
5/20000 train_loss: 8.3911 train_time: 0.0m tok/s: 7764479 looping:False
500/20000 train_loss: 3.3106 train_time: 0.8m tok/s: 7721577 looping:False
1000/20000 train_loss: 3.2012 train_time: 1.7m tok/s: 7711891 looping:False
1500/20000 train_loss: 3.1827 train_time: 2.5m tok/s: 7711479 looping:False
2000/20000 train_loss: 2.9936 train_time: 3.4m tok/s: 7711845 looping:False
2500/20000 train_loss: 3.0679 train_time: 4.2m tok/s: 7713114 looping:False
layer_loop:phase1 step:2884 frac:0.500
3000/20000 train_loss: 3.1068 train_time: 5.1m tok/s: 7668374 looping:True
3500/20000 train_loss: 2.9483 train_time: 6.1m tok/s: 7510155 looping:True
layer_loop:phase2 step:3634 frac:0.650
4000/20000 train_loss: 2.9482 train_time: 7.2m tok/s: 7297673 looping:True
4000/20000 val_loss: 2.9279 val_bpb: 1.1335
4500/20000 train_loss: 2.8499 train_time: 8.3m tok/s: 7110716 looping:True
5000/20000 train_loss: 2.8598 train_time: 9.4m tok/s: 6967320 looping:True
5178/20000 val_loss: 2.8121 val_bpb: 1.0887
stopping_early: wallclock_cap train_time: 588103ms step: 5178/20000
peak memory allocated: 34604 MiB reserved: 34634 MiB
ema:applying EMA weights
pre-quantization post-ema val_loss:2.80947408 val_bpb:1.08764765 eval_time:6554ms
Serialized model: 135426937 bytes
Code size: 78688 bytes
GPTQ:collecting Hessians from calibration data...
GPTQ:collected 67 Hessians in 11.3s
GPTQ:saved Hessian diagnostics to hessian_diagnostics.pt (67 matrices)
Quantized weights:
gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight
gptq (int8): tok_emb.weight
passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights
Serialized model quantized+brotli: 15976275 bytes
Total submission size quantized+brotli: 16054963 bytes
quantized val_loss:2.84032998 val_bpb:1.09959307 eval_time:8134ms
quantized_sliding_window val_loss:2.79749368 val_bpb:1.08300961 eval_time:82837ms
Lines changed: 1 addition & 0 deletions

# Loader shim: runs train_gpt_decode.py (in the same directory) in place of this file.
exec(open(__file__.replace("train_gpt.py","train_gpt_decode.py")).read())
