# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
Online TMA Descriptor Construction Utilities.

Provides utilities for dynamically creating TMA descriptors at kernel runtime
based on runtime-provided information (problem sizes, pointers, etc.).

Key components:
- OnlineTensormapDescCreator: Simplified ABC for TMA descriptor builders (2 abstract methods)
- TensormapWorkspace: Helper for linear workspace layout of TMA descriptors
- MoEGroupedGemmTensormapConstructor: TMA descriptor constructor for MoE Grouped GEMM
- MoEScaledGroupedGemmTensormapConstructor: TMA descriptor constructor for
  block-scaled MoE Grouped GEMM (adds SFA/SFB descriptors)
- GeneralGroupedGemmTensormapConstructor: TMA descriptor constructor for general Grouped GEMM
- Pointer utility functions (ptr_offset_bytes, gmem_ptr_to_generic, etc.)
- spin_wait: Poll a memory location until a predicate holds
- prefetch_tma_descriptor: Prefetch a TMA descriptor
- tensormap_ptr_for_copy: Convert raw desc ptr to cute.copy-compatible type
- compute_expert_token_range: Compute per-expert token offset and count from offs
- rewrite_tensor_shape: Debug-friendly tensor shape rewrite utility
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Literal, Tuple, Union

import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import AddressSpace, Pointer
from cutlass.cute.nvgpu import cpasync
from cutlass.cutlass_dsl import dsl_user_op, Int32
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
from cutlass._mlir.dialects import cute as _cute_ir
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF

# Byte stride between consecutive TMA descriptors in a workspace; also used as
# the alignment applied to every descriptor pointer handed out by
# TensormapWorkspace.get_ptr.
TensormapDescBytes = 128


# =============================================================================
# Pointer Utilities
# =============================================================================


@dsl_user_op
@cute.jit
def spin_wait(
    ptr: Pointer, condition, fail_sleep_cycles: int = 100, *, loc=None, ip=None
) -> None:
    """
    Generic spin-wait.

    Repeatedly loads the value at ``ptr`` until ``condition(value)`` is truthy.

    Example usage:
        # Wait until counter >= total_blocks
        spin_wait(counter_ptr, lambda x: x >= total_blocks, fail_sleep_cycles=100)

        # Wait until flag == 1
        spin_wait(flag_ptr, lambda x: x == 1)

    :param ptr: Pointer to the value being polled (loaded as ``ptr.dtype``)
    :param condition: Callable taking the loaded value and returning the exit predicate
    :param fail_sleep_cycles: Value forwarded as ``sleep_time`` to
        ``cute.arch.nanosleep`` after each failed check; a value <= 0 disables
        the sleep entirely (decided at trace time via ``const_expr``)
    """
    # Load with L1 cache bypass (ld.global.cg)
    current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)
    while not condition(current):
        # Back off before polling again so the spin does not hammer memory.
        if cutlass.const_expr(fail_sleep_cycles > 0):
            cute.arch.nanosleep(sleep_time=fail_sleep_cycles, loc=loc, ip=ip)
        current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)


