
Commit 4908dc4

Merge pull request #1443 from hardik-bhadani-git/bytejepa-submission
Non-record: ByteJEPA — True Byte-Level JEPA (val_bpb 1.3496)
2 parents e159eec + 9ce181e commit 4908dc4

5 files changed: 2695 additions & 0 deletions

# True Byte-Level JEPA for Parameter Golf
Hey there! This submission is a direct answer to the open **JEPA bounty / PR request** in the challenge.
## Why this is different
Every other submission on the leaderboard (whether it's testing new layers, distillation, or quantization tricks) still trains the model the normal way: by predicting the next token using Cross-Entropy loss.
But a true JEPA (Joint-Embedding Predictive Architecture) is fundamentally different. It doesn't guess the next token. Instead, it tries to predict the *abstract concept* of what comes next, operating entirely in the model's hidden layers.
Because a pure JEPA model never produces a probability distribution over bytes, it has no way, on its own, to output the Bits Per Byte (BPB) score the challenge requires.
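
For reference, BPB is just the model's per-byte cross-entropy expressed in bits rather than nats, which is exactly why it needs a byte distribution in the first place (this assumes `val_loss` is the per-byte cross-entropy in nats, which the numbers in the metrics section are consistent with):

$$\text{bpb} = \frac{\text{val\_loss}}{\ln 2}, \qquad \frac{0.9308}{\ln 2} \approx 1.3429$$
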
## How we solved the JEPA vs. BPB problem
To give the challenge the BPB score it needs without compromising the pure JEPA architecture, we use the standard "two-phase" approach from the official Meta JEPA papers (a minimal sketch of the schedule follows the list below):
1. **Phase 1: Pure JEPA Pretraining (70% of the 10-minute clock)**
   - The model learns entirely by predicting its own future hidden states.
   - We use a regularizer called SIGReg to keep the model from cheating by collapsing to the same output vector for every input.
   - During this phase, the model never tries to guess bytes; it is just learning the raw structure of the language.

2. **Phase 2: Supervised Fine-Tuning (30% of the 10-minute clock)**
   - At the 7-minute mark, the script automatically shifts gears.
   - It attaches a simple translation layer (a byte-prediction head) on top of the backbone and fine-tunes the network to map its abstract representations into actual byte probabilities.
   - This lets us cleanly output the rigorous `val_bpb` metric you see on the leaderboard.

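To make the split concrete, here is a minimal, self-contained PyTorch sketch of the schedule. Every name and number in it (`ByteBackbone`, the GRU standing in for the attention blocks, the smooth-L1 representation loss, the simple variance penalty standing in for SIGReg, the tiny sizes and the 3-second toy clock) is an illustration of the idea, not the actual code in `train_gpt.py`:

```python
# Illustrative two-phase schedule: names, sizes, losses, and the variance penalty
# standing in for SIGReg are all assumptions, not the submission's actual code.
import copy
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

VOCAB_SIZE, D_MODEL, SEQ_LEN = 256, 64, 32       # tiny stand-in sizes
TOTAL_SECONDS, PRETRAIN_FRAC = 3.0, 0.7          # toy clock + the JEPA_PRETRAIN_FRAC=0.7 split

class ByteBackbone(nn.Module):
    """Stand-in for the transformer trunk: raw bytes -> hidden states."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE, D_MODEL)           # only 256 rows: no tokenizer
        self.mixer = nn.GRU(D_MODEL, D_MODEL, batch_first=True)  # placeholder for attention blocks
    def forward(self, x):
        h, _ = self.mixer(self.embed(x))
        return h                                                  # (B, T, D)

backbone  = ByteBackbone()
target    = copy.deepcopy(backbone)              # slow-moving "teacher" encoder, never backprop'd
predictor = nn.Sequential(nn.Linear(D_MODEL, 4 * D_MODEL), nn.GELU(),
                          nn.Linear(4 * D_MODEL, D_MODEL))       # the 2-layer predictor MLP
byte_head = nn.Linear(D_MODEL, VOCAB_SIZE)       # only attached/useful in Phase 2

opt = torch.optim.AdamW(
    list(backbone.parameters()) + list(predictor.parameters()) + list(byte_head.parameters()),
    lr=3e-4)

def anti_collapse(z):
    """Toy stand-in for SIGReg: push the per-dimension std of the embeddings away from zero."""
    std = z.flatten(0, 1).std(dim=0)
    return F.relu(1.0 - std).mean()

def jepa_step(x):
    """Phase 1: predict the teacher's hidden state for the next position; no byte targets at all."""
    h = backbone(x[:, :-1])
    with torch.no_grad():
        h_future = target(x)[:, 1:]              # representation of "what comes next"
    return F.smooth_l1_loss(predictor(h), h_future) + 0.1 * anti_collapse(h)

def sft_step(x):
    """Phase 2: byte head + cross-entropy, which is what makes val_bpb measurable."""
    logits = byte_head(backbone(x[:, :-1]))
    return F.cross_entropy(logits.reshape(-1, VOCAB_SIZE), x[:, 1:].reshape(-1))

start = time.time()
while (elapsed := time.time() - start) < TOTAL_SECONDS:
    x = torch.randint(0, VOCAB_SIZE, (8, SEQ_LEN))                # stand-in for a real byte batch
    loss = jepa_step(x) if elapsed < PRETRAIN_FRAC * TOTAL_SECONDS else sft_step(x)
    opt.zero_grad(); loss.backward(); opt.step()
    with torch.no_grad():                                         # EMA update of the teacher
        for p_t, p_s in zip(target.parameters(), backbone.parameters()):
            p_t.mul_(0.99).add_(p_s, alpha=0.01)
```

In the real run, the 70/30 split corresponds to the `JEPA_PRETRAIN_FRAC=0.7` setting shown in the run command below.
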
## Under the Hood
- **No Tokenizer**: We dumped the tokenizer completely. The model reads raw UTF-8 bytes (`vocab_size=256`). This forces the JEPA to learn word boundaries blindly from scratch, a true test of its representation power (see the short example after this list).
- **Predictor Network**: We added a 2-layer MLP that tries to guess the target's hidden state. We funded the parameter cost of this network with the space saved on the now-disabled tokenizer embeddings!
- **Constraint Safe**: Our code tracks the 16MB file-size limit natively.

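A quick illustration of the "No Tokenizer" point: raw UTF-8 bytes are already integer IDs in `[0, 255]`, so they can index a 256-row embedding table directly, with no merges and no word-boundary hints. The snippet below is only that observation in code; the hidden size of 64 is made up and the repo's own data pipeline may look different:

```python
import torch

text = "Byte-level JEPA reads café ☕ directly."
ids = torch.tensor(list(text.encode("utf-8")), dtype=torch.long)  # multi-byte chars become several IDs
assert 0 <= ids.min() and ids.max() <= 255                        # vocab_size=256 covers everything

embed = torch.nn.Embedding(256, 64)   # hypothetical hidden size
h = embed(ids.unsqueeze(0))           # (1, number_of_bytes, 64)
print(ids.numel(), h.shape)
```
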
It's worth noting that because byte-level models have a shorter effective "memory window" than tokenized ones in a 10-minute race, this won't shatter the top of the leaderboard purely on BPB. But as the very first submission to successfully change the core learning objective from token-prediction to representation-prediction, we hope it hits the "weird & creative" mark for the non-record track!
## Run Command
You'll need the byte-level data shards (script included).
```bash
# 1. Convert the standard SP1024 shards into raw Bytes
DATA_PATH_BASE="../../data" python transpile_to_bytes.py

# 2. Run the Two-Phase JEPA trainer
DATA_PATH=./data/datasets/fineweb10B_bytes \
VOCAB_SIZE=256 \
JEPA_PRETRAIN_FRAC=0.7 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
## Challenge Metrics
1x RTX 5090, 3600s wallclock, 4141 steps:
- Pre-quant: `val_loss:0.9308 val_bpb:1.3429`
- Post-quant (int8+zlib roundtrip, sketched below): `val_loss:0.9355 val_bpb:1.3496`
- Serialized model int8+zlib: `15080959 bytes`
- Code size: `51873 bytes`
- Total submission size int8+zlib: `15132832 bytes`
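
For context on the int8+zlib numbers above, here is a hypothetical sketch of what such a roundtrip size check can look like: quantize each tensor to int8 with a per-tensor scale, serialize and zlib-compress to count the bytes reported against the size cap, then dequantize and reload before re-measuring the post-quant `val_loss`/`val_bpb`. The helper names and the symmetric per-tensor scheme are assumptions; the submission's own packer may do this differently:

```python
# Hypothetical int8+zlib roundtrip size check; the submission's actual packer may differ.
import io
import zlib

import torch

def quantize_int8(state_dict):
    """Per-tensor symmetric int8 quantization: store (int8 weights, float scale) per tensor."""
    out = {}
    for name, w in state_dict.items():
        scale = w.abs().max().clamp(min=1e-8) / 127.0
        out[name] = (torch.round(w / scale).to(torch.int8), scale)
    return out

def dequantize(quantized):
    return {name: q.to(torch.float32) * scale for name, (q, scale) in quantized.items()}

model = torch.nn.Linear(256, 256)               # stand-in for the real checkpoint
quant = quantize_int8(model.state_dict())

buf = io.BytesIO()
torch.save(quant, buf)                          # serialize int8 weights + scales
blob = zlib.compress(buf.getvalue(), level=9)
print(f"int8+zlib size: {len(blob)} bytes")     # the number compared against the size cap

model.load_state_dict(dequantize(quant))        # roundtrip: "post-quant" metrics are measured after this
```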
