Deterministic Fine-Tuning on Dual MI100s
· v37
Most fine-tuning recipes published in 2025 don’t reproduce. Same script, same data, same seed, different weights. The discrepancy is small enough to look like training noise — small enough that nobody usually checks. But “small” matters when you’re trying to compare two LoRA runs to decide which won, or when you’re trying to cite a result.
On AMD MI100s the situation is worse: the upstream Hugging Face stack assumes a CUDA-shaped world, and the divergences are subtle enough to look like the hardware is the problem. It isn’t. With three small patches, the same smoke run produced the same adapter weights twice in a row, bit for bit. That’s narrower than “all training is solved forever,” but it’s the verification that matters before you trust a fine-tuning recipe.
This is the recipe. The repo is at forge.v37.io (Codeberg mirror) and has the full README; this post is the story.
The symptom that started it
Train Qwen3.6-MoE with LoRA + gradient checkpointing under ROCm. You’ll see:
RuntimeError: torch.utils.checkpoint: a different number of tensors was saved
during the original forward and recomputation.
Repeated attempts died around steps 1-2. That's PyTorch's gradient checkpointing complaining that the forward pass and the recompute pass disagree on the intermediate tensors they saved. In some runs the number of saved tensors differed. In others their metadata differed, with shapes like [83, 512] in forward and [82, 512] in recompute.
That sounds like a PyTorch bug. It isn’t.
Why it happens
PyTorch’s gradient checkpointing trades memory for compute by re-running the forward pass during the backward pass to recover activations rather than keeping them in VRAM. The contract is: forward and recompute must produce structurally identical computation graphs. Same ops, same tensor shapes, same number of tensors saved.
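Here is a toy reproduction of that contract being broken. It assumes a recent PyTorch 2.x with non-reentrant checkpointing (use_reentrant=False). unstable_block is a hypothetical function that uses a call counter to run one extra op on the original forward only. The recompute therefore saves fewer tensors, and backward should fail with a CheckpointError, which is a RuntimeError subclass. The exact message wording varies by PyTorch version.

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = 0


def unstable_block(x):
    """Toy block whose graph differs between the original forward and the recompute."""
    global calls
    calls += 1
    y = x * x             # MulBackward saves tensors for backward
    if calls == 1:        # only on the original forward...
        y = torch.sin(y)  # ...an extra op saves one more tensor
    return y.sum()


x = torch.randn(8, requires_grad=True)
loss = checkpoint(unstable_block, x, use_reentrant=False)  # activations are dropped here
try:
    loss.backward()       # the recompute calls unstable_block a second time
    error = None
except RuntimeError as exc:  # torch.utils.checkpoint.CheckpointError subclasses RuntimeError
    error = exc

print(calls, type(error).__name__)
```

In the real run nothing counts calls. The graph changed because the router picked different experts, which is the same contract violation reached by a less obvious route.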
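Qwen3.6-MoE has 256 experts and routes 8 of them per token via top-k. The router's computation is roughly linear → softmax → top-k. A linear gate scores every expert for each token, softmax turns those scores into probabilities, and top-k keeps the 8 most likely experts. Their probabilities are then gathered, renormalized, and cast back to the model dtype.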
With 256 experts, many of those softmax probabilities are near-tied. Tiny floating-point drift between forward and recompute can flip which experts win the top-8. Different experts firing means different per-expert tensor shapes during the recompute. The count/metadata checks catch it. The training step dies.
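Here is a plain-Python illustration of that mechanism. The logits are made up and the drift is injected by hand; nothing here was measured on the MI100s. Experts 7 and 8 sit 1e-6 apart. Nudging one token's logit by 2e-6 on the "recompute" swaps which of the two makes the top-8. That changes the per-expert row counts, which are the leading dimension of each expert's input, from 83 to 82. The 83 is chosen only to echo the shapes in the error above.

```python
import math

NUM_EXPERTS, TOP_K, TOKENS = 256, 8, 83


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def topk_experts(logits, k):
    probs = softmax(logits)
    # Highest probability first (Python's sort is stable, so exact ties keep index order).
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]


def rows_per_expert(batch, k):
    """How many tokens each expert receives: the leading dim of its input tensor."""
    counts = [0] * NUM_EXPERTS
    for token_logits in batch:
        for expert in topk_experts(token_logits, k):
            counts[expert] += 1
    return counts


# Made-up logits: strictly decreasing, with experts 7 and 8 nearly tied for 8th place.
base = [-0.01 * i for i in range(NUM_EXPERTS)]
base[8] = base[7] - 1e-6

forward_batch = [list(base) for _ in range(TOKENS)]
recompute_batch = [list(base) for _ in range(TOKENS)]
recompute_batch[0][8] += 2e-6  # hand-injected drift on a single token

fwd = rows_per_expert(forward_batch, TOP_K)
rec = rows_per_expert(recompute_batch, TOP_K)
print(fwd[7], rec[7])  # 83 82
print(fwd[8], rec[8])  # 0 1
```

One token's flip is enough. Expert 7's input loses a row, expert 8 gains one, and checkpoint's metadata check sees a different shape.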
The logs prove the shape/count mismatch and the successful patched run. They do not contain a clean router-index ablation that prints “forward top-k != recompute top-k.” The routing explanation is the mechanism that fits the failure shape, the model architecture, and the fixes that finally held. The evidence trail points at three surfaces:
- torch.conv1d fallback: Qwen's GatedDeltaNet layer wires its causal_conv1d_fn attribute to the CUDA-only causal-conv1d PyPI package. When that's not installed (it can't be on ROCm), the layer falls back to a torch conv1d path. Replacing that fallback was one layer in the final working stack; the artifacts don't isolate it alone.
- SDPA / flash backward: forcing eager attention removed another fused ROCm surface from the recompute path and changed the failure ladder.
- Residual bf16 drift before routing: even after the obvious fused paths were removed, the MoE router still needed stable top-k decisions across forward and recompute. Small numeric differences are enough when many experts are near-tied.
Each one feeds the next. The router is just where the consequence becomes visible.
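Residual drift can survive once the fused kernels are gone because bf16 addition isn't associative. Summing the same numbers in a different order can land on a different bf16 value. The sketch below emulates bf16 rounding in pure Python, using round-to-nearest-even via float32 bit manipulation and omitting NaN handling. It demonstrates the arithmetic property only, not what any particular ROCm kernel does.

```python
import struct


def to_bf16(x: float) -> float:
    """Round to bfloat16 (round-to-nearest-even) via float32 bits; NaN handling omitted."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add just under half of the dropped range, plus 1 if the kept LSB is odd (ties to even).
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_add(a: float, b: float) -> float:
    """One bf16 addition: exact sum of two bf16 values, rounded once back to bf16."""
    return to_bf16(to_bf16(a) + to_bf16(b))


a, b, c = 1.0, 2.0 ** -8, 2.0 ** -8
left = bf16_add(bf16_add(a, b), c)   # (a + b) + c
right = bf16_add(a, bf16_add(b, c))  # a + (b + c)
print(left, right)  # 1.0 1.0078125
```

At 1.0 the bf16 spacing is 2^-7. Adding 2^-8 lands exactly halfway and rounds back down to 1.0 (ties to even), twice. Adding the two small terms first gives 2^-7, which survives.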
The three patches
1. Replace the conv1d fallback
The conv1d fallback was one layer of the problem, and the fix is small.
fla (flash-linear-attention) ships a Triton-based causal_conv1d that works on the ROCm stack we tested. The adapter:
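# Adapter from the model's [B, D, T] calling convention to fla's [B, T, D] kernel.
from fla.modules.conv.causal_conv1d import causal_conv1d as _fla

def fla_causal_conv1d_fn(x, weight, bias, activation=None, **_):
    # x arrives as [B, D, T]; fla expects [B, T, D]. Extra keyword arguments are swallowed by **_.
    out = _fla(x=x.transpose(1, 2).contiguous(), weight=weight,
               bias=bias, activation=activation)
    # Keep only the output tensor if fla returns a tuple.
    out = out[0] if isinstance(out, tuple) else out
    return out.transpose(1, 2).contiguous()

# `model` is the already-loaded Qwen3.6-MoE model; patch every GatedDeltaNet layer.
for m in model.modules():
    if m.__class__.__name__ == 'Qwen3_5MoeGatedDeltaNet':
        m.causal_conv1d_fn = fla_causal_conv1d_fn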
The [B, D, T] ↔ [B, T, D] transposes handle the layout disagreement
between the two functions. That’s the entire shim.
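Here is a self-contained check that the transposes are the whole story, with no fla or GPU dependency. causal_conv1d_bdt is a reference depthwise causal conv in the [B, D, T] layout the model passes in. causal_conv1d_btd is a hypothetical stand-in for a [B, T, D]-layout kernel like fla's, with the activation omitted for brevity. shim wraps it the same way the adapter above does.

```python
import torch
import torch.nn.functional as F


def causal_conv1d_bdt(x, weight, bias=None):
    """Reference depthwise causal conv. x: [B, D, T], weight: [D, W], bias: [D]."""
    d, w = weight.shape
    y = F.conv1d(x, weight.unsqueeze(1), bias, padding=w - 1, groups=d)
    return y[..., : x.shape[-1]]  # keep the first T outputs; the rest come from right padding


def causal_conv1d_btd(x, weight, bias=None):
    """Hypothetical [B, T, D]-layout kernel standing in for fla's causal_conv1d."""
    w = weight.shape[1]
    padded = F.pad(x, (0, 0, w - 1, 0))   # left-pad the time axis only
    windows = padded.unfold(1, w, 1)      # [B, T, D, W]: last W steps at each position
    out = (windows * weight).sum(dim=-1)  # weight [D, W] broadcasts over B and T
    return out if bias is None else out + bias


def shim(x, weight, bias=None, **_):
    """Same transpose dance as the adapter: [B, D, T] -> [B, T, D] -> kernel -> back."""
    out = causal_conv1d_btd(x.transpose(1, 2).contiguous(), weight, bias)
    return out.transpose(1, 2).contiguous()


torch.manual_seed(0)
x = torch.randn(2, 6, 10)   # [B, D, T]
weight = torch.randn(6, 4)  # [D, W]
bias = torch.randn(6)
print(torch.allclose(shim(x, weight, bias), causal_conv1d_bdt(x, weight, bias), atol=1e-5))
```

If the transposes were dropped or swapped, the kernel would convolve across channels instead of time, and the comparison would fail (or the shapes would not line up).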
2. Force eager attention
model.set_attn_implementation("eager")
Eager is torch.matmul + torch.softmax. Slower per step, but it keeps the
full-attention layers on ordinary PyTorch ops where
use_deterministic_algorithms(True) can do its job or fail loudly. SDPA and
flash are faster; in this run, they were the wrong surface to debug through.
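To show what "eager" means here, this is a minimal causal attention written with only matmul, masked_fill, and softmax, checked against torch's fused scaled_dot_product_attention on CPU. It sketches the technique; it is not transformers' eager implementation, which also handles GQA head repetition, attention masks, and dropout. torch.use_deterministic_algorithms(True) is set so that any op with only a nondeterministic implementation raises instead of running.

```python
import math

import torch
import torch.nn.functional as F

# Ops that only have nondeterministic kernels now raise a RuntimeError.
torch.use_deterministic_algorithms(True)


def eager_causal_attention(q, k, v):
    """q, k, v: [batch, heads, seq, head_dim]. Plain matmul + softmax, no fused kernel."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ k.transpose(-2, -1)) * scale
    seq = q.shape[-2]
    causal = torch.ones(seq, seq, dtype=torch.bool).tril()  # position i sees 0..i
    scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))
out_a = eager_causal_attention(q, k, v)
out_b = eager_causal_attention(q, k, v)
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.equal(out_a, out_b), torch.allclose(out_a, fused, atol=1e-5))
```

The results agree to rounding, but only the eager version is made of ops whose determinism you can reason about one at a time.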
3. Pin top-k routing across forward / recompute
This is the load-bearing patch, and the cleverest of the three, mostly because the obvious fix doesn't work.
Naive idea: skip the gate’s compute on recompute and reuse the routing decisions from forward. Doesn’t work, because checkpoint’s count check also enforces that the number of tensors saved by autograd matches. Skipping the gate drops ~6 ops from the autograd graph per layer. Count check fires.
The fix that works: run the full
chain on both passes
(same ops, same tensor count), but on recompute, replace topk’s indices
with the cached ones from forward before they’re used downstream. The
gather op runs on both passes, so the autograd graph is structurally
identical. The only difference between forward and recompute is the integer
index tensor’s values — and integer tensors carry no gradient.
forward: linear → softmax → topk → [cache indices] → gather → divide → cast
recompute: linear → softmax → topk → [override w/ cache] → gather → divide → cast
Tensor count matches. Per-expert shapes match. Gradients flow through
gate.weight normally because the integer-index swap happens after topk
and is invisible to autograd.
The patch is ~25 lines. It’s the kind of code that takes an afternoon to debug and ten minutes to write once you understand the constraint.
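Here is a toy version of the pinning trick, following the chain in the diagram above. It is not the repo's ~25-line patch. PinnedTopKRouter and its replay flag are hypothetical, and the flag is flipped by hand where the real patch has to detect the recompute itself. count_saved_tensors uses torch.autograd.graph.saved_tensors_hooks, the hook mechanism non-reentrant checkpointing builds on, to count what autograd saves on each pass.

```python
import torch
import torch.nn.functional as F


class PinnedTopKRouter(torch.nn.Module):
    """Toy router: linear -> softmax -> topk -> [cache | override] -> gather -> divide -> cast.

    `replay` is a hypothetical manual switch; a real patch must detect the
    checkpoint recompute on its own.
    """

    def __init__(self, hidden_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = torch.nn.Linear(hidden_size, num_experts, bias=False)
        self.top_k = top_k
        self.replay = False
        self.cached_idx = None

    def forward(self, hidden: torch.Tensor):
        logits = self.gate(hidden)
        probs = F.softmax(logits, dim=-1, dtype=torch.float32)
        _, idx = torch.topk(probs, self.top_k, dim=-1)  # runs on both passes
        if self.replay:
            idx = self.cached_idx                   # swap the integer indices only
        else:
            self.cached_idx = idx.detach().clone()  # remember forward's decision
        weights = torch.gather(probs, -1, idx)      # same op on both passes
        weights = weights / weights.sum(dim=-1, keepdim=True)
        return weights.to(hidden.dtype), idx


def count_saved_tensors(fn):
    """Run fn() and count the tensors autograd saves for backward."""
    count = 0

    def pack(t):
        nonlocal count
        count += 1
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
        out = fn()
    return count, out


torch.manual_seed(0)
router = PinnedTopKRouter(hidden_size=12, num_experts=16, top_k=4)
x = torch.randn(5, 12)

n_fwd, (w_fwd, idx_fwd) = count_saved_tensors(lambda: router(x))
router.replay = True
# Negating the input would make an unpinned router pick entirely different experts.
n_rec, (w_rec, idx_rec) = count_saved_tensors(lambda: router(-x))

loss = (w_rec * torch.randn(5, 4)).sum()
loss.backward()  # gradients still reach gate.weight through softmax and gather

print(n_fwd == n_rec, torch.equal(idx_fwd, idx_rec))  # True True
print(router.gate.weight.grad.shape)                 # torch.Size([16, 12])
```

The weights come from gather on the full probability tensor, not from topk's values output. That lets the pinned indices decide which probabilities flow downstream, while topk itself still runs and saves its tensors on both passes.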
Verifying determinism
Run twice with the same seed, the same data, the same base. Hash the saved adapter artifacts:
uv run python tests/hash_adapter.py /tmp/deterministic-stack-smoke/run-a \
/tmp/deterministic-stack-smoke/run-b
The 2026-06-11 verification used /home/daniele/models/Qwen3.6-35B-A3B-GPTQ-Int4-self,
train-v3.jsonl, --smoke, seed 42, and the full Qwen3.6 LoRA target list
q_proj,k_proj,v_proj,o_proj,in_proj_qkv,out_proj. The result:
ADAPTER_BYTES_IDENTICAL=yes
adapter_model.safetensors
a: 288c2ee7ca0e511ccbfaf350bc15ec3c0c10572a98d9f01132564cb0049a4bbc
b: 288c2ee7ca0e511ccbfaf350bc15ec3c0c10572a98d9f01132564cb0049a4bbc
checkpoint-2/adapter_model.safetensors
a: 288c2ee7ca0e511ccbfaf350bc15ec3c0c10572a98d9f01132564cb0049a4bbc
b: 288c2ee7ca0e511ccbfaf350bc15ec3c0c10572a98d9f01132564cb0049a4bbc
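The repo's tests/hash_adapter.py isn't reproduced in this post. Below is a stdlib-only sketch of the same idea. It assumes "identical" means every *.safetensors file exists at the same relative path in both runs and has the same SHA-256. compare_adapters and its output format are hypothetical.

```python
import hashlib
from pathlib import Path


def sha256_file(path: Path) -> str:
    """Stream a file through SHA-256 so large adapters don't need to fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()


def compare_adapters(run_a: Path, run_b: Path) -> bool:
    """Byte-compare every *.safetensors under run_a with the same relative path under run_b."""
    rel_a = sorted(p.relative_to(run_a) for p in run_a.rglob("*.safetensors"))
    rel_b = sorted(p.relative_to(run_b) for p in run_b.rglob("*.safetensors"))
    # Require the same set of files, and at least one of them.
    identical = bool(rel_a) and rel_a == rel_b
    rows = []
    for rel in rel_a:
        b_path = run_b / rel
        a_hash = sha256_file(run_a / rel)
        b_hash = sha256_file(b_path) if b_path.exists() else "<missing>"
        identical = identical and a_hash == b_hash
        rows.append((rel.as_posix(), a_hash, b_hash))
    print(f"ADAPTER_BYTES_IDENTICAL={'yes' if identical else 'no'}")
    for name, a_hash, b_hash in rows:
        print(name)
        print(f"a: {a_hash}")
        print(f"b: {b_hash}")
    return identical
```

Call it as compare_adapters(Path("/tmp/deterministic-stack-smoke/run-a"), Path("/tmp/deterministic-stack-smoke/run-b")). Hashing whole files is deliberately stricter than comparing tensors with a tolerance: any byte of drift counts as a failure.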
That’s the verified claim: the smoke run saved byte-identical adapter artifacts
twice in a row. If a future run differs, start with the training log: check for
Patched causal_conv1d_fn on N GatedDeltaNet layers, the eager-attention line,
the routing-cache install line, and the LoRA target list before looking for more
exotic causes.
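Those checks are easy to script. Below is a minimal preflight sketch. Only the first pattern's wording comes from this post. The eager-attention and routing-cache patterns are placeholders to replace with whatever your run actually prints, and the LoRA target check is a crude substring test.

```python
import re

# Only the conv1d pattern's wording comes from the post; the other two are
# placeholders. Swap in the exact lines your training script logs.
REQUIRED_MARKERS = {
    "conv1d patch": r"Patched causal_conv1d_fn on (\d+) GatedDeltaNet layers",
    "eager attention": r"eager",        # placeholder pattern
    "routing cache": r"routing.cache",  # placeholder pattern
}
EXPECTED_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj", "in_proj_qkv", "out_proj"]


def preflight(log_text: str) -> list:
    """Return a list of problems found in a training log (empty list means all present)."""
    problems = []
    for label, pattern in REQUIRED_MARKERS.items():
        if not re.search(pattern, log_text):
            problems.append(f"missing {label} line")
    patched = re.search(REQUIRED_MARKERS["conv1d patch"], log_text)
    if patched and int(patched.group(1)) == 0:
        problems.append("conv1d patch matched zero layers")
    # Crude substring check: each expected LoRA target must appear somewhere in the log.
    missing = [t for t in EXPECTED_TARGETS if t not in log_text]
    if missing:
        problems.append("LoRA targets not found: " + ", ".join(missing))
    return problems
```

A zero-layer match is worth flagging separately. It usually means the class name didn't match, so the patch silently did nothing.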
What this is not
This is not a Qwen-specific library; the patches reach into class names
(Qwen3_5MoeGatedDeltaNet, Qwen3_5MoeTopKRouter) and would need renaming
for other MoE architectures. (One source of confusion worth flagging: the
Qwen3_5Moe... class prefix is a Hugging Face naming holdover and does
not match the model line — the model is Qwen3.6-MoE; the classes are
just what transformers calls them.) It is not the full Lares fine-tune; the corpus,
the rubrics, the held-out eval set, and the produced adapters all live
elsewhere and aren’t shipped with this recipe. It is not a replacement for
FSDP at scale; we’re verifying determinism on a 2-card device_map="auto"
setup, not a multi-process training run.
It is, narrowly, one verified recipe that made a Qwen3.6-MoE LoRA smoke run reproduce on AMD MI100s under ROCm 7.0–7.2. That’s a small claim, and a useful one if you’re trying to do the same thing.
What’s open
- Whether the routing-cache patch is necessary on the batched_mm/grouped_mm experts implementations. We use the default linear_loop_experts_forward because gptqmodel unrolls per-expert modules into individual nn.Linears; the fused-tensor paths don't apply to our layout. On a different layout, the contention surface is different; the patch may or may not still be required.
- Whether ROCm 7.3+ improves determinism around the fused/numeric surfaces enough to make patch #2 unnecessary. We'll re-test when it drops.
- Whether the same recipe works unchanged on MI200 / MI300. They have better fp determinism than MI100, so the failure mode may not even trigger — but unverified.
The repo will be updated with answers as we get them. If you’ve got data points from your own runs, the issue tracker on forge.v37.io/ai-lab/deterministic-stack is the right place to drop them.
#reproducibility #training #rocm #qwen #moe