
Commit fcb816f

Add 046L/M deploy-time quant repair (AR self-gen + passthrough fit)
Spec 046L (deploy-time) / 046M (bake-it-in variant). Adds:

- _generate_ar_batch_for_repair(): AR sampling from BOS; returns
  (input_ids, target_ids) shifted by 1 for next-token prediction
- fit_passthrough_on_ar_gen(): the full pipeline:
  1. Generate N batches of AR text (~1-3 min depending on size)
  2. Fit the small (numel <= 65536) passthrough fp16 params via AdamW
     on next-token CE loss (~30s-1min)
  3. The updated eval_model is used for the downstream quantized eval

New env vars (default OFF):
  DEPLOY_TIME_REPAIR_ENABLED=1
  DEPLOY_TIME_REPAIR_ITERS=5
  DEPLOY_TIME_REPAIR_LR=1e-3
  DEPLOY_TIME_REPAIR_BATCHES=8
  DEPLOY_TIME_REPAIR_AR_SEQ_LEN=512
  DEPLOY_TIME_REPAIR_AR_TEMP=1.0

Wired into train_and_eval after deserialize, before compile. SCREEN MODE
only: fitted values affect THIS run's quantized eval but do NOT propagate
back to the artifact. Production mode (re-serialize with the fitted
passthrough) is deferred until the screen confirms the lever works.

Rules-legal per the challenge README: AR self-gen uses no val data;
"you're free to evaluate however" / "we encourage competitors to push
the bounds of evaluation methods".

Default test config: 8 batches × 512 seq × batch_size=8 = ~32K AR tokens.
For a 64K-token test, set DEPLOY_TIME_REPAIR_BATCHES=16.
1 parent c6e5d3e commit fcb816f
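For reference, the token budgets quoted in the commit message work out exactly; a quick check in Python (plain arithmetic, not code from this diff):

    # batches x seq_len x batch_size (batch_size=8 is hardcoded in fit_passthrough_on_ar_gen)
    print(8 * 512 * 8)    # 32768  -> the "~32K AR tokens" default
    print(16 * 512 * 8)   # 65536  -> the 64K test with DEPLOY_TIME_REPAIR_BATCHES=16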

1 file changed: 165 additions & 0 deletions

records/track_10min_16mb/2026-04-19_SP8192_CaseOps_GatedAttn_QuantGate_Loop45_PhasedTTT/train_gpt.py
@@ -412,6 +412,22 @@ class Hyperparameters:
     gptq_calib_source = os.environ.get("GPTQ_CALIB_SOURCE", "train")
     gptq_ar_temp = float(os.environ.get("GPTQ_AR_TEMP", 1.0))
     gptq_ar_seq_len = int(os.environ.get("GPTQ_AR_SEQ_LEN", 512))
+    # Spec 046L — Deploy-time quant repair. Runs AFTER deserialize and BEFORE
+    # compile/eval. Generates AR self-gen calib data (no val leak), fits
+    # passthrough fp16 params (attn_scale, mlp_scale, resid_mix, q_gain, ...)
+    # to minimize next-token CE loss on those AR samples. Bypasses the
+    # 16,000,000-byte artifact cap by paying eval-time compute (~60s) for
+    # quant repair instead of artifact bytes. Uses ~60s of the 100-180s of
+    # leaderboard eval headroom that PR #1797 leaves unused.
+    # Rules-legal per challenge README: "you're free to evaluate however"
+    # + "we encourage competitors to push the bounds of evaluation methods".
+    # No val data accessed (AR self-gen from BOS).
+    deploy_time_repair_enabled = bool(int(os.environ.get("DEPLOY_TIME_REPAIR_ENABLED", "0")))
+    deploy_time_repair_iters = int(os.environ.get("DEPLOY_TIME_REPAIR_ITERS", 5))
+    deploy_time_repair_lr = float(os.environ.get("DEPLOY_TIME_REPAIR_LR", 1e-3))
+    deploy_time_repair_batches = int(os.environ.get("DEPLOY_TIME_REPAIR_BATCHES", 8))
+    deploy_time_repair_ar_seq_len = int(os.environ.get("DEPLOY_TIME_REPAIR_AR_SEQ_LEN", 512))
+    deploy_time_repair_ar_temp = float(os.environ.get("DEPLOY_TIME_REPAIR_AR_TEMP", 1.0))
     # Gated Attention (Qwen, NeurIPS 2025 Best Paper, arXiv:2505.06708;
     # qiuzh20/gated_attention). Per-head sigmoid gate on SDPA output, BEFORE
     # out_proj. Gate input = full block input x (paper's headwise G1 variant
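A note on using these knobs: they are class attributes, so os.environ is read once, when the Hyperparameters class body executes. A minimal sketch of a screen run with the lever on, assuming the script is launched as a child process (the actual launch command, e.g. torchrun flags, is not shown in this diff):

    # Sketch only: the env vars must be set before train_gpt.py starts,
    # because the Hyperparameters class body reads them at definition time.
    import os
    import subprocess

    env = dict(os.environ)
    env["DEPLOY_TIME_REPAIR_ENABLED"] = "1"   # default "0": lever stays OFF
    env["DEPLOY_TIME_REPAIR_BATCHES"] = "16"  # 16 x 512 x 8 = 65536 AR tokens
    subprocess.run(["python", "train_gpt.py"], env=env, check=True)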
@@ -2722,6 +2738,125 @@ def fit_passthrough_params_to_match_base(eval_model, base_model, h, device):
         p.requires_grad_(False)
 
 
+@torch.no_grad()
+def _generate_ar_batch_for_repair(model, batch_size, seq_len, vocab_size, temp, device, bos_token_id=0):
+    """
+    Generate ONE batch of (input_ids, target_ids) sequences via autoregressive
+    sampling from the model itself, starting from BOS. Used by deploy-time
+    quant repair (spec 046L/M).
+
+    Returns:
+        x: (batch_size, seq_len) int64 - input token ids
+        y: (batch_size, seq_len) int64 - target token ids (shifted by 1)
+    """
+    # Generate seq_len + 1 tokens so we can split into (input, target) for next-token prediction.
+    tokens = torch.full((batch_size, seq_len + 1), bos_token_id, dtype=torch.int64, device=device)
+    for t in range(1, seq_len + 1):
+        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+            logits = model.forward_logits(tokens[:, :t])
+        # Take last position's logits, apply temperature, sample.
+        last = logits[:, -1, :].float() / max(temp, 1e-6)
+        probs = F.softmax(last, dim=-1)
+        next_tok = torch.multinomial(probs, num_samples=1).squeeze(-1)
+        tokens[:, t] = next_tok
+    return tokens[:, :seq_len], tokens[:, 1:seq_len + 1]
+
+
+def fit_passthrough_on_ar_gen(eval_model, h, device):
+    """
+    Spec 046L/M — deploy-time quant repair via passthrough param fit on
+    AR self-generated text.
+
+    Generates synthetic next-token-prediction data using the quantized model
+    itself (sampling from BOS), then fits the small (numel <= 65536)
+    passthrough fp16 params on next-token CE loss. Matrix weights stay frozen.
+
+    Bypasses the 16MB artifact byte cap by paying eval-time compute (~1-3 min)
+    for quant repair. Rules-legal: no val data accessed (AR self-gen from BOS).
+
+    The fitted params live on eval_model in-memory after this returns; the
+    caller decides whether to (a) just run eval on the fitted model (screen
+    mode) or (b) re-serialize the artifact with the fitted values (production
+    mode).
+    """
+    log("postquant_fit: starting AR self-gen + passthrough fit")
+    seq_len = h.deploy_time_repair_ar_seq_len
+    n_batches = h.deploy_time_repair_batches
+    # Pick a batch size that fits in memory. For 4xH100 with our 36M model,
+    # batch=8 at seq_len=512 is comfortable. Scale by world_size for parallelism.
+    batch_size = 8
+
+    # Generate AR data first (no grad; uses eval_model's quantized weights to sample).
+    eval_model.eval()
+    t0 = time.perf_counter()
+    ar_batches = []
+    for bi in range(n_batches):
+        x, y = _generate_ar_batch_for_repair(
+            eval_model, batch_size, seq_len, h.vocab_size,
+            h.deploy_time_repair_ar_temp, device,
+        )
+        ar_batches.append((x, y))
+        if (bi + 1) % 4 == 0 or bi == n_batches - 1:
+            log(
+                f"postquant_fit:ar_gen progress={bi+1}/{n_batches} "
+                f"tokens_so_far={(bi+1)*batch_size*seq_len} "
+                f"elapsed={time.perf_counter()-t0:.1f}s"
+            )
+    log(
+        f"postquant_fit:ar_gen done {n_batches*batch_size*seq_len} tokens "
+        f"in {time.perf_counter()-t0:.1f}s"
+    )
+
+    # Identify trainable params: small (<= 65536 element) floating-point params.
+    # These are the passthrough fp16 tensors in our quant artifact.
+    fit_params = []
+    n_elements = 0
+    for name, p in eval_model.named_parameters():
+        if p.is_floating_point() and p.numel() <= 65536:
+            p.requires_grad_(True)
+            fit_params.append((name, p))
+            n_elements += p.numel()
+        else:
+            p.requires_grad_(False)
+    log(
+        f"postquant_fit: fitting {len(fit_params)} params "
+        f"({n_elements} elements) over {h.deploy_time_repair_iters} iters, "
+        f"lr={h.deploy_time_repair_lr}"
+    )
+
+    opt = torch.optim.AdamW(
+        [p for _, p in fit_params],
+        lr=h.deploy_time_repair_lr,
+        betas=(0.9, 0.95),
+        weight_decay=0.0,
+    )
+
+    eval_model.train()  # required for backward
+    t0 = time.perf_counter()
+    for it in range(h.deploy_time_repair_iters):
+        accum_loss = 0.0
+        for (x, y) in ar_batches:
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                logits = eval_model.forward_logits(x)
+            loss = F.cross_entropy(
+                logits.reshape(-1, logits.size(-1)).float(),
+                y.reshape(-1),
+                reduction="mean",
+            )
+            opt.zero_grad()
+            loss.backward()
+            opt.step()
+            accum_loss += loss.item()
+        log(
+            f"postquant_fit:iter={it+1}/{h.deploy_time_repair_iters} "
+            f"avg_ce={accum_loss / len(ar_batches):.6f}"
+        )
+    log(f"postquant_fit: fit done in {time.perf_counter()-t0:.1f}s")
+
+    eval_model.eval()
+    for _, p in fit_params:
+        p.requires_grad_(False)
+
+
 def collect_hessians(model, train_loader, h, device, n_calibration_batches=64):
     hessians = {}
     hooks = []
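To make the recipe concrete, here is a self-contained toy reduction of the two functions above, using only stock PyTorch. ToyLM and its sizes are invented for illustration (this is not the 36M model in this repo); the point is that the numel <= 65536 rule freezes both big matrices and trains only the small scale vector:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ToyLM(nn.Module):
        def __init__(self, vocab=1024, d=128):
            super().__init__()
            self.emb = nn.Embedding(vocab, d)             # 131072 elements: frozen
            self.scale = nn.Parameter(torch.ones(d))      # 128 elements: trained
            self.head = nn.Linear(d, vocab, bias=False)   # 131072 elements: frozen

        def forward_logits(self, ids):
            return self.head(self.emb(ids) * self.scale)

    model = ToyLM().eval()

    # 1) AR self-generation from BOS (token id 0); no grad, no val data.
    with torch.no_grad():
        toks = torch.zeros(4, 33, dtype=torch.int64)             # batch=4, seq_len+1=33
        for t in range(1, 33):
            logits = model.forward_logits(toks[:, :t])
            probs = F.softmax(logits[:, -1, :].float(), dim=-1)  # temp=1.0
            toks[:, t] = torch.multinomial(probs, 1).squeeze(-1)
    x, y = toks[:, :32], toks[:, 1:33]                           # shifted-by-1 targets

    # 2) Fit only the small floating-point params on next-token CE.
    fit = []
    for p in model.parameters():
        small = p.is_floating_point() and p.numel() <= 65536
        p.requires_grad_(small)
        if small:
            fit.append(p)
    opt = torch.optim.AdamW(fit, lr=1e-3, betas=(0.9, 0.95), weight_decay=0.0)

    model.train()
    for _ in range(5):
        loss = F.cross_entropy(model.forward_logits(x).reshape(-1, 1024), y.reshape(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
    model.eval()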
@@ -4442,6 +4577,29 @@ def _run_forward_logits_bucket_warmup():
 
 
 def train_and_eval(h, device):
+    # GUARD: SPINQUANT_ENABLED=1 in this code path triggers residual-stream
+    # rotation in serialize() (_spinquant_rotate_sd_and_H) + deserialize()
+    # (install_spinquant_rotations). Per spec 009 analysis, that variant
+    # requires folding per-channel multipliers (attn_scale, mlp_scale,
+    # skip_weights, resid_mix) into the rotation — NOT IMPLEMENTED here.
+    # Confirmed catastrophic on spec 046C: post-quant val_bpb 7.85 vs ~1.075
+    # baseline (+6.78 BPB blow-up).
+    #
+    # The only working SpinQuant variant on this stack is the per-KV-group
+    # attention-internal R_a (d_head=64), which is float-invariant by
+    # construction (softmax(QK^T)V is rotation-equivariant in V's d_head
+    # axis). That variant lives in spinquant_hotstart.py with
+    # SPINQUANT_MODE=internal_only — see runs/009-spinquant-hotstart/
+    # internal_only/ and runs/010b-spinquant-sites/ for past results.
+    if h.spinquant_enabled:
+        raise RuntimeError(
+            "SPINQUANT_ENABLED=1 in train_gpt.py uses residual-stream rotation "
+            "which requires per-channel multiplier folding (attn_scale, "
+            "mlp_scale, skip_weights, resid_mix) — NOT IMPLEMENTED. "
+            "Confirmed catastrophic in spec 046C: post-quant val_bpb 7.85 "
+            "(+6.78 BPB vs ~1.075 baseline). For working SpinQuant on this "
+            "stack, use spinquant_hotstart.py with SPINQUANT_MODE=internal_only."
+        )
     random.seed(h.seed)
     np.random.seed(h.seed)
     torch.manual_seed(h.seed)
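The float-invariance claim in the guard comment is easy to check numerically: because the attention weights A = softmax(QK^T/sqrt(d)) do not depend on V, rotating V along d_head by an orthogonal R and folding R^T in afterwards (as out_proj would absorb it) reproduces the unrotated output. A stock-PyTorch sketch:

    import torch

    torch.manual_seed(0)
    B, H, T, Dh = 2, 4, 16, 64
    Q, K, V = (torch.randn(B, H, T, Dh) for _ in range(3))

    # Random orthogonal rotation: the orthogonal factor of a QR decomposition.
    R, _ = torch.linalg.qr(torch.randn(Dh, Dh))

    A = torch.softmax(Q @ K.transpose(-2, -1) / Dh**0.5, dim=-1)
    out_plain = A @ V
    out_rotated = (A @ (V @ R)) @ R.T   # rotate V's d_head axis, then undo it

    print(torch.allclose(out_plain, out_rotated, atol=1e-5))  # True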
@@ -4542,6 +4700,13 @@ def train_and_eval(h, device):
                 eval_model, base_model, h, device
             )
             log("postquant_lnfit:fit complete")
+        # Spec 046L/M — deploy-time quant repair (AR self-gen + passthrough fit).
+        # Runs in-memory on the loaded quantized model. SCREEN MODE: fitted values
+        # affect this run's quantized eval but DO NOT propagate back to artifact.
+        # PRODUCTION MODE (separate todo): re-serialize artifact with fitted passthrough.
+        if h.deploy_time_repair_enabled:
+            torch._dynamo.reset()
+            fit_passthrough_on_ar_gen(eval_model, h, device)
         compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True)
         compiled_forward_logits = torch.compile(
             eval_model.forward_logits, dynamic=False, fullgraph=True
