
Commit 5c8e045

Merge PR #1626: Record: VarLen Attention + Fused MLP + Multi-Phase Global SGD TTT — val_bpb 1.07193 (3-seed mean)
Merge accepted Parameter Golf record submission #1626.
2 parents 96d3c34 + 9c5a579 commit 5c8e045

6 files changed: 5383 additions & 0 deletions

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@

# Record: VarLen Attention + Fused MLP + Multi-Phase Global SGD TTT

**val_bpb: 1.07193** (3-seed mean, std 0.00063) | **2.76890 nats** | **~15.93 MB** | 8xH100 SXM, 596s train + ~331s TTT eval

**Improvement over PR #1530** (@samacqua, 1.07336 BPB): -0.00143 BPB / -0.00370 nats

**Improvement over merged SOTA** (PR #1493, 1.0810 BPB): -0.00907 BPB / -0.02344 nats

## Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128, Phased TTT)

| Seed | Steps | Pre-TTT BPB | **Post-TTT BPB** | TTT gain | TTT time | Artifact bytes |
|------|-------|-------------|------------------|----------|----------|----------------|
| 42 | 4,971 | 1.08502 | **1.07280** | -0.01222 | 329.0s | 15,932,897 |
| 0 | 4,967 | 1.08392 | **1.07134** | -0.01258 | 332.1s | 15,939,841 |
| 1234 | 4,977 | 1.08517 | **1.07164** | -0.01353 | 332.8s | 15,932,419 |
| **Mean** | | | **1.07193** | -0.01278 | | |

### Supplemental Diagnostics

| Seed | Pre-EMA BPB | Post-EMA BPB | Post-Quant BPB | Post-TTT BPB | val_loss (nats) | Code bytes | Total bytes | Train time | Eval time |
|------|-------------|--------------|----------------|--------------|-----------------|------------|-------------|------------|-----------|
| 42 | 1.0733 | 1.07257 | 1.08502 | 1.07280 | 2.77116 | 122,168 | 15,932,897 | 596.1s | 329.0s |
| 0 | 1.0723 | 1.07108 | 1.08392 | 1.07134 | 2.76739 | 122,168 | 15,939,841 | 596.1s | 332.1s |
| 1234 | 1.0713 | 1.07174 | 1.08517 | 1.07164 | 2.76815 | 122,168 | 15,932,419 | 596.2s | 332.8s |
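
As a consistency check, the val_loss and post-TTT BPB columns above agree under the standard conversion bpb = nats / (ln 2 × bytes per token); the snippet below assumes that relationship (the ~3.727 bytes/token it implies is derived here, not quoted from the logs):

```python
import math

# (val_loss in nats/token, post-TTT BPB) from the diagnostics table above
rows = {"42": (2.77116, 1.07280), "0": (2.76739, 1.07134), "1234": (2.76815, 1.07164)}

for seed, (nats, bpb) in rows.items():
    # Assumed relationship: bpb = nats / (ln 2 * bytes_per_token)
    bytes_per_token = nats / math.log(2) / bpb
    print(seed, f"{bytes_per_token:.4f}")  # ~3.727 for all three seeds
```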

## Key Innovation: Multi-Phase Global SGD

This submission introduces **multi-phase global SGD** during phased TTT evaluation. While PR #1610 (@romeerp) introduced single-phase global SGD (score prefix docs, run one round of SGD, score suffix), we extend this to **N phases** with interleaved scoring and adaptation:

1. Split 2000 prefix docs into 3 equal chunks (~666 docs each)
2. Score chunk 1 with base model (score-before-update)
3. Run distributed SGD on scored chunk 1
4. Score chunk 2 with improved model
5. Run SGD on scored chunks 1+2
6. Score chunk 3 with further improved model
7. Run SGD on all scored prefix docs
8. Score remaining 48,000 suffix docs with fully adapted model
This progressively improves the base model through multiple adaptation rounds while maintaining strict score-before-update legality. Each phase scores new tokens BEFORE any SGD update uses them.

```python
# Key code (simplified)
prev_boundary = 0
for phase_idx in range(num_phases):
    boundary = boundaries[phase_idx]  # boundaries = [666, 1333, 2000]
    # Score docs from the previous boundary to this boundary
    for doc in docs[prev_boundary:boundary]:
        score(doc)  # score-first, no adaptation yet
    # SGD on ALL scored docs so far
    global_sgd(scored_docs[:boundary])
    prev_boundary = boundary
# Score remaining 48000 suffix docs with adapted model
for doc in suffix_docs:
    score(doc)
```

**3-phase gives -0.00081 BPB over 1-phase** (1.07190 vs 1.07271, same seed). More phases (6+) cause overfitting on the small per-phase subsets.
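
For concreteness, the `boundaries = [666, 1333, 2000]` used in the sketch above follow from an equal 3-way split of the 2,000 prefix docs; `phase_boundaries` below is a hypothetical helper, not a function in `train_gpt.py`:

```python
def phase_boundaries(num_prefix_docs: int, num_phases: int) -> list[int]:
    """Equal-sized cumulative boundaries over the scored prefix docs."""
    return [num_prefix_docs * (i + 1) // num_phases for i in range(num_phases)]

print(phase_boundaries(2000, 3))  # [666, 1333, 2000]
```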

## Changes from PR #1530 Baseline

| Change | Source | Effect |
|--------|--------|--------|
| Multi-phase global SGD (3-phase) | **Novel (this work)** | -0.0008 BPB at eval time |
| Trimmed GPTQ (reserve=4s, calib=16) | PR #1586 (@dexhunter) | -0.0013 BPB, +72 training steps |
| MATRIX_LR=0.026 | PR #1586 (@dexhunter) | -0.0003 BPB (sharp optimum) |
| Per-layer adaptive GPTQ clip (MLP=12, Attn=13, Emb=15) | PR #1586 (@dexhunter) | Better quant-vs-bytes tradeoff (sketch below) |
| int7 embeddings (EMBED_BITS=7) | PR #1586 (@dexhunter) | -530 KB artifact, ~0 BPB cost |
| WARMDOWN_FRAC=0.75 | PR #1560 (@dexhunter) | More warmdown iterations |
| Dead code removal | This work | -1.9 KB compressed code size |
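
The clip values are in units of weight standard deviations. The sketch below illustrates sigma-clipped symmetric quantization only; it is a generic stand-in, not the full-Hessian GPTQ in `train_gpt.py`, and `sigma_clip_quantize` is a hypothetical name:

```python
import torch

def sigma_clip_quantize(w: torch.Tensor, bits: int, clip_sigmas: float):
    """Clip to +/- clip_sigmas * std(w), then round onto a symmetric int grid."""
    bound = clip_sigmas * w.std()
    max_q = 2 ** (bits - 1) - 1                  # max magnitude: 63 for int7, 31 for int6
    scale = bound / max_q
    q = torch.clamp(torch.round(w / scale), -max_q, max_q).to(torch.int8)
    return q, scale                              # dequantize as q.float() * scale

# Embeddings at EMBED_BITS=7, EMBED_CLIP_SIGMAS=15.0 (names from the run command):
# q, scale = sigma_clip_quantize(embed_weight, bits=7, clip_sigmas=15.0)
```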

## Architecture

| Component | Setting | Source |
|-----------|---------|--------|
| Layers | 11 (512d, 8 heads, 4 KV heads) | Baseline |
| MLP | 4x (2048) with LeakyReLU(0.5)^2, Triton fused (reference sketch below) | PR #1530 @samacqua |
| Attention | VarLen (flash_attn_varlen_func), causal | PR #1530 @samacqua |
| Recurrence | 3-layer loop (L3-5), encoder+decoder | PR #1523 @EthanYangTW |
| Skip connections | U-Net encoder-decoder | Baseline |
| RoPE | Partial (16/64 dims) | Baseline |
| Optimizer | Muon (momentum=0.97) + AdamW | PR #1530 @samacqua |
| EMA | Decay 0.9965 | Baseline |
| Quantization | Full Hessian GPTQ int6 + int7 embeddings | PR #1530, enhanced |
| Compression | Brotli quality=11 | Baseline |
| TTT | Phased LoRA TTT with multi-phase global SGD | **This work** + PR #1530 + PR #1610 |
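
The MLP activation is simple to state in plain PyTorch. The reference below is an unfused sketch, not the Triton kernel from PR #1530; it reads LeakyReLU(0.5)^2 as squaring the LeakyReLU output, with `d_model=512` and the 4x width taken from the table:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SquaredLeakyReLUMLP(nn.Module):
    """Unfused reference for the 4x MLP with LeakyReLU(0.5)^2 activation."""
    def __init__(self, d_model: int = 512, expansion: int = 4):
        super().__init__()
        self.up = nn.Linear(d_model, expansion * d_model, bias=False)
        self.down = nn.Linear(expansion * d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = F.leaky_relu(self.up(x), negative_slope=0.5)
        return self.down(h.square())  # LeakyReLU(0.5), then square
```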

## Rule Compliance

- **Condition 1 (Causal):** All attention uses `causal=True`. No future token leakage.
- **Condition 2 (Normalized):** All scoring uses `F.cross_entropy` (full softmax over the vocabulary).
- **Condition 3 (Score-before-update):** Prefix docs are scored BEFORE any global SGD update. Each phase scores new docs first, then runs SGD on already-scored data only.
- **Condition 4 (Single pass):** Single left-to-right pass over validation data. No rescoring.
- **No val data during training:** Training uses only fineweb train shards.
- **Full validation split:** All fineweb_val shards loaded via sorted glob.
- **Byte accounting:** Tokenizer-derived byte counts including boundary/leading-space handling (see the scoring sketch after this list).
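
Condition 2 and the byte-accounting bullet together pin down what a legal per-document scoring step computes. Below is a minimal sketch under stated assumptions: `score_document` is a hypothetical helper (not from `train_gpt.py`), the model is assumed to return causal next-token logits, and the document's UTF-8 byte count is taken as the denominator:

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def score_document(model, tokens: torch.Tensor, num_utf8_bytes: int) -> float:
    """One causal left-to-right pass; full-softmax cross-entropy -> bits per byte."""
    logits = model(tokens[:-1].unsqueeze(0))            # assumed shape: (1, T-1, vocab)
    nats = F.cross_entropy(logits.squeeze(0),           # normalized over the full vocab
                           tokens[1:], reduction="sum")
    return nats.item() / math.log(2) / num_utf8_bytes   # nats -> bits, then per byte
```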

## Requirements

Python >= 3.12 (PEP 701 f-strings). Flash Attention 3 (Hopper) required.

```bash
pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291
pip install sentencepiece brotli
```

## Run Command

```bash
for seed in 42 0 1234; do
  NCCL_NET=Socket SEED=$seed \
  PHASED_TTT_ENABLED=1 PHASED_TTT_PREFIX_DOCS=2000 PHASED_TTT_NUM_PHASES=3 \
  MLP_CLIP_SIGMAS=12.0 ATTN_CLIP_SIGMAS=13.0 EMBED_BITS=7 EMBED_CLIP_SIGMAS=15.0 \
  MATRIX_LR=0.026 GPTQ_RESERVE_SECONDS=4 GPTQ_CALIBRATION_BATCHES=16 \
  torchrun --standalone --nproc_per_node=8 train_gpt.py \
    > train_seed${seed}.log 2>&1
done
```

## Lineage

```
PR #1493 (Merged SOTA, 1.0810) by @bigbag
  -> PR #1523 (1.0778) by @EthanYangTW — triple recurrence, parameter banking
    -> PR #1530 (1.07336) by @samacqua — varlen attention, fused MLP, doc-TTT
      -> PR #1610 (1.07281) by @romeerp — phased TTT (single-phase global SGD)
        -> This work (1.07193) adds:
           +-- Multi-phase global SGD (3-phase, novel)
           +-- Trimmed GPTQ (reserve=4s, calib=16)
           +-- MATRIX_LR=0.026 (sharp optimum)
           +-- Per-layer adaptive GPTQ clip
           +-- int7 embeddings
           +-- Dead code removal
```

## Credits

- @samacqua — PR #1530 base (VarLen attention, fused MLP, doc-TTT)
- @romeerp — PR #1610 phased TTT concept (single-phase global SGD)
- @EthanYangTW — PR #1523 triple recurrence, parameter banking
- @bigbag — PR #1493 merged SOTA baseline
- @abaybektursun — PR #549 legal TTT framework

## Included Files

- `train_gpt.py` — Complete training + eval script (122,168 bytes)
- `submission.json` — Metadata
- `train_seed42.log`, `train_seed0.log`, `train_seed1234.log` — Full seed logs
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
{
  "author": "dexhunter",
  "github_id": "dexhunter",
  "name": "VarLen Attention + Fused MLP + Multi-Phase Global SGD TTT + Trimmed GPTQ + MLR 0.026",
  "date": "2026-04-14",
  "track": "10min_16mb",
  "val_bpb": 1.07193,
  "val_bpb_std": 0.00063,
  "seeds": [42, 0, 1234],
  "seed_results": {
    "42": {"val_bpb": 1.07280, "val_loss": 2.77116, "artifact_bytes": 15932897},
    "0": {"val_bpb": 1.07134, "val_loss": 2.76739, "artifact_bytes": 15939841},
    "1234": {"val_bpb": 1.07164, "val_loss": 2.76815, "artifact_bytes": 15932419}
  },
  "hardware": "8xH100 80GB SXM",
  "pytorch_version": "2.9.1+cu128",
  "technique_summary": "VarLen Attention + Triton Fused MLP + Multi-Phase Global SGD (3-phase) during Phased TTT + Trimmed GPTQ (reserve=4s, calib=16) + int7 Embeddings + Per-Layer Adaptive Clip + MLR 0.026 + Warmdown 0.75 + Brotli-11",
  "compliance": {
    "train_under_600s": true,
    "artifact_under_16mb": true,
    "eval_under_600s": true,
    "no_slot": true,
    "no_pre_quant_ttt": true,
    "no_etlb": true,
    "no_ngram_cache": true,
    "score_first_ttt": true,
    "three_seeds": true
  },
  "attribution": {
    "varlen_attention_fused_mlp_doc_ttt": "@samacqua (PR #1530)",
    "phased_ttt_concept": "@romeerp (PR #1610)",
    "triple_recurrence_parallel_residuals": "@bigbag (PR #1493), @EthanYangTW (PR #1523)",
    "trimmed_gptq_mlr026": "@dexhunter (PR #1586)",
    "legal_ttt_framework": "@abaybektursun (PR #549)"
  }
}