@dsl_user_op
def gmem_ptr_to_generic(
    gmem_ptr: Pointer,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
    """
    Cast a global-memory pointer to the generic address space.

    :param gmem_ptr: Pointer whose ``memspace`` is ``AddressSpace.gmem``
    :return: Pointer with the same dtype and alignment in ``AddressSpace.generic``
    :raises ValueError: If ``gmem_ptr`` is not in the gmem address space
    """
    if gmem_ptr.memspace != AddressSpace.gmem:
        raise ValueError(
            f"gmem_ptr_to_generic requires pointer in gmem address space, "
            f"got {gmem_ptr.memspace}"
        )

    # Get LLVM pointer and cast to generic address space
    llvm_ptr = gmem_ptr.to_llvm_ptr(loc=loc, ip=ip)
    generic_llvm_ptr = llvm.addrspacecast(
        llvm.PointerType.get(AddressSpace.generic), llvm_ptr, loc=loc, ip=ip
    )

    # Create a new cute.Pointer with generic address space, preserving alignment
    return cute.make_ptr(
        gmem_ptr.dtype,
        generic_llvm_ptr,
        AddressSpace.generic,
        assumed_align=gmem_ptr.alignment,
        loc=loc,
        ip=ip,
    )


@dsl_user_op
def generic_ptr_to_gmem(
    generic_ptr: Pointer,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
    """
    Cast a generic-address-space pointer to the global-memory address space.

    :param generic_ptr: Pointer whose ``memspace`` is ``AddressSpace.generic``
    :return: Pointer with the same dtype and alignment in ``AddressSpace.gmem``
    :raises ValueError: If ``generic_ptr`` is not in the generic address space
    """
    if generic_ptr.memspace != AddressSpace.generic:
        raise ValueError(
            f"generic_ptr_to_gmem requires pointer in generic address space, "
            f"got {generic_ptr.memspace}"
        )

    # Get LLVM pointer and cast to gmem address space
    llvm_ptr = generic_ptr.to_llvm_ptr(loc=loc, ip=ip)
    gmem_llvm_ptr = llvm.addrspacecast(
        llvm.PointerType.get(AddressSpace.gmem), llvm_ptr, loc=loc, ip=ip
    )

    # Create a new cute.Pointer with gmem address space, preserving alignment
    return cute.make_ptr(
        generic_ptr.dtype,
        gmem_llvm_ptr,
        AddressSpace.gmem,
        assumed_align=generic_ptr.alignment,
        loc=loc,
        ip=ip,
    )


@dsl_user_op
def prefetch_tma_descriptor(tma_desc_ptr: Pointer, *, loc=None, ip=None) -> None:
    """
    Prefetch a TMA descriptor from global memory.

    This function prefetches the TMA descriptor pointed to by tma_desc_ptr
    into the TMA descriptor cache. The pointer must be in generic or global
    address space. If a gmem pointer is passed, it will be automatically
    converted to generic address space.

    :param tma_desc_ptr: Pointer to the TMA descriptor in global or generic memory
    :type tma_desc_ptr: Pointer
    :raises ValueError: If pointer is not in generic or global address space
    """
    if tma_desc_ptr.memspace not in (AddressSpace.gmem, AddressSpace.generic):
        raise ValueError(
            f"prefetch_tma_descriptor requires pointer in gmem or generic address space, "
            f"got {tma_desc_ptr.memspace}"
        )

    # Convert gmem pointer to generic if needed
    if tma_desc_ptr.memspace == AddressSpace.gmem:
        tma_desc_ptr = gmem_ptr_to_generic(tma_desc_ptr, loc=loc, ip=ip)

    # Convert cute.Pointer to LLVM pointer for prefetch
    llvm_ptr = tma_desc_ptr.to_llvm_ptr(loc=loc, ip=ip)

    from cutlass.cute.arch.nvvm_wrappers import prefetch as nvvm_prefetch

    nvvm_prefetch(llvm_ptr, tensormap=True, loc=loc, ip=ip)


def ptr_offset_bytes(ptr: Pointer, byte_offset: Union[int, Int32]) -> Pointer:
    """
    Offset a pointer by a given number of bytes.

    The byte offset is converted to an element count using the pointer's
    element width in bits (``ptr.dtype.width``) with floor division.

    :param ptr: Base pointer
    :param byte_offset: Offset in bytes (Python int or DSL integer expression)
    :return: ``ptr`` advanced by the equivalent number of elements
    """
    element_offset = byte_offset * 8 // ptr.dtype.width
    return ptr + element_offset


@dsl_user_op
def tensormap_ptr_for_copy(raw_ptr: Pointer, *, loc=None, ip=None) -> Pointer:
    """
    Convert a raw TMA descriptor gmem pointer to the type expected by cute.copy.

    cute.copy requires the tma_desc_ptr to be in generic address space
    and recast to TmaDescriptorTiledType. This utility performs both conversions.

    :param raw_ptr: Raw pointer to TMA descriptor in gmem
    :type raw_ptr: Pointer
    :return: Pointer compatible with cute.copy's tma_desc_ptr parameter
    :rtype: Pointer
    :raises ValueError: If ``raw_ptr`` is not in the gmem address space
        (raised by ``gmem_ptr_to_generic``)
    """
    generic_ptr = gmem_ptr_to_generic(raw_ptr, loc=loc, ip=ip)
    tma_desc_ptr_ty = _cute_ir.PtrType.get(
        _cute_nvgpu_ir.TmaDescriptorTiledType.get(),
        generic_ptr.memspace,
        generic_ptr.alignment,
    )
    # Forward loc/ip like every other op emitted in this module so the recast
    # carries the caller's source location and insertion point.
    return _cute_ir.recast_iter(tma_desc_ptr_ty, generic_ptr.value, loc=loc, ip=ip)


# =============================================================================
# MoE Utilities
# =============================================================================


@dsl_user_op
@cute.jit
def compute_expert_token_range(
    offs: cute.Tensor,
    expert_idx: Int32,
    *,
    loc=None,
    ip=None,
) -> Tuple[Int32, Int32]:
    """
    Compute token offset and count for a given expert from the cumsum offs tensor.

    :param offs: Cumulative sum tensor of token counts per expert, shape (experts,)
    :param expert_idx: Index of the expert
    :return: (token_offset, tokens_i) where token_offset is the start position
             and tokens_i is the number of tokens for this expert
    """
    # Expert 0 starts at 0; every other expert starts where the previous
    # expert's cumulative count ends.
    token_offset = Int32(0)
    if expert_idx > Int32(0):
        token_offset = offs[expert_idx - 1]  # type: ignore[assignment]
    tokens_i = offs[expert_idx] - token_offset
    return token_offset, tokens_i


@dsl_user_op
def rewrite_tensor_shape(
    tensor: cute.Tensor,
    new_shape: Tuple,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """
    Rewrite tensor shape while keeping the same stride and iterator.

    This is primarily for debug friendliness - shows the actual expert's shape
    instead of the fake global shape. No runtime overhead as it becomes dead code
    in non-debug builds.

    :param tensor: Source tensor whose stride and iterator to preserve
    :param new_shape: New shape to apply
    :return: New tensor with the given shape but original stride and iterator
    """
    new_layout = cute.make_layout(new_shape, stride=tensor.stride, loc=loc, ip=ip)
    return cute.make_tensor(tensor.iterator, new_layout, loc=loc, ip=ip)


# =============================================================================
# TMA Descriptor Workspace Helper
# =============================================================================


class TensormapWorkspace:
    """
    Helper for linear workspace layout of TMA descriptors.

    Manages address calculation for a workspace buffer containing TMA descriptors
    organized as: for each executor (e.g., expert or group), a fixed set of named
    descriptor slots.

    Layout: [slot_0_exec_0, slot_1_exec_0, ..., slot_0_exec_1, slot_1_exec_1, ...]

    Example:
        # 2Dx3D MoE: only C is expert-wise
        workspace = TensormapWorkspace(workspace_ptr, ["c"])

        # 2Dx2D MoE: A and B are expert-wise
        workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])

        # General grouped GEMM: all three tensors
        workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "c"])
    """

    def __init__(self, workspace_ptr: Pointer, slot_names: list):
        """
        :param workspace_ptr: Pointer to the beginning of the workspace buffer
        :param slot_names: Ordered list of tensor names, defining the slot layout
                           per executor. e.g., ["a", "b", "c"]
        """
        self.workspace_ptr = workspace_ptr
        self._name_to_slot = {name: i for i, name in enumerate(slot_names)}
        self._slots_per_executor = len(slot_names)

    @cute.jit
    def get_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        """
        Get the workspace pointer for a specific TMA descriptor.

        :param tensor_name: Name of the tensor (must be one of the slot_names)
        :param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
        :return: Aligned pointer to the TMA descriptor in workspace
        :raises ValueError: If ``tensor_name`` is not one of the configured slots
        """
        if cutlass.const_expr(tensor_name not in self._name_to_slot):
            raise ValueError(
                f"Invalid tensor_name '{tensor_name}', "
                f"expected one of {list(self._name_to_slot.keys())}"
            )
        slot = self._name_to_slot[tensor_name]
        # Slots are interleaved per executor: descriptor index is
        # executor_idx * slots_per_executor + slot.
        byte_offset = (
            executor_idx * self._slots_per_executor + slot
        ) * TensormapDescBytes
        return ptr_offset_bytes(self.workspace_ptr, byte_offset).align(
            TensormapDescBytes
        )

    @staticmethod
    def size_bytes(num_slots: int, num_executors: int) -> int:
        """
        Calculate workspace size in bytes.

        :param num_slots: Number of descriptor slots per executor
        :param num_executors: Total number of executors (e.g., expert_cnt or group_cnt)
        :return: Total workspace size in bytes
        """
        return num_slots * num_executors * TensormapDescBytes


# =============================================================================
# Online TMA Descriptor Creator (Abstract Base Class)
# =============================================================================


@dataclass(frozen=True)
class OnlineTensormapDescCreator(ABC):
    """
    Abstract base class for building TMA descriptors online (at kernel runtime).

    Subclasses store all needed parameters (both codegen-time configs and runtime
    values) as explicit instance attributes in __init__. No dict-based APIs.

    Subclasses must implement exactly 2 abstract methods:
    - construct_and_write: Build TMA descriptor(s) for one executor and write to workspace
    - get_desc_ptr: Return raw gmem pointer to a specific descriptor in workspace

    To convert the raw pointer for use with cute.copy, callers should use the
    standalone tensormap_ptr_for_copy() utility.
    """

    @abstractmethod
    def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
        """
        Build TMA descriptor(s) for one executor and write to workspace.

        :param executor_idx: Index of the executor (e.g., group_idx or expert_idx).
            Semantics may vary by subclass when ``dependency`` is provided.
        :param dependency: Optional pipeline consumer for inter-warp-group
            synchronization. When provided, the subclass decides when to
            wait (via ``dependency.wait_and_advance()``) and release.
            The subclass also decides how to interpret ``executor_idx``
            in this mode.
        """
        ...

    @abstractmethod
    def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        """
        Get the raw gmem pointer to a specific TMA descriptor in workspace.

        :param tensor_name: Name identifying which tensor's descriptor
        :param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
        :return: Raw pointer (gmem) to the TMA descriptor
        """
        ...


# =============================================================================
# MoE Grouped GEMM Tensormap Constructor
# =============================================================================


class MoEGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
    """
    Tensormap descriptor constructor for MoE Grouped GEMM (expert-wise descriptors only).

    Non-expert-wise descriptors are passed directly at kernel launch.
    This class only handles:
    - 2Dx3D: C descriptors (expert-wise, to avoid write conflicts)
    - 2Dx2D: A and B descriptors (expert-wise, tokens is reduction axis)

    All parameters are stored as explicit instance attributes (no dicts).

    Workspace layout (slots are interleaved per expert, see TensormapWorkspace):
    - 2Dx3D: [C_0, C_1, ..., C_{n-1}]
    - 2Dx2D: [A_0, B_0, A_1, B_1, ..., A_{n-1}, B_{n-1}]
    """

    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        # Codegen-time configs
        a_dtype,
        b_dtype,
        c_dtype,
        a_smem_layout,
        b_smem_layout,
        epi_smem_layout,
        a_tma_op,
        b_tma_op,
        c_tma_op,
        tiled_mma,
        mma_tiler,
        cluster_layout_vmnk_shape,
        epi_tile,
        # Runtime params
        a_tensor: cute.Tensor,  # fake GEMM domain A
        b_tensor: cute.Tensor,  # fake GEMM domain B
        c_tensor: cute.Tensor,  # fake GEMM domain C
        offs: cute.Tensor,  # (experts,) cumsum
        workspace_ptr: Pointer,
    ) -> None:
        super().__init__()
        self.scenario = scenario
        # Codegen-time configs
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.c_dtype = c_dtype
        self.a_smem_layout = a_smem_layout
        self.b_smem_layout = b_smem_layout
        self.epi_smem_layout = epi_smem_layout
        self.a_tma_op = a_tma_op
        self.b_tma_op = b_tma_op
        self.c_tma_op = c_tma_op
        self.tiled_mma = tiled_mma
        self.mma_tiler = mma_tiler
        self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
        self.epi_tile = epi_tile
        # Runtime params
        self.a_tensor = a_tensor
        self.b_tensor = b_tensor
        self.c_tensor = c_tensor
        self.offs = offs
        # Workspace with scenario-specific slot layout
        if scenario == "2Dx3D":
            self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
        else:
            self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])

    @staticmethod
    def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
        """Calculate workspace size in bytes for tensormap descriptors."""
        if scenario == "2Dx3D":
            return TensormapWorkspace.size_bytes(1, expert_cnt)  # only C
        else:
            return TensormapWorkspace.size_bytes(2, expert_cnt)  # A and B

    @cute.jit
    def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        """Return the workspace pointer of ``tensor_name``'s descriptor for one expert."""
        return self.workspace.get_ptr(tensor_name, executor_idx)

    @cute.jit
    def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
        """
        Create expert-wise tensormap descriptors for the given expert.

        - 2Dx3D: Creates C descriptor for this expert
        - 2Dx2D: Creates A and B descriptors for this expert

        ``dependency`` is accepted for interface compatibility and is not used.
        """
        if cutlass.const_expr(self.scenario == "2Dx3D"):
            self._construct_c_desc_2dx3d(executor_idx)
        else:  # 2Dx2D
            self._construct_ab_descs_2dx2d(executor_idx)

    @cute.jit
    def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
        """
        2Dx3D: Create expert-wise C descriptor.

        C tensor: (fake_m, n, 1) = (tokens_sum, intermediate, 1)
        Slice fake_m -> (tokens_i, intermediate, 1) per expert.
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)

        c_ptr = self.c_tensor.iterator
        c_stride = self.c_tensor.stride
        intermediate = self.c_tensor.shape[1]  # type: ignore[index]

        c1 = cutlass.Int32(1)
        c0 = cutlass.Int32(0)

        # Advance along mode 0 (tokens) to this expert's first row.
        c_ptr_i = c_ptr + token_offset * c_stride[0]  # type: ignore[index]
        c_layout_i = cute.make_layout(
            (tokens_i, intermediate, c1),
            stride=(c_stride[0], c_stride[1], c0),  # type: ignore[index]
        )
        c_tensor_i = cute.make_tensor(c_ptr_i, c_layout_i)

        tma_atom_c, _ = cpasync.make_tiled_tma_atom(
            self.c_tma_op,
            c_tensor_i,
            self.epi_smem_layout,
            self.epi_tile,
        )

        cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))

    @cute.jit
    def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
        """
        2Dx2D: Create expert-wise A and B descriptors.

        A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
        B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)

        c1 = cutlass.Int32(1)
        c0 = cutlass.Int32(0)

        # A tensor: (m, fake_k, 1) -> (m, tokens_i, 1)
        a_ptr = self.a_tensor.iterator
        a_stride = self.a_tensor.stride
        a_m = self.a_tensor.shape[0]  # type: ignore[index]

        a_ptr_i = a_ptr + token_offset * a_stride[1]  # type: ignore[index]
        a_layout_i = cute.make_layout(
            (a_m, tokens_i, c1),
            stride=(a_stride[0], a_stride[1], c0),  # type: ignore[index]
        )
        a_tensor_i = cute.make_tensor(a_ptr_i, a_layout_i)

        tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
            self.a_tma_op,
            a_tensor_i,
            self.a_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )

        cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))

        # B tensor: (n, fake_k, 1) -> (n, tokens_i, 1)
        b_ptr = self.b_tensor.iterator
        b_stride = self.b_tensor.stride
        b_n = self.b_tensor.shape[0]  # type: ignore[index]

        b_ptr_i = b_ptr + token_offset * b_stride[1]  # type: ignore[index]
        b_layout_i = cute.make_layout(
            (b_n, tokens_i, c1),
            stride=(b_stride[0], b_stride[1], c0),  # type: ignore[index]
        )
        b_tensor_i = cute.make_tensor(b_ptr_i, b_layout_i)

        tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
            self.b_tma_op,
            b_tensor_i,
            self.b_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )

        cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))


