Commit 2895db3

leon2k2k2k and claude committed
spec 019: extend constant-α wiring to TTT forward path
018c (aabfbea) applied constant-α to forward_logits but not to forward_ttt / _block_with_lora. For a full-pipeline run we need TTT to also exercise the hardcoded α values (same lerp with a literal weight, same compile-specialization benefit).

Mirror the encoder/decoder pattern from forward_logits: the precomputed _encoder_alpha_info and _decoder_alpha_info lists store Python floats; forward_ttt reads them via Python indexing and calls torch.lerp(x_before, x, alpha) after each _block_with_lora at recur-alpha sites.

Closes the TTT-path gap from spec 015's original patch AND maintains the compile-time-constant α optimization validated at proxy scale in 018c (92% of the blend overhead recovered).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent aabfbea commit 2895db3
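
A minimal sketch (not part of this commit) of the compile-specialization point the message leans on: when the lerp weight is a Python float, torch.compile bakes it into the traced graph as a constant, whereas a tensor-valued alpha stays a runtime input. The names ALPHA, blend_constant, and blend_tensor below are illustrative, not from train_gpt.py.

    import torch

    ALPHA = 0.8  # hypothetical constant; the real values come from the 017 endpoint table

    @torch.compile
    def blend_constant(x_before, x):
        # Python-float weight: specialized as a literal in the compiled graph.
        return torch.lerp(x_before, x, ALPHA)

    @torch.compile
    def blend_tensor(x_before, x, alpha):
        # Tensor-valued weight: stays a runtime graph input, no constant folding.
        return torch.lerp(x_before, x, alpha)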

1 file changed

Lines changed: 25 additions & 3 deletions


records/track_10min_16mb/2026-04-19_SP8192_CaseOps_GatedAttn_QuantGate_Loop45_PhasedTTT/train_gpt.py

@@ -1310,10 +1310,28 @@ def forward_ttt(self, input_ids, target_ids, lora):
                 )
             )
         )
+        # Spec 019: apply constant-α blend in the TTT forward path too.
+        # alpha_info lists contain Python floats (set at __init__ from the 017
+        # endpoint table). The torch.lerp call sees the literal in the weight
+        # position just like forward_logits, so compile specialization applies.
+        enc_alpha_info = (
+            self._encoder_alpha_info
+            if (self.recur_alpha_enabled and self.looping_active)
+            else None
+        )
+        dec_alpha_info = (
+            self._decoder_alpha_info
+            if (self.recur_alpha_enabled and self.looping_active)
+            else None
+        )
         slot = 0
-        for i in enc_iter:
+        for step_idx, i in enumerate(enc_iter):
             q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i)
-            x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w)
+            x_before = x
+            x = self._block_with_lora(self.blocks[i], x_before, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w)
+            if enc_alpha_info is not None and enc_alpha_info[step_idx] is not None:
+                alpha = enc_alpha_info[step_idx]  # Python float constant
+                x = torch.lerp(x_before, x, alpha)
             slot += 1
             skips.append(x)
         psl = self.parallel_start_layer
@@ -1348,7 +1366,11 @@ def forward_ttt(self, input_ids, target_ids, lora):
                 x = torch.lerp(scaled_skip, x, g)
             else:
                 x = x + scaled_skip
-            x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w)
+            x_before = x
+            x = self._block_with_lora(self.blocks[i], x_before, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w)
+            if dec_alpha_info is not None and dec_alpha_info[skip_idx] is not None:
+                alpha = dec_alpha_info[skip_idx]  # Python float constant
+                x = torch.lerp(x_before, x, alpha)
             slot += 1
         if lane0 is not None:
             x = self._final_parallel_hidden(lane0, lane1)
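
For orientation only: the diff assumes _encoder_alpha_info and _decoder_alpha_info were already built at __init__ (per 018c and the 017 endpoint table), which this commit does not show. A hypothetical sketch of that shape, with _build_alpha_info, alpha_by_step, and the example values all invented names and placeholders:

    def _build_alpha_info(num_steps, alpha_by_step):
        # One entry per loop step: a Python float at recur-alpha sites,
        # None elsewhere so forward_ttt skips the blend for that step.
        return [float(alpha_by_step[i]) if i in alpha_by_step else None
                for i in range(num_steps)]

    # e.g. self._encoder_alpha_info = _build_alpha_info(num_encoder_steps, {2: 0.75, 5: 0.9})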
