@@ -412,6 +412,22 @@ class Hyperparameters:
gptq_calib_source = os.environ.get("GPTQ_CALIB_SOURCE", "train")
gptq_ar_temp = float(os.environ.get("GPTQ_AR_TEMP", 1.0))
gptq_ar_seq_len = int(os.environ.get("GPTQ_AR_SEQ_LEN", 512))
# Spec 046L — deploy-time quant repair. Runs AFTER deserialize and BEFORE
# compile/eval: generate AR self-gen calibration data (no val leak), then fit
# the passthrough fp16 params (attn_scale, mlp_scale, resid_mix, q_gain, ...)
# to minimize next-token CE on those AR samples. Trades ~60s of eval-time
# compute for quant repair instead of spending artifact bytes, staying under
# the 16,000,000-byte artifact cap; uses part of the 100-180s leaderboard
# eval headroom that PR #1797 leaves unused.
# Rules-legal per challenge README: "you're free to evaluate however"
# + "we encourage competitors to push the bounds of evaluation methods".
# No val data accessed (AR self-gen from BOS).
deploy_time_repair_enabled = bool(int(os.environ.get("DEPLOY_TIME_REPAIR_ENABLED", "0")))
deploy_time_repair_iters = int(os.environ.get("DEPLOY_TIME_REPAIR_ITERS", 5))
deploy_time_repair_lr = float(os.environ.get("DEPLOY_TIME_REPAIR_LR", 1e-3))
deploy_time_repair_batches = int(os.environ.get("DEPLOY_TIME_REPAIR_BATCHES", 8))
deploy_time_repair_ar_seq_len = int(os.environ.get("DEPLOY_TIME_REPAIR_AR_SEQ_LEN", 512))
deploy_time_repair_ar_temp = float(os.environ.get("DEPLOY_TIME_REPAIR_AR_TEMP", 1.0))
415431 # Gated Attention (Qwen, NeurIPS 2025 Best Paper, arXiv:2505.06708;
416432 # qiuzh20/gated_attention). Per-head sigmoid gate on SDPA output, BEFORE
417433 # out_proj. Gate input = full block input x (paper's headwise G1 variant
@@ -2722,6 +2738,125 @@ def fit_passthrough_params_to_match_base(eval_model, base_model, h, device):
27222738 p .requires_grad_ (False )
27232739
27242740
2741+ @torch .no_grad ()
2742+ def _generate_ar_batch_for_repair (model , batch_size , seq_len , vocab_size , temp , device , bos_token_id = 0 ):
2743+ """
2744+ Generate ONE batch of (input_ids, target_ids) sequences via autoregressive
2745+ sampling from the model itself, starting from BOS. Used by deploy-time
2746+ quant repair (spec 046L/M).
2747+
2748+ Returns:
2749+ x: (batch_size, seq_len) int64 - input token ids
2750+ y: (batch_size, seq_len) int64 - target token ids (shifted by 1)
2751+ """
2752+ # Generate seq_len + 1 tokens so we can split into (input, target) for next-token prediction
2753+ tokens = torch .full ((batch_size , seq_len + 1 ), bos_token_id , dtype = torch .int64 , device = device )
2754+ for t in range (1 , seq_len + 1 ):
2755+ with torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 ):
2756+ logits = model .forward_logits (tokens [:, :t ])
2757+ # Take last position's logits, apply temperature, sample
2758+ last = logits [:, - 1 , :].float () / max (temp , 1e-6 )
2759+ probs = F .softmax (last , dim = - 1 )
2760+ next_tok = torch .multinomial (probs , num_samples = 1 ).squeeze (- 1 )
2761+ tokens [:, t ] = next_tok
2762+ return tokens [:, :seq_len ], tokens [:, 1 :seq_len + 1 ]
2763+
2764+
def fit_passthrough_on_ar_gen(eval_model, h, device, *, batch_size=8, max_passthrough_numel=65536):
    """
    Spec 046L/M — deploy-time quant repair via passthrough param fit on
    AR self-generated text.

    Generates synthetic next-token-prediction data using the quantized model
    itself (sampling from BOS), then fits the small
    (numel <= max_passthrough_numel) floating-point passthrough params on
    next-token CE loss with AdamW. Matrix weights stay frozen.

    Bypasses the 16MB artifact byte cap by paying eval-time compute (~1-3 min)
    for quant repair. Rules-legal: no val data accessed (AR self-gen from BOS).

    Args:
        eval_model: deserialized quantized model; must expose forward_logits()
            and named_parameters(). Selected params are updated in place.
        h: hyperparameters object; reads deploy_time_repair_* and vocab_size.
        device: device the generated calibration tensors are created on.
        batch_size: sequences per AR batch (keyword-only; previously a
            hard-coded 8 — "comfortable for 4xH100 with our 36M model").
        max_passthrough_numel: element-count cutoff separating the small
            passthrough fp tensors from frozen matrix weights.

    The fitted params live on eval_model in-memory after this returns; caller
    decides whether to (a) just run eval on the fitted model (screen mode) or
    (b) re-serialize the artifact with the fitted values (production mode).
    """
    log("postquant_fit: starting AR self-gen + passthrough fit")
    seq_len = h.deploy_time_repair_ar_seq_len
    n_batches = h.deploy_time_repair_batches

    # Phase 1 — generate AR data (no grad; samples with the quantized weights).
    eval_model.eval()
    t0 = time.perf_counter()
    ar_batches = []
    for bi in range(n_batches):
        x, y = _generate_ar_batch_for_repair(
            eval_model, batch_size, seq_len, h.vocab_size,
            h.deploy_time_repair_ar_temp, device,
        )
        ar_batches.append((x, y))
        if (bi + 1) % 4 == 0 or bi == n_batches - 1:
            log(
                f"postquant_fit:ar_gen progress={bi + 1}/{n_batches} "
                f"tokens_so_far={(bi + 1) * batch_size * seq_len} "
                f"elapsed={time.perf_counter() - t0:.1f}s"
            )
    log(
        f"postquant_fit:ar_gen done {n_batches * batch_size * seq_len} tokens "
        f"in {time.perf_counter() - t0:.1f}s"
    )

    # Phase 2 — select trainable params: small floating-point tensors are the
    # passthrough fp16 params in our quant artifact; everything larger (the
    # matrix weights) is frozen.
    fit_params = []
    n_elements = 0
    for name, p in eval_model.named_parameters():
        if p.is_floating_point() and p.numel() <= max_passthrough_numel:
            p.requires_grad_(True)
            fit_params.append((name, p))
            n_elements += p.numel()
        else:
            p.requires_grad_(False)
    log(
        f"postquant_fit: fitting {len(fit_params)} params "
        f"({n_elements} elements) over {h.deploy_time_repair_iters} iters, "
        f"lr={h.deploy_time_repair_lr}"
    )

    opt = torch.optim.AdamW(
        [p for _, p in fit_params],
        lr=h.deploy_time_repair_lr,
        betas=(0.9, 0.95),
        weight_decay=0.0,
    )

    # train() kept from the original flow. Backward itself does NOT require
    # train mode; this toggles dropout/norm train behavior — confirm the model
    # has none that should stay in eval mode during the fit.
    eval_model.train()
    t0 = time.perf_counter()
    for it in range(h.deploy_time_repair_iters):
        accum_loss = 0.0
        for (x, y) in ar_batches:
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = eval_model.forward_logits(x)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y.reshape(-1),
                reduction="mean",
            )
            opt.zero_grad()
            loss.backward()
            opt.step()
            accum_loss += loss.item()
        # max(..., 1) guards DEPLOY_TIME_REPAIR_BATCHES=0 (previously a
        # ZeroDivisionError in this log line).
        log(
            f"postquant_fit:iter={it + 1}/{h.deploy_time_repair_iters} "
            f"avg_ce={accum_loss / max(len(ar_batches), 1):.6f}"
        )
    log(f"postquant_fit: fit done in {time.perf_counter() - t0:.1f}s")

    # Restore eval mode and re-freeze the fitted params before compile/eval.
    eval_model.eval()
    for _, p in fit_params:
        p.requires_grad_(False)
