
Deterministic Fine-Tuning on Dual MI100s


Most fine-tuning recipes published in 2025 don’t reproduce. Same script, same data, same seed, different weights. The discrepancy is small enough to look like training noise — small enough that nobody usually checks. But “small” matters when you’re trying to compare two LoRA runs to decide which won, or when you’re trying to cite a result.

On AMD MI100s the situation is worse: the upstream Hugging Face stack assumes a CUDA-shaped world, and the divergences are subtle enough to look like the hardware is the problem. It isn’t. With three small patches, the same inputs produce the same weights — bit for bit, across forward and gradient-checkpoint recompute, twice in a row, every time.

This is the recipe. The repo is at forge.v37.io/ai-lab/deterministic-stack (mirror coming). The README has the full details; this post is the story.

The symptom that started it

Train Qwen3.6-MoE with LoRA + gradient checkpointing under ROCm. You’ll see:

RuntimeError: torch.utils.checkpoint: a different number of tensors was saved
during the original forward and recomputation.

at step 2, every time, on every adapter you try. That’s PyTorch’s gradient checkpointing complaining that the forward pass and the recompute pass disagree on how many intermediate tensors to save. Specifically: the number differs, not just the values.

That sounds like a PyTorch bug. It isn’t.

Why it happens

PyTorch’s gradient checkpointing trades memory for compute by re-running the forward pass during the backward pass to recover activations rather than keeping them in VRAM. The contract is: forward and recompute must produce structurally identical computation graphs. Same ops, same tensor shapes, same number of tensors saved.
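
For a sense of how narrow that contract is, here is a contrived toy (not the Qwen code): a checkpointed function with data-dependent control flow is exactly the kind of thing that can save a different number of tensors on recompute if its inputs drift between the two passes.

import torch
from torch.utils.checkpoint import checkpoint

# Contrived illustration of the contract, not the Qwen code. If the branch below
# resolved differently during recompute than during the original forward, the two
# passes would save a different number of tensors and checkpoint would raise the
# error shown above.
def f(x):
    if x.sum() > 0:            # data-dependent control flow
        return x * 2           # path A: one chain of ops
    return x.relu() * 2 + x    # path B: a different chain, different saved tensors

x = torch.randn(4, requires_grad=True)
out = checkpoint(f, x, use_reentrant=False)
out.sum().backward()           # fine here, because nothing drifts on CPU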

Qwen3.6-MoE has 256 experts and routes 8 of them per token via top-k. The router’s computation is roughly:

routing_indices = TopK_8(softmax(W_r h))

With 256 experts, many of those softmax probabilities are near-tied. Tiny floating-point drift between forward and recompute — last few bits of bf16 mantissa — flips which experts win the top-8. Different experts firing means different per-expert tensor shapes during the recompute. The count check catches it. The training step dies.
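
To get a feel for how little drift it takes, here is a toy example (not the model's router, just two bf16 scores one ULP apart):

import torch

# Toy illustration: a one-ULP bf16 difference between two near-tied scores
# changes which indices win the top-k.
scores = torch.tensor([2.0, 2.0, 1.0], dtype=torch.bfloat16)
drifted = scores.clone()
drifted[1] += 0.015625          # one ULP at this magnitude in bf16
print(scores.topk(2).indices)   # typically tensor([0, 1]); tie order is implementation-defined
print(drifted.topk(2).indices)  # tensor([1, 0]): a different expert would now fire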

The drift is real. On ROCm with bf16, three sources contribute:

  1. torch.conv1d fallback — Qwen’s GatedDeltaNet layer wires its causal_conv1d_fn attribute to the CUDA-only causal-conv1d PyPI package. When that’s not installed (it can’t be on ROCm), the layer falls back to a torch conv1d that is non-deterministic on ROCm even with torch.use_deterministic_algorithms(True).
  2. SDPA / flash backward — both have non-deterministic backward kernels on ROCm. Forward is fine; backward isn’t.
  3. Residual hipBLAS GEMM drift — even with everything above pinned, bf16 GEMMs sometimes produce a few-bit-different output across recompute. Small, but enough to flip near-tied softmax scores.

Each one feeds the next. The router is just where the consequence becomes visible.
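
None of the patches below matter if the usual determinism knobs aren't already on. A minimal sketch of the baseline we assume throughout (the repo's training script is the source of truth; the helper name here is ours):

import random

import numpy as np
import torch

# Baseline determinism setup assumed before any of the patches (a sketch, not the
# repo's exact script).
def seed_everything(seed: int = 0) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)                   # seeds CPU and all GPU generators
    torch.use_deterministic_algorithms(True)  # error out on known non-deterministic ops
    torch.backends.cudnn.benchmark = False    # no shape-dependent kernel autotuning (MIOpen on ROCm)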

The three patches

1. Replace the conv1d fallback

The non-deterministic conv1d isn’t a hard problem; it’s just a wrong default. fla ships a Triton-based causal_conv1d that works on ROCm and is deterministic by construction. Adapter:

from fla.modules.conv.causal_conv1d import causal_conv1d as _fla

def fla_causal_conv1d_fn(x, weight, bias, activation=None, **_):
    # fla and the model disagree on layout; transpose in, transpose back out.
    out = _fla(x=x.transpose(1, 2).contiguous(), weight=weight,
               bias=bias, activation=activation)
    out = out[0] if isinstance(out, tuple) else out   # unwrap if the kernel returns a tuple
    return out.transpose(1, 2).contiguous()

# Swap the non-deterministic fallback on every GatedDeltaNet layer.
for m in model.modules():
    if m.__class__.__name__ == 'Qwen3_5MoeGatedDeltaNet':
        m.causal_conv1d_fn = fla_causal_conv1d_fn

The [B, D, T] ↔ [B, T, D] transposes handle the layout disagreement between the two functions. That's the entire shim.

2. Force eager attention

model.set_attn_implementation("eager")

Eager is torch.matmul + torch.softmax. Slower per step, deterministic by virtue of touching only ops that are themselves deterministic under use_deterministic_algorithms(True). SDPA and flash get measurably faster training but break the contract; eager respects it.
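
If you load the model yourself rather than switching the implementation afterwards, the same thing can be requested at load time (standard transformers kwarg; the model id below is a placeholder):

import torch
from transformers import AutoModelForCausalLM

# Same effect as the one-liner above, requested at load time. "Qwen/Qwen3.6-MoE"
# is a placeholder id, not the actual checkpoint name.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3.6-MoE",
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)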

3. Pin top-k routing across forward / recompute

This is the load-bearing one and it’s the cleverest of the three, mostly because the obvious fix doesn’t work.

Naive idea: skip the gate’s compute on recompute and reuse the routing decisions from forward. Doesn’t work, because checkpoint’s count check also enforces that the number of tensors saved by autograd matches. Skipping the gate drops ~6 ops from the autograd graph per layer. Count check fires.

The fix that works: run the full linear → softmax → topk chain on both passes (same ops, same tensor count), but on recompute, replace topk's indices with the cached ones from forward before they're used downstream. The gather op runs on both passes, so the autograd graph is structurally identical. The only difference between forward and recompute is the integer index tensor's values — and integer tensors carry no gradient.

forward:    linear → softmax → topk → [cache indices]    → gather → divide → cast
recompute:  linear → softmax → topk → [override w/ cache] → gather → divide → cast

Tensor count matches. Per-expert shapes match. Gradients flow through gate.weight normally because the integer-index swap happens after topk and is invisible to autograd.

The patch is ~25 lines. It’s the kind of code that takes an afternoon to debug and ten minutes to write once you understand the constraint.
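
A minimal sketch of the shape of it (not the repo's ~25-line patch; the signature and cache handling here are simplified, and the real patch targets Qwen3_5MoeTopKRouter directly):

import torch
import torch.nn.functional as F

# Sketch of the idea only, not the repo's ~25-line patch. `cache` is a per-layer
# dict; in this sketch the caller clears it at the start of every training step,
# so the first call (forward) fills it and the second (recompute) reads it.
def route(hidden, gate_weight, cache, top_k=8):
    logits = F.linear(hidden, gate_weight)                  # linear
    probs = logits.softmax(dim=-1, dtype=torch.float32)     # softmax
    _, indices = probs.topk(top_k, dim=-1)                  # topk runs on BOTH passes
    if "indices" not in cache:
        cache["indices"] = indices                          # forward: remember the winners
    else:
        indices = cache["indices"]                          # recompute: override values only
    weights = probs.gather(-1, indices)                     # gather runs on both passes
    weights = weights / weights.sum(dim=-1, keepdim=True)   # divide
    return weights.to(hidden.dtype), indices                # cast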

Verifying determinism

Run twice with the same seed, the same data, the same base. Diff the adapters:

import safetensors.torch as st

a = st.load_file('checkpoints/run-a/adapter_model.safetensors')
b = st.load_file('checkpoints/run-b/adapter_model.safetensors')
assert a.keys() == b.keys()                                 # same adapter structure
deltas = {k: (a[k] - b[k]).abs().max().item() for k in a}   # per-tensor max |a - b|
print(max(deltas.values()), sum(1 for v in deltas.values() if v > 0))
# 0.0 0

max=0.0, nonzero=0. Bit-for-bit identical. If anything is non-zero, one of the three patches isn't applied — check the training log for the "Patched causal_conv1d_fn on N GatedDeltaNet layers" line and its siblings.

What this is not

This is not a general-purpose MoE library; the patches reach into Qwen class names (Qwen3_5MoeGatedDeltaNet, Qwen3_5MoeTopKRouter) and would need renaming for other MoE architectures. (One source of confusion worth flagging: the Qwen3_5Moe... class prefix is a Hugging Face naming holdover and does not match the model line — the model is Qwen3.6-MoE; the classes are just what transformers calls them.) It is not the full Lares fine-tune; the corpus, the rubrics, the held-out eval set, and the produced adapters all live elsewhere and aren't shipped with this recipe. It is not a replacement for FSDP at scale; we're verifying determinism on a 2-card device_map="auto" setup, not a multi-process training run.

It is, narrowly, what it takes to make a Qwen3.6-MoE LoRA fine-tune reproduce on AMD MI100s under ROCm 7.0–7.2. That’s a small claim, and a useful one if you’re trying to do the same thing.

What’s open

  • Whether the routing-cache patch is necessary on the batched_mm / grouped_mm experts implementations. We use the default linear_loop_experts_forward because gptqmodel unrolls per-expert modules into individual nn.Linears; the fused-tensor paths don’t apply to our layout. On a different layout, the contention surface is different; the patch may or may not still be required.
  • Whether ROCm 7.3+ improves the underlying bf16 drift enough to make patch #2 unnecessary. We’ll re-test when it drops.
  • Whether the same recipe works unchanged on MI200 / MI300. They have better fp determinism than MI100, so the failure mode may not even trigger — but unverified.

The repo will be updated with answers as we get them. If you’ve got data points from your own runs, the issue tracker on forge.v37.io/ai-lab/deterministic-stack is the right place to drop them.

#reproducibility #training #rocm #qwen #moe
