
Commit 8d54854

leon2k2k2k and claude committed
spec 012: training-bundle patch — tapered WD + GradPower + per-layer QK
Four new env-gated hyperparameters, all defaulting to no-op so spec 008 is byte-identical when the vars are unset:

- WD_TAPER_START_FRAC / WD_TAPER_FINAL_MULT (port openai#1729): linear Muon WD taper from 1.0 at start_step to final_mult at h.iterations. Applied in step_fn before optimizers.step. Adam/embed WD untouched per openai#1729.
- MUON_GRAD_POWER (port openai#1682): g = sign(g) * |g|^p, applied to Muon gradients just before the momentum-buffer update. Covers both the sharded and non-sharded gradient paths.
- QK_GAIN_INIT (existing): already present; its default is not changed here. Setting QK_GAIN_INIT=2.5 at runtime gives uniformly softer attention per openai#1648's convergence finding.
- QK_GAIN_PER_LAYER (new): comma-separated list that overrides each block's attn.q_gain after block construction. Validated to match num_layers.

Also: one startup log line echoing the four values for post-hoc verification.

Spec: research/specs/012-training-bundle.md.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
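For orientation, a minimal usage sketch, not part of the commit: the environment variable names come from the diff, while the launch style and every numeric value are hypothetical. The variables have to be exported before train_gpt.py starts, since the Hyperparameters class reads os.environ when its class body is evaluated.

import os
import subprocess

# Hypothetical values; only the variable names are taken from the commit.
env = dict(
    os.environ,
    WD_TAPER_START_FRAC="0.5",   # begin tapering Muon weight decay at 50% of training
    WD_TAPER_FINAL_MULT="0.1",   # finish at 0.1 * MUON_WD on the final iteration
    MUON_GRAD_POWER="0.8",       # GradPower exponent p in sign(g) * |g|^p
    QK_GAIN_INIT="2.5",          # uniform softer attention (openai#1648)
)
# The launch command is illustrative; substitute whatever launcher this record actually uses.
subprocess.run(["python", "train_gpt.py"], env=env, check=True)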
1 parent 5123db7 commit 8d54854

1 file changed

Lines changed: 42 additions & 0 deletions

File tree

  • records/track_10min_16mb/2026-04-19_SP8192_CaseOps_GatedAttn_QuantGate_Loop45_PhasedTTT

records/track_10min_16mb/2026-04-19_SP8192_CaseOps_GatedAttn_QuantGate_Loop45_PhasedTTT/train_gpt.py

@@ -69,6 +69,11 @@ class Hyperparameters:
     adam_wd = float(os.environ.get("ADAM_WD", 0.02))
     muon_wd = float(os.environ.get("MUON_WD", 0.095))
     embed_wd = float(os.environ.get("EMBED_WD", 0.085))
+    # Spec 012 training-bundle: all four default to no-op so spec 008 is byte-identical.
+    wd_taper_start_frac = float(os.environ.get("WD_TAPER_START_FRAC", 0.0))
+    wd_taper_final_mult = float(os.environ.get("WD_TAPER_FINAL_MULT", 1.0))
+    muon_grad_power = float(os.environ.get("MUON_GRAD_POWER", 1.0))
+    qk_gain_per_layer = os.environ.get("QK_GAIN_PER_LAYER", "")
     ema_decay = float(os.environ.get("EMA_DECAY", 0.9965))
     ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1")))
     ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 96))
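The "byte-identical when unset" claim follows directly from these defaults. A minimal sketch, not part of the commit, of how each default disables its code path:

# Defaults when the spec 012 variables are unset (from the hunk above):
wd_taper_start_frac = 0.0
wd_taper_final_mult = 1.0
muon_grad_power = 1.0
qk_gain_per_layer = ""

# Each new code path is guarded so that these defaults disable it:
assert not (wd_taper_start_frac > 0.0)  # step_fn WD-taper branch never runs
assert muon_grad_power == 1.0           # Muon.step GradPower branch is skipped (p == 1.0)
assert not qk_gain_per_layer            # empty string -> per-layer q_gain override skipped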
@@ -967,6 +972,17 @@ def __init__(self, h):
         head_dim = h.model_dim // h.num_heads
         for block in self.blocks:
             block.attn.rope_dims = h.rope_dims
+            # Spec 012: per-layer QK_GAIN override (port #1648 methodology; uniform softer is cheapest first pass).
+            # Set QK_GAIN_INIT=<scalar> for uniform override. Set QK_GAIN_PER_LAYER="v0,v1,...,vN-1" for per-layer.
+            if h.qk_gain_per_layer:
+                vals = [float(v) for v in h.qk_gain_per_layer.split(",")]
+                if len(vals) != h.num_layers:
+                    raise ValueError(
+                        f"QK_GAIN_PER_LAYER has {len(vals)} values but num_layers={h.num_layers}"
+                    )
+                with torch.no_grad():
+                    for block, v in zip(self.blocks, vals):
+                        block.attn.q_gain.data.fill_(v)
             block.attn.rotary = Rotary(
                 head_dim,
                 base=h.rope_base,
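One way to construct a valid QK_GAIN_PER_LAYER string is sketched below. This is not part of the commit: the layer count and the 2.5 → 1.0 linear profile are hypothetical, and the only constraint imposed by the diff is exactly num_layers comma-separated floats.

num_layers = 16  # hypothetical; must equal h.num_layers or the ValueError above is raised

# Example profile: soften attention more in early layers, approach 1.0 in later ones.
gains = [2.5 + (1.0 - 2.5) * i / (num_layers - 1) for i in range(num_layers)]
print(",".join(f"{g:.3f}" for g in gains))
# export the printed string as QK_GAIN_PER_LAYER before launching the run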
@@ -1595,6 +1611,12 @@ def step(self, closure=None):
                 if "momentum_buffer" not in state:
                     state["momentum_buffer"] = torch.zeros_like(g)
                 buf = state["momentum_buffer"]
+                # Spec 012: GradPower port from #1682. Elementwise sign(g)*|g|^p.
+                # Default p=1.0 → identity (no-op). Applied pre-momentum, pre-orthogonalization.
+                # Covers both sharded (g = m["shard"]) and non-sharded (g = p.grad) paths.
+                gp = getattr(self, "grad_power", 1.0)
+                if gp != 1.0:
+                    g = torch.sign(g) * g.abs().pow(gp)
                 buf.mul_(momentum).add_(g)
                 if nesterov:
                     update = g.add(buf, alpha=momentum)
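As a standalone illustration (not from the file), the same elementwise transform applied to a toy gradient: p < 1 boosts small-magnitude entries relative to large ones, p > 1 does the opposite, and the sign is always preserved.

import torch

def grad_power(g: torch.Tensor, p: float) -> torch.Tensor:
    # Elementwise sign(g) * |g|^p, the same form as the Muon.step branch above.
    return torch.sign(g) * g.abs().pow(p)

g = torch.tensor([-4.0, -0.01, 0.0, 0.01, 4.0])
print(grad_power(g, 1.0))  # [-4.0, -0.01, 0.0, 0.01, 4.0]   (identity, the default)
print(grad_power(g, 0.5))  # [-2.0, -0.1, 0.0, 0.1, 2.0]     (magnitude range compressed)
print(grad_power(g, 2.0))  # [-16.0, -0.0001, 0.0, 0.0001, 16.0]  (magnitude range expanded)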
@@ -1685,6 +1707,8 @@ def __init__(self, h, base_model):
             weight_decay=h.muon_wd,
             row_normalize=h.muon_row_normalize,
         )
+        # Spec 012: GradPower (port #1682). Read by Muon.step via getattr.
+        self.optimizer_muon.grad_power = h.muon_grad_power
         for group in self.optimizer_muon.param_groups:
             group["base_lr"] = h.matrix_lr
         self.optimizer_scalar = torch.optim.AdamW(
@@ -3035,6 +3059,13 @@ def train_model(h, device, val_data):
     )
     model = compiled_model
     log(f"model_params:{sum(p.numel()for p in base_model.parameters())}")
+    log(
+        f"training_bundle: wd_taper_start_frac={h.wd_taper_start_frac} "
+        f"wd_taper_final_mult={h.wd_taper_final_mult} "
+        f"muon_grad_power={h.muon_grad_power} "
+        f"qk_gain_init={h.qk_gain_init} "
+        f"qk_gain_per_layer='{h.qk_gain_per_layer}'"
+    )
     optimizers = Optimizers(h, base_model)
     train_loader = DocumentPackingLoader(h, device)
     max_wallclock_ms = (
30403071
max_wallclock_ms = (
@@ -3080,6 +3111,17 @@ def step_fn(step, lr_scale):
         ) * h.muon_momentum_warmup_start + frac * h.muon_momentum
         for group in optimizers.optimizer_muon.param_groups:
             group["momentum"] = muon_momentum
+        # Spec 012: tapered WD (port #1729). Linear from 1.0 at start_step to final_mult at h.iterations.
+        # Applied to Muon only (per #1729). Default: start_frac=0 → no-op (group["weight_decay"] untouched).
+        if h.wd_taper_start_frac > 0.0:
+            start_step = int(h.wd_taper_start_frac * h.iterations)
+            if step >= start_step:
+                progress = (step - start_step) / max(1, h.iterations - start_step)
+                mult = 1.0 - progress * (1.0 - h.wd_taper_final_mult)
+            else:
+                mult = 1.0
+            for group in optimizers.optimizer_muon.param_groups:
+                group["weight_decay"] = h.muon_wd * mult
         for opt in optimizers:
             for group in opt.param_groups:
                 group["lr"] = group["base_lr"] * lr_scale