2858+
2859+
27252860def collect_hessians (model , train_loader , h , device , n_calibration_batches = 64 ):
27262861 hessians = {}
27272862 hooks = []
@@ -4442,6 +4577,29 @@ def _run_forward_logits_bucket_warmup():
44424577
44434578
44444579def train_and_eval (h , device ):
4580+ # GUARD: SPINQUANT_ENABLED=1 in this code path triggers residual-stream
4581+ # rotation in serialize() (_spinquant_rotate_sd_and_H) + deserialize()
4582+ # (install_spinquant_rotations). Per spec 009 analysis, that variant
4583+ # requires folding per-channel multipliers (attn_scale, mlp_scale,
4584+ # skip_weights, resid_mix) into the rotation — NOT IMPLEMENTED here.
4585+ # Confirmed catastrophic on spec 046C: post-quant val_bpb 7.85 vs ~1.075
4586+ # baseline (+6.78 BPB blow-up).
4587+ #
4588+ # The only working SpinQuant variant on this stack is the per-KV-group
4589+ # attention-internal R_a (d_head=64), which is float-invariant by
4590+ # construction (softmax(QK^T)V is rotation-equivariant in V's d_head
4591+ # axis). That variant lives in spinquant_hotstart.py with
4592+ # SPINQUANT_MODE=internal_only — see runs/009-spinquant-hotstart/
4593+ # internal_only/ and runs/010b-spinquant-sites/ for past results.
4594+ if h .spinquant_enabled :
4595+ raise RuntimeError (
4596+ "SPINQUANT_ENABLED=1 in train_gpt.py uses residual-stream rotation "
4597+ "which requires per-channel multiplier folding (attn_scale, "
4598+ "mlp_scale, skip_weights, resid_mix) — NOT IMPLEMENTED. "
4599+ "Confirmed catastrophic in spec 046C: post-quant val_bpb 7.85 "
4600+ "(+6.78 BPB vs ~1.075 baseline). For working SpinQuant on this "
4601+ "stack, use spinquant_hotstart.py with SPINQUANT_MODE=internal_only."
4602+ )
44454603 random .seed (h .seed )
44464604 np .random .seed (h .seed )
44474605 torch .manual_seed (h .seed )
@@ -4542,6 +4700,13 @@ def train_and_eval(h, device):
45424700 eval_model , base_model , h , device
45434701 )
45444702 log ("postquant_lnfit:fit complete" )
4703+ # Spec 046L/M — deploy-time quant repair (AR self-gen + passthrough fit).
4704+ # Runs in-memory on the loaded quantized model. SCREEN MODE: fitted values
4705+ # affect this run's quantized eval but DO NOT propagate back to artifact.
4706+ # PRODUCTION MODE (separate todo): re-serialize artifact with fitted passthrough.
4707+ if h .deploy_time_repair_enabled :
4708+ torch ._dynamo .reset ()
4709+ fit_passthrough_on_ar_gen (eval_model , h , device )
45454710 compiled_model = torch .compile (eval_model , dynamic = False , fullgraph = True )
45464711 compiled_forward_logits = torch .compile (
45474712 eval_model .forward_logits , dynamic = False , fullgraph = True
0 commit comments