# =============================================================================
# MoE Scaled Grouped GEMM Tensormap Constructor
# =============================================================================


class MoEScaledGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
    """
    Tensormap descriptor constructor for MoE Scaled Grouped GEMM (block-scaled).

    .. py:attribute:: ChunkSize
        :value: 128

        Number of experts processed per chunk in the desc_init_kernel.
        Must match the warp-group width (4 warps × 32 threads).

    Provides the same A/B/C descriptor construction as
    MoEGroupedGemmTensormapConstructor and adds SFA/SFB descriptor support.
    (It derives directly from OnlineTensormapDescCreator, not from
    MoEGroupedGemmTensormapConstructor.)
    Expert-wise descriptors only — non-expert-wise descriptors are passed
    directly at kernel launch.

    Workspace layout:

    - 2Dx3D: [C_0, C_1, ..., C_{n-1}]  (1 slot per expert)
    - 2Dx2D: [A_0, B_0, SFA_0, SFB_0, A_1, B_1, SFA_1, SFB_1, ...]  (4 slots per expert)

    :param scenario: "2Dx3D" or "2Dx2D"
    :param sf_vec_size: Scale factor vector size (32 for MXFP8/MXFP4, 16 for NVFP4)
    :param a_dtype: Data type for tensor A
    :param b_dtype: Data type for tensor B
    :param c_dtype: Data type for tensor C
    :param sf_dtype: Data type for scale factors (SFA/SFB)
    :param a_smem_layout: SMEM layout for A TMA
    :param b_smem_layout: SMEM layout for B TMA
    :param epi_smem_layout: SMEM layout for epilogue (C) TMA
    :param sfa_smem_layout: SMEM layout for SFA TMA
    :param sfb_smem_layout: SMEM layout for SFB TMA
    :param a_tma_op: TMA operation for A
    :param b_tma_op: TMA operation for B
    :param c_tma_op: TMA operation for C (S2G store or reduce)
    :param sfa_tma_op: TMA operation for SFA
    :param sfb_tma_op: TMA operation for SFB
    :param tiled_mma: TiledMma for A/B/SFA/C TMA atom construction
    :param tiled_mma_sfb: TiledMma for SFB (separate due to 2CTA replication)
    :param mma_tiler: MMA tiler shape (M, N, K)
    :param mma_tiler_sfb: MMA tiler shape for SFB
    :param cluster_layout_vmnk_shape: Cluster layout shape for A/B/SFA multicast
    :param cluster_layout_sfb_vmnk_shape: Cluster layout shape for SFB multicast
    :param epi_tile: Epilogue tile shape
    :param a_tensor: Fake GEMM domain A tensor
    :param b_tensor: Fake GEMM domain B tensor
    :param c_tensor: Fake GEMM domain C tensor
    :param sfa_tensor: Fake GEMM domain SFA tensor (atom-tiled layout)
    :param sfb_tensor: Fake GEMM domain SFB tensor (atom-tiled layout)
    :param offs: (experts,) cumsum offsets in data domain
    :param offs_padded: (experts,) cumsum offsets in padded scale domain
    :param workspace_ptr: Pointer to workspace for TMA descriptors
    :param expert_cnt: Total number of experts
    """

    ChunkSize = 128

    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        sf_vec_size: int,
        # Codegen-time configs: dtypes
        a_dtype,
        b_dtype,
        c_dtype,
        sf_dtype,
        # Codegen-time configs: SMEM layouts
        a_smem_layout,
        b_smem_layout,
        epi_smem_layout,
        sfa_smem_layout,
        sfb_smem_layout,
        # Codegen-time configs: TMA ops
        a_tma_op,
        b_tma_op,
        c_tma_op,
        sfa_tma_op,
        sfb_tma_op,
        # Codegen-time configs: MMA / cluster / tile
        tiled_mma,
        tiled_mma_sfb,
        mma_tiler,
        mma_tiler_sfb,
        cluster_layout_vmnk_shape,
        cluster_layout_sfb_vmnk_shape,
        epi_tile,
        # Runtime params
        a_tensor: cute.Tensor,
        b_tensor: cute.Tensor,
        c_tensor: cute.Tensor,
        sfa_tensor: cute.Tensor,
        sfb_tensor: cute.Tensor,
        offs: cute.Tensor,
        offs_padded: cute.Tensor,
        workspace_ptr: Pointer,
        expert_cnt: Optional[Union[Int32, int]] = None,
    ) -> None:
        super().__init__()
        self.scenario = scenario
        self.sf_vec_size = sf_vec_size
        # Dtypes
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.c_dtype = c_dtype
        self.sf_dtype = sf_dtype
        # SMEM layouts
        self.a_smem_layout = a_smem_layout
        self.b_smem_layout = b_smem_layout
        self.epi_smem_layout = epi_smem_layout
        self.sfa_smem_layout = sfa_smem_layout
        self.sfb_smem_layout = sfb_smem_layout
        # TMA ops
        self.a_tma_op = a_tma_op
        self.b_tma_op = b_tma_op
        self.c_tma_op = c_tma_op
        self.sfa_tma_op = sfa_tma_op
        self.sfb_tma_op = sfb_tma_op
        # MMA / cluster / tile
        self.tiled_mma = tiled_mma
        self.tiled_mma_sfb = tiled_mma_sfb
        self.mma_tiler = mma_tiler
        self.mma_tiler_sfb = mma_tiler_sfb
        self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
        self.cluster_layout_sfb_vmnk_shape = cluster_layout_sfb_vmnk_shape
        self.epi_tile = epi_tile
        # Runtime params
        self.a_tensor = a_tensor
        self.b_tensor = b_tensor
        self.c_tensor = c_tensor
        self.sfa_tensor = sfa_tensor
        self.sfb_tensor = sfb_tensor
        self.offs = offs
        self.offs_padded = offs_padded
        self.expert_cnt = expert_cnt
        # Workspace with scenario-specific slot layout
        if scenario == "2Dx3D":
            self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
        else:
            self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "sfa", "sfb"])

    @staticmethod
    def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
        """Calculate workspace size in bytes for tensormap descriptors."""
        if scenario == "2Dx3D":
            return TensormapWorkspace.size_bytes(1, expert_cnt)  # C only
        else:
            return TensormapWorkspace.size_bytes(4, expert_cnt)  # A, B, SFA, SFB

    @cute.jit
    def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        """Return the workspace pointer of ``tensor_name``'s descriptor for one expert."""
        return self.workspace.get_ptr(tensor_name, executor_idx)

    @cute.jit
    def construct_and_write(self, lane_in_group: Int32, dependency=None) -> None:
        """
        Create expert-wise tensormap descriptors for all experts.

        ``lane_in_group`` is the thread's position within its warp group
        (0..ChunkSize-1).  The method loops internally over all experts
        in chunks of ``ChunkSize``, with two-phase pipeline synchronization
        per chunk.

        Per-chunk execution:

        1. Phase 1: Build descriptors that do NOT depend on ``offs_padded``
           (A/B for 2Dx2D, C for 2Dx3D).  Overlaps with Group A's prefix sum.
        2. Barrier: ``consumer.wait_and_advance()`` — all threads participate.
        3. Phase 2: Build descriptors that depend on ``offs_padded``
           (SFA/SFB for 2Dx2D).  Reads padded offsets from SMEM buffer.
        4. Release: ``handle.release()`` — all threads participate.

        :param lane_in_group: Thread's position within the warp group (0..127).
        :param dependency: ``(PipelineConsumer, smem_offs_padded)`` — the
            consumer for mbarrier sync, and the SMEM tensor of shape
            ``(ChunkSize + 1,)`` with layout ``[carry, offs_padded[0..127]]``.
            Despite the ``None`` default (kept for interface compatibility
            with the base class), this argument is required here: it is
            unpacked unconditionally.
        """
        consumer, smem_offs_padded = dependency
        assert self.expert_cnt is not None
        # Ceil-divide so a trailing partial chunk is still processed.
        num_chunks = (self.expert_cnt + self.ChunkSize - 1) // self.ChunkSize

        chunk_idx = cutlass.Int32(0)
        while chunk_idx < num_chunks:
            expert_idx = chunk_idx * self.ChunkSize + lane_in_group
            in_bounds = expert_idx < self.expert_cnt

            # Phase 1: non-dependent descriptors
            if in_bounds:
                if cutlass.const_expr(self.scenario == "2Dx2D"):
                    self._construct_ab_descs_2dx2d(expert_idx)
                else:
                    self._construct_c_desc_2dx3d(expert_idx)

            # All threads participate in barrier (fixed arrive count)
            handle = consumer.wait_and_advance()

            # Phase 2: dependent descriptors (read padded offsets from SMEM)
            if in_bounds:
                if cutlass.const_expr(self.scenario == "2Dx2D"):
                    # smem_offs_padded layout: [carry, chunk[0], ..., chunk[127]]
                    # padded_offset = smem[lane]       (prev expert's cumulative)
                    # padded_end    = smem[lane + 1]   (this expert's cumulative)
                    padded_offset = smem_offs_padded[lane_in_group]
                    padded_size_i = smem_offs_padded[lane_in_group + 1] - padded_offset
                    self._construct_sf_descs_2dx2d_direct(
                        expert_idx, padded_offset, padded_size_i
                    )

            # All threads release (fixed arrive count)
            handle.release()
            chunk_idx += 1

    # -----------------------------------------------------------------
    # 2Dx3D: C descriptor (same as MoEGroupedGemmTensormapConstructor)
    # -----------------------------------------------------------------

    @cute.jit
    def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
        """
        2Dx3D: Create expert-wise C descriptor.

        C: (fake_m, n, 1) -> slice to (tokens_i, n, 1) per expert.
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
        c1 = cutlass.Int32(1)

        c_i = cute.domain_offset((token_offset, 0, 0), self.c_tensor)
        c_i = rewrite_tensor_shape(c_i, (tokens_i, self.c_tensor.shape[1], c1))  # type: ignore[index]

        tma_atom_c, _ = cpasync.make_tiled_tma_atom(
            self.c_tma_op,
            c_i,
            self.epi_smem_layout,
            self.epi_tile,
        )
        cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))

    # -----------------------------------------------------------------
    # 2Dx2D: A, B descriptors (same as MoEGroupedGemmTensormapConstructor)
    # -----------------------------------------------------------------

    @cute.jit
    def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
        """
        2Dx2D: Create expert-wise A and B descriptors.

        A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
        B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
        c1 = cutlass.Int32(1)

        # A: (m, fake_k, 1) -> domain_offset + rewrite shape
        a_i = cute.domain_offset((0, token_offset, 0), self.a_tensor)
        a_i = rewrite_tensor_shape(a_i, (self.a_tensor.shape[0], tokens_i, c1))  # type: ignore[index]

        tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
            self.a_tma_op,
            a_i,
            self.a_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))

        # B: (n, fake_k, 1) -> domain_offset + rewrite shape
        b_i = cute.domain_offset((0, token_offset, 0), self.b_tensor)
        b_i = rewrite_tensor_shape(b_i, (self.b_tensor.shape[0], tokens_i, c1))  # type: ignore[index]

        tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
            self.b_tma_op,
            b_i,
            self.b_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))

    # -----------------------------------------------------------------
    # 2Dx2D: SFA, SFB descriptors (new for block-scaled)
    # -----------------------------------------------------------------

    @cute.jit
    def _construct_sf_descs_2dx2d_direct(
        self,
        expert_idx: Int32,
        padded_offset: Int32,
        padded_size_i: Int32,
    ) -> None:
        """
        2Dx2D: Create expert-wise SFA and SFB descriptors with pre-computed
        padded offset and size.

        This variant allows the caller to supply padded offsets from SMEM
        (in desc_init_kernel) instead of reading from ``self.offs_padded``
        in GMEM.

        :param expert_idx: Expert whose SFA/SFB descriptor slots are written
        :param padded_offset: Start of this expert in the padded scale domain
        :param padded_size_i: Extent of this expert in the padded scale domain
        """
        c1 = cutlass.Int32(1)

        # Element offset of this expert's scale factors from the tensor base:
        # (mode-0 extent) * padded_offset / sf_vec_size.
        a_elems_to_move = (
            cute.size(self.sfa_tensor, mode=[0]) * padded_offset // self.sf_vec_size
        )
        b_elems_to_move = (
            cute.size(self.sfb_tensor, mode=[0]) * padded_offset // self.sf_vec_size
        )

        per_expert_sfa_shape = (self.sfa_tensor.shape[0], padded_size_i, c1)  # type: ignore[index]
        sfa_layout_i = tile_atom_to_shape_SF(per_expert_sfa_shape, self.sf_vec_size)
        sfa_i = cute.make_tensor(
            self.sfa_tensor.iterator + a_elems_to_move, sfa_layout_i
        )

        tma_atom_sfa, _ = cute.nvgpu.make_tiled_tma_atom_A(
            self.sfa_tma_op,
            sfa_i,
            self.sfa_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
            internal_type=cutlass.Uint64,
        )
        cpasync.copy_tensormap(tma_atom_sfa, self.get_desc_ptr("sfa", expert_idx))

        per_expert_sfb_shape = (self.sfb_tensor.shape[0], padded_size_i, c1)  # type: ignore[index]
        sfb_layout_i = tile_atom_to_shape_SF(per_expert_sfb_shape, self.sf_vec_size)
        sfb_i = cute.make_tensor(
            self.sfb_tensor.iterator + b_elems_to_move, sfb_layout_i
        )

        # SFB uses its own tiler / TiledMma / cluster shape.
        tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B(
            self.sfb_tma_op,
            sfb_i,
            self.sfb_smem_layout,
            self.mma_tiler_sfb,
            self.tiled_mma_sfb,
            self.cluster_layout_sfb_vmnk_shape,
            internal_type=cutlass.Uint64,
        )
        cpasync.copy_tensormap(tma_atom_sfb, self.get_desc_ptr("sfb", expert_idx))