# Record: VarLen Attention + Fused MLP + Multi-Phase Global SGD TTT

**val_bpb: 1.07193** (3-seed mean, std 0.00063) | **2.76890 nats** | **~15.93 MB** | 8xH100 SXM, 596s train + ~331s TTT eval

**Improvement over PR #1530** (@samacqua, 1.07336 BPB): -0.00143 BPB / -0.00370 nats

**Improvement over merged SOTA** (PR #1493, 1.0810 BPB): -0.00907 BPB / -0.02344 nats

## Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128, Phased TTT)

| Seed | Steps | Pre-TTT BPB | **Post-TTT BPB** | TTT gain | TTT time | Artifact (bytes) |
|------|-------|-------------|------------------|----------|----------|------------------|
| 42 | 4,971 | 1.08502 | **1.07280** | -0.01222 | 329.0s | 15,932,897 |
| 0 | 4,967 | 1.08392 | **1.07134** | -0.01258 | 332.1s | 15,939,841 |
| 1234 | 4,977 | 1.08517 | **1.07164** | -0.01353 | 332.8s | 15,932,419 |
| **Mean** | | | **1.07193** | -0.01278 | | |
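
The headline mean and std follow directly from the per-seed post-TTT values; a quick check (population std, which matches the reported 0.00063):

```python
import statistics

# Per-seed post-TTT bpb from the table above
bpb = {42: 1.07280, 0: 1.07134, 1234: 1.07164}
print(round(statistics.fmean(bpb.values()), 5))   # 1.07193
print(round(statistics.pstdev(bpb.values()), 5))  # 0.00063
```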

### Supplemental Diagnostics

| Seed | Pre-EMA BPB | Post-EMA BPB | Post-Quant BPB | Post-TTT BPB | val_loss (nats) | Code size (bytes) | Artifact (bytes) | Train time | Eval time |
|------|-------------|--------------|----------------|--------------|-----------------|-------------------|------------------|------------|-----------|
| 42 | 1.0733 | 1.07257 | 1.08502 | 1.07280 | 2.77116 | 122,168 | 15,932,897 | 596.1s | 329.0s |
| 0 | 1.0723 | 1.07108 | 1.08392 | 1.07134 | 2.76739 | 122,168 | 15,939,841 | 596.1s | 332.1s |
| 1234 | 1.0713 | 1.07174 | 1.08517 | 1.07164 | 2.76815 | 122,168 | 15,932,419 | 596.2s | 332.8s |
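
The bpb and nats columns are two views of the same cross-entropy. A minimal sketch of the conversion (the ~3.7266 bytes-per-token factor is inferred from the reported means, not read out of `train_gpt.py`):

```python
import math

def nats_to_bpb(nats_per_token: float, bytes_per_token: float) -> float:
    """Convert mean cross-entropy in nats/token to bits per byte."""
    return nats_per_token / (math.log(2) * bytes_per_token)

# 3-seed means: 2.76890 nats/token at ~3.7266 bytes/token
print(nats_to_bpb(2.76890, 3.7266))  # ~1.0719
```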

## Key Innovation: Multi-Phase Global SGD

This submission introduces **multi-phase global SGD** during phased TTT evaluation. While PR #1610 (@romeerp) introduced single-phase global SGD (score prefix docs, run one round of SGD, score suffix), we extend this to **N phases** with interleaved scoring and adaptation:

1. Split 2000 prefix docs into 3 roughly equal chunks (~666 docs each)
2. Score chunk 1 with the base model (score-before-update)
3. Run distributed SGD on scored chunk 1
4. Score chunk 2 with the improved model
5. Run SGD on scored chunks 1+2
6. Score chunk 3 with the further improved model
7. Run SGD on all scored prefix docs
8. Score the remaining 48,000 suffix docs with the fully adapted model

This progressively improves the base model through multiple adaptation rounds while maintaining strict score-before-update legality: each phase scores new tokens BEFORE any SGD update uses them.

```python
# Key code (simplified)
boundaries = [666, 1333, 2000]   # cumulative prefix-doc boundaries, one per phase
prev_boundary = 0
for boundary in boundaries:
    # Score docs between the previous boundary and this one with the
    # current model -- score-first, no adaptation has used these docs yet
    for doc in prefix_docs[prev_boundary:boundary]:
        score(doc)
    # Distributed SGD on ALL docs scored so far (never on unscored data)
    global_sgd(prefix_docs[:boundary])
    prev_boundary = boundary
# Score the remaining 48,000 suffix docs with the fully adapted model
for doc in suffix_docs:
    score(doc)
```

**3-phase gives -0.00081 BPB over 1-phase** (1.07190 vs 1.07271, same seed). More phases (6+) overfit the small per-phase subsets.

## Changes from PR #1530 Baseline

| Change | Source | Effect |
|--------|--------|--------|
| Multi-phase global SGD (3-phase) | **Novel (this work)** | -0.0008 BPB eval-time |
| Trimmed GPTQ (reserve=4s, calib=16) | PR #1586 (@dexhunter) | -0.0013 BPB, +72 training steps |
| MATRIX_LR=0.026 | PR #1586 (@dexhunter) | -0.0003 BPB (sharp optimum) |
| Per-layer adaptive GPTQ clip (MLP=12, Attn=13, Emb=15) | PR #1586 (@dexhunter) | Better quant-vs-bytes tradeoff |
| int7 embeddings (EMBED_BITS=7) | PR #1586 (@dexhunter) | -530 KB artifact, ~0 BPB cost |
| WARMDOWN_FRAC=0.75 | PR #1560 (@dexhunter) | More warmdown iterations |
| Dead code removal | This work | -1.9 KB compressed code size |
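
The per-layer clip sigmas bound each weight tensor before GPTQ picks its quantization grid, trading a little clipping error for a tighter int6/int7 range. A hedged sketch of the clipping rule only (the Hessian-based GPTQ pass itself lives in `train_gpt.py`; names here are illustrative):

```python
import torch

def sigma_clip(w: torch.Tensor, clip_sigmas: float) -> torch.Tensor:
    """Clamp a weight tensor to +/- clip_sigmas standard deviations."""
    bound = clip_sigmas * w.std().item()
    return w.clamp(-bound, bound)

# Per-layer settings from the table: MLP=12, Attn=13, Emb=15 sigmas
w_mlp_clipped = sigma_clip(torch.randn(2048, 512), clip_sigmas=12.0)
```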

## Architecture

| Component | Setting | Source |
|-----------|---------|--------|
| Layers | 11 (512d, 8 heads, 4 KV heads) | Baseline |
| MLP | 4x (2048) with LeakyReLU(0.5)^2, Triton fused | PR #1530 @samacqua |
| Attention | VarLen (flash_attn_varlen_func), causal | PR #1530 @samacqua |
| Recurrence | 3-layer loop (L3-5), encoder+decoder | PR #1523 @EthanYangTW |
| Skip connections | U-Net encoder-decoder | Baseline |
| RoPE | Partial (16/64 dims) | Baseline |
| Optimizer | Muon (momentum=0.97) + AdamW | PR #1530 @samacqua |
| EMA | Decay 0.9965 | Baseline |
| Quantization | Full Hessian GPTQ int6 + int7 embeddings | PR #1530, enhanced |
| Compression | Brotli quality=11 | Baseline |
| TTT | Phased LoRA TTT with multi-phase global SGD | **This work** + PR #1530 + PR #1610 |
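
For the MLP row, a minimal eager-mode sketch of the activation, reading "LeakyReLU(0.5)^2" at face value as a squared LeakyReLU with slope 0.5 (the submission uses a fused Triton kernel; the class and dimensions here are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SquaredLeakyMLP(nn.Module):
    """Reference 512 -> 2048 -> 512 MLP with a squared LeakyReLU(0.5)."""

    def __init__(self, dim: int = 512, hidden: int = 2048):
        super().__init__()
        self.fc_in = nn.Linear(dim, hidden, bias=False)
        self.fc_out = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = F.leaky_relu(self.fc_in(x), negative_slope=0.5)
        return self.fc_out(h * h)  # elementwise square of the activation
```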

## Rule Compliance

- **Condition 1 (Causal):** All attention uses `causal=True`. No future token leakage.
- **Condition 2 (Normalized):** All scoring uses `F.cross_entropy` (full softmax over vocabulary).
- **Condition 3 (Score-before-update):** Prefix docs are scored BEFORE any global SGD update. Each phase scores new docs first, then runs SGD on already-scored data only.
- **Condition 4 (Single pass):** Single left-to-right pass over validation data. No rescoring.
- **No val data during training:** Training uses only fineweb train shards.
- **Full validation split:** All fineweb_val shards loaded via sorted glob.
- **Byte accounting:** Tokenizer-derived byte counts including boundary/leading-space handling.
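
Condition 2 pins down what `score(doc)` in the pseudocode above must do: a full softmax over the vocabulary via `F.cross_entropy`, summed to nats. A minimal sketch under that assumption (`model` and the tensor shapes are illustrative, not the `train_gpt.py` implementation):

```python
import torch
import torch.nn.functional as F

def score(model: torch.nn.Module, tokens: torch.Tensor) -> float:
    """Sum of next-token cross-entropy (nats) for one document."""
    with torch.no_grad():
        logits = model(tokens[:-1].unsqueeze(0)).squeeze(0)  # (T-1, vocab)
        nats = F.cross_entropy(logits, tokens[1:], reduction="sum")
    return nats.item()
```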

## Requirements

Python >= 3.12 (PEP 701 f-strings). Flash Attention 3 (Hopper) required.

```bash
pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291
pip install sentencepiece brotli
```

## Run Command

```bash
for seed in 42 0 1234; do
  NCCL_NET=Socket SEED=$seed \
  PHASED_TTT_ENABLED=1 PHASED_TTT_PREFIX_DOCS=2000 PHASED_TTT_NUM_PHASES=3 \
  MLP_CLIP_SIGMAS=12.0 ATTN_CLIP_SIGMAS=13.0 EMBED_BITS=7 EMBED_CLIP_SIGMAS=15.0 \
  MATRIX_LR=0.026 GPTQ_RESERVE_SECONDS=4 GPTQ_CALIBRATION_BATCHES=16 \
  torchrun --standalone --nproc_per_node=8 train_gpt.py \
    > train_seed${seed}.log 2>&1
done
```
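
A quick way to see how the `PHASED_TTT_*` variables map to the phase boundaries used in the pseudocode (the derivation below is an assumption consistent with the [666, 1333, 2000] split; the actual parsing lives in `train_gpt.py`):

```python
import os

prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", "2000"))
num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", "3"))
boundaries = [prefix_docs * (i + 1) // num_phases for i in range(num_phases)]
print(boundaries)  # [666, 1333, 2000] with the defaults above
```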

## Lineage

```
PR #1493 (Merged SOTA, 1.0810) by @bigbag
 -> PR #1523 (1.0778) by @EthanYangTW — triple recurrence, parameter banking
  -> PR #1530 (1.07336) by @samacqua — varlen attention, fused MLP, doc-TTT
   -> PR #1610 (1.07281) by @romeerp — phased TTT (single-phase global SGD)
    -> This work (1.07193) adds:
        +-- Multi-phase global SGD (3-phase, novel)
        +-- Trimmed GPTQ (reserve=4s, calib=16)
        +-- MATRIX_LR=0.026 (sharp optimum)
        +-- Per-layer adaptive GPTQ clip
        +-- int7 embeddings
        +-- Dead code removal
```

## Credits

- @samacqua — PR #1530 base (VarLen attention, fused MLP, doc-TTT)
- @romeerp — PR #1610 phased TTT concept (single-phase global SGD)
- @EthanYangTW — PR #1523 triple recurrence, parameter banking
- @bigbag — PR #1493 merged SOTA baseline
- @abaybektursun — PR #549 legal TTT framework

## Included Files

- `train_gpt.py` — Complete training + eval script (122,168 bytes)
- `submission.json` — Metadata
- `train_seed42.log`, `train_seed0.log`, `train_seed1234.log` — Full seed logs