
Commit 66e57bf

leon2k2k2k and claude committed
spec 013: BigramHash auxiliary embedding (port openai#1716)
110 LOC pure addition to train_gpt.py, fully env-gated by BIGRAM_HASH_ENABLED=0/1.
Default-off invariant: with the env var unset, the forward pass, state_dict, and
optimizer param list are byte-identical to baseline.

Components:
- BigramHashEmbedding(nn.Module): embed(buckets, dim) + CastedLinear
  proj(dim, model_dim). proj._zero_init=True -> identity at step 0.
  Hash: ((prime_a * curr) ^ (prime_b * prev)) % buckets. Position-0 fallback:
  prev = curr (self-bigram). Cross-doc leakage is not special-cased, matching
  openai#1736's SmearGate convention.
- GPT.__init__: creates self.bigram_embed when enabled, else None.
- forward_logits + forward_ttt: additive merge of bigram(input_ids) into
  tok_emb(input_ids) before SmearGate; attr-guarded.
- Optimizers: embed.weight -> AdamW optimizer_tok (embed_wd); proj.weight ->
  Muon matrix_params.
- GPTQ hessian hooks: bigram_embed.embed output -> (dim, dim) hessian;
  bigram_embed.proj input -> (dim, dim) hessian (proj is <=65536 numel, so
  fp16 passthrough; harmless hook).
- Startup log line echoing the config.

Sizing: 16384*32 int6 embed ~= 393KB; 512*32 fp16 proj = 32KB. Total ~425KB
added to the artifact; a budget dry-run is needed before launch.

Env vars (defaults): BIGRAM_HASH_ENABLED=0, BIGRAM_HASH_BUCKETS=16384,
BIGRAM_HASH_DIM=32, BIGRAM_HASH_PRIME_A=36313, BIGRAM_HASH_PRIME_B=27191.

Bug lesson learned from exp/training-bundle commit 8d54854: when Edit's
old_string only captures part of a for-loop body, trailing loop statements get
pushed outside the loop and may be absorbed by nearby conditional blocks. This
patch is pure prepend/append style (no splits of existing blocks), so that
failure mode is avoided.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
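For reference, a minimal sketch (editor-added, outside the diff) of the bucket
hash and the sizing arithmetic from the message. The model_dim=512 figure is
read off the "512*32 fp16 proj" line above; all names here are illustrative,
not from train_gpt.py:

import torch

# Spec defaults (BIGRAM_HASH_*).
BUCKETS, DIM, PRIME_A, PRIME_B = 16384, 32, 36313, 27191

ids = torch.tensor([[17, 42, 42, 9]])               # (batch, seq)
prev = torch.cat([ids[:, :1], ids[:, :-1]], dim=1)  # position 0: prev = curr
hash_ids = ((PRIME_A * ids) ^ (PRIME_B * prev)) % BUCKETS
assert hash_ids.shape == ids.shape and int(hash_ids.max()) < BUCKETS

# Sizing check against the message: int6 table + fp16 projection.
assert BUCKETS * DIM * 6 // 8 == 393216             # ~393KB embed
assert 512 * DIM * 2 == 32768                       # 32KB proj (model_dim=512)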
1 parent 2b4d12e commit 66e57bf

1 file changed

Lines changed: 110 additions & 0 deletions

File tree

  • records/track_10min_16mb/2026-04-19_SP8192_CaseOps_GatedAttn_QuantGate_Loop45_PhasedTTT

records/track_10min_16mb/2026-04-19_SP8192_CaseOps_GatedAttn_QuantGate_Loop45_PhasedTTT/train_gpt.py

@@ -117,6 +117,15 @@ class Hyperparameters:
     smear_gate_enabled = bool(int(os.environ.get("SMEAR_GATE_ENABLED", "0")))
     # Window: first GATE_WINDOW dims of the source feed the gate projection.
     gate_window = int(os.environ.get("GATE_WINDOW", 12))
+    # Spec 013 BigramHash (port #1716 himanshudongre). Auxiliary (buckets, d)
+    # embedding keyed by hash(prev_token, curr_token), projected to model_dim
+    # and added to tok_emb pre-block-0. Zero-init projection -> byte-identical
+    # to baseline at init. Default off.
+    bigram_hash_enabled = bool(int(os.environ.get("BIGRAM_HASH_ENABLED", "0")))
+    bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 16384))
+    bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 32))
+    bigram_hash_prime_a = int(os.environ.get("BIGRAM_HASH_PRIME_A", 36313))
+    bigram_hash_prime_b = int(os.environ.get("BIGRAM_HASH_PRIME_B", 27191))
     # Gated Attention (Qwen, NeurIPS 2025 Best Paper, arXiv:2505.06708;
     # qiuzh20/gated_attention). Per-head sigmoid gate on SDPA output, BEFORE
     # out_proj. Gate input = full block input x (paper's headwise G1 variant
@@ -922,6 +931,42 @@ def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, ma
         ] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w)
         return x_out
 
+
+class BigramHashEmbedding(nn.Module):
+    """Spec 013 BigramHash (port #1716).
+
+    Hash-keyed auxiliary embedding table. Hash of (prev_token, curr_token)
+    selects a row from `embed`; the row is projected to model_dim by `proj`
+    and added to tok_emb before block 0.
+
+    - `embed.weight` shape (buckets, dim); quantized at matrix_bits via GPTQ
+      using a hook-collected (dim, dim) hessian on the embedding output.
+    - `proj.weight` shape (model_dim, dim); zero-init so the module is
+      identity at step 0 (output is exactly tok_emb[input_ids]).
+
+    Position-0 handling: fallback `prev = curr` (self-bigram). Cross-document
+    leakage via cu_seqlens is NOT special-cased, matching #1736's SmearGate
+    convention (which also uses ids[:-1] without doc-boundary sentinels).
+    """
+
+    def __init__(self, buckets, dim, model_dim, prime_a, prime_b):
+        super().__init__()
+        self.buckets = buckets
+        self.prime_a = prime_a
+        self.prime_b = prime_b
+        self.embed = nn.Embedding(buckets, dim)
+        self.proj = CastedLinear(dim, model_dim, bias=False)
+        self.proj._zero_init = True  # baseline-preserving at init
+        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
+
+    def forward(self, input_ids):
+        prev = torch.cat([input_ids[:, :1], input_ids[:, :-1]], dim=1).to(torch.long)
+        curr = input_ids.to(torch.long)
+        hash_ids = ((self.prime_a * curr) ^ (self.prime_b * prev)) % self.buckets
+        h = self.embed(hash_ids)
+        return self.proj(h)
+
+
 class GPT(nn.Module):
     def __init__(self, h):
         super().__init__()
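As a sanity check on the zero-init claim in the docstring above, a small
stand-alone sketch (editor-added); nn.Linear with a zeroed weight stands in
for the repo's CastedLinear + _zero_init handling, and model_dim=512 is
assumed from the sizing note in the commit message:

import torch
import torch.nn as nn

buckets, dim, model_dim = 16384, 32, 512
embed = nn.Embedding(buckets, dim)
proj = nn.Linear(dim, model_dim, bias=False)
nn.init.zeros_(proj.weight)  # stand-in for proj._zero_init = True

ids = torch.randint(0, 50257, (2, 8))  # illustrative vocab size
prev = torch.cat([ids[:, :1], ids[:, :-1]], dim=1)
hash_ids = ((36313 * ids) ^ (27191 * prev)) % buckets
out = proj(embed(hash_ids))
assert out.shape == (2, 8, model_dim) and torch.all(out == 0)  # exact identity at step 0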
@@ -1029,6 +1074,18 @@ def __init__(self, h):
             self.smear_gate = CastedLinear(self.smear_window, 1, bias=False)
             self.smear_gate._zero_init = True
             self.smear_lambda = nn.Parameter(torch.zeros(1, dtype=torch.float32))
+        # Spec 013 BigramHash (port #1716). None-gated so default=off is byte-identical.
+        self.bigram_hash_enabled = h.bigram_hash_enabled
+        if self.bigram_hash_enabled:
+            self.bigram_embed = BigramHashEmbedding(
+                buckets=h.bigram_hash_buckets,
+                dim=h.bigram_hash_dim,
+                model_dim=h.model_dim,
+                prime_a=h.bigram_hash_prime_a,
+                prime_b=h.bigram_hash_prime_b,
+            )
+        else:
+            self.bigram_embed = None
         self._init_weights()
 
     def _init_weights(self):
@@ -1103,6 +1160,11 @@ def _final_parallel_hidden(self, lane0, lane1):
 
     def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0):
         x = self.tok_emb(input_ids)
+        # Spec 013 BigramHash: add hash(prev, curr) embedding additively. proj is
+        # zero-init so this is identity at step 0. Attr-guarded for byte-identical
+        # default behavior.
+        if self.bigram_embed is not None:
+            x = x + self.bigram_embed(input_ids).to(dtype=x.dtype)
         # SmearGate (PR #1667). Inline gate compute with .contiguous() on the slice fed
         # to the projection so torch.compile fullgraph is happy. lam=0 + W=0 -> identity
         # at init. This block runs unconditionally on the smear path; the cat keeps
@@ -1186,6 +1248,9 @@ def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0):
 
     def forward_ttt(self, input_ids, target_ids, lora):
         x = self.tok_emb(input_ids)
+        # Spec 013 BigramHash: TTT path parallel to forward_logits.
+        if self.bigram_embed is not None:
+            x = x + self.bigram_embed(input_ids).to(dtype=x.dtype)
         # SmearGate on the TTT path — same inline compute as forward_logits.
         if self.smear_gate_enabled:
             sl = self.smear_lambda.to(dtype=x.dtype)
@@ -1666,10 +1731,24 @@ def __init__(self, h, base_model):
         if getattr(base_model, "smear_gate_enabled", False):
             scalar_params.append(base_model.smear_gate.weight)
             scalar_params.append(base_model.smear_lambda)
+        # Spec 013 BigramHash: embed table -> AdamW with embed_wd (sparse-gradient
+        # pattern like tok_emb); proj weight -> Muon (standard matrix). Only
+        # appended when enabled so the optimizer param list is identical to
+        # baseline when BIGRAM_HASH_ENABLED=0.
+        if getattr(base_model, "bigram_hash_enabled", False):
+            matrix_params.append(base_model.bigram_embed.proj.weight)
         token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr
         tok_params = [
             {"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}
         ]
+        if getattr(base_model, "bigram_hash_enabled", False):
+            tok_params.append(
+                {
+                    "params": [base_model.bigram_embed.embed.weight],
+                    "lr": token_lr,
+                    "base_lr": token_lr,
+                }
+            )
         self.optimizer_tok = torch.optim.AdamW(
             tok_params,
             betas=(h.beta1, h.beta2),
@@ -1865,6 +1944,32 @@ def hook_fn(module, inp, out):
     hooks.append(
         hook_module.register_forward_hook(make_output_hook("tok_emb.weight"))
     )
+    # Spec 013 BigramHash: collect (dim, dim) hessian from embed OUTPUT during
+    # calibration. GPTQ on the weight of shape (buckets, dim) treats dim as cols;
+    # covariance over output rows is the activation-weighted column importance.
+    if getattr(model, "bigram_embed", None) is not None:
+        def _make_bigram_embed_hook(name):
+            def hook_fn(module, inp, out):
+                x = out.detach().float()
+                if x.ndim == 3:
+                    x = x.reshape(-1, x.shape[-1])
+                if name not in hessians:
+                    hessians[name] = torch.zeros(
+                        x.shape[1], x.shape[1], dtype=torch.float32, device=device
+                    )
+                hessians[name].addmm_(x.T, x)
+            return hook_fn
+
+        hooks.append(
+            model.bigram_embed.embed.register_forward_hook(
+                _make_bigram_embed_hook("bigram_embed.embed.weight")
+            )
+        )
+        hooks.append(
+            model.bigram_embed.proj.register_forward_hook(
+                make_linear_input_hook("bigram_embed.proj.weight")
+            )
+        )
     model.eval()
     with torch.no_grad():
         for _ in range(n_calibration_batches):
@@ -3035,6 +3140,11 @@ def train_model(h, device, val_data):
         )
         model = compiled_model
     log(f"model_params:{sum(p.numel() for p in base_model.parameters())}")
+    log(
+        f"bigram_hash: enabled={h.bigram_hash_enabled} "
+        f"buckets={h.bigram_hash_buckets} dim={h.bigram_hash_dim} "
+        f"primes=({h.bigram_hash_prime_a},{h.bigram_hash_prime_b})"
+    )
     optimizers = Optimizers(h, base_model)
     train_loader = DocumentPackingLoader(h, device)
     max_wallclock_ms = (
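Finally, a hedged sketch (editor-added) of how the default-off invariant
claimed in the commit message could be checked; build_model_and_optimizers is
a hypothetical stand-in for the repo's GPT(h) + Optimizers(h, base_model)
construction:

import os

def check_default_off(build_model_and_optimizers):
    # Hypothetical harness: with the env var unset, no bigram tensors may
    # appear in the state_dict and no extra params in any optimizer group.
    os.environ.pop("BIGRAM_HASH_ENABLED", None)
    model, optimizers = build_model_and_optimizers()
    assert not any("bigram" in k for k in model.state_dict())
    n = sum(len(g["params"]) for opt in optimizers for g in opt.param_groups)
    return n  # compare against the same count on a baseline checkout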
