mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
[Hopper CuTeDSL] Add grouped GEMM persistent kernel and tests (#3091)
Implement grouped GEMM (C_g = A_g x B_g for g groups) on Hopper using CuTe DSL, extending the dense persistent GEMM with per-group TMA descriptor management.

Kernel design (grouped_gemm.py):
- Warp-specialized pipeline: DMA warp group handles TMA loads and per-group tensormap updates; MMA warp group runs WGMMA and stores C
- StaticPersistentGroupTileScheduler for cross-group tile scheduling
- Per-group TMA descriptor updates via GMEM or SMEM mode
- Supports fp16, fp8 (E4M3FN/E5M2), int8 with mixed A/B dtypes
- Configurable tile shapes (128x128, 128x256) and cluster shapes
- Fix base TensorMapManager: hoist uniform_smem_ptrs outside predicated block to avoid illegal @P0 R2UR on sm_90a

Tests (test/examples/CuTeDSL/hopper/test_grouped_gemm.py):
- L0 compile and L1 correctness pytest suite covering tile shapes, dtypes, major modes, cluster shapes, group counts, and mixed sizes
- Move to test/examples/CuTeDSL/hopper/ following sm_100a convention
- Fix deprecated startdir arg in test_sharding.py pytest hook
This commit is contained in:
2421
examples/python/CuTeDSL/hopper/grouped_gemm.py
Normal file
2421
examples/python/CuTeDSL/hopper/grouped_gemm.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -19,6 +19,8 @@ from cutlass.cutlass_dsl import dsl_user_op
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass import const_expr
|
||||
from cutlass.cute.core import AddressSpace as _CuteAddressSpace
|
||||
from cutlass.cute.core import make_ptr as _cute_make_ptr
|
||||
|
||||
|
||||
class TensorMapUpdateMode(Enum):
|
||||
@@ -138,11 +140,25 @@ class TensorMapManager:
|
||||
warp_idx = cute.arch.make_warp_uniform(
|
||||
cute.arch.warp_idx(loc=loc, ip=ip), loc=loc, ip=ip
|
||||
)
|
||||
if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
|
||||
# Hoist SMEM pointer integer values into warp-uniform registers before
|
||||
# entering predicated blocks. This avoids predicated R2UR lowering on sm_90a.
|
||||
uniform_smem_ptrs = tuple(
|
||||
_cute_make_ptr(
|
||||
p.dtype,
|
||||
cute.arch.make_warp_uniform(p.toint(), loc=loc, ip=ip),
|
||||
mem_space=_CuteAddressSpace.smem,
|
||||
assumed_align=p.alignment,
|
||||
)
|
||||
for p in tensormap_smem_ptr
|
||||
)
|
||||
else:
|
||||
uniform_smem_ptrs = tensormap_smem_ptr
|
||||
# updates before touching tensormap in global memory
|
||||
if warp_idx == warp_id:
|
||||
if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
|
||||
for copy_atom, tensor, smem_ptr in zip(
|
||||
tma_copy_atom, tensor_gmem, tensormap_smem_ptr
|
||||
tma_copy_atom, tensor_gmem, uniform_smem_ptrs
|
||||
):
|
||||
cute.nvgpu.cpasync.update_tma_descriptor(
|
||||
copy_atom, tensor, smem_ptr, loc=loc, ip=ip
|
||||
@@ -154,7 +170,7 @@ class TensorMapManager:
|
||||
cute.arch.sync_warp(loc=loc, ip=ip)
|
||||
# updates to tensormap in global memory
|
||||
if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
|
||||
for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
|
||||
for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, uniform_smem_ptrs):
|
||||
cute.nvgpu.cpasync.cp_fence_tma_desc_release(
|
||||
gmem_ptr, smem_ptr, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
30
test/examples/CuTeDSL/hopper/conftest.py
Normal file
30
test/examples/CuTeDSL/hopper/conftest.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
def pytest_configure(config):
    """Register the "90a" SM entry for this conftest file.

    Stores ``"90a"`` in ``config.default_SMs`` under the path of this
    ``conftest.py`` (``__file__``).

    NOTE(review): ``default_SMs`` is not a stock pytest attribute; it is
    presumably attached to ``config`` by a parent conftest/plugin that reads
    it to choose the target architecture -- confirm against that plugin.
    """
    sm_by_conftest = config.default_SMs
    sm_by_conftest[__file__] = "90a"
|
||||
554
test/examples/CuTeDSL/hopper/test_grouped_gemm.py
Normal file
554
test/examples/CuTeDSL/hopper/test_grouped_gemm.py
Normal file
@@ -0,0 +1,554 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Comprehensive pytest test suite for hopper/grouped_gemm.py.
|
||||
|
||||
Test organization
|
||||
-----------------
|
||||
L0 — compilation tests (skip_ref_check=True, iterations=0)
|
||||
Verify that the kernel compiles for a broad range of configurations
|
||||
without running on the GPU. Fast (~1-3 s each).
|
||||
|
||||
L1 — correctness tests (GPU execution, checked against torch.einsum)
|
||||
Verify numerical correctness for the key configurations.
|
||||
|
||||
Coverage
|
||||
--------
|
||||
* All tile shapes: (64,64), (128,64), (128,128), (128,256)
|
||||
* Both tensormap update modes: GMEM, SMEM
|
||||
* Data types: fp16 (fp32 or fp16 accumulator), fp8 (E4M3FN, E5M2, mixed), int8/uint8
|
||||
* Matrix major modes: A k/m-major, B k/n-major, C n/m-major
|
||||
* Cluster shapes: (1,1), (2,1), (1,2), (2,2) [mcast paths]
|
||||
* Group counts: 1, 2, 4, 8
|
||||
* Mixed problem sizes across groups in the same batch
|
||||
* Edge cases: single tile, non-uniform groups, same-shape groups
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Keep test behavior deterministic regardless of caller shell env.
|
||||
# These are consumed at grouped_gemm import time.
|
||||
os.environ.setdefault("GROUPED_GEMM_FORCE_CUTE_COPY", "0")
|
||||
|
||||
import pytest
|
||||
import cutlass
|
||||
import cutlass.utils as utils
|
||||
from hopper.grouped_gemm import run
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Tensormap update modes that the parametrized tests switch between.
SMEM = utils.TensorMapUpdateMode.SMEM
GMEM = utils.TensorMapUpdateMode.GMEM

# Short aliases for the cutlass element types used in the parametrizations.
F16, F32 = cutlass.Float16, cutlass.Float32
F8E4, F8E5 = cutlass.Float8E4M3FN, cutlass.Float8E5M2
I8, U8, I32 = cutlass.Int8, cutlass.Uint8, cutlass.Int32

# Parametrization values and their pytest ids; the two lists are index-aligned.
TMAP_MODES = [SMEM, GMEM]
TMAP_MODE_IDS = ["smem", "gmem"]
|
||||
|
||||
|
||||
def _run_compile(
    num_groups,
    problem_sizes_mnkl,
    tile_shape_mn,
    cluster_shape_mn=(1, 1),
    a_dtype=F16,
    b_dtype=F16,
    c_dtype=F16,
    acc_dtype=F32,
    a_major="k",
    b_major="k",
    c_major="n",
    tensormap_update_mode=SMEM,
):
    """Run one configuration in compile-only form.

    Forwards every argument unchanged to ``_run_case`` and pins the three
    execution knobs to ``skip_ref_check=True``, ``warmup_iterations=0`` and
    ``iterations=0``. ``tolerance`` is not passed, so ``_run_case`` uses its
    own default.
    """
    # Fixed settings that distinguish a compile-only run from a checked run.
    compile_only = dict(skip_ref_check=True, warmup_iterations=0, iterations=0)
    _run_case(
        num_groups=num_groups,
        problem_sizes_mnkl=problem_sizes_mnkl,
        tile_shape_mn=tile_shape_mn,
        cluster_shape_mn=cluster_shape_mn,
        a_dtype=a_dtype,
        b_dtype=b_dtype,
        c_dtype=c_dtype,
        acc_dtype=acc_dtype,
        a_major=a_major,
        b_major=b_major,
        c_major=c_major,
        tensormap_update_mode=tensormap_update_mode,
        **compile_only,
    )
|
||||
|
||||
|
||||
def _run_correctness(
    num_groups,
    problem_sizes_mnkl,
    tile_shape_mn,
    cluster_shape_mn=(1, 1),
    a_dtype=F16,
    b_dtype=F16,
    c_dtype=F16,
    acc_dtype=F32,
    a_major="k",
    b_major="k",
    c_major="n",
    tensormap_update_mode=SMEM,
    tolerance=1e-1,
):
    """Run one configuration once with the reference check enabled.

    Forwards every argument (including ``tolerance``) to ``_run_case`` and
    pins ``warmup_iterations=0``, ``iterations=1`` and
    ``skip_ref_check=False``.
    """
    # Fixed settings for a single, reference-checked execution.
    checked_once = dict(warmup_iterations=0, iterations=1, skip_ref_check=False)
    _run_case(
        num_groups=num_groups,
        problem_sizes_mnkl=problem_sizes_mnkl,
        tile_shape_mn=tile_shape_mn,
        cluster_shape_mn=cluster_shape_mn,
        a_dtype=a_dtype,
        b_dtype=b_dtype,
        c_dtype=c_dtype,
        acc_dtype=acc_dtype,
        a_major=a_major,
        b_major=b_major,
        c_major=c_major,
        tensormap_update_mode=tensormap_update_mode,
        tolerance=tolerance,
        **checked_once,
    )
|
||||
|
||||
|
||||
def _run_case(
    *,
    num_groups,
    problem_sizes_mnkl,
    tile_shape_mn,
    cluster_shape_mn=(1, 1),
    a_dtype=F16,
    b_dtype=F16,
    c_dtype=F16,
    acc_dtype=F32,
    a_major="k",
    b_major="k",
    c_major="n",
    tensormap_update_mode=SMEM,
    tolerance=1e-1,
    warmup_iterations=0,
    iterations=1,
    skip_ref_check=False,
):
    """Single call site for ``hopper.grouped_gemm.run`` used by both helpers.

    All parameters are keyword-only and are passed through to ``run`` under
    the same names; the return value of ``run`` is discarded.
    """
    run_kwargs = {
        # Problem description.
        "num_groups": num_groups,
        "problem_sizes_mnkl": problem_sizes_mnkl,
        # Element types and layouts.
        "a_dtype": a_dtype,
        "b_dtype": b_dtype,
        "c_dtype": c_dtype,
        "acc_dtype": acc_dtype,
        "a_major": a_major,
        "b_major": b_major,
        "c_major": c_major,
        # Kernel configuration.
        "tile_shape_mn": tile_shape_mn,
        "cluster_shape_mn": cluster_shape_mn,
        "tensormap_update_mode": tensormap_update_mode,
        # Execution / verification knobs.
        "tolerance": tolerance,
        "warmup_iterations": warmup_iterations,
        "iterations": iterations,
        "skip_ref_check": skip_ref_check,
    }
    run(**run_kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L0 — tile shape coverage (both SMEM and GMEM)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
@pytest.mark.parametrize(
    "tile_shape_mn, problem_sizes_mnkl",
    [
        # Each case uses one group whose (M, N) equals the tile shape.
        pytest.param((128, 256), [(128, 256, 64, 1)], id="tile128x256"),
        pytest.param((128, 128), [(128, 128, 64, 1)], id="tile128x128"),
        pytest.param((128, 64), [(128, 64, 64, 1)], id="tile128x64"),
        pytest.param((64, 64), [(64, 64, 64, 1)], id="tile64x64"),
    ],
)
def test_l0_tile_shapes(tile_shape_mn, problem_sizes_mnkl, tmap_mode):
    """Each tile shape goes through the compile-only path in SMEM and GMEM mode."""
    _run_compile(
        num_groups=1,
        problem_sizes_mnkl=problem_sizes_mnkl,
        tile_shape_mn=tile_shape_mn,
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L0 — group count coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
@pytest.mark.parametrize(
    "num_groups, problem_sizes_mnkl",
    [
        pytest.param(2, [(128, 256, 64, 1)] * 2, id="2g-uniform"),
        pytest.param(
            4,
            [
                (128, 256, 64, 1),
                (64, 128, 64, 1),
                (256, 128, 64, 1),
                (192, 256, 64, 1),
            ],
            id="4g-mixed",
        ),
        pytest.param(8, [(128, 256, 64, 1)] * 8, id="8g-uniform"),
    ],
)
def test_l0_group_counts(num_groups, problem_sizes_mnkl, tmap_mode):
    """2, 4 and 8 groups go through the compile-only path with tile (128, 256)."""
    _run_compile(
        num_groups=num_groups,
        problem_sizes_mnkl=problem_sizes_mnkl,
        tile_shape_mn=(128, 256),
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L0 — data type coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
@pytest.mark.parametrize(
    "a_dtype, b_dtype, c_dtype, acc_dtype, problem_sizes_mnkl",
    [
        # fp16 inputs, fp16 output, fp32 accumulator
        pytest.param(
            F16, F16, F16, F32, [(128, 256, 64, 1)], id="fp16-fp16-fp16-fp32"
        ),
        # fp16 inputs, fp32 output
        pytest.param(
            F16, F16, F32, F32, [(128, 256, 64, 1)], id="fp16-fp16-fp32-fp32"
        ),
        # fp16 inputs, fp16 accumulator
        pytest.param(
            F16, F16, F16, F16, [(128, 256, 64, 1)], id="fp16-fp16-fp16-fp16"
        ),
        # fp8 E4M3 inputs, fp16 output (K must be a multiple of 16 for fp8 alignment)
        pytest.param(
            F8E4, F8E4, F16, F32, [(128, 256, 128, 1)], id="fp8e4-fp8e4-fp16-fp32"
        ),
        # fp8 E5M2 inputs, fp16 output
        pytest.param(
            F8E5, F8E5, F16, F32, [(128, 256, 128, 1)], id="fp8e5-fp8e5-fp16-fp32"
        ),
        # mixed fp8 inputs: A is E4M3, B is E5M2
        pytest.param(
            F8E4, F8E5, F16, F32, [(128, 256, 128, 1)], id="fp8e4-fp8e5-fp16-fp32"
        ),
        # int8 inputs, int32 output (K must be a multiple of 16)
        pytest.param(
            I8, I8, I32, I32, [(128, 256, 128, 1)], id="int8-int8-int32-int32"
        ),
        # uint8 inputs, int32 output
        pytest.param(
            U8, U8, I32, I32, [(128, 256, 128, 1)], id="uint8-uint8-int32-int32"
        ),
    ],
)
def test_l0_dtypes(a_dtype, b_dtype, c_dtype, acc_dtype, problem_sizes_mnkl, tmap_mode):
    """Each A/B/C/accumulator dtype combination goes through the compile-only path.

    Every case uses a single group and tile shape (128, 256).
    """
    _run_compile(
        num_groups=1,
        problem_sizes_mnkl=problem_sizes_mnkl,
        tile_shape_mn=(128, 256),
        a_dtype=a_dtype,
        b_dtype=b_dtype,
        c_dtype=c_dtype,
        acc_dtype=acc_dtype,
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L0 — matrix major modes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
@pytest.mark.parametrize(
    "a_major, b_major, c_major, problem_sizes_mnkl, tile_shape_mn",
    [
        # Default layout: A k-major, B k-major, C n-major.
        pytest.param(
            "k", "k", "n", [(128, 256, 64, 1)], (128, 256), id="akm-bkm-cn"
        ),
        # A m-major (contiguous in M; M must be a multiple of 8 for fp16).
        pytest.param(
            "m", "k", "n", [(128, 256, 64, 1)], (128, 128), id="amaj-bkm-cn"
        ),
        # B n-major (contiguous in N; N must be a multiple of 8 for fp16).
        pytest.param(
            "k", "n", "n", [(128, 128, 64, 1)], (128, 128), id="akm-bnmaj-cn"
        ),
        # C m-major output (M must be a multiple of 8).
        pytest.param(
            "k", "k", "m", [(128, 256, 64, 1)], (128, 256), id="akm-bkm-cmaj"
        ),
        # A m-major combined with B n-major.
        pytest.param(
            "m", "n", "n", [(128, 128, 64, 1)], (128, 128), id="amaj-bnmaj-cn"
        ),
    ],
)
def test_l0_major_modes(a_major, b_major, c_major, problem_sizes_mnkl, tile_shape_mn, tmap_mode):
    """Each A/B/C major-mode combination goes through the compile-only path."""
    _run_compile(
        num_groups=1,
        problem_sizes_mnkl=problem_sizes_mnkl,
        tile_shape_mn=tile_shape_mn,
        a_major=a_major,
        b_major=b_major,
        c_major=c_major,
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L0 — cluster shapes (mcast paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
@pytest.mark.parametrize(
    "cluster_shape_mn, problem_sizes_mnkl, tile_shape_mn",
    [
        # 1x1 cluster: baseline without multicast.
        pytest.param((1, 1), [(128, 256, 64, 1)], (128, 256), id="cluster1x1"),
        # 2x1 cluster: A multicast across 2 CTAs in M; needs M >= 2 * tile_m.
        pytest.param((2, 1), [(256, 256, 64, 1)], (128, 256), id="cluster2x1"),
        # 1x2 cluster: B multicast across 2 CTAs in N; needs N >= 2 * tile_n.
        pytest.param((1, 2), [(128, 512, 64, 1)], (128, 256), id="cluster1x2"),
        # 2x2 cluster: multicast for both A and B.
        pytest.param((2, 2), [(256, 512, 64, 1)], (128, 256), id="cluster2x2"),
    ],
)
def test_l0_cluster_shapes(cluster_shape_mn, problem_sizes_mnkl, tile_shape_mn, tmap_mode):
    """Each cluster shape, multicast ones included, goes through the compile-only path."""
    _run_compile(
        num_groups=1,
        problem_sizes_mnkl=problem_sizes_mnkl,
        tile_shape_mn=tile_shape_mn,
        cluster_shape_mn=cluster_shape_mn,
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L0 — mixed problem sizes (non-uniform groups)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
@pytest.mark.parametrize(
    "num_groups, problem_sizes_mnkl",
    [
        # Four groups with clearly different (M, N) shapes.
        pytest.param(
            4,
            [
                (64, 64, 64, 1),
                (128, 128, 64, 1),
                (256, 128, 64, 1),
                (128, 256, 64, 1),
            ],
            id="4g-all-tiles",
        ),
        # One tiny group next to one large group.
        pytest.param(
            2,
            [
                (64, 64, 64, 1),
                (512, 512, 64, 1),
            ],
            id="2g-tiny-large",
        ),
    ],
)
def test_l0_mixed_problem_sizes(num_groups, problem_sizes_mnkl, tmap_mode):
    """Groups with differing problem sizes go through the compile-only path.

    Every case uses tile shape (128, 256).
    """
    _run_compile(
        num_groups=num_groups,
        problem_sizes_mnkl=problem_sizes_mnkl,
        tile_shape_mn=(128, 256),
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L1 — correctness: both tensormap modes, fp16, tile (128,256)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0(0)
@pytest.mark.L1
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
def test_l1_fp16_4g_mixed(tmap_mode):
    """Reference-check four fp16 groups of differing sizes with tile (128, 256)."""
    mixed_sizes = [
        (128, 256, 64, 1),
        (64, 128, 64, 1),
        (256, 128, 64, 1),
        (192, 256, 64, 1),
    ]
    _run_correctness(
        num_groups=4,
        problem_sizes_mnkl=mixed_sizes,
        tile_shape_mn=(128, 256),
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L1 — correctness: all tile shapes with fp16 SMEM + GMEM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0(0)
@pytest.mark.L1
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
@pytest.mark.parametrize(
    "tile_shape_mn, problem_sizes_mnkl",
    [
        # Each case uses one group whose (M, N) equals the tile shape.
        pytest.param((128, 256), [(128, 256, 64, 1)], id="tile128x256"),
        pytest.param((128, 128), [(128, 128, 64, 1)], id="tile128x128"),
        pytest.param((128, 64), [(128, 64, 64, 1)], id="tile128x64"),
        pytest.param((64, 64), [(64, 64, 64, 1)], id="tile64x64"),
    ],
)
def test_l1_tile_shapes_fp16(tile_shape_mn, problem_sizes_mnkl, tmap_mode):
    """Reference-check a single default-dtype group for every tile shape."""
    _run_correctness(
        num_groups=1,
        problem_sizes_mnkl=problem_sizes_mnkl,
        tile_shape_mn=tile_shape_mn,
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L1 — correctness: group count scaling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0(0)
@pytest.mark.L1
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
@pytest.mark.parametrize(
    "num_groups",
    [2, 4],
    ids=["2g", "4g"],
)
def test_l1_group_count_scaling(num_groups, tmap_mode):
    """Reference-check 2 and 4 identical (128, 256, 64, 1) groups."""
    uniform_sizes = [(128, 256, 64, 1)] * num_groups
    _run_correctness(
        num_groups=num_groups,
        problem_sizes_mnkl=uniform_sizes,
        tile_shape_mn=(128, 256),
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L1 — correctness: data types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0(0)
@pytest.mark.L1
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
def test_l1_fp16_c_fp32(tmap_mode):
    """Reference-check default (fp16) A/B with fp32 C and fp32 accumulator."""
    _run_correctness(
        num_groups=1,
        problem_sizes_mnkl=[(128, 256, 64, 1)],
        tile_shape_mn=(128, 256),
        c_dtype=F32,
        acc_dtype=F32,
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
@pytest.mark.L0(0)
@pytest.mark.L1
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
def test_l1_fp8_e4m3(tmap_mode):
    """Reference-check fp8 E4M3FN A/B with fp16 C (K=128 for 16B alignment).

    Uses a looser tolerance (0.5) than the helper default.
    """
    _run_correctness(
        num_groups=1,
        problem_sizes_mnkl=[(128, 256, 128, 1)],
        tile_shape_mn=(128, 256),
        a_dtype=F8E4,
        b_dtype=F8E4,
        c_dtype=F16,
        acc_dtype=F32,
        tensormap_update_mode=tmap_mode,
        tolerance=0.5,
    )
|
||||
|
||||
|
||||
@pytest.mark.L0(0)
@pytest.mark.L1
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
def test_l1_fp8_mixed(tmap_mode):
    """Reference-check mixed fp8 inputs: A is E4M3FN, B is E5M2, C is fp16.

    Uses a looser tolerance (0.5) than the helper default.
    """
    _run_correctness(
        num_groups=1,
        problem_sizes_mnkl=[(128, 256, 128, 1)],
        tile_shape_mn=(128, 256),
        a_dtype=F8E4,
        b_dtype=F8E5,
        c_dtype=F16,
        acc_dtype=F32,
        tensormap_update_mode=tmap_mode,
        tolerance=0.5,
    )
|
||||
|
||||
|
||||
@pytest.mark.L0(0)
@pytest.mark.L1
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
def test_l1_int8(tmap_mode):
    """Reference-check int8 A/B with int32 C and accumulator at tolerance 0."""
    _run_correctness(
        num_groups=1,
        problem_sizes_mnkl=[(128, 256, 128, 1)],
        tile_shape_mn=(128, 256),
        a_dtype=I8,
        b_dtype=I8,
        c_dtype=I32,
        acc_dtype=I32,
        tensormap_update_mode=tmap_mode,
        tolerance=0,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L1 — correctness: matrix major modes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0(0)
@pytest.mark.L1
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
def test_l1_c_m_major(tmap_mode):
    """Reference-check a single group whose C output is m-major."""
    _run_correctness(
        num_groups=1,
        problem_sizes_mnkl=[(128, 256, 64, 1)],
        tile_shape_mn=(128, 256),
        c_major="m",
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="JIT compile time too long for CI (~25 min); run manually")
@pytest.mark.L0(0)
@pytest.mark.L1
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
def test_l1_all_non_default_majors(tmap_mode):
    """Reference-check m-major A, n-major B and m-major C in one run.

    Skipped unconditionally; see the skip reason on the decorator.
    """
    _run_correctness(
        num_groups=1,
        problem_sizes_mnkl=[(64, 64, 64, 1)],
        tile_shape_mn=(128, 128),
        a_major="m",
        b_major="n",
        c_major="m",
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L1 — correctness: cluster shapes (mcast paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0(0)
@pytest.mark.L1
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
@pytest.mark.parametrize(
    "cluster_shape_mn, problem_sizes_mnkl",
    [
        pytest.param((2, 2), [(256, 512, 64, 1)], id="cluster2x2"),
    ],
)
def test_l1_cluster_shapes(cluster_shape_mn, problem_sizes_mnkl, tmap_mode):
    """Reference-check the parametrized cluster shape with tile (128, 256).

    Only the (2, 2) cluster is currently listed.
    """
    _run_correctness(
        num_groups=1,
        problem_sizes_mnkl=problem_sizes_mnkl,
        tile_shape_mn=(128, 256),
        cluster_shape_mn=cluster_shape_mn,
        tensormap_update_mode=tmap_mode,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# L1 — correctness: multi-group with mixed sizes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.L0(0)
@pytest.mark.L1
@pytest.mark.parametrize("tmap_mode", TMAP_MODES, ids=TMAP_MODE_IDS)
def test_l1_8g_mixed_sizes(tmap_mode):
    """Reference-check eight groups with differing M, N and K, tile (128, 256)."""
    per_group_sizes = [
        (128, 256, 64, 1),
        (64, 128, 64, 1),
        (256, 128, 64, 1),
        (128, 128, 128, 1),
        (192, 256, 64, 1),
        (64, 64, 64, 1),
        (128, 256, 128, 1),
        (256, 256, 64, 1),
    ]
    _run_correctness(
        num_groups=8,
        problem_sizes_mnkl=per_group_sizes,
        tile_shape_mn=(128, 256),
        tensormap_update_mode=tmap_mode,
    )
|
||||
@@ -495,7 +495,7 @@ def pytest_runtest_makereport(item, call):
|
||||
report.longrepr = f"Expect exception '{xfail_info}', but got '{type(call.excinfo.value)}'"
|
||||
|
||||
|
||||
def pytest_report_collectionfinish(config, start_path, startdir, items):
|
||||
def pytest_report_collectionfinish(config, start_path, items):
|
||||
if not config.getoption("--collect-only"):
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user