# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
|
||
Online TMA Descriptor Construction Utilities.
|
||
|
||
Provides utilities for dynamically creating TMA descriptors at kernel runtime
|
||
based on runtime-provided information (problem sizes, pointers, etc.).
|
||
|
||
Key components:
|
||
- OnlineTensormapDescCreator: Simplified ABC for TMA descriptor builders (2 abstract methods)
|
||
- TensormapWorkspace: Helper for linear workspace layout of TMA descriptors
|
||
- MoEGroupedGemmTensormapConstructor: TMA descriptor constructor for MoE Grouped GEMM
|
||
- GeneralGroupedGemmTensormapConstructor: TMA descriptor constructor for general Grouped GEMM
|
||
- Pointer utility functions (ptr_offset_bytes, gmem_ptr_to_generic, etc.)
|
||
- tensormap_ptr_for_copy: Convert raw desc ptr to cute.copy-compatible type
|
||
- compute_expert_token_range: Compute per-expert token offset and count from offs
|
||
- rewrite_tensor_shape: Debug-friendly tensor shape rewrite utility
|
||
"""
|
||
|
||
from abc import ABC, abstractmethod
from typing import Optional, Literal, Tuple, Union

import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import AddressSpace, Pointer
from cutlass.cute.nvgpu import cpasync
from cutlass.cutlass_dsl import dsl_user_op, Int32
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
from cutlass._mlir.dialects import cute as _cute_ir
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
from dataclasses import dataclass

from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF

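# Size in bytes of a single TMA descriptor (CUtensorMap is 128 bytes); each
# workspace slot below holds one descriptor and is kept 128-byte aligned.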
TensormapDescBytes = 128

# =============================================================================
# Pointer Utilities
# =============================================================================


@dsl_user_op
@cute.jit
def spin_wait(
    ptr: Pointer, condition, fail_sleep_cycles: int = 100, *, loc=None, ip=None
) -> None:
    """
    Generic spin-wait.
    Example usage:
        # Wait until counter >= total_blocks
        spin_wait(counter_ptr, lambda x: x >= total_blocks, fail_sleep_cycles=100)

        # Wait until flag == 1
        spin_wait(flag_ptr, lambda x: x == 1)
    """
    # Load with L1 cache bypass (ld.global.cg)
    current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)
    while not condition(current):
        if cutlass.const_expr(fail_sleep_cycles > 0):
            cute.arch.nanosleep(sleep_time=fail_sleep_cycles, loc=loc, ip=ip)
        # Reload with L1 cache bypass until the condition holds
        current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)


@dsl_user_op
def gmem_ptr_to_generic(
    gmem_ptr: Pointer,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
    """Cast a gmem pointer to the generic address space, preserving dtype and alignment."""
    if gmem_ptr.memspace != AddressSpace.gmem:
        raise ValueError(
            f"gmem_ptr_to_generic requires pointer in gmem address space, "
            f"got {gmem_ptr.memspace}"
        )
    # Get LLVM pointer and cast to generic address space
    llvm_ptr = gmem_ptr.to_llvm_ptr(loc=loc, ip=ip)
    generic_llvm_ptr = llvm.addrspacecast(
        llvm.PointerType.get(AddressSpace.generic), llvm_ptr, loc=loc, ip=ip
    )
    # Create a new cute.Pointer with generic address space, preserving alignment
    return cute.make_ptr(
        gmem_ptr.dtype,
        generic_llvm_ptr,
        AddressSpace.generic,
        assumed_align=gmem_ptr.alignment,
        loc=loc,
        ip=ip,
    )


@dsl_user_op
def generic_ptr_to_gmem(
    generic_ptr: Pointer,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
    """Cast a generic-address-space pointer back to gmem, preserving dtype and alignment."""
    if generic_ptr.memspace != AddressSpace.generic:
        raise ValueError(
            f"generic_ptr_to_gmem requires pointer in generic address space, "
            f"got {generic_ptr.memspace}"
        )
    # Get LLVM pointer and cast to gmem address space
    llvm_ptr = generic_ptr.to_llvm_ptr(loc=loc, ip=ip)
    gmem_llvm_ptr = llvm.addrspacecast(
        llvm.PointerType.get(AddressSpace.gmem), llvm_ptr, loc=loc, ip=ip
    )
    # Create a new cute.Pointer with gmem address space, preserving alignment
    return cute.make_ptr(
        generic_ptr.dtype,
        gmem_llvm_ptr,
        AddressSpace.gmem,
        assumed_align=generic_ptr.alignment,
        loc=loc,
        ip=ip,
    )


@dsl_user_op
def prefetch_tma_descriptor(tma_desc_ptr: Pointer, *, loc=None, ip=None) -> None:
    """
    Prefetch a TMA descriptor from global memory.

    This function prefetches the TMA descriptor pointed to by tma_desc_ptr
    into the TMA descriptor cache. The pointer must be in generic or global
    address space. If a gmem pointer is passed, it will be automatically
    converted to generic address space.

    :param tma_desc_ptr: Pointer to the TMA descriptor in global or generic memory
    :type tma_desc_ptr: Pointer
    :raises ValueError: If pointer is not in generic or global address space
    """
    if tma_desc_ptr.memspace not in (AddressSpace.gmem, AddressSpace.generic):
        raise ValueError(
            f"prefetch_tma_descriptor requires pointer in gmem or generic address space, "
            f"got {tma_desc_ptr.memspace}"
        )
    # Convert gmem pointer to generic if needed
    if tma_desc_ptr.memspace == AddressSpace.gmem:
        tma_desc_ptr = gmem_ptr_to_generic(tma_desc_ptr, loc=loc, ip=ip)
    # Convert cute.Pointer to LLVM pointer for prefetch
    llvm_ptr = tma_desc_ptr.to_llvm_ptr(loc=loc, ip=ip)
    from cutlass.cute.arch.nvvm_wrappers import prefetch as nvvm_prefetch

    nvvm_prefetch(llvm_ptr, tensormap=True, loc=loc, ip=ip)


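# Illustrative use of prefetch_tma_descriptor (sketch only; `constructor` and
# `expert_idx` are placeholders for state owned by the calling kernel):
#   prefetch_tma_descriptor(constructor.get_desc_ptr("a", expert_idx))

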
def ptr_offset_bytes(ptr: Pointer, byte_offset: int) -> Pointer:
    """Offset a pointer by a given number of bytes."""
    element_offset = byte_offset * 8 // ptr.dtype.width
    return ptr + element_offset


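# ptr_offset_bytes example (illustrative): for a Float16 pointer (dtype.width == 16),
# a 256-byte offset is 256 * 8 // 16 == 128 elements, so the pointer is advanced by
# 128 elements. byte_offset is assumed to be a multiple of the element size.

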
@dsl_user_op
def tensormap_ptr_for_copy(raw_ptr: Pointer, *, loc=None, ip=None) -> Pointer:
    """
    Convert a raw TMA descriptor gmem pointer to the type expected by cute.copy.

    cute.copy requires the tma_desc_ptr to be in generic address space and
    recast to TmaDescriptorTiledType. This utility performs both conversions.

    :param raw_ptr: Raw pointer to TMA descriptor in gmem
    :type raw_ptr: Pointer
    :return: Pointer compatible with cute.copy's tma_desc_ptr parameter
    :rtype: Pointer
    """
    generic_ptr = gmem_ptr_to_generic(raw_ptr, loc=loc, ip=ip)
    tma_desc_ptr_ty = _cute_ir.PtrType.get(
        _cute_nvgpu_ir.TmaDescriptorTiledType.get(),
        generic_ptr.memspace,
        generic_ptr.alignment,
    )
    return _cute_ir.recast_iter(tma_desc_ptr_ty, generic_ptr.value)


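# tensormap_ptr_for_copy usage sketch (illustrative; `tma_atom`, `src`, and `dst`
# are placeholders for the caller's copy atom and partitioned tensors):
#   desc_ptr = constructor.get_desc_ptr("a", expert_idx)
#   cute.copy(tma_atom, src, dst, tma_desc_ptr=tensormap_ptr_for_copy(desc_ptr))

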
# =============================================================================
# MoE Utilities
# =============================================================================


@dsl_user_op
@cute.jit
def compute_expert_token_range(
    offs: cute.Tensor,
    expert_idx: Int32,
    *,
    loc=None,
    ip=None,
) -> Tuple[Int32, Int32]:
    """
    Compute token offset and count for a given expert from the cumsum offs tensor.

    :param offs: Cumulative sum tensor of token counts per expert, shape (experts,)
    :param expert_idx: Index of the expert
    :return: (token_offset, tokens_i) where token_offset is the start position
        and tokens_i is the number of tokens for this expert
    """
    token_offset = Int32(0)
    if expert_idx > Int32(0):
        token_offset = offs[expert_idx - 1]  # type: ignore[assignment]
    tokens_i = offs[expert_idx] - token_offset
    return token_offset, tokens_i


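# compute_expert_token_range worked example (illustrative): with offs = [3, 7, 10]
# (cumulative token counts for 3 experts), expert 0 yields (offset=0, count=3),
# expert 1 yields (offset=3, count=4), and expert 2 yields (offset=7, count=3).

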
@dsl_user_op
def rewrite_tensor_shape(
    tensor: cute.Tensor,
    new_shape: Tuple,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """
    Rewrite tensor shape while keeping the same stride and iterator.

    This is primarily for debug friendliness - shows the actual expert's shape
    instead of the fake global shape. No runtime overhead as it becomes
    dead code in non-debug builds.

    :param tensor: Source tensor whose stride and iterator to preserve
    :param new_shape: New shape to apply
    :return: New tensor with the given shape but original stride and iterator
    """
    new_layout = cute.make_layout(new_shape, stride=tensor.stride, loc=loc, ip=ip)
    return cute.make_tensor(tensor.iterator, new_layout, loc=loc, ip=ip)


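# rewrite_tensor_shape example (illustrative): given a "fake" GEMM-domain C tensor
# of shape (tokens_sum, n, 1) and a per-expert token count tokens_i,
#   c_i = rewrite_tensor_shape(c_tensor, (tokens_i, c_tensor.shape[1], Int32(1)))
# produces a view with the expert's true shape while reusing the original stride
# and iterator (see the constructors below for real call sites).

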
# =============================================================================
# TMA Descriptor Workspace Helper
# =============================================================================


class TensormapWorkspace:
    """
    Helper for linear workspace layout of TMA descriptors.

    Manages address calculation for a workspace buffer containing TMA descriptors
    organized as: for each executor (e.g., expert or group), a fixed set of
    named descriptor slots.

    Layout: [slot_0_exec_0, slot_1_exec_0, ..., slot_0_exec_1, slot_1_exec_1, ...]

    Example:
        # 2Dx3D MoE: only C is expert-wise
        workspace = TensormapWorkspace(workspace_ptr, ["c"])

        # 2Dx2D MoE: A and B are expert-wise
        workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])

        # General grouped GEMM: all three tensors
        workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "c"])
    """

    def __init__(self, workspace_ptr: Pointer, slot_names: list):
        """
        :param workspace_ptr: Pointer to the beginning of the workspace buffer
        :param slot_names: Ordered list of tensor names, defining the slot layout
            per executor. e.g., ["a", "b", "c"]
        """
        self.workspace_ptr = workspace_ptr
        self._name_to_slot = {name: i for i, name in enumerate(slot_names)}
        self._slots_per_executor = len(slot_names)

    @cute.jit
    def get_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        """
        Get the workspace pointer for a specific TMA descriptor.

        :param tensor_name: Name of the tensor (must be one of the slot_names)
        :param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
        :return: Aligned pointer to the TMA descriptor in workspace
        """
        if cutlass.const_expr(tensor_name not in self._name_to_slot):
            raise ValueError(
                f"Invalid tensor_name '{tensor_name}', "
                f"expected one of {list(self._name_to_slot.keys())}"
            )
        slot = self._name_to_slot[tensor_name]
        byte_offset = (
            executor_idx * self._slots_per_executor + slot
        ) * TensormapDescBytes
        return ptr_offset_bytes(self.workspace_ptr, byte_offset).align(
            TensormapDescBytes
        )

    @staticmethod
    def size_bytes(num_slots: int, num_executors: int) -> int:
        """
        Calculate workspace size in bytes.

        :param num_slots: Number of descriptor slots per executor
        :param num_executors: Total number of executors (e.g., expert_cnt or group_cnt)
        :return: Total workspace size in bytes
        """
        return num_slots * num_executors * TensormapDescBytes


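# TensormapWorkspace addressing example (illustrative): for slot_names ["a", "b"]
# and executor_idx 3, get_ptr("b", 3) returns workspace_ptr advanced by
# (3 * 2 + 1) * 128 = 896 bytes, i.e. the 8th 128-byte descriptor slot.

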
# =============================================================================
# Online TMA Descriptor Creator (Abstract Base Class)
# =============================================================================


@dataclass
class OnlineTensormapDescCreator(ABC):
    """
    Abstract base class for building TMA descriptors online (at kernel runtime).

    Subclasses store all needed parameters (both codegen-time configs and runtime
    values) as explicit instance attributes in __init__. No dict-based APIs.

    Subclasses must implement exactly 2 abstract methods:
    - construct_and_write: Build TMA descriptor(s) for one executor and write to workspace
    - get_desc_ptr: Return raw gmem pointer to a specific descriptor in workspace

    To convert the raw pointer for use with cute.copy, callers should use the
    standalone tensormap_ptr_for_copy() utility.
    """

    @abstractmethod
    def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
        """
        Build TMA descriptor(s) for one executor and write to workspace.

        :param executor_idx: Index of the executor (e.g., group_idx or expert_idx).
            Semantics may vary by subclass when ``dependency`` is provided.
        :param dependency: Optional pipeline consumer for inter-warp-group
            synchronization. When provided, the subclass decides when to wait
            (via ``dependency.wait_and_advance()``) and release. The subclass
            also decides how to interpret ``executor_idx`` in this mode.
        """
        ...

    @abstractmethod
    def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        """
        Get the raw gmem pointer to a specific TMA descriptor in workspace.

        :param tensor_name: Name identifying which tensor's descriptor
        :param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
        :return: Raw pointer (gmem) to the TMA descriptor
        """
        ...


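# Minimal subclass sketch (illustrative only; `tensor`, `tma_op`, `smem_layout`,
# and `tile` stand in for a concrete kernel's state, and the slot name "c" is an
# arbitrary choice):
#
#   class MyTensormapConstructor(OnlineTensormapDescCreator):
#       def __init__(self, tensor, tma_op, smem_layout, tile, workspace_ptr):
#           super().__init__()
#           self.tensor, self.tma_op = tensor, tma_op
#           self.smem_layout, self.tile = smem_layout, tile
#           self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
#
#       @cute.jit
#       def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
#           return self.workspace.get_ptr(tensor_name, executor_idx)
#
#       @cute.jit
#       def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
#           tma_atom, _ = cpasync.make_tiled_tma_atom(
#               self.tma_op, self.tensor, self.smem_layout, self.tile
#           )
#           cpasync.copy_tensormap(tma_atom, self.get_desc_ptr("c", executor_idx))

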
# =============================================================================
# MoE Grouped GEMM Tensormap Constructor
# =============================================================================


class MoEGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
    """
    Tensormap descriptor constructor for MoE Grouped GEMM (expert-wise descriptors only).

    Non-expert-wise descriptors are passed directly at kernel launch.
    This class only handles:
    - 2Dx3D: C descriptors (expert-wise, to avoid write conflicts)
    - 2Dx2D: A and B descriptors (expert-wise, tokens is the reduction axis)

    All parameters are stored as explicit instance attributes (no dicts).

    Workspace layout (slots are interleaved per expert, see TensormapWorkspace):
    - 2Dx3D: [C_0, C_1, ..., C_{n-1}]
    - 2Dx2D: [A_0, B_0, A_1, B_1, ..., A_{n-1}, B_{n-1}]
    """

    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        # Codegen-time configs
        a_dtype,
        b_dtype,
        c_dtype,
        a_smem_layout,
        b_smem_layout,
        epi_smem_layout,
        a_tma_op,
        b_tma_op,
        c_tma_op,
        tiled_mma,
        mma_tiler,
        cluster_layout_vmnk_shape,
        epi_tile,
        # Runtime params
        a_tensor: cute.Tensor,  # fake GEMM domain A
        b_tensor: cute.Tensor,  # fake GEMM domain B
        c_tensor: cute.Tensor,  # fake GEMM domain C
        offs: cute.Tensor,  # (experts,) cumsum
        workspace_ptr: Pointer,
    ) -> None:
        super().__init__()
        self.scenario = scenario
        # Codegen-time configs
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.c_dtype = c_dtype
        self.a_smem_layout = a_smem_layout
        self.b_smem_layout = b_smem_layout
        self.epi_smem_layout = epi_smem_layout
        self.a_tma_op = a_tma_op
        self.b_tma_op = b_tma_op
        self.c_tma_op = c_tma_op
        self.tiled_mma = tiled_mma
        self.mma_tiler = mma_tiler
        self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
        self.epi_tile = epi_tile
        # Runtime params
        self.a_tensor = a_tensor
        self.b_tensor = b_tensor
        self.c_tensor = c_tensor
        self.offs = offs
        # Workspace with scenario-specific slot layout
        if scenario == "2Dx3D":
            self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
        else:
            self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])

    @staticmethod
    def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
        """Calculate workspace size in bytes for tensormap descriptors."""
        if scenario == "2Dx3D":
            return TensormapWorkspace.size_bytes(1, expert_cnt)  # only C
        else:
            return TensormapWorkspace.size_bytes(2, expert_cnt)  # A and B

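    # get_workspace_size example (illustrative): a 2Dx2D launch with 64 experts
    # needs 2 slots * 64 experts * 128 bytes = 16384 bytes of descriptor workspace.
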
    @cute.jit
    def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        return self.workspace.get_ptr(tensor_name, executor_idx)

    @cute.jit
    def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
        """
        Create expert-wise tensormap descriptors for the given expert.

        - 2Dx3D: Creates C descriptor for this expert
        - 2Dx2D: Creates A and B descriptors for this expert
        """
        if cutlass.const_expr(self.scenario == "2Dx3D"):
            self._construct_c_desc_2dx3d(executor_idx)
        else:  # 2Dx2D
            self._construct_ab_descs_2dx2d(executor_idx)

    @cute.jit
    def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
        """
        2Dx3D: Create expert-wise C descriptor.
        C tensor: (fake_m, n, 1) = (tokens_sum, intermediate, 1)
        Slice fake_m -> (tokens_i, intermediate, 1) per expert.
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)

        c_ptr = self.c_tensor.iterator
        c_stride = self.c_tensor.stride
        intermediate = self.c_tensor.shape[1]  # type: ignore[index]

        c1 = cutlass.Int32(1)
        c0 = cutlass.Int32(0)

        c_ptr_i = c_ptr + token_offset * c_stride[0]  # type: ignore[index]
        c_layout_i = cute.make_layout(
            (tokens_i, intermediate, c1),
            stride=(c_stride[0], c_stride[1], c0),  # type: ignore[index]
        )
        c_tensor_i = cute.make_tensor(c_ptr_i, c_layout_i)

        tma_atom_c, _ = cpasync.make_tiled_tma_atom(
            self.c_tma_op,
            c_tensor_i,
            self.epi_smem_layout,
            self.epi_tile,
        )
        cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))

    @cute.jit
    def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
        """
        2Dx2D: Create expert-wise A and B descriptors.
        A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
        B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)

        c1 = cutlass.Int32(1)
        c0 = cutlass.Int32(0)

        # A tensor: (m, fake_k, 1) -> (m, tokens_i, 1)
        a_ptr = self.a_tensor.iterator
        a_stride = self.a_tensor.stride
        a_m = self.a_tensor.shape[0]  # type: ignore[index]

        a_ptr_i = a_ptr + token_offset * a_stride[1]  # type: ignore[index]
        a_layout_i = cute.make_layout(
            (a_m, tokens_i, c1),
            stride=(a_stride[0], a_stride[1], c0),  # type: ignore[index]
        )
        a_tensor_i = cute.make_tensor(a_ptr_i, a_layout_i)

        tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
            self.a_tma_op,
            a_tensor_i,
            self.a_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))

        # B tensor: (n, fake_k, 1) -> (n, tokens_i, 1)
        b_ptr = self.b_tensor.iterator
        b_stride = self.b_tensor.stride
        b_n = self.b_tensor.shape[0]  # type: ignore[index]

        b_ptr_i = b_ptr + token_offset * b_stride[1]  # type: ignore[index]
        b_layout_i = cute.make_layout(
            (b_n, tokens_i, c1),
            stride=(b_stride[0], b_stride[1], c0),  # type: ignore[index]
        )
        b_tensor_i = cute.make_tensor(b_ptr_i, b_layout_i)

        tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
            self.b_tma_op,
            b_tensor_i,
            self.b_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))


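# End-to-end usage sketch for MoEGroupedGemmTensormapConstructor (illustrative;
# the surrounding kernel and all variables shown are assumptions of the sketch):
#
#   constructor = MoEGroupedGemmTensormapConstructor(
#       "2Dx2D", a_dtype, b_dtype, c_dtype,
#       a_smem_layout, b_smem_layout, epi_smem_layout,
#       a_tma_op, b_tma_op, c_tma_op,
#       tiled_mma, mma_tiler, cluster_layout_vmnk_shape, epi_tile,
#       a_tensor, b_tensor, c_tensor, offs, workspace_ptr,
#   )
#   constructor.construct_and_write(expert_idx)  # build this expert's A/B descriptors
#   a_desc = tensormap_ptr_for_copy(constructor.get_desc_ptr("a", expert_idx))
#   # ... pass a_desc as cute.copy's tma_desc_ptr when loading this expert's A tiles

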
# =============================================================================
# MoE Scaled Grouped GEMM Tensormap Constructor
# =============================================================================


class MoEScaledGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
    """
    Tensormap descriptor constructor for MoE Scaled Grouped GEMM (block-scaled).

    .. py:attribute:: ChunkSize
        :value: 128

        Number of experts processed per chunk in the desc_init_kernel.
        Must match the warp-group width (4 warps × 32 threads).

    Extends the MoEGroupedGemmTensormapConstructor pattern with SFA/SFB
    descriptor support.

    Expert-wise descriptors only; non-expert-wise descriptors are passed
    directly at kernel launch.

    Workspace layout:
    - 2Dx3D: [C_0, C_1, ..., C_{n-1}] (1 slot per expert)
    - 2Dx2D: [A_0, B_0, SFA_0, SFB_0, A_1, B_1, SFA_1, SFB_1, ...] (4 slots per expert)

    :param scenario: "2Dx3D" or "2Dx2D"
    :param sf_vec_size: Scale factor vector size (32 for MXFP8/MXFP4, 16 for NVFP4)
    :param a_dtype: Data type for tensor A
    :param b_dtype: Data type for tensor B
    :param c_dtype: Data type for tensor C
    :param sf_dtype: Data type for scale factors (SFA/SFB)
    :param a_smem_layout: SMEM layout for A TMA
    :param b_smem_layout: SMEM layout for B TMA
    :param epi_smem_layout: SMEM layout for epilogue (C) TMA
    :param sfa_smem_layout: SMEM layout for SFA TMA
    :param sfb_smem_layout: SMEM layout for SFB TMA
    :param a_tma_op: TMA operation for A
    :param b_tma_op: TMA operation for B
    :param c_tma_op: TMA operation for C (S2G store or reduce)
    :param sfa_tma_op: TMA operation for SFA
    :param sfb_tma_op: TMA operation for SFB
    :param tiled_mma: TiledMma for A/B/SFA/C TMA atom construction
    :param tiled_mma_sfb: TiledMma for SFB (separate due to 2CTA replication)
    :param mma_tiler: MMA tiler shape (M, N, K)
    :param mma_tiler_sfb: MMA tiler shape for SFB
    :param cluster_layout_vmnk_shape: Cluster layout shape for A/B/SFA multicast
    :param cluster_layout_sfb_vmnk_shape: Cluster layout shape for SFB multicast
    :param epi_tile: Epilogue tile shape
    :param a_tensor: Fake GEMM domain A tensor
    :param b_tensor: Fake GEMM domain B tensor
    :param c_tensor: Fake GEMM domain C tensor
    :param sfa_tensor: Fake GEMM domain SFA tensor (atom-tiled layout)
    :param sfb_tensor: Fake GEMM domain SFB tensor (atom-tiled layout)
    :param offs: (experts,) cumsum offsets in data domain
    :param offs_padded: (experts,) cumsum offsets in padded scale domain
    :param workspace_ptr: Pointer to workspace for TMA descriptors
    :param expert_cnt: Total number of experts
    """

    ChunkSize = 128

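    # Workspace size example (illustrative): a 2Dx2D launch with 256 experts needs
    # 4 slots * 256 experts * 128 bytes = 131072 bytes (see get_workspace_size below).
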
    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        sf_vec_size: int,
        # Codegen-time configs: dtypes
        a_dtype,
        b_dtype,
        c_dtype,
        sf_dtype,
        # Codegen-time configs: SMEM layouts
        a_smem_layout,
        b_smem_layout,
        epi_smem_layout,
        sfa_smem_layout,
        sfb_smem_layout,
        # Codegen-time configs: TMA ops
        a_tma_op,
        b_tma_op,
        c_tma_op,
        sfa_tma_op,
        sfb_tma_op,
        # Codegen-time configs: MMA / cluster / tile
        tiled_mma,
        tiled_mma_sfb,
        mma_tiler,
        mma_tiler_sfb,
        cluster_layout_vmnk_shape,
        cluster_layout_sfb_vmnk_shape,
        epi_tile,
        # Runtime params
        a_tensor: cute.Tensor,
        b_tensor: cute.Tensor,
        c_tensor: cute.Tensor,
        sfa_tensor: cute.Tensor,
        sfb_tensor: cute.Tensor,
        offs: cute.Tensor,
        offs_padded: cute.Tensor,
        workspace_ptr: Pointer,
        expert_cnt: Optional[Union[Int32, int]] = None,
    ) -> None:
        super().__init__()
        self.scenario = scenario
        self.sf_vec_size = sf_vec_size
        # Dtypes
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.c_dtype = c_dtype
        self.sf_dtype = sf_dtype
        # SMEM layouts
        self.a_smem_layout = a_smem_layout
        self.b_smem_layout = b_smem_layout
        self.epi_smem_layout = epi_smem_layout
        self.sfa_smem_layout = sfa_smem_layout
        self.sfb_smem_layout = sfb_smem_layout
        # TMA ops
        self.a_tma_op = a_tma_op
        self.b_tma_op = b_tma_op
        self.c_tma_op = c_tma_op
        self.sfa_tma_op = sfa_tma_op
        self.sfb_tma_op = sfb_tma_op
        # MMA / cluster / tile
        self.tiled_mma = tiled_mma
        self.tiled_mma_sfb = tiled_mma_sfb
        self.mma_tiler = mma_tiler
        self.mma_tiler_sfb = mma_tiler_sfb
        self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
        self.cluster_layout_sfb_vmnk_shape = cluster_layout_sfb_vmnk_shape
        self.epi_tile = epi_tile
        # Runtime params
        self.a_tensor = a_tensor
        self.b_tensor = b_tensor
        self.c_tensor = c_tensor
        self.sfa_tensor = sfa_tensor
        self.sfb_tensor = sfb_tensor
        self.offs = offs
        self.offs_padded = offs_padded
        self.expert_cnt = expert_cnt
        # Workspace with scenario-specific slot layout
        if scenario == "2Dx3D":
            self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
        else:
            self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "sfa", "sfb"])

    @staticmethod
    def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
        """Calculate workspace size in bytes for tensormap descriptors."""
        if scenario == "2Dx3D":
            return TensormapWorkspace.size_bytes(1, expert_cnt)  # C only
        else:
            return TensormapWorkspace.size_bytes(4, expert_cnt)  # A, B, SFA, SFB

    @cute.jit
    def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        return self.workspace.get_ptr(tensor_name, executor_idx)

    @cute.jit
    def construct_and_write(self, lane_in_group: Int32, dependency=None) -> None:
        """
        Create expert-wise tensormap descriptors for all experts.

        ``lane_in_group`` is the thread's position within its warp group
        (0..ChunkSize-1). The method loops internally over all experts in
        chunks of ``ChunkSize``, with two-phase pipeline synchronization
        per chunk.

        Per-chunk execution:

        1. Phase 1: Build descriptors that do NOT depend on ``offs_padded``
           (A/B for 2Dx2D, C for 2Dx3D). Overlaps with Group A's prefix sum.
        2. Barrier: ``consumer.wait_and_advance()``; all threads participate.
        3. Phase 2: Build descriptors that depend on ``offs_padded``
           (SFA/SFB for 2Dx2D). Reads padded offsets from the SMEM buffer.
        4. Release: ``handle.release()``; all threads participate.

        :param lane_in_group: Thread's position within the warp group (0..127).
        :param dependency: ``(PipelineConsumer, smem_offs_padded)``: the
            consumer for mbarrier sync, and the SMEM tensor of shape
            ``(ChunkSize + 1,)`` with layout ``[carry, offs_padded[0..127]]``.
        """
        consumer, smem_offs_padded = dependency
        assert self.expert_cnt is not None
        num_chunks = (self.expert_cnt + self.ChunkSize - 1) // self.ChunkSize

        chunk_idx = cutlass.Int32(0)
        while chunk_idx < num_chunks:
            expert_idx = chunk_idx * self.ChunkSize + lane_in_group
            in_bounds = expert_idx < self.expert_cnt

            # Phase 1: non-dependent descriptors
            if in_bounds:
                if cutlass.const_expr(self.scenario == "2Dx2D"):
                    self._construct_ab_descs_2dx2d(expert_idx)
                else:
                    self._construct_c_desc_2dx3d(expert_idx)

            # All threads participate in barrier (fixed arrive count)
            handle = consumer.wait_and_advance()

            # Phase 2: dependent descriptors (read padded offsets from SMEM)
            if in_bounds:
                if cutlass.const_expr(self.scenario == "2Dx2D"):
                    # smem_offs_padded layout: [carry, chunk[0], ..., chunk[127]]
                    # padded_offset = smem[lane] (prev expert's cumulative)
                    # padded_end = smem[lane + 1] (this expert's cumulative)
                    padded_offset = smem_offs_padded[lane_in_group]
                    padded_size_i = smem_offs_padded[lane_in_group + 1] - padded_offset
                    self._construct_sf_descs_2dx2d_direct(
                        expert_idx, padded_offset, padded_size_i
                    )

            # All threads release (fixed arrive count)
            handle.release()

            chunk_idx += 1

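    # Chunking example (illustrative): with expert_cnt = 200 and ChunkSize = 128,
    # num_chunks = 2; the thread with lane_in_group = 5 builds descriptors for
    # expert 5 in chunk 0 and expert 133 in chunk 1 (133 < 200, so it is in bounds).
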
    # -----------------------------------------------------------------
    # 2Dx3D: C descriptor (same as MoEGroupedGemmTensormapConstructor)
    # -----------------------------------------------------------------

    @cute.jit
    def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
        """
        2Dx3D: Create expert-wise C descriptor.
        C: (fake_m, n, 1) -> slice to (tokens_i, n, 1) per expert.
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
        c1 = cutlass.Int32(1)

        c_i = cute.domain_offset((token_offset, 0, 0), self.c_tensor)
        c_i = rewrite_tensor_shape(c_i, (tokens_i, self.c_tensor.shape[1], c1))  # type: ignore[index]

        tma_atom_c, _ = cpasync.make_tiled_tma_atom(
            self.c_tma_op,
            c_i,
            self.epi_smem_layout,
            self.epi_tile,
        )
        cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))

    # -----------------------------------------------------------------
    # 2Dx2D: A, B descriptors (same as MoEGroupedGemmTensormapConstructor)
    # -----------------------------------------------------------------

    @cute.jit
    def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
        """
        2Dx2D: Create expert-wise A and B descriptors.
        A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
        B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
        c1 = cutlass.Int32(1)

        # A: (m, fake_k, 1) -> domain_offset + rewrite shape
        a_i = cute.domain_offset((0, token_offset, 0), self.a_tensor)
        a_i = rewrite_tensor_shape(a_i, (self.a_tensor.shape[0], tokens_i, c1))  # type: ignore[index]

        tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
            self.a_tma_op,
            a_i,
            self.a_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))

        # B: (n, fake_k, 1) -> domain_offset + rewrite shape
        b_i = cute.domain_offset((0, token_offset, 0), self.b_tensor)
        b_i = rewrite_tensor_shape(b_i, (self.b_tensor.shape[0], tokens_i, c1))  # type: ignore[index]

        tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
            self.b_tma_op,
            b_i,
            self.b_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))

    # -----------------------------------------------------------------
    # 2Dx2D: SFA, SFB descriptors (new for block-scaled)
    # -----------------------------------------------------------------

    @cute.jit
    def _construct_sf_descs_2dx2d_direct(
        self,
        expert_idx: Int32,
        padded_offset: Int32,
        padded_size_i: Int32,
    ) -> None:
        """
        2Dx2D: Create expert-wise SFA and SFB descriptors with pre-computed
        padded offset and size.

        This variant allows the caller to supply padded offsets from SMEM
        (in desc_init_kernel) instead of reading from ``self.offs_padded`` in GMEM.
        """
        c1 = cutlass.Int32(1)

        a_chunks_to_move = (
            padded_offset
            // self.sf_vec_size
            * cute.size(self.sfa_tensor, mode=[0])
            // 128
        )
        a_elems_to_move = (
            cute.size(self.sfa_tensor, mode=[0]) * padded_offset // self.sf_vec_size
        )
        b_chunks_to_move = (
            padded_offset
            // self.sf_vec_size
            * cute.size(self.sfb_tensor, mode=[0])
            // 128
        )
        b_elems_to_move = (
            cute.size(self.sfb_tensor, mode=[0]) * padded_offset // self.sf_vec_size
        )

        per_expert_sfa_shape = (self.sfa_tensor.shape[0], padded_size_i, c1)  # type: ignore[index]
        sfa_layout_i = tile_atom_to_shape_SF(per_expert_sfa_shape, self.sf_vec_size)
        sfa_i = cute.make_tensor(
            self.sfa_tensor.iterator + a_elems_to_move, sfa_layout_i
        )

        tma_atom_sfa, _ = cute.nvgpu.make_tiled_tma_atom_A(
            self.sfa_tma_op,
            sfa_i,
            self.sfa_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
            internal_type=cutlass.Uint64,
        )
        cpasync.copy_tensormap(tma_atom_sfa, self.get_desc_ptr("sfa", expert_idx))

        per_expert_sfb_shape = (self.sfb_tensor.shape[0], padded_size_i, c1)  # type: ignore[index]
        sfb_layout_i = tile_atom_to_shape_SF(per_expert_sfb_shape, self.sf_vec_size)
        sfb_i = cute.make_tensor(
            self.sfb_tensor.iterator + b_elems_to_move, sfb_layout_i
        )

        tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B(
            self.sfb_tma_op,
            sfb_i,
            self.sfb_smem_layout,
            self.mma_tiler_sfb,
            self.tiled_mma_sfb,
            self.cluster_layout_sfb_vmnk_shape,
            internal_type=cutlass.Uint64,
        )
        cpasync.copy_tensormap(tma_atom_sfb, self.get_desc_ptr("sfb", expert_idx))