# True Byte-Level JEPA for Parameter Golf

Hey there! This submission is a direct answer to the open **JEPA [bounty/PR request]** in the challenge.

## Why this is different
Every other submission on the leaderboard (whether it's testing new layers, distillation, or quantization tricks) still trains the model the normal way: by predicting the next token with a cross-entropy loss.

But a true JEPA (Joint-Embedding Predictive Architecture) is fundamentally different. It doesn't guess the next token. Instead, it tries to predict the *abstract concept* of what comes next, operating entirely in the model's hidden space.
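
To make that concrete, here is a minimal sketch of a JEPA-style training step. The names `context_encoder`, `target_encoder` (a stop-gradient/EMA copy), and `predictor` are illustrative assumptions, not the submission's actual modules; the point is that the loss is a regression between hidden states and no byte probabilities appear anywhere.

```python
import torch
import torch.nn.functional as F

def jepa_step(context_encoder, target_encoder, predictor, x_context, x_target):
    # Encode the visible context; gradients flow through this branch.
    z_context = context_encoder(x_context)           # (B, T, D)
    # Encode the target with a stop-gradient (e.g. EMA) copy so the two
    # encoders cannot trivially collapse onto each other.
    with torch.no_grad():
        z_target = target_encoder(x_target)          # (B, T, D)
    # Predict the target's hidden state from the context's hidden state.
    z_pred = predictor(z_context)
    # Regression in representation space, not classification over bytes.
    return F.smooth_l1_loss(z_pred, z_target)
```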

Because a pure JEPA model never produces a probability distribution over bytes, it cannot, by construction, report the Bits Per Byte (BPB) score the challenge requires.

## How we solved the JEPA vs. BPB problem
To give the challenge the BPB score it needs without compromising the pure JEPA architecture, we use the standard two-phase approach from the official Meta JEPA papers:

1. **Phase 1: Pure JEPA Pretraining (70% of the 10-minute clock)**
   - The model learns entirely by predicting its own future hidden states.
   - We use SIGReg regularization to keep the model from cheating by collapsing, i.e. outputting the same vector for every input.
   - During this phase, the model doesn't even try to guess bytes. It's just learning the raw structure of the language.

2. **Phase 2: Supervised Fine-Tuning (30% of the 10-minute clock)**
   - At the 7-minute mark, the script automatically shifts gears.
   - It attaches a simple translation layer on top of the model and fine-tunes the network to map its abstract thoughts into actual byte probabilities (see the sketch after this list).
   - This lets us cleanly report the rigorous `val_bpb` metric you see on the leaderboard.
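
Here's a minimal sketch of those two mechanics, with loud assumptions: `WALLCLOCK_BUDGET_S`, `in_jepa_phase`, and `ByteHead` are illustrative names, not what `train_gpt.py` actually uses (only `JEPA_PRETRAIN_FRAC` appears in the run command below). The BPB conversion itself is standard: byte-level cross-entropy in nats divided by ln 2, which is exactly how the reported `val_loss:0.9308` maps to `val_bpb:1.3429`.

```python
import math
import os
import time

import torch.nn as nn

JEPA_PRETRAIN_FRAC = float(os.environ.get("JEPA_PRETRAIN_FRAC", "0.7"))
# Hypothetical knob; 600 s matches the 10-minute clock described above.
BUDGET_S = float(os.environ.get("WALLCLOCK_BUDGET_S", "600"))
_t0 = time.perf_counter()

def in_jepa_phase() -> bool:
    # Phase 1 runs until 70% of the wallclock budget, then we switch to SFT.
    return time.perf_counter() - _t0 < JEPA_PRETRAIN_FRAC * BUDGET_S

class ByteHead(nn.Module):
    """The 'translation layer': hidden states -> logits over 256 bytes."""
    def __init__(self, d_model: int, vocab_size: int = 256):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, hidden):  # (B, T, D) -> (B, T, 256)
        return self.proj(hidden)

def bpb(nll_nats: float) -> float:
    # Byte-level cross-entropy in nats -> bits per byte (one target == one byte).
    # e.g. 0.9308 / ln 2 ~= 1.3429.
    return nll_nats / math.log(2)
```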

## Under the Hood
- **No Tokenizer**: We dumped the tokenizer completely. The model reads raw UTF-8 bytes (`vocab_size=256`). This forces the JEPA to learn complex word boundaries blindly from scratch, a true test of its representation power.
- **Predictor Network**: Added a 2-layer MLP that tries to guess the target's hidden state (sketched below). We funded the parameter cost of this network with the space saved by dropping the tokenizer embeddings!
- **Constraint Safe**: Our code tracks the 16 MB file limit natively.
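
A hypothetical shape for that predictor; only "2-layer MLP" comes from the design above, and the width and activation are assumptions. The comment sketches the parameter budget trade.

```python
import torch.nn as nn

def make_predictor(d_model: int) -> nn.Module:
    # Cost: roughly 2 * d_model^2 weights (plus biases). Shrinking the
    # embedding table from an SP1024 vocab (1024 * d_model) to raw bytes
    # (256 * d_model) frees 768 * d_model parameters to help pay for it.
    return nn.Sequential(
        nn.Linear(d_model, d_model),
        nn.GELU(),
        nn.Linear(d_model, d_model),
    )
```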

It's worth noting that byte-level models have a shorter effective "memory window" than tokenized ones: each token covers multiple bytes of text, so the same sequence length spans several times less context when reading raw bytes. In a 10-minute race that means this won't shatter the top of the leaderboard purely on BPB. But as the very first submission to successfully change the core learning objective from token-prediction to representation-prediction, we hope it hits the "weird & creative" mark for the non-record track!

## Run Command

You'll need the byte-level data shards (script included).

```bash
# 1. Convert the standard SP1024 shards into raw Bytes
DATA_PATH_BASE="../../data" python transpile_to_bytes.py

# 2. Run the Two-Phase JEPA trainer
DATA_PATH=./data/datasets/fineweb10B_bytes \
VOCAB_SIZE=256 \
JEPA_PRETRAIN_FRAC=0.7 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
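
Note: the reference metrics below were collected on a single RTX 5090, so for a matching single-GPU run set `--nproc_per_node=1`.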

## Challenge Metrics

1x RTX 5090, 3600s wallclock, 4141 steps:
- Pre-quant: `val_loss:0.9308 val_bpb:1.3429`
- Post-quant (int8+zlib roundtrip): `val_loss:0.9355 val_bpb:1.3496`
- Serialized model int8+zlib: `15080959 bytes`
- Code size: `51873 bytes`
- Total submission size int8+zlib: `15132832 bytes`
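
For context on the "int8+zlib roundtrip" above, here is a minimal sketch of what that kind of serialization can look like. Assumptions: per-tensor symmetric scales and `torch.save` as the container; the submission's actual packing may differ.

```python
import io
import zlib

import torch

def serialize_int8_zlib(model: torch.nn.Module) -> bytes:
    payload = {}
    for name, w in model.state_dict().items():
        if not torch.is_floating_point(w):
            payload[name] = (w, None)  # keep non-float buffers as-is
            continue
        t = w.detach().float()
        scale = max(t.abs().max().item(), 1e-8) / 127.0  # symmetric per-tensor
        q = torch.clamp((t / scale).round(), -127, 127).to(torch.int8)
        payload[name] = (q, scale)
    buf = io.BytesIO()
    torch.save(payload, buf)
    return zlib.compress(buf.getvalue(), 9)  # DEFLATE the packed bytes
```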