Files
cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/moe/moe_utils.py
Junkai-Wu cb37157db5 v4.5 tag update (#3202)
* Python DSL examples reorganization.

* v4.5 tag update.
2026-05-05 20:55:27 -04:00

911 lines
34 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Online TMA Descriptor Construction Utilities.
Provides utilities for dynamically creating TMA descriptors at kernel runtime
based on runtime-provided information (problem sizes, pointers, etc.).
Key components:
- OnlineTensormapDescCreator: Simplified ABC for TMA descriptor builders (2 abstract methods)
- TensormapWorkspace: Helper for linear workspace layout of TMA descriptors
- MoEGroupedGemmTensormapConstructor: TMA descriptor constructor for MoE Grouped GEMM
- MoEScaledGroupedGemmTensormapConstructor: TMA descriptor constructor for block-scaled MoE Grouped GEMM (adds SFA/SFB descriptors)
- GeneralGroupedGemmTensormapConstructor: TMA descriptor constructor for general Grouped GEMM
- Pointer utility functions (spin_wait, ptr_offset_bytes, gmem_ptr_to_generic, generic_ptr_to_gmem, prefetch_tma_descriptor)
- tensormap_ptr_for_copy: Convert raw desc ptr to cute.copy-compatible type
- compute_expert_token_range: Compute per-expert token offset and count from offs
- rewrite_tensor_shape: Rewrite a tensor's shape while keeping its stride and iterator
"""
from abc import ABC, abstractmethod
from typing import Optional, Literal, Tuple, Union
import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import AddressSpace, Pointer
from cutlass.cute.nvgpu import cpasync
from cutlass.cutlass_dsl import dsl_user_op, Int32
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
from cutlass._mlir.dialects import cute as _cute_ir
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
from dataclasses import dataclass
from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF
# Size in bytes of one TMA descriptor slot in the workspace (matches the
# 128-byte CUDA driver CUtensorMap object). TensormapWorkspace uses it both as
# the per-slot byte stride and as the alignment applied to each slot pointer.
TensormapDescBytes = 128
# =============================================================================
# Pointer Utilities
# =============================================================================
@dsl_user_op
@cute.jit
def spin_wait(
    ptr: Pointer, condition, fail_sleep_cycles: int = 100, *, loc=None, ip=None
) -> None:
    """
    Busy-poll the value behind ``ptr`` until ``condition`` accepts it.

    Every read is issued with ``cop="cg"`` (ld.global.cg, bypassing L1),
    presumably so that updates written by other CTAs are observed. Between
    failed polls the thread calls ``cute.arch.nanosleep`` with
    ``sleep_time=fail_sleep_cycles``; pass 0 to poll back-to-back.
    NOTE: PTX ``nanosleep`` interprets its operand as an approximate duration
    in nanoseconds, despite this parameter's name.

    Example usage:
        # Wait until counter >= total_blocks
        spin_wait(counter_ptr, lambda x: x >= total_blocks, fail_sleep_cycles=100)
        # Wait until flag == 1
        spin_wait(flag_ptr, lambda x: x == 1)

    :param ptr: Pointer to the polled value; ``ptr.dtype`` is the load type.
    :param condition: Callable mapping the loaded value to a boolean.
    :param fail_sleep_cycles: Compile-time sleep amount between failed polls.
    """
    # Decided at trace time: the sleep is either always or never emitted.
    sleep_between_polls = fail_sleep_cycles > 0
    # Load with L1 cache bypass (ld.global.cg)
    observed = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)
    while not condition(observed):
        if cutlass.const_expr(sleep_between_polls):
            cute.arch.nanosleep(sleep_time=fail_sleep_cycles, loc=loc, ip=ip)
        observed = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)
def _cast_ptr_address_space(
    ptr: Pointer,
    src_space: AddressSpace,
    dst_space: AddressSpace,
    op_name: str,
    src_label: str,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
    """
    Shared address-space cast for gmem_ptr_to_generic / generic_ptr_to_gmem.

    Checks that ``ptr`` is in ``src_space``, emits an LLVM ``addrspacecast``
    to ``dst_space`` and rewraps the result as a cute Pointer with the same
    dtype and alignment.

    :param op_name: Public function name used in the error message.
    :param src_label: Human-readable name of ``src_space`` for the message.
    :raises ValueError: If ``ptr.memspace`` differs from ``src_space``.
    """
    if ptr.memspace != src_space:
        raise ValueError(
            f"{op_name} requires pointer in {src_label} address space, "
            f"got {ptr.memspace}"
        )
    llvm_ptr = ptr.to_llvm_ptr(loc=loc, ip=ip)
    cast_llvm_ptr = llvm.addrspacecast(
        llvm.PointerType.get(dst_space), llvm_ptr, loc=loc, ip=ip
    )
    # Rewrap in the destination address space, preserving dtype and alignment.
    return cute.make_ptr(
        ptr.dtype,
        cast_llvm_ptr,
        dst_space,
        assumed_align=ptr.alignment,
        loc=loc,
        ip=ip,
    )


@dsl_user_op
def gmem_ptr_to_generic(
    gmem_ptr: Pointer,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
    """
    Cast a global-memory pointer to the generic address space.

    :param gmem_ptr: Pointer in ``AddressSpace.gmem``.
    :return: Pointer to the same address in ``AddressSpace.generic`` with the
        same dtype and alignment.
    :raises ValueError: If ``gmem_ptr`` is not in the gmem address space.
    """
    return _cast_ptr_address_space(
        gmem_ptr,
        AddressSpace.gmem,
        AddressSpace.generic,
        "gmem_ptr_to_generic",
        "gmem",
        loc=loc,
        ip=ip,
    )


@dsl_user_op
def generic_ptr_to_gmem(
    generic_ptr: Pointer,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
    """
    Cast a generic-address-space pointer to the global memory address space.

    :param generic_ptr: Pointer in ``AddressSpace.generic``.
    :return: Pointer to the same address in ``AddressSpace.gmem`` with the
        same dtype and alignment.
    :raises ValueError: If ``generic_ptr`` is not in the generic address space.
    """
    return _cast_ptr_address_space(
        generic_ptr,
        AddressSpace.generic,
        AddressSpace.gmem,
        "generic_ptr_to_gmem",
        "generic",
        loc=loc,
        ip=ip,
    )
@dsl_user_op
def prefetch_tma_descriptor(tma_desc_ptr: Pointer, *, loc=None, ip=None) -> None:
    """
    Prefetch a TMA descriptor from global memory into the descriptor cache.

    The pointer may be in the gmem or the generic address space. A gmem
    pointer is first cast to generic with gmem_ptr_to_generic; the prefetch
    is then issued through the NVVM ``prefetch`` wrapper with
    ``tensormap=True``.

    :param tma_desc_ptr: Pointer to the TMA descriptor in global or generic memory
    :type tma_desc_ptr: Pointer
    :raises ValueError: If pointer is not in generic or global address space
    """
    space = tma_desc_ptr.memspace
    if space == AddressSpace.gmem:
        # Normalize gmem pointers to the generic address space first.
        tma_desc_ptr = gmem_ptr_to_generic(tma_desc_ptr, loc=loc, ip=ip)
    elif space != AddressSpace.generic:
        raise ValueError(
            f"prefetch_tma_descriptor requires pointer in gmem or generic address space, "
            f"got {tma_desc_ptr.memspace}"
        )
    desc_llvm_ptr = tma_desc_ptr.to_llvm_ptr(loc=loc, ip=ip)
    from cutlass.cute.arch.nvvm_wrappers import prefetch as nvvm_prefetch

    nvvm_prefetch(desc_llvm_ptr, tensormap=True, loc=loc, ip=ip)
def ptr_offset_bytes(ptr: "Pointer", byte_offset: int) -> "Pointer":
    """
    Offset a pointer by a given number of bytes.

    The byte offset is converted to an element offset using the bit width of
    ``ptr.dtype``, so sub-byte element types work too. The element offset is
    then added to ``ptr``.

    :param ptr: Pointer to advance.
    :param byte_offset: Offset in bytes. Either a Python ``int`` or a DSL
        integer value computed at runtime (e.g. ``Int32``).
    :return: ``ptr`` advanced by ``byte_offset`` bytes.
    :raises ValueError: If ``byte_offset`` is a Python ``int`` that is not a
        whole number of elements. Without this check it would be silently
        truncated to an element boundary, producing a misaligned pointer.
    """
    elem_bits = ptr.dtype.width
    # Only static offsets can be checked here; runtime values pass through.
    if isinstance(byte_offset, int) and (byte_offset * 8) % elem_bits != 0:
        raise ValueError(
            f"byte_offset {byte_offset} is not a multiple of the "
            f"{elem_bits}-bit element size of ptr.dtype"
        )
    return ptr + byte_offset * 8 // elem_bits
@dsl_user_op
def tensormap_ptr_for_copy(raw_ptr: Pointer, *, loc=None, ip=None) -> Pointer:
    """
    Convert a raw TMA descriptor pointer to the type expected by cute.copy.

    cute.copy requires the tma_desc_ptr to be in generic address space and
    recast to TmaDescriptorTiledType. This utility performs both conversions.
    A gmem pointer is first cast to generic. A pointer that is already
    generic is recast directly, consistent with prefetch_tma_descriptor.

    :param raw_ptr: Raw pointer to a TMA descriptor, in gmem or generic
        address space
    :type raw_ptr: Pointer
    :return: Pointer compatible with cute.copy's tma_desc_ptr parameter
    :rtype: Pointer
    :raises ValueError: If raw_ptr is in neither gmem nor generic address space
    """
    if raw_ptr.memspace == AddressSpace.gmem:
        generic_ptr = gmem_ptr_to_generic(raw_ptr, loc=loc, ip=ip)
    elif raw_ptr.memspace == AddressSpace.generic:
        generic_ptr = raw_ptr
    else:
        raise ValueError(
            f"tensormap_ptr_for_copy requires pointer in gmem or generic address space, "
            f"got {raw_ptr.memspace}"
        )
    tma_desc_ptr_ty = _cute_ir.PtrType.get(
        _cute_nvgpu_ir.TmaDescriptorTiledType.get(),
        generic_ptr.memspace,
        generic_ptr.alignment,
    )
    return _cute_ir.recast_iter(tma_desc_ptr_ty, generic_ptr.value, loc=loc, ip=ip)
# =============================================================================
# MoE Utilities
# =============================================================================
@dsl_user_op
@cute.jit
def compute_expert_token_range(
    offs: cute.Tensor,
    expert_idx: Int32,
    *,
    loc=None,
    ip=None,
) -> Tuple[Int32, Int32]:
    """
    Look up where an expert's tokens start and how many there are.

    ``offs`` holds cumulative token counts. Expert ``i`` therefore owns
    ``[offs[i - 1], offs[i])``, with an implicit 0 before expert 0.

    :param offs: Cumulative sum tensor of token counts per expert, shape (experts,)
    :param expert_idx: Index of the expert
    :return: ``(token_offset, tokens_i)`` - start position and number of tokens
        for this expert
    """
    range_begin = Int32(0)
    if expert_idx > Int32(0):
        # All earlier experts' tokens precede this one.
        range_begin = offs[expert_idx - 1]  # type: ignore[assignment]
    range_end = offs[expert_idx]
    return range_begin, range_end - range_begin
@dsl_user_op
def rewrite_tensor_shape(
    tensor: cute.Tensor,
    new_shape: Tuple,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """
    Rewrite tensor shape while keeping the same stride and iterator.

    Builds ``make_layout(new_shape, stride=tensor.stride)`` and attaches it to
    ``tensor.iterator``; no data is moved. ``new_shape`` presumably must be
    congruent with ``tensor.stride`` for make_layout to accept it.

    The new shape is not cosmetic. In this module the result is passed to the
    TMA atom builders (per-expert C/A/B descriptors), so the shape given here
    becomes the extent encoded in the descriptor. For example, ``tokens_i``
    replaces the fake global token count, which keeps each expert's descriptor
    from covering other experts' rows.

    :param tensor: Source tensor whose stride and iterator to preserve
    :param new_shape: New shape to apply
    :return: New tensor with the given shape but original stride and iterator
    """
    new_layout = cute.make_layout(new_shape, stride=tensor.stride, loc=loc, ip=ip)
    return cute.make_tensor(tensor.iterator, new_layout, loc=loc, ip=ip)
# =============================================================================
# TMA Descriptor Workspace Helper
# =============================================================================
class TensormapWorkspace:
    """
    Helper for linear workspace layout of TMA descriptors.

    Manages address calculation for a workspace buffer containing TMA
    descriptors. For each executor (e.g. expert or group) the buffer holds a
    fixed set of named descriptor slots, each ``TensormapDescBytes`` wide.

    Layout: [slot_0_exec_0, slot_1_exec_0, ..., slot_0_exec_1, slot_1_exec_1, ...]

    Example:
        # 2Dx3D MoE: only C is expert-wise
        workspace = TensormapWorkspace(workspace_ptr, ["c"])
        # 2Dx2D MoE: A and B are expert-wise
        workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])
        # General grouped GEMM: all three tensors
        workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "c"])
    """

    def __init__(self, workspace_ptr: Pointer, slot_names: list):
        """
        :param workspace_ptr: Pointer to the beginning of the workspace buffer
        :param slot_names: Ordered, duplicate-free sequence of tensor names
            defining the slot layout per executor, e.g. ["a", "b", "c"]
        :raises ValueError: If ``slot_names`` contains duplicates. The
            name->slot map would then disagree with the per-executor slot count.
        """
        names = list(slot_names)
        if len(set(names)) != len(names):
            raise ValueError(f"slot_names must be unique, got {names}")
        self.workspace_ptr = workspace_ptr
        self._name_to_slot = {name: i for i, name in enumerate(names)}
        self._slots_per_executor = len(names)

    @cute.jit
    def get_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        """
        Get the workspace pointer for a specific TMA descriptor.

        :param tensor_name: Name of the tensor (must be one of the slot_names)
        :param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
        :return: Pointer to the TMA descriptor in workspace, with
            ``.align(TensormapDescBytes)`` applied
        :raises ValueError: At trace time, if ``tensor_name`` is not a slot name
        """
        if cutlass.const_expr(tensor_name not in self._name_to_slot):
            raise ValueError(
                f"Invalid tensor_name '{tensor_name}', "
                f"expected one of {list(self._name_to_slot.keys())}"
            )
        slot = self._name_to_slot[tensor_name]
        # Slots are laid out executor-major: all slots of executor 0, then 1, ...
        byte_offset = (
            executor_idx * self._slots_per_executor + slot
        ) * TensormapDescBytes
        # NOTE(review): presumably workspace_ptr is already 128-byte aligned,
        # so .align() only records alignment; confirm if callers may pass
        # unaligned buffers (size_bytes does not reserve padding).
        return ptr_offset_bytes(self.workspace_ptr, byte_offset).align(
            TensormapDescBytes
        )

    @staticmethod
    def size_bytes(num_slots: int, num_executors: int) -> int:
        """
        Calculate workspace size in bytes.

        :param num_slots: Number of descriptor slots per executor
        :param num_executors: Total number of executors (e.g., expert_cnt or group_cnt)
        :return: Total workspace size in bytes
        """
        return num_slots * num_executors * TensormapDescBytes
# =============================================================================
# Online TMA Descriptor Creator (Abstract Base Class)
# =============================================================================
@dataclass(frozen=True)
class OnlineTensormapDescCreator(ABC):
    """
    Abstract base class for building TMA descriptors online (at kernel runtime).

    Subclasses store all needed parameters (both codegen-time configs and runtime
    values) as explicit instance attributes in __init__. No dict-based APIs.

    Subclasses must implement exactly 2 abstract methods:
    - construct_and_write: Build TMA descriptor(s) for one executor and write to workspace
    - get_desc_ptr: Return raw gmem pointer to a specific descriptor in workspace

    To convert the raw pointer for use with cute.copy, callers should use the
    standalone tensormap_ptr_for_copy() utility.

    NOTE(review): this base is a frozen dataclass that declares no fields.
    Under standard dataclass semantics:
    (a) Subclasses that are not themselves dataclasses can still assign
        attributes in __init__, because the generated __setattr__ only
        rejects writes when ``type(self)`` is exactly this class.
    (b) The generated __eq__, __hash__ and __repr__ ignore those attributes.
        Two instances of the same subclass compare equal, share a hash, and
        repr as ``ClassName()``.
    Confirm nothing (e.g. a cache keyed on these objects) relies on equality
    distinguishing configurations before changing this decorator.
    """

    @abstractmethod
    def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
        """
        Build TMA descriptor(s) for one executor and write to workspace.

        :param executor_idx: Index of the executor (e.g., group_idx or expert_idx).
            Semantics may vary by subclass when ``dependency`` is provided.
        :param dependency: Optional pipeline consumer for inter-warp-group
            synchronization. When provided, the subclass decides when to wait
            (via ``dependency.wait_and_advance()``) and release. The subclass
            also decides how to interpret ``executor_idx`` in this mode.
        """
        ...

    @abstractmethod
    def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        """
        Get the raw gmem pointer to a specific TMA descriptor in workspace.

        :param tensor_name: Name identifying which tensor's descriptor
        :param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
        :return: Raw pointer (gmem) to the TMA descriptor
        """
        ...
# =============================================================================
# MoE Grouped GEMM Tensormap Constructor
# =============================================================================
class MoEGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
    """
    Tensormap descriptor constructor for MoE Grouped GEMM (expert-wise descriptors only).

    Non-expert-wise descriptors are passed directly at kernel launch.
    This class only handles:
    - 2Dx3D: C descriptors (expert-wise, to avoid write conflicts)
    - 2Dx2D: A and B descriptors (expert-wise, tokens is reduction axis)

    All parameters are stored as explicit instance attributes (no dicts).

    Workspace layout (TensormapWorkspace keeps each expert's slots contiguous):
    - 2Dx3D: [C_0, C_1, ..., C_{n-1}]
    - 2Dx2D: [A_0, B_0, A_1, B_1, ..., A_{n-1}, B_{n-1}]
    """

    @staticmethod
    def _check_scenario(scenario) -> None:
        """Raise ValueError unless ``scenario`` is "2Dx3D" or "2Dx2D"."""
        if scenario not in ("2Dx3D", "2Dx2D"):
            raise ValueError(
                f"Invalid scenario '{scenario}', expected '2Dx3D' or '2Dx2D'"
            )

    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        # Codegen-time configs
        a_dtype,
        b_dtype,
        c_dtype,
        a_smem_layout,
        b_smem_layout,
        epi_smem_layout,
        a_tma_op,
        b_tma_op,
        c_tma_op,
        tiled_mma,
        mma_tiler,
        cluster_layout_vmnk_shape,
        epi_tile,
        # Runtime params
        a_tensor: cute.Tensor,  # fake GEMM domain A
        b_tensor: cute.Tensor,  # fake GEMM domain B
        c_tensor: cute.Tensor,  # fake GEMM domain C
        offs: cute.Tensor,  # (experts,) cumsum
        workspace_ptr: Pointer,
    ) -> None:
        """
        :raises ValueError: If ``scenario`` is not "2Dx3D" or "2Dx2D". Any other
            value used to fall through to the 2Dx2D path silently.
        """
        self._check_scenario(scenario)
        super().__init__()
        self.scenario = scenario
        # Codegen-time configs
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.c_dtype = c_dtype
        self.a_smem_layout = a_smem_layout
        self.b_smem_layout = b_smem_layout
        self.epi_smem_layout = epi_smem_layout
        self.a_tma_op = a_tma_op
        self.b_tma_op = b_tma_op
        self.c_tma_op = c_tma_op
        self.tiled_mma = tiled_mma
        self.mma_tiler = mma_tiler
        self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
        self.epi_tile = epi_tile
        # Runtime params
        self.a_tensor = a_tensor
        self.b_tensor = b_tensor
        self.c_tensor = c_tensor
        self.offs = offs
        # Workspace with scenario-specific slot layout
        if scenario == "2Dx3D":
            self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
        else:
            self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])

    @staticmethod
    def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
        """
        Calculate workspace size in bytes for tensormap descriptors.

        :raises ValueError: If ``scenario`` is not "2Dx3D" or "2Dx2D".
        """
        MoEGroupedGemmTensormapConstructor._check_scenario(scenario)
        if scenario == "2Dx3D":
            return TensormapWorkspace.size_bytes(1, expert_cnt)  # only C
        else:
            return TensormapWorkspace.size_bytes(2, expert_cnt)  # A and B

    @cute.jit
    def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        """Return the workspace slot pointer for ``tensor_name`` of expert ``executor_idx``."""
        return self.workspace.get_ptr(tensor_name, executor_idx)

    @cute.jit
    def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
        """
        Create expert-wise tensormap descriptors for the given expert.

        - 2Dx3D: Creates C descriptor for this expert
        - 2Dx2D: Creates A and B descriptors for this expert

        :param executor_idx: Expert index.
        :param dependency: Accepted for interface compatibility; unused here.
        """
        if cutlass.const_expr(self.scenario == "2Dx3D"):
            self._construct_c_desc_2dx3d(executor_idx)
        else:  # 2Dx2D
            self._construct_ab_descs_2dx2d(executor_idx)

    @cute.jit
    def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
        """
        2Dx3D: Create expert-wise C descriptor.

        C tensor: (fake_m, n, 1) = (tokens_sum, intermediate, 1)
        Slice fake_m -> (tokens_i, intermediate, 1) per expert.
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
        c_ptr = self.c_tensor.iterator
        c_stride = self.c_tensor.stride
        intermediate = self.c_tensor.shape[1]  # type: ignore[index]
        c1 = cutlass.Int32(1)
        c0 = cutlass.Int32(0)
        # Advance to this expert's first token row along mode 0.
        c_ptr_i = c_ptr + token_offset * c_stride[0]  # type: ignore[index]
        c_layout_i = cute.make_layout(
            (tokens_i, intermediate, c1),
            stride=(c_stride[0], c_stride[1], c0),  # type: ignore[index]
        )
        c_tensor_i = cute.make_tensor(c_ptr_i, c_layout_i)
        tma_atom_c, _ = cpasync.make_tiled_tma_atom(
            self.c_tma_op,
            c_tensor_i,
            self.epi_smem_layout,
            self.epi_tile,
        )
        cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))

    @cute.jit
    def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
        """
        2Dx2D: Create expert-wise A and B descriptors.

        A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
        B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
        c1 = cutlass.Int32(1)
        c0 = cutlass.Int32(0)
        # A tensor: (m, fake_k, 1) -> (m, tokens_i, 1)
        a_ptr = self.a_tensor.iterator
        a_stride = self.a_tensor.stride
        a_m = self.a_tensor.shape[0]  # type: ignore[index]
        # Tokens are mode 1 (reduction axis) for A and B in 2Dx2D.
        a_ptr_i = a_ptr + token_offset * a_stride[1]  # type: ignore[index]
        a_layout_i = cute.make_layout(
            (a_m, tokens_i, c1),
            stride=(a_stride[0], a_stride[1], c0),  # type: ignore[index]
        )
        a_tensor_i = cute.make_tensor(a_ptr_i, a_layout_i)
        tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
            self.a_tma_op,
            a_tensor_i,
            self.a_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))
        # B tensor: (n, fake_k, 1) -> (n, tokens_i, 1)
        b_ptr = self.b_tensor.iterator
        b_stride = self.b_tensor.stride
        b_n = self.b_tensor.shape[0]  # type: ignore[index]
        b_ptr_i = b_ptr + token_offset * b_stride[1]  # type: ignore[index]
        b_layout_i = cute.make_layout(
            (b_n, tokens_i, c1),
            stride=(b_stride[0], b_stride[1], c0),  # type: ignore[index]
        )
        b_tensor_i = cute.make_tensor(b_ptr_i, b_layout_i)
        tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
            self.b_tma_op,
            b_tensor_i,
            self.b_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))
# =============================================================================
# MoE Scaled Grouped GEMM Tensormap Constructor
# =============================================================================
class MoEScaledGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
    """
    Tensormap descriptor constructor for MoE Scaled Grouped GEMM (block-scaled).

    Mirrors MoEGroupedGemmTensormapConstructor (C for 2Dx3D, A/B for 2Dx2D).
    It also builds SFA/SFB scale-factor descriptors for 2Dx2D. It subclasses
    OnlineTensormapDescCreator directly rather than extending that class.
    Expert-wise descriptors only - non-expert-wise descriptors are passed
    directly at kernel launch.

    .. py:attribute:: ChunkSize
       :value: 128

       Number of experts processed per chunk in the desc_init_kernel.
       Must match the warp-group width (4 warps x 32 threads).

    Workspace layout:
    - 2Dx3D: [C_0, C_1, ..., C_{n-1}] (1 slot per expert)
    - 2Dx2D: [A_0, B_0, SFA_0, SFB_0, A_1, B_1, SFA_1, SFB_1, ...] (4 slots per expert)

    :param scenario: "2Dx3D" or "2Dx2D"
    :param sf_vec_size: Scale factor vector size (32 for MXFP8/MXFP4, 16 for NVFP4)
    :param a_dtype: Data type for tensor A
    :param b_dtype: Data type for tensor B
    :param c_dtype: Data type for tensor C
    :param sf_dtype: Data type for scale factors (SFA/SFB)
    :param a_smem_layout: SMEM layout for A TMA
    :param b_smem_layout: SMEM layout for B TMA
    :param epi_smem_layout: SMEM layout for epilogue (C) TMA
    :param sfa_smem_layout: SMEM layout for SFA TMA
    :param sfb_smem_layout: SMEM layout for SFB TMA
    :param a_tma_op: TMA operation for A
    :param b_tma_op: TMA operation for B
    :param c_tma_op: TMA operation for C (S2G store or reduce)
    :param sfa_tma_op: TMA operation for SFA
    :param sfb_tma_op: TMA operation for SFB
    :param tiled_mma: TiledMma for A/B/SFA/C TMA atom construction
    :param tiled_mma_sfb: TiledMma for SFB (separate due to 2CTA replication)
    :param mma_tiler: MMA tiler shape (M, N, K)
    :param mma_tiler_sfb: MMA tiler shape for SFB
    :param cluster_layout_vmnk_shape: Cluster layout shape for A/B/SFA multicast
    :param cluster_layout_sfb_vmnk_shape: Cluster layout shape for SFB multicast
    :param epi_tile: Epilogue tile shape
    :param a_tensor: Fake GEMM domain A tensor
    :param b_tensor: Fake GEMM domain B tensor
    :param c_tensor: Fake GEMM domain C tensor
    :param sfa_tensor: Fake GEMM domain SFA tensor (atom-tiled layout)
    :param sfb_tensor: Fake GEMM domain SFB tensor (atom-tiled layout)
    :param offs: (experts,) cumsum offsets in data domain
    :param offs_padded: (experts,) cumsum offsets in padded scale domain
    :param workspace_ptr: Pointer to workspace for TMA descriptors
    :param expert_cnt: Total number of experts (required by construct_and_write)
    :raises ValueError: If ``scenario`` is not "2Dx3D" or "2Dx2D"
    """

    ChunkSize = 128

    @staticmethod
    def _check_scenario(scenario) -> None:
        """Raise ValueError unless ``scenario`` is "2Dx3D" or "2Dx2D"."""
        if scenario not in ("2Dx3D", "2Dx2D"):
            raise ValueError(
                f"Invalid scenario '{scenario}', expected '2Dx3D' or '2Dx2D'"
            )

    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        sf_vec_size: int,
        # Codegen-time configs: dtypes
        a_dtype,
        b_dtype,
        c_dtype,
        sf_dtype,
        # Codegen-time configs: SMEM layouts
        a_smem_layout,
        b_smem_layout,
        epi_smem_layout,
        sfa_smem_layout,
        sfb_smem_layout,
        # Codegen-time configs: TMA ops
        a_tma_op,
        b_tma_op,
        c_tma_op,
        sfa_tma_op,
        sfb_tma_op,
        # Codegen-time configs: MMA / cluster / tile
        tiled_mma,
        tiled_mma_sfb,
        mma_tiler,
        mma_tiler_sfb,
        cluster_layout_vmnk_shape,
        cluster_layout_sfb_vmnk_shape,
        epi_tile,
        # Runtime params
        a_tensor: cute.Tensor,
        b_tensor: cute.Tensor,
        c_tensor: cute.Tensor,
        sfa_tensor: cute.Tensor,
        sfb_tensor: cute.Tensor,
        offs: cute.Tensor,
        offs_padded: cute.Tensor,
        workspace_ptr: Pointer,
        expert_cnt: Optional[Union[Int32, int]] = None,
    ) -> None:
        self._check_scenario(scenario)
        super().__init__()
        self.scenario = scenario
        self.sf_vec_size = sf_vec_size
        # Dtypes
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.c_dtype = c_dtype
        self.sf_dtype = sf_dtype
        # SMEM layouts
        self.a_smem_layout = a_smem_layout
        self.b_smem_layout = b_smem_layout
        self.epi_smem_layout = epi_smem_layout
        self.sfa_smem_layout = sfa_smem_layout
        self.sfb_smem_layout = sfb_smem_layout
        # TMA ops
        self.a_tma_op = a_tma_op
        self.b_tma_op = b_tma_op
        self.c_tma_op = c_tma_op
        self.sfa_tma_op = sfa_tma_op
        self.sfb_tma_op = sfb_tma_op
        # MMA / cluster / tile
        self.tiled_mma = tiled_mma
        self.tiled_mma_sfb = tiled_mma_sfb
        self.mma_tiler = mma_tiler
        self.mma_tiler_sfb = mma_tiler_sfb
        self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
        self.cluster_layout_sfb_vmnk_shape = cluster_layout_sfb_vmnk_shape
        self.epi_tile = epi_tile
        # Runtime params
        self.a_tensor = a_tensor
        self.b_tensor = b_tensor
        self.c_tensor = c_tensor
        self.sfa_tensor = sfa_tensor
        self.sfb_tensor = sfb_tensor
        self.offs = offs
        self.offs_padded = offs_padded
        self.expert_cnt = expert_cnt
        # Workspace with scenario-specific slot layout
        if scenario == "2Dx3D":
            self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
        else:
            self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "sfa", "sfb"])

    @staticmethod
    def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
        """
        Calculate workspace size in bytes for tensormap descriptors.

        :raises ValueError: If ``scenario`` is not "2Dx3D" or "2Dx2D".
        """
        MoEScaledGroupedGemmTensormapConstructor._check_scenario(scenario)
        if scenario == "2Dx3D":
            return TensormapWorkspace.size_bytes(1, expert_cnt)  # C only
        else:
            return TensormapWorkspace.size_bytes(4, expert_cnt)  # A, B, SFA, SFB

    @cute.jit
    def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        """Return the workspace slot pointer for ``tensor_name`` of expert ``executor_idx``."""
        return self.workspace.get_ptr(tensor_name, executor_idx)

    @cute.jit
    def construct_and_write(self, lane_in_group: Int32, dependency=None) -> None:
        """
        Create expert-wise tensormap descriptors for all experts.

        ``lane_in_group`` is the thread's position within its warp group
        (0..ChunkSize-1). The method loops internally over all experts in
        chunks of ``ChunkSize``, with two-phase pipeline synchronization
        per chunk.

        Per-chunk execution:
        1. Phase 1: Build descriptors that do NOT depend on ``offs_padded``
           (A/B for 2Dx2D, C for 2Dx3D). Overlaps with Group A's prefix sum.
        2. Barrier: ``consumer.wait_and_advance()`` - all threads participate.
        3. Phase 2: Build descriptors that depend on ``offs_padded``
           (SFA/SFB for 2Dx2D; nothing for 2Dx3D). Reads padded offsets from
           the SMEM buffer.
        4. Release: ``handle.release()`` - all threads participate.

        :param lane_in_group: Thread's position within the warp group (0..127).
        :param dependency: ``(PipelineConsumer, smem_offs_padded)`` - the
            consumer for mbarrier sync, and the SMEM tensor of shape
            ``(ChunkSize + 1,)`` with layout ``[carry, offs_padded[0..127]]``.
            Required.
        :raises ValueError: At trace time, if ``dependency`` is None or
            ``expert_cnt`` was not provided at construction.
        """
        # Explicit errors instead of an opaque tuple-unpack TypeError and an
        # ``assert`` that would be stripped under ``python -O``.
        if cutlass.const_expr(dependency is None):
            raise ValueError(
                "construct_and_write requires dependency=(consumer, smem_offs_padded)"
            )
        if cutlass.const_expr(self.expert_cnt is None):
            raise ValueError("construct_and_write requires expert_cnt to be set")
        consumer, smem_offs_padded = dependency
        num_chunks = (self.expert_cnt + self.ChunkSize - 1) // self.ChunkSize
        chunk_idx = cutlass.Int32(0)
        while chunk_idx < num_chunks:
            expert_idx = chunk_idx * self.ChunkSize + lane_in_group
            in_bounds = expert_idx < self.expert_cnt
            # Phase 1: non-dependent descriptors
            if in_bounds:
                if cutlass.const_expr(self.scenario == "2Dx2D"):
                    self._construct_ab_descs_2dx2d(expert_idx)
                else:
                    self._construct_c_desc_2dx3d(expert_idx)
            # All threads participate in barrier (fixed arrive count)
            handle = consumer.wait_and_advance()
            # Phase 2: dependent descriptors (read padded offsets from SMEM)
            if in_bounds:
                if cutlass.const_expr(self.scenario == "2Dx2D"):
                    # smem_offs_padded layout: [carry, chunk[0], ..., chunk[127]]
                    # padded_offset = smem[lane] (prev expert's cumulative)
                    # padded_end = smem[lane + 1] (this expert's cumulative)
                    padded_offset = smem_offs_padded[lane_in_group]
                    padded_size_i = smem_offs_padded[lane_in_group + 1] - padded_offset
                    self._construct_sf_descs_2dx2d_direct(
                        expert_idx, padded_offset, padded_size_i
                    )
            # All threads release (fixed arrive count)
            handle.release()
            chunk_idx += 1

    # -----------------------------------------------------------------
    # 2Dx3D: C descriptor (counterpart of MoEGroupedGemmTensormapConstructor's,
    # built via domain_offset + rewrite_tensor_shape)
    # -----------------------------------------------------------------
    @cute.jit
    def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
        """
        2Dx3D: Create expert-wise C descriptor.

        C: (fake_m, n, 1) -> slice to (tokens_i, n, 1) per expert.
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
        c1 = cutlass.Int32(1)
        c_i = cute.domain_offset((token_offset, 0, 0), self.c_tensor)
        c_i = rewrite_tensor_shape(c_i, (tokens_i, self.c_tensor.shape[1], c1))  # type: ignore[index]
        tma_atom_c, _ = cpasync.make_tiled_tma_atom(
            self.c_tma_op,
            c_i,
            self.epi_smem_layout,
            self.epi_tile,
        )
        cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))

    # -----------------------------------------------------------------
    # 2Dx2D: A, B descriptors (counterpart of MoEGroupedGemmTensormapConstructor's,
    # built via domain_offset + rewrite_tensor_shape)
    # -----------------------------------------------------------------
    @cute.jit
    def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
        """
        2Dx2D: Create expert-wise A and B descriptors.

        A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
        B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
        c1 = cutlass.Int32(1)
        # A: (m, fake_k, 1) -> domain_offset + rewrite shape
        a_i = cute.domain_offset((0, token_offset, 0), self.a_tensor)
        a_i = rewrite_tensor_shape(a_i, (self.a_tensor.shape[0], tokens_i, c1))  # type: ignore[index]
        tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
            self.a_tma_op,
            a_i,
            self.a_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))
        # B: (n, fake_k, 1) -> domain_offset + rewrite shape
        b_i = cute.domain_offset((0, token_offset, 0), self.b_tensor)
        b_i = rewrite_tensor_shape(b_i, (self.b_tensor.shape[0], tokens_i, c1))  # type: ignore[index]
        tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
            self.b_tma_op,
            b_i,
            self.b_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))

    # -----------------------------------------------------------------
    # 2Dx2D: SFA, SFB descriptors (new for block-scaled)
    # -----------------------------------------------------------------
    @cute.jit
    def _construct_sf_descs_2dx2d_direct(
        self,
        expert_idx: Int32,
        padded_offset: Int32,
        padded_size_i: Int32,
    ) -> None:
        """
        2Dx2D: Create expert-wise SFA and SFB descriptors.

        Uses a pre-computed padded offset and size. The caller can therefore
        supply padded offsets from SMEM (in desc_init_kernel) instead of this
        method reading ``self.offs_padded`` from GMEM.

        :param expert_idx: Expert whose SFA/SFB slots are written.
        :param padded_offset: Cumulative padded token count before this expert.
        :param padded_size_i: Padded token count of this expert.
        """
        c1 = cutlass.Int32(1)
        # Element offset of this expert's scale factors in each SF tensor:
        # (mode-0 size * padded_offset) // sf_vec_size.
        a_elems_to_move = (
            cute.size(self.sfa_tensor, mode=[0]) * padded_offset // self.sf_vec_size
        )
        b_elems_to_move = (
            cute.size(self.sfb_tensor, mode=[0]) * padded_offset // self.sf_vec_size
        )
        per_expert_sfa_shape = (self.sfa_tensor.shape[0], padded_size_i, c1)  # type: ignore[index]
        sfa_layout_i = tile_atom_to_shape_SF(per_expert_sfa_shape, self.sf_vec_size)
        sfa_i = cute.make_tensor(
            self.sfa_tensor.iterator + a_elems_to_move, sfa_layout_i
        )
        tma_atom_sfa, _ = cute.nvgpu.make_tiled_tma_atom_A(
            self.sfa_tma_op,
            sfa_i,
            self.sfa_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
            internal_type=cutlass.Uint64,
        )
        cpasync.copy_tensormap(tma_atom_sfa, self.get_desc_ptr("sfa", expert_idx))
        per_expert_sfb_shape = (self.sfb_tensor.shape[0], padded_size_i, c1)  # type: ignore[index]
        sfb_layout_i = tile_atom_to_shape_SF(per_expert_sfb_shape, self.sf_vec_size)
        sfb_i = cute.make_tensor(
            self.sfb_tensor.iterator + b_elems_to_move, sfb_layout_i
        )
        tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B(
            self.sfb_tma_op,
            sfb_i,
            self.sfb_smem_layout,
            self.mma_tiler_sfb,
            self.tiled_mma_sfb,
            self.cluster_layout_sfb_vmnk_shape,
            internal_type=cutlass.Uint64,
        )
        cpasync.copy_tensormap(tma_atom_sfb, self.get_desc_ptr("sfb", expert_idx))