Add tensor parallel support for MiniMax M2 Q/K norms

MiniMax M2 uses Q/K RMSNorm with span_heads=True, which normalizes
across ALL heads at each sequence position. When using tensor
parallelism, heads are split across devices, so each device sees only
a subset of heads and would compute an incorrect, local-only variance.
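
For illustration, with made-up numbers: 4 heads of head_dim 2 give 8
elements per position; if the sums of squares on two TP ranks are 8 and
200, the true mean of squares is (8 + 200) / 8 = 26, whereas the ranks
would locally compute 8 / 4 = 2 and 200 / 4 = 50.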

The fix follows vLLM's approach:
- Compute local sum of squares on each TP rank
- All-reduce the sum across ranks
- Divide by the global dimension to get the true global mean of squares
- Apply normalization with corrected global variance
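
A minimal sketch of these four steps using torch.distributed (names such
as rms_norm_tp_sketch, q_local and global_dim are illustrative, and the
padding required by the native TP backend is omitted; the actual
implementation is apply_qk_norms_tp() in attn.py below):

import torch
import torch.distributed as dist

def rms_norm_tp_sketch(q_local, weight_local, global_dim, eps = 1e-6):
    # q_local: this rank's shard, shape (bsz, seq_len, local_heads * head_dim)
    x = q_local.float()
    # 1) local sum of squares per position
    sq_sum = x.pow(2).sum(dim = -1, keepdim = True)
    # 2) all-reduce so every rank holds the global sum of squares
    dist.all_reduce(sq_sum, op = dist.ReduceOp.SUM)
    # 3) divide by the global width (all heads), not the local shard width
    mean_sq = sq_sum / global_dim
    # 4) normalize the local shard with the global statistic
    return (x * torch.rsqrt(mean_sq + eps) * weight_local).to(q_local.dtype)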

Key changes:
- attn.py: Add apply_qk_norms_tp() method with variance all-reduce
- attn.py: Modify tp_export/tp_import to handle span_heads norms
- rmsnorm.py: Preserve span_heads in tp_export, handle 1D weight tensors
  in split (see the sketch after this list)
- minimax_m2.py: Enable TP support (supports_tp: True)
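
A rough sketch of the 1D vs 2D weight split mentioned above (hypothetical
helper; the real logic lives in the attn.py and rmsnorm.py hunks below),
where first/last are the head indices assigned to a TP rank:

def split_norm_weight(weight, first, last, head_dim, span_heads):
    if span_heads:
        # 1D weight of shape (num_heads * head_dim,): slice by element index
        return weight[first * head_dim : last * head_dim]
    # 2D weight of shape (num_heads, head_dim): slice by head index
    return weight[first : last, :]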

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

minimax_m2.py

@@ -165,8 +165,8 @@ class MiniMaxM2Model(Model):
# Activate all experts during H capture pass in quantization
self.calibration_all_experts = True
# TODO: Q/K norms span all heads, so TP requires an additional step to reduce variance across ranks
self.caps.update({"supports_tp": False})
# Q/K norms span all heads - TP support uses variance all-reduce across ranks
self.caps.update({"supports_tp": True})
@override

attn.py

@@ -257,6 +257,11 @@ class Attention(Module):
self.has_split_cache = False
# TP-aware span_heads norm support
self.tp_span_heads_norm = False
self.q_global_dim = 0
self.k_global_dim = 0
@override
def optimizer_targets(self):
@@ -400,6 +405,68 @@ class Attention(Module):
return x
def apply_qk_norms_tp(
self,
q: torch.Tensor,
k: torch.Tensor,
params: dict,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply Q/K RMSNorm with TP-aware variance reduction.
Used when span_heads=True and tp_reduce=True.
The variance is computed locally, then all-reduced across TP ranks
to get the true global variance.
"""
backend = params["backend"]
orig_q_shape = q.shape
orig_k_shape = k.shape
bsz, seq_len = q.shape[0], q.shape[1]
# Flatten head dimension for norm computation: (bsz, seq, num_heads*head_dim)
q_flat = q.view(bsz, seq_len, -1).float()
k_flat = k.view(bsz, seq_len, -1).float()
# Compute local sum of squares
q_sq_sum = q_flat.pow(2).sum(dim=-1, keepdim=True)
k_sq_sum = k_flat.pow(2).sum(dim=-1, keepdim=True)
# Native TP backend requires data_size (numel * 2 bytes) to be multiple of 16.
# So numel must be multiple of 8. We have 2 values (q, k) per position.
# Flatten to 1D and pad to multiple of 8 elements.
qk_sq_sum = torch.cat([q_sq_sum.view(-1), k_sq_sum.view(-1)]) # (bsz*seq_len*2,)
numel = qk_sq_sum.numel()
pad_to = (numel + 7) // 8 * 8
if pad_to > numel:
qk_sq_sum = torch.nn.functional.pad(qk_sq_sum, (0, pad_to - numel))
backend.all_reduce(qk_sq_sum)
# Extract back (first half is q, second half is k)
half = bsz * seq_len
q_sq_sum = qk_sq_sum[:half].view(bsz, seq_len, 1)
k_sq_sum = qk_sq_sum[half:half*2].view(bsz, seq_len, 1)
# Compute global variance (sum / global_dim)
q_var = q_sq_sum / self.q_global_dim
k_var = k_sq_sum / self.k_global_dim
# Compute normalization factors
q_rmf = torch.rsqrt(q_var + self.norm_eps)
k_rmf = torch.rsqrt(k_var + self.norm_eps)
# Get weights (handle constant_bias if needed)
q_w = self.q_norm.weight
k_w = self.k_norm.weight
if self.norm_constant_bias != 0.0:
q_w = q_w + self.norm_constant_bias
k_w = k_w + self.norm_constant_bias
# Apply normalization and reshape back
q = (q_flat * q_rmf * q_w).half().view(orig_q_shape)
k = (k_flat * k_rmf * k_w).half().view(orig_k_shape)
return q, k
def decode_sdpa_nc(
self,
x: torch.Tensor,
@@ -423,9 +490,13 @@ class Attention(Module):
assert self.logit_softcapping == 0.0, \
"Torch SDPA does not support logit softcapping"
if self.q_norm and (not self.rope or self.q_norm_tensor is None):
q = self.q_norm.forward(q, params, out_dtype = torch.half)
k = self.k_norm.forward(k, params, out_dtype = torch.half)
if self.q_norm:
if self.tp_span_heads_norm:
# TP-aware path for span_heads=True
q, k = self.apply_qk_norms_tp(q, k, params)
elif not self.rope or self.q_norm_tensor is None:
q = self.q_norm.forward(q, params, out_dtype = torch.half)
k = self.k_norm.forward(k, params, out_dtype = torch.half)
if self.rope:
q, k = self.rope.apply(
@@ -434,8 +505,8 @@ class Attention(Module):
positions,
position_ids,
True,
self.q_norm_tensor,
self.k_norm_tensor,
self.q_norm_tensor if not self.tp_span_heads_norm else None,
self.k_norm_tensor if not self.tp_span_heads_norm else None,
self.norm_eps,
self.norm_constant_bias,
inv_freq,
@@ -471,9 +542,13 @@ class Attention(Module):
k = k.view(bsz, seqlen, self.num_kv_heads, self.head_dim)
v = v.view(bsz, seqlen, self.num_kv_heads, self.head_dim)
if self.q_norm and (not self.rope or self.q_norm_tensor is None):
q = self.q_norm.forward(q, params, out_dtype = torch.half)
k = self.k_norm.forward(k, params, out_dtype = torch.half)
if self.q_norm:
if self.tp_span_heads_norm:
# TP-aware path for span_heads=True
q, k = self.apply_qk_norms_tp(q, k, params)
elif not self.rope or self.q_norm_tensor is None:
q = self.q_norm.forward(q, params, out_dtype = torch.half)
k = self.k_norm.forward(k, params, out_dtype = torch.half)
if self.rope:
q, k = self.rope.apply(
@@ -482,8 +557,8 @@ class Attention(Module):
positions,
position_ids,
True,
self.q_norm_tensor,
self.k_norm_tensor,
self.q_norm_tensor if not self.tp_span_heads_norm else None,
self.k_norm_tensor if not self.tp_span_heads_norm else None,
self.norm_eps,
self.norm_constant_bias,
inv_freq,
@@ -545,9 +620,13 @@ class Attention(Module):
v = v.view(bsz, seqlen, self.num_kv_heads, self.head_dim)
# TODO: Add LayerNorm option to fused norm/RoPE kernel
if self.q_norm and (not self.rope or self.q_norm_tensor is None):
q = self.q_norm.forward(q, params, out_dtype = torch.half)
k = self.k_norm.forward(k, params, out_dtype = torch.half)
if self.q_norm:
if self.tp_span_heads_norm:
# TP-aware path for span_heads=True
q, k = self.apply_qk_norms_tp(q, k, params)
elif not self.rope or self.q_norm_tensor is None:
q = self.q_norm.forward(q, params, out_dtype = torch.half)
k = self.k_norm.forward(k, params, out_dtype = torch.half)
if self.rope:
q, k = self.rope.apply(
@@ -556,8 +635,8 @@ class Attention(Module):
positions,
position_ids,
True,
self.q_norm_tensor,
self.k_norm_tensor,
self.q_norm_tensor if not self.tp_span_heads_norm else None,
self.k_norm_tensor if not self.tp_span_heads_norm else None,
self.norm_eps,
self.norm_constant_bias,
inv_freq,
@@ -648,6 +727,13 @@ class Attention(Module):
nonlocal producer
return child.tp_export(plan, producer) if child is not None else None
# Check if q_norm uses span_heads
q_norm_span_heads = (
self.q_norm is not None and
isinstance(self.q_norm, RMSNorm) and
self.q_norm.span_heads
)
return {
"cls": Attention,
"kwargs": {
@@ -676,7 +762,11 @@ class Attention(Module):
"cache_layers": [
cl.tp_export(plan) for cl in self.cache_layers
],
"n_gqa": self.num_q_heads // self.num_kv_heads
"n_gqa": self.num_q_heads // self.num_kv_heads,
# For TP-aware span_heads norm
"q_norm_span_heads": q_norm_span_heads,
"q_global_dim": self.num_q_heads * self.head_dim if self.q_norm else 0,
"k_global_dim": self.num_kv_heads * self.head_dim if self.k_norm else 0,
}
@@ -697,10 +787,21 @@ class Attention(Module):
if num_kv_heads else None
o_split = (False, first * head_dim * n_gqa, last * head_dim * n_gqa) \
if num_kv_heads else None
norm_q_split = (first * n_gqa, last * n_gqa) \
if num_kv_heads else None
norm_k_split = (first, last) \
if num_kv_heads else None
# For span_heads norms, we need element indices (head_idx * head_dim)
# For regular norms, we use head indices
q_norm_span_heads = exported.get("q_norm_span_heads", False)
if q_norm_span_heads:
# span_heads=True: norm weight is 1D with shape (num_heads * head_dim,)
norm_q_split = (first * head_dim * n_gqa, last * head_dim * n_gqa) \
if num_kv_heads else None
norm_k_split = (first * head_dim, last * head_dim) \
if num_kv_heads else None
else:
# span_heads=False: norm weight is 2D with shape (num_heads, head_dim)
norm_q_split = (first * n_gqa, last * n_gqa) \
if num_kv_heads else None
norm_k_split = (first, last) \
if num_kv_heads else None
# def _import(name):
# nonlocal exported, plan
@@ -738,6 +839,13 @@ class Attention(Module):
module.device = device
if not kwargs.get("skip_reduction"):
module.tp_reduce = True
# Set up TP-aware span_heads norm if needed
if exported.get("q_norm_span_heads", False):
module.tp_span_heads_norm = True
module.q_global_dim = exported.get("q_global_dim", 0)
module.k_global_dim = exported.get("k_global_dim", 0)
module.load_local(device)
torch.cuda.synchronize()
return module

rmsnorm.py

@@ -111,6 +111,7 @@ class RMSNorm(Module):
"rms_norm_eps": self.rms_norm_eps,
"out_dtype": self.out_dtype,
"constant_bias": self.constant_bias,
"span_heads": self.span_heads,
},
"weight": producer.send(self.weight),
"device": self.device,
@@ -127,6 +128,7 @@ class RMSNorm(Module):
module.device = device
w = consumer.recv(exported["weight"], cuda = True)
module.weight = nn.Parameter(w) if w is not None else None
# span_heads is preserved via kwargs
torch.cuda.synchronize()
return module
@@ -145,6 +147,11 @@ class RMSNorm(Module):
if w is not None:
if w.dim() == 2:
w = w[first : last, :]
elif w.dim() == 1:
# 1D weight tensor (e.g., span_heads=True norms)
# split contains element indices
w = w[first : last]
module.weight = nn.Parameter(w.to(module.device).contiguous())
# span_heads is preserved via kwargs
return module
return module