mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-22 05:48:18 +00:00
Fix examples and pytest, run ruff (#3230)
Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com>
This commit is contained in:
@@ -30,6 +30,8 @@ import argparse
|
||||
import ctypes
|
||||
import functools
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
@@ -44,10 +46,11 @@ from cutlass import Boolean, Float32, Int32, Int64
|
||||
from cutlass.cute.runtime import make_ptr
|
||||
|
||||
# Support both direct execution and module import
|
||||
try:
|
||||
from .reduce import row_reduce
|
||||
except ImportError:
|
||||
from reduce import row_reduce
|
||||
if __name__ == "__main__":
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../../.."))
|
||||
|
||||
from blackwell.kernel.reduce.reduce import row_reduce
|
||||
|
||||
"""
|
||||
RMSNorm: Root Mean Square Layer Normalization for Hopper & Blackwell (SM90+)
|
||||
|
||||
@@ -39,6 +39,8 @@ import numpy as np
|
||||
|
||||
|
||||
project_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
|
||||
cute_example_path = project_root / "examples" / "python" / "CuTeDSL" / "cute"
|
||||
example_path = project_root / "examples" / "python" / "CuTeDSL"
|
||||
utils_path = project_root / "test" / "utils"
|
||||
|
||||
@@ -50,9 +52,11 @@ utils_path = project_root / "test" / "utils"
|
||||
# Importing cutlass here, while sys.path is still clean, avoids that race.
|
||||
import cutlass # noqa: E402 (intentional early import)
|
||||
|
||||
sys.path.append(str(cute_example_path))
|
||||
sys.path.append(str(example_path))
|
||||
sys.path.append(str(utils_path))
|
||||
|
||||
|
||||
# The helper class to prevent modification of sys.path from test files
|
||||
# Only allow modification of sys.path from pytest monkeypatch API calls
|
||||
class ImmutableSysPath(list):
|
||||
@@ -70,6 +74,7 @@ class ImmutableSysPath(list):
|
||||
}
|
||||
|
||||
for mtd in mutating_methods:
|
||||
|
||||
def mutating_method(self, *args, mtd=mtd, **kwargs):
|
||||
frame = sys._getframe().f_back
|
||||
if (
|
||||
@@ -98,6 +103,7 @@ sys.path = ImmutableSysPath(list(sys.path))
|
||||
|
||||
pytest_plugins = ["test_sharding"]
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--sample-interval",
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.default_SMs[__file__] = "90a"
|
||||
config.addinivalue_line(
|
||||
|
||||
@@ -55,7 +55,7 @@ Coverage
|
||||
|
||||
import pytest
|
||||
import cutlass
|
||||
from hopper.dense_gemm_fp8_2xacc import run
|
||||
from hopper.kernel.dense_gemm.dense_gemm_fp8_2xacc import run
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type aliases
|
||||
@@ -169,8 +169,8 @@ def _run_benchmark(
|
||||
[
|
||||
pytest.param((128, 256), (2048, 2048, 2048, 1), id="tile128x256"),
|
||||
pytest.param((128, 128), (2048, 2048, 2048, 1), id="tile128x128"),
|
||||
pytest.param((128, 64), (2048, 2048, 2048, 1), id="tile128x64"),
|
||||
pytest.param((64, 64), (2048, 2048, 2048, 1), id="tile64x64"),
|
||||
pytest.param((128, 64), (2048, 2048, 2048, 1), id="tile128x64"),
|
||||
pytest.param((64, 64), (2048, 2048, 2048, 1), id="tile64x64"),
|
||||
],
|
||||
)
|
||||
def test_l0_tile_shapes(tile_shape_mn, mnkl):
|
||||
@@ -195,8 +195,11 @@ def test_l0_tile_shapes(tile_shape_mn, mnkl):
|
||||
)
|
||||
def test_l0_cluster_shapes(cluster_shape_mn):
|
||||
"""All valid cluster shapes compile (tile 128x128, 2048^3)."""
|
||||
_run_compile(mnkl=(2048, 2048, 2048, 1), tile_shape_mn=(128, 128),
|
||||
cluster_shape_mn=cluster_shape_mn)
|
||||
_run_compile(
|
||||
mnkl=(2048, 2048, 2048, 1),
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mn=cluster_shape_mn,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -208,8 +211,8 @@ def test_l0_cluster_shapes(cluster_shape_mn):
|
||||
@pytest.mark.parametrize(
|
||||
"c_dtype",
|
||||
[
|
||||
pytest.param(F16, id="Float16"),
|
||||
pytest.param(F32, id="Float32"),
|
||||
pytest.param(F16, id="Float16"),
|
||||
pytest.param(F32, id="Float32"),
|
||||
pytest.param(F8E4, id="Float8E4M3FN"),
|
||||
],
|
||||
)
|
||||
@@ -227,8 +230,8 @@ def test_l0_output_dtypes(c_dtype):
|
||||
@pytest.mark.parametrize(
|
||||
"mma_promotion_interval",
|
||||
[
|
||||
pytest.param(4, id="interval4"),
|
||||
pytest.param(8, id="interval8"),
|
||||
pytest.param(4, id="interval4"),
|
||||
pytest.param(8, id="interval8"),
|
||||
pytest.param(16, id="interval16"),
|
||||
],
|
||||
)
|
||||
@@ -249,8 +252,8 @@ def test_l0_mma_promotion_intervals(mma_promotion_interval):
|
||||
[
|
||||
pytest.param((128, 256), (2048, 2048, 2048, 1), id="tile128x256"),
|
||||
pytest.param((128, 128), (2048, 2048, 2048, 1), id="tile128x128"),
|
||||
pytest.param((128, 64), (2048, 2048, 2048, 1), id="tile128x64"),
|
||||
pytest.param((64, 64), (2048, 2048, 2048, 1), id="tile64x64"),
|
||||
pytest.param((128, 64), (2048, 2048, 2048, 1), id="tile128x64"),
|
||||
pytest.param((64, 64), (2048, 2048, 2048, 1), id="tile64x64"),
|
||||
],
|
||||
)
|
||||
def test_l1_tile_shapes(tile_shape_mn, mnkl):
|
||||
@@ -276,8 +279,11 @@ def test_l1_tile_shapes(tile_shape_mn, mnkl):
|
||||
)
|
||||
def test_l1_cluster_shapes(cluster_shape_mn):
|
||||
"""All cluster shapes (including A/B multicast paths) produce correct results."""
|
||||
_run_correctness(mnkl=(2048, 2048, 2048, 1), tile_shape_mn=(128, 128),
|
||||
cluster_shape_mn=cluster_shape_mn)
|
||||
_run_correctness(
|
||||
mnkl=(2048, 2048, 2048, 1),
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mn=cluster_shape_mn,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -290,9 +296,9 @@ def test_l1_cluster_shapes(cluster_shape_mn):
|
||||
@pytest.mark.parametrize(
|
||||
"c_dtype, tolerance",
|
||||
[
|
||||
pytest.param(F16, 0.1, id="Float16"),
|
||||
pytest.param(F32, 0.1, id="Float32"),
|
||||
pytest.param(F8E4, 0.5, id="Float8E4M3FN"),
|
||||
pytest.param(F16, 0.1, id="Float16"),
|
||||
pytest.param(F32, 0.1, id="Float32"),
|
||||
pytest.param(F8E4, 0.5, id="Float8E4M3FN"),
|
||||
],
|
||||
)
|
||||
def test_l1_output_dtypes(c_dtype, tolerance):
|
||||
@@ -310,8 +316,8 @@ def test_l1_output_dtypes(c_dtype, tolerance):
|
||||
@pytest.mark.parametrize(
|
||||
"mma_promotion_interval",
|
||||
[
|
||||
pytest.param(4, id="interval4"),
|
||||
pytest.param(8, id="interval8"),
|
||||
pytest.param(4, id="interval4"),
|
||||
pytest.param(8, id="interval8"),
|
||||
pytest.param(16, id="interval16"),
|
||||
],
|
||||
)
|
||||
@@ -330,9 +336,9 @@ def test_l1_mma_promotion_intervals(mma_promotion_interval):
|
||||
@pytest.mark.parametrize(
|
||||
"scale_a_val, scale_b_val",
|
||||
[
|
||||
pytest.param(0.5, 2.0, id="scale_a0.5_b2.0"),
|
||||
pytest.param(0.25, 4.0, id="scale_a0.25_b4.0"),
|
||||
pytest.param(2.0, 0.5, id="scale_a2.0_b0.5"),
|
||||
pytest.param(0.5, 2.0, id="scale_a0.5_b2.0"),
|
||||
pytest.param(0.25, 4.0, id="scale_a0.25_b4.0"),
|
||||
pytest.param(2.0, 0.5, id="scale_a2.0_b0.5"),
|
||||
],
|
||||
)
|
||||
def test_l1_scale_factors(scale_a_val, scale_b_val):
|
||||
@@ -351,7 +357,7 @@ def test_l1_scale_factors(scale_a_val, scale_b_val):
|
||||
"mnkl",
|
||||
[
|
||||
pytest.param((1024, 1024, 1024, 2), id="L2"),
|
||||
pytest.param((512, 512, 512, 4), id="L4"),
|
||||
pytest.param((512, 512, 512, 4), id="L4"),
|
||||
],
|
||||
)
|
||||
def test_l1_batched(mnkl):
|
||||
@@ -370,25 +376,118 @@ def test_l1_batched(mnkl):
|
||||
"mnkl, tile_shape_mn, cluster_shape_mn, mma_promotion_interval, label",
|
||||
[
|
||||
# Square 4096^3 — tile / cluster sweep
|
||||
pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 1), 4, "4096^3 tile=128x128 cluster=1x1", id="4096-128x128-1x1"),
|
||||
pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 4, "4096^3 tile=128x128 cluster=1x2", id="4096-128x128-1x2"),
|
||||
pytest.param((4096, 4096, 4096, 1), (128, 128), (2, 2), 4, "4096^3 tile=128x128 cluster=2x2", id="4096-128x128-2x2"),
|
||||
pytest.param((4096, 4096, 4096, 1), (128, 256), (1, 2), 4, "4096^3 tile=128x256 cluster=1x2", id="4096-128x256-1x2"),
|
||||
pytest.param((4096, 4096, 4096, 1), (128, 256), (2, 2), 4, "4096^3 tile=128x256 cluster=2x2", id="4096-128x256-2x2"),
|
||||
pytest.param((4096, 4096, 4096, 1), (128, 64), (1, 2), 4, "4096^3 tile=128x64 cluster=1x2", id="4096-128x64-1x2"),
|
||||
pytest.param((4096, 4096, 4096, 1), (64, 64), (1, 2), 4, "4096^3 tile=64x64 cluster=1x2", id="4096-64x64-1x2"),
|
||||
pytest.param(
|
||||
(4096, 4096, 4096, 1),
|
||||
(128, 128),
|
||||
(1, 1),
|
||||
4,
|
||||
"4096^3 tile=128x128 cluster=1x1",
|
||||
id="4096-128x128-1x1",
|
||||
),
|
||||
pytest.param(
|
||||
(4096, 4096, 4096, 1),
|
||||
(128, 128),
|
||||
(1, 2),
|
||||
4,
|
||||
"4096^3 tile=128x128 cluster=1x2",
|
||||
id="4096-128x128-1x2",
|
||||
),
|
||||
pytest.param(
|
||||
(4096, 4096, 4096, 1),
|
||||
(128, 128),
|
||||
(2, 2),
|
||||
4,
|
||||
"4096^3 tile=128x128 cluster=2x2",
|
||||
id="4096-128x128-2x2",
|
||||
),
|
||||
pytest.param(
|
||||
(4096, 4096, 4096, 1),
|
||||
(128, 256),
|
||||
(1, 2),
|
||||
4,
|
||||
"4096^3 tile=128x256 cluster=1x2",
|
||||
id="4096-128x256-1x2",
|
||||
),
|
||||
pytest.param(
|
||||
(4096, 4096, 4096, 1),
|
||||
(128, 256),
|
||||
(2, 2),
|
||||
4,
|
||||
"4096^3 tile=128x256 cluster=2x2",
|
||||
id="4096-128x256-2x2",
|
||||
),
|
||||
pytest.param(
|
||||
(4096, 4096, 4096, 1),
|
||||
(128, 64),
|
||||
(1, 2),
|
||||
4,
|
||||
"4096^3 tile=128x64 cluster=1x2",
|
||||
id="4096-128x64-1x2",
|
||||
),
|
||||
pytest.param(
|
||||
(4096, 4096, 4096, 1),
|
||||
(64, 64),
|
||||
(1, 2),
|
||||
4,
|
||||
"4096^3 tile=64x64 cluster=1x2",
|
||||
id="4096-64x64-1x2",
|
||||
),
|
||||
# LLM-like: 8192x8192x4096
|
||||
pytest.param((8192, 8192, 4096, 1), (128, 128), (1, 2), 4, "8192x8192x4096 tile=128x128 cluster=1x2", id="llm-128x128-1x2"),
|
||||
pytest.param((8192, 8192, 4096, 1), (128, 256), (2, 2), 4, "8192x8192x4096 tile=128x256 cluster=2x2", id="llm-128x256-2x2"),
|
||||
pytest.param(
|
||||
(8192, 8192, 4096, 1),
|
||||
(128, 128),
|
||||
(1, 2),
|
||||
4,
|
||||
"8192x8192x4096 tile=128x128 cluster=1x2",
|
||||
id="llm-128x128-1x2",
|
||||
),
|
||||
pytest.param(
|
||||
(8192, 8192, 4096, 1),
|
||||
(128, 256),
|
||||
(2, 2),
|
||||
4,
|
||||
"8192x8192x4096 tile=128x256 cluster=2x2",
|
||||
id="llm-128x256-2x2",
|
||||
),
|
||||
# mma_promotion_interval sweep (shows precision/performance trade-off)
|
||||
pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 4, "4096^3 interval=4", id="4096-interval4"),
|
||||
pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 8, "4096^3 interval=8", id="4096-interval8"),
|
||||
pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 16, "4096^3 interval=16", id="4096-interval16"),
|
||||
pytest.param(
|
||||
(4096, 4096, 4096, 1),
|
||||
(128, 128),
|
||||
(1, 2),
|
||||
4,
|
||||
"4096^3 interval=4",
|
||||
id="4096-interval4",
|
||||
),
|
||||
pytest.param(
|
||||
(4096, 4096, 4096, 1),
|
||||
(128, 128),
|
||||
(1, 2),
|
||||
8,
|
||||
"4096^3 interval=8",
|
||||
id="4096-interval8",
|
||||
),
|
||||
pytest.param(
|
||||
(4096, 4096, 4096, 1),
|
||||
(128, 128),
|
||||
(1, 2),
|
||||
16,
|
||||
"4096^3 interval=16",
|
||||
id="4096-interval16",
|
||||
),
|
||||
# FP8 output
|
||||
pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 4, "4096^3 out=FP8E4M3", id="4096-fp8-out"),
|
||||
pytest.param(
|
||||
(4096, 4096, 4096, 1),
|
||||
(128, 128),
|
||||
(1, 2),
|
||||
4,
|
||||
"4096^3 out=FP8E4M3",
|
||||
id="4096-fp8-out",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_bench(mnkl, tile_shape_mn, cluster_shape_mn, mma_promotion_interval, label, capsys):
|
||||
def test_bench(
|
||||
mnkl, tile_shape_mn, cluster_shape_mn, mma_promotion_interval, label, capsys
|
||||
):
|
||||
"""
|
||||
Performance benchmark — run with: pytest -m bench -s
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ os.environ.setdefault("GROUPED_GEMM_FORCE_CUTE_COPY", "0")
|
||||
import pytest
|
||||
import cutlass
|
||||
import cutlass.utils as utils
|
||||
from hopper.grouped_gemm import run
|
||||
from hopper.kernel.grouped_gemm.grouped_gemm import run
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -203,8 +203,8 @@ def _run_case(
|
||||
[
|
||||
pytest.param((128, 256), [(128, 256, 64, 1)], id="tile128x256"),
|
||||
pytest.param((128, 128), [(128, 128, 64, 1)], id="tile128x128"),
|
||||
pytest.param((128, 64), [(128, 64, 64, 1)], id="tile128x64"),
|
||||
pytest.param((64, 64), [(64, 64, 64, 1)], id="tile64x64"),
|
||||
pytest.param((128, 64), [(128, 64, 64, 1)], id="tile128x64"),
|
||||
pytest.param((64, 64), [(64, 64, 64, 1)], id="tile64x64"),
|
||||
],
|
||||
)
|
||||
def test_l0_tile_shapes(tile_shape_mn, problem_sizes_mnkl, tmap_mode):
|
||||
@@ -222,15 +222,20 @@ def test_l0_tile_shapes(tile_shape_mn, problem_sizes_mnkl, tmap_mode):
|
||||
@pytest.mark.parametrize(
|
||||
"num_groups, problem_sizes_mnkl",
|
||||
[
|
||||
pytest.param(2, [(128, 256, 64, 1)] * 2, id="2g-uniform"),
|
||||
pytest.param(4, [(128, 256, 64, 1), (64, 128, 64, 1),
|
||||
(256, 128, 64, 1), (192, 256, 64, 1)], id="4g-mixed"),
|
||||
pytest.param(8, [(128, 256, 64, 1)] * 8, id="8g-uniform"),
|
||||
pytest.param(2, [(128, 256, 64, 1)] * 2, id="2g-uniform"),
|
||||
pytest.param(
|
||||
4,
|
||||
[(128, 256, 64, 1), (64, 128, 64, 1), (256, 128, 64, 1), (192, 256, 64, 1)],
|
||||
id="4g-mixed",
|
||||
),
|
||||
pytest.param(8, [(128, 256, 64, 1)] * 8, id="8g-uniform"),
|
||||
],
|
||||
)
|
||||
def test_l0_group_counts(num_groups, problem_sizes_mnkl, tmap_mode):
|
||||
"""Various group counts compile for tile (128,256) fp16."""
|
||||
_run_compile(num_groups, problem_sizes_mnkl, (128, 256), tensormap_update_mode=tmap_mode)
|
||||
_run_compile(
|
||||
num_groups, problem_sizes_mnkl, (128, 256), tensormap_update_mode=tmap_mode
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -244,28 +249,43 @@ def test_l0_group_counts(num_groups, problem_sizes_mnkl, tmap_mode):
|
||||
"a_dtype, b_dtype, c_dtype, acc_dtype, problem_sizes_mnkl",
|
||||
[
|
||||
# fp16 → fp16 output
|
||||
pytest.param(F16, F16, F16, F32, [(128, 256, 64, 1)], id="fp16-fp16-fp16-fp32"),
|
||||
pytest.param(F16, F16, F16, F32, [(128, 256, 64, 1)], id="fp16-fp16-fp16-fp32"),
|
||||
# fp16 → fp32 output
|
||||
pytest.param(F16, F16, F32, F32, [(128, 256, 64, 1)], id="fp16-fp16-fp32-fp32"),
|
||||
pytest.param(F16, F16, F32, F32, [(128, 256, 64, 1)], id="fp16-fp16-fp32-fp32"),
|
||||
# fp16 with fp16 accumulator
|
||||
pytest.param(F16, F16, F16, F16, [(128, 256, 64, 1)], id="fp16-fp16-fp16-fp16"),
|
||||
pytest.param(F16, F16, F16, F16, [(128, 256, 64, 1)], id="fp16-fp16-fp16-fp16"),
|
||||
# fp8 E4M3 → fp16 output (K must be multiple of 16 for fp8 alignment)
|
||||
pytest.param(F8E4, F8E4, F16, F32, [(128, 256, 128, 1)], id="fp8e4-fp8e4-fp16-fp32"),
|
||||
pytest.param(
|
||||
F8E4, F8E4, F16, F32, [(128, 256, 128, 1)], id="fp8e4-fp8e4-fp16-fp32"
|
||||
),
|
||||
# fp8 E5M2 → fp16 output
|
||||
pytest.param(F8E5, F8E5, F16, F32, [(128, 256, 128, 1)], id="fp8e5-fp8e5-fp16-fp32"),
|
||||
pytest.param(
|
||||
F8E5, F8E5, F16, F32, [(128, 256, 128, 1)], id="fp8e5-fp8e5-fp16-fp32"
|
||||
),
|
||||
# mixed fp8: E4M3 × E5M2
|
||||
pytest.param(F8E4, F8E5, F16, F32, [(128, 256, 128, 1)], id="fp8e4-fp8e5-fp16-fp32"),
|
||||
pytest.param(
|
||||
F8E4, F8E5, F16, F32, [(128, 256, 128, 1)], id="fp8e4-fp8e5-fp16-fp32"
|
||||
),
|
||||
# int8 → int32 output (K must be multiple of 16)
|
||||
pytest.param(I8, I8, I32, I32, [(128, 256, 128, 1)], id="int8-int8-int32-int32"),
|
||||
pytest.param(
|
||||
I8, I8, I32, I32, [(128, 256, 128, 1)], id="int8-int8-int32-int32"
|
||||
),
|
||||
# uint8 → int32 output
|
||||
pytest.param(U8, U8, I32, I32, [(128, 256, 128, 1)], id="uint8-uint8-int32-int32"),
|
||||
pytest.param(
|
||||
U8, U8, I32, I32, [(128, 256, 128, 1)], id="uint8-uint8-int32-int32"
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_l0_dtypes(a_dtype, b_dtype, c_dtype, acc_dtype, problem_sizes_mnkl, tmap_mode):
|
||||
"""Data type combinations compile for tile (128,256)."""
|
||||
_run_compile(
|
||||
1, problem_sizes_mnkl, (128, 256),
|
||||
a_dtype=a_dtype, b_dtype=b_dtype, c_dtype=c_dtype, acc_dtype=acc_dtype,
|
||||
1,
|
||||
problem_sizes_mnkl,
|
||||
(128, 256),
|
||||
a_dtype=a_dtype,
|
||||
b_dtype=b_dtype,
|
||||
c_dtype=c_dtype,
|
||||
acc_dtype=acc_dtype,
|
||||
tensormap_update_mode=tmap_mode,
|
||||
)
|
||||
|
||||
@@ -289,14 +309,22 @@ def test_l0_dtypes(a_dtype, b_dtype, c_dtype, acc_dtype, problem_sizes_mnkl, tma
|
||||
# m-major C output (M must be multiple of 8)
|
||||
pytest.param("k", "k", "m", [(128, 256, 64, 1)], (128, 256), id="akm-bkm-cmaj"),
|
||||
# m-major A + n-major B
|
||||
pytest.param("m", "n", "n", [(128, 128, 64, 1)], (128, 128), id="amaj-bnmaj-cn"),
|
||||
pytest.param(
|
||||
"m", "n", "n", [(128, 128, 64, 1)], (128, 128), id="amaj-bnmaj-cn"
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_l0_major_modes(a_major, b_major, c_major, problem_sizes_mnkl, tile_shape_mn, tmap_mode):
|
||||
def test_l0_major_modes(
|
||||
a_major, b_major, c_major, problem_sizes_mnkl, tile_shape_mn, tmap_mode
|
||||
):
|
||||
"""Matrix major mode combinations compile."""
|
||||
_run_compile(
|
||||
1, problem_sizes_mnkl, tile_shape_mn,
|
||||
a_major=a_major, b_major=b_major, c_major=c_major,
|
||||
1,
|
||||
problem_sizes_mnkl,
|
||||
tile_shape_mn,
|
||||
a_major=a_major,
|
||||
b_major=b_major,
|
||||
c_major=c_major,
|
||||
tensormap_update_mode=tmap_mode,
|
||||
)
|
||||
|
||||
@@ -321,10 +349,14 @@ def test_l0_major_modes(a_major, b_major, c_major, problem_sizes_mnkl, tile_shap
|
||||
pytest.param((2, 2), [(256, 512, 64, 1)], (128, 256), id="cluster2x2"),
|
||||
],
|
||||
)
|
||||
def test_l0_cluster_shapes(cluster_shape_mn, problem_sizes_mnkl, tile_shape_mn, tmap_mode):
|
||||
def test_l0_cluster_shapes(
|
||||
cluster_shape_mn, problem_sizes_mnkl, tile_shape_mn, tmap_mode
|
||||
):
|
||||
"""Cluster shapes including multicast paths compile."""
|
||||
_run_compile(
|
||||
1, problem_sizes_mnkl, tile_shape_mn,
|
||||
1,
|
||||
problem_sizes_mnkl,
|
||||
tile_shape_mn,
|
||||
cluster_shape_mn=cluster_shape_mn,
|
||||
tensormap_update_mode=tmap_mode,
|
||||
)
|
||||
@@ -341,18 +373,20 @@ def test_l0_cluster_shapes(cluster_shape_mn, problem_sizes_mnkl, tile_shape_mn,
|
||||
"num_groups, problem_sizes_mnkl",
|
||||
[
|
||||
# groups with very different shapes
|
||||
pytest.param(4, [(64, 64, 64, 1),
|
||||
(128, 128, 64, 1),
|
||||
(256, 128, 64, 1),
|
||||
(128, 256, 64, 1)], id="4g-all-tiles"),
|
||||
pytest.param(
|
||||
4,
|
||||
[(64, 64, 64, 1), (128, 128, 64, 1), (256, 128, 64, 1), (128, 256, 64, 1)],
|
||||
id="4g-all-tiles",
|
||||
),
|
||||
# tiny vs large
|
||||
pytest.param(2, [(64, 64, 64, 1),
|
||||
(512, 512, 64, 1)], id="2g-tiny-large"),
|
||||
pytest.param(2, [(64, 64, 64, 1), (512, 512, 64, 1)], id="2g-tiny-large"),
|
||||
],
|
||||
)
|
||||
def test_l0_mixed_problem_sizes(num_groups, problem_sizes_mnkl, tmap_mode):
|
||||
"""Heterogeneous per-group problem sizes compile."""
|
||||
_run_compile(num_groups, problem_sizes_mnkl, (128, 256), tensormap_update_mode=tmap_mode)
|
||||
_run_compile(
|
||||
num_groups, problem_sizes_mnkl, (128, 256), tensormap_update_mode=tmap_mode
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -386,13 +420,15 @@ def test_l1_fp16_4g_mixed(tmap_mode):
|
||||
[
|
||||
pytest.param((128, 256), [(128, 256, 64, 1)], id="tile128x256"),
|
||||
pytest.param((128, 128), [(128, 128, 64, 1)], id="tile128x128"),
|
||||
pytest.param((128, 64), [(128, 64, 64, 1)], id="tile128x64"),
|
||||
pytest.param((64, 64), [(64, 64, 64, 1)], id="tile64x64"),
|
||||
pytest.param((128, 64), [(128, 64, 64, 1)], id="tile128x64"),
|
||||
pytest.param((64, 64), [(64, 64, 64, 1)], id="tile64x64"),
|
||||
],
|
||||
)
|
||||
def test_l1_tile_shapes_fp16(tile_shape_mn, problem_sizes_mnkl, tmap_mode):
|
||||
"""All tile shapes produce correct results."""
|
||||
_run_correctness(1, problem_sizes_mnkl, tile_shape_mn, tensormap_update_mode=tmap_mode)
|
||||
_run_correctness(
|
||||
1, problem_sizes_mnkl, tile_shape_mn, tensormap_update_mode=tmap_mode
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -429,8 +465,11 @@ def test_l1_group_count_scaling(num_groups, tmap_mode):
|
||||
def test_l1_fp16_c_fp32(tmap_mode):
|
||||
"""fp16 inputs with fp32 output are numerically correct."""
|
||||
_run_correctness(
|
||||
1, [(128, 256, 64, 1)], (128, 256),
|
||||
c_dtype=F32, acc_dtype=F32,
|
||||
1,
|
||||
[(128, 256, 64, 1)],
|
||||
(128, 256),
|
||||
c_dtype=F32,
|
||||
acc_dtype=F32,
|
||||
tensormap_update_mode=tmap_mode,
|
||||
)
|
||||
|
||||
@@ -441,8 +480,13 @@ def test_l1_fp16_c_fp32(tmap_mode):
|
||||
def test_l1_fp8_e4m3(tmap_mode):
|
||||
"""fp8 E4M3FN inputs are numerically correct (K=128 for 16B alignment)."""
|
||||
_run_correctness(
|
||||
1, [(128, 256, 128, 1)], (128, 256),
|
||||
a_dtype=F8E4, b_dtype=F8E4, c_dtype=F16, acc_dtype=F32,
|
||||
1,
|
||||
[(128, 256, 128, 1)],
|
||||
(128, 256),
|
||||
a_dtype=F8E4,
|
||||
b_dtype=F8E4,
|
||||
c_dtype=F16,
|
||||
acc_dtype=F32,
|
||||
tensormap_update_mode=tmap_mode,
|
||||
tolerance=0.5,
|
||||
)
|
||||
@@ -454,8 +498,13 @@ def test_l1_fp8_e4m3(tmap_mode):
|
||||
def test_l1_fp8_mixed(tmap_mode):
|
||||
"""Mixed fp8 inputs (E4M3 × E5M2) are numerically correct."""
|
||||
_run_correctness(
|
||||
1, [(128, 256, 128, 1)], (128, 256),
|
||||
a_dtype=F8E4, b_dtype=F8E5, c_dtype=F16, acc_dtype=F32,
|
||||
1,
|
||||
[(128, 256, 128, 1)],
|
||||
(128, 256),
|
||||
a_dtype=F8E4,
|
||||
b_dtype=F8E5,
|
||||
c_dtype=F16,
|
||||
acc_dtype=F32,
|
||||
tensormap_update_mode=tmap_mode,
|
||||
tolerance=0.5,
|
||||
)
|
||||
@@ -467,8 +516,13 @@ def test_l1_fp8_mixed(tmap_mode):
|
||||
def test_l1_int8(tmap_mode):
|
||||
"""int8 inputs with int32 accumulator are correct."""
|
||||
_run_correctness(
|
||||
1, [(128, 256, 128, 1)], (128, 256),
|
||||
a_dtype=I8, b_dtype=I8, c_dtype=I32, acc_dtype=I32,
|
||||
1,
|
||||
[(128, 256, 128, 1)],
|
||||
(128, 256),
|
||||
a_dtype=I8,
|
||||
b_dtype=I8,
|
||||
c_dtype=I32,
|
||||
acc_dtype=I32,
|
||||
tensormap_update_mode=tmap_mode,
|
||||
tolerance=0,
|
||||
)
|
||||
@@ -485,7 +539,9 @@ def test_l1_int8(tmap_mode):
|
||||
def test_l1_c_m_major(tmap_mode):
|
||||
"""m-major C output is correct."""
|
||||
_run_correctness(
|
||||
1, [(128, 256, 64, 1)], (128, 256),
|
||||
1,
|
||||
[(128, 256, 64, 1)],
|
||||
(128, 256),
|
||||
c_major="m",
|
||||
tensormap_update_mode=tmap_mode,
|
||||
)
|
||||
@@ -498,8 +554,12 @@ def test_l1_c_m_major(tmap_mode):
|
||||
def test_l1_all_non_default_majors(tmap_mode):
|
||||
"""m-major A, n-major B, m-major C together are correct."""
|
||||
_run_correctness(
|
||||
1, [(64, 64, 64, 1)], (128, 128),
|
||||
a_major="m", b_major="n", c_major="m",
|
||||
1,
|
||||
[(64, 64, 64, 1)],
|
||||
(128, 128),
|
||||
a_major="m",
|
||||
b_major="n",
|
||||
c_major="m",
|
||||
tensormap_update_mode=tmap_mode,
|
||||
)
|
||||
|
||||
@@ -521,7 +581,9 @@ def test_l1_all_non_default_majors(tmap_mode):
|
||||
def test_l1_cluster_shapes(cluster_shape_mn, problem_sizes_mnkl, tmap_mode):
|
||||
"""Multicast cluster shapes produce correct results."""
|
||||
_run_correctness(
|
||||
1, problem_sizes_mnkl, (128, 256),
|
||||
1,
|
||||
problem_sizes_mnkl,
|
||||
(128, 256),
|
||||
cluster_shape_mn=cluster_shape_mn,
|
||||
tensormap_update_mode=tmap_mode,
|
||||
)
|
||||
@@ -540,14 +602,14 @@ def test_l1_8g_mixed_sizes(tmap_mode):
|
||||
_run_correctness(
|
||||
8,
|
||||
[
|
||||
(128, 256, 64, 1),
|
||||
(64, 128, 64, 1),
|
||||
(256, 128, 64, 1),
|
||||
(128, 256, 64, 1),
|
||||
(64, 128, 64, 1),
|
||||
(256, 128, 64, 1),
|
||||
(128, 128, 128, 1),
|
||||
(192, 256, 64, 1),
|
||||
(64, 64, 64, 1),
|
||||
(192, 256, 64, 1),
|
||||
(64, 64, 64, 1),
|
||||
(128, 256, 128, 1),
|
||||
(256, 256, 64, 1),
|
||||
(256, 256, 64, 1),
|
||||
],
|
||||
(128, 256),
|
||||
tensormap_update_mode=tmap_mode,
|
||||
|
||||
@@ -26,5 +26,6 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.default_SMs[__file__] = "100f"
|
||||
config.default_SMs[__file__] = "100f"
|
||||
|
||||
@@ -41,29 +41,33 @@ from typing import Tuple, Type, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from blackwell.dense_blockscaled_gemm_persistent_prefetch import (
|
||||
from blackwell.kernel.blockscaled_gemm.dense_blockscaled_gemm_persistent_prefetch import (
|
||||
Sm100BlockScaledPersistentDenseGemmKernel,
|
||||
run,
|
||||
)
|
||||
|
||||
import cutlass
|
||||
|
||||
pytestmark = [pytest.mark.arch(["100a"])]
|
||||
|
||||
|
||||
@pytest.mark.invalid_case(
|
||||
lambda: not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
|
||||
ab_dtype,
|
||||
sf_dtype,
|
||||
sf_vec_size,
|
||||
c_dtype,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
mnkl[0],
|
||||
mnkl[1],
|
||||
mnkl[2],
|
||||
mnkl[3],
|
||||
a_major,
|
||||
b_major,
|
||||
c_major,
|
||||
lambda: (
|
||||
not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
|
||||
ab_dtype,
|
||||
sf_dtype,
|
||||
sf_vec_size,
|
||||
c_dtype,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
mnkl[0],
|
||||
mnkl[1],
|
||||
mnkl[2],
|
||||
mnkl[3],
|
||||
a_major,
|
||||
b_major,
|
||||
c_major,
|
||||
)
|
||||
)
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
@@ -110,7 +114,7 @@ pytestmark = [pytest.mark.arch(["100a"])]
|
||||
"prefetch_dist",
|
||||
[
|
||||
None, # Default: auto (uses num_ab_stage)
|
||||
0, # Disabled
|
||||
0, # Disabled
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("tolerance", [1e-01])
|
||||
@@ -145,20 +149,22 @@ def test_dense_blockscaled_gemm_prefetch(
|
||||
|
||||
|
||||
@pytest.mark.invalid_case(
|
||||
lambda: not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
|
||||
ab_dtype,
|
||||
sf_dtype,
|
||||
sf_vec_size,
|
||||
c_dtype,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
mnkl[0],
|
||||
mnkl[1],
|
||||
mnkl[2],
|
||||
mnkl[3],
|
||||
a_major,
|
||||
b_major,
|
||||
c_major,
|
||||
lambda: (
|
||||
not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
|
||||
ab_dtype,
|
||||
sf_dtype,
|
||||
sf_vec_size,
|
||||
c_dtype,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
mnkl[0],
|
||||
mnkl[1],
|
||||
mnkl[2],
|
||||
mnkl[3],
|
||||
a_major,
|
||||
b_major,
|
||||
c_major,
|
||||
)
|
||||
)
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
@@ -190,7 +196,7 @@ def test_dense_blockscaled_gemm_prefetch(
|
||||
"prefetch_dist",
|
||||
[
|
||||
None, # Default: auto (uses num_ab_stage)
|
||||
4, # Explicit distance
|
||||
4, # Explicit distance
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("tolerance", [1e-01])
|
||||
@@ -228,15 +234,15 @@ def test_dense_blockscaled_gemm_prefetch_L0(
|
||||
"prefetch_dist",
|
||||
[
|
||||
None, # Auto: uses num_ab_stage
|
||||
0, # Disabled
|
||||
2, # Small distance
|
||||
4, # Medium distance
|
||||
0, # Disabled
|
||||
2, # Small distance
|
||||
4, # Medium distance
|
||||
],
|
||||
)
|
||||
def test_prefetch_dist_configurations(prefetch_dist: Optional[int]):
|
||||
"""
|
||||
Test different prefetch_dist configurations specifically for blockscaled GEMM.
|
||||
|
||||
|
||||
- None: Auto mode, uses num_ab_stage as prefetch distance
|
||||
- 0: Prefetch disabled
|
||||
- >0: Explicit prefetch distance
|
||||
@@ -451,4 +457,3 @@ def test_invalid_tensor_alignment(
|
||||
cluster_shape_mn,
|
||||
tolerance,
|
||||
)
|
||||
|
||||
|
||||
@@ -40,8 +40,7 @@ from typing import Tuple, Type, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from blackwell.dense_gemm_persistent_prefetch import (
|
||||
PersistentDenseGemmKernel,
|
||||
from blackwell.kernel.dense_gemm.dense_gemm_persistent_prefetch import (
|
||||
run,
|
||||
)
|
||||
|
||||
@@ -92,8 +91,8 @@ import cutlass.cute.testing as testing
|
||||
"prefetch_dist",
|
||||
[
|
||||
None, # Default: auto (uses num_ab_stage)
|
||||
0, # Disabled
|
||||
2, # Explicit distance
|
||||
0, # Disabled
|
||||
2, # Explicit distance
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("tolerance", [1e-01])
|
||||
@@ -168,7 +167,7 @@ def test_dense_gemm_prefetch(
|
||||
"prefetch_dist",
|
||||
[
|
||||
None, # Default: auto (uses num_ab_stage)
|
||||
4, # Explicit distance
|
||||
4, # Explicit distance
|
||||
],
|
||||
)
|
||||
def test_dense_gemm_prefetch_L0(
|
||||
@@ -215,15 +214,15 @@ def test_dense_gemm_prefetch_L0(
|
||||
"prefetch_dist",
|
||||
[
|
||||
None, # Auto: uses num_ab_stage
|
||||
0, # Disabled
|
||||
2, # Small distance
|
||||
4, # Medium distance
|
||||
0, # Disabled
|
||||
2, # Small distance
|
||||
4, # Medium distance
|
||||
],
|
||||
)
|
||||
def test_prefetch_dist_configurations(prefetch_dist: Optional[int]):
|
||||
"""
|
||||
Test different prefetch_dist configurations specifically.
|
||||
|
||||
|
||||
- None: Auto mode, uses num_ab_stage as prefetch distance
|
||||
- 0: Prefetch disabled
|
||||
- >0: Explicit prefetch distance
|
||||
@@ -259,4 +258,3 @@ def test_prefetch_dist_configurations(prefetch_dist: Optional[int]):
|
||||
)
|
||||
except testing.CantImplementError:
|
||||
pytest.skip(f"Skip unsupported testcase with prefetch_dist={prefetch_dist}")
|
||||
|
||||
|
||||
@@ -38,11 +38,10 @@ Tests various configurations of:
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
|
||||
from blackwell.rmsnorm import (
|
||||
from blackwell.kernel.rmsnorm.rmsnorm import (
|
||||
run,
|
||||
get_sm_version,
|
||||
supports_cluster,
|
||||
@@ -104,7 +103,9 @@ class TestRMSNormCorrectness:
|
||||
class TestRMSNormClusterPath:
|
||||
"""Test the cluster path for large N (SM90+/SM100 only)."""
|
||||
|
||||
@pytest.mark.skipif(not supports_cluster(), reason="Cluster not supported on this GPU")
|
||||
@pytest.mark.skipif(
|
||||
not supports_cluster(), reason="Cluster not supported on this GPU"
|
||||
)
|
||||
@pytest.mark.parametrize("N", [32768, 65536])
|
||||
def test_cluster_path_correctness(self, N):
|
||||
"""Test cluster path produces correct results."""
|
||||
@@ -119,6 +120,7 @@ class TestRMSNormClusterPath:
|
||||
benchmark=False,
|
||||
)
|
||||
|
||||
|
||||
class TestRMSNormLargeN:
|
||||
"""Test RMSNorm with large N values."""
|
||||
|
||||
@@ -151,7 +153,6 @@ class TestRMSNormLargeN:
|
||||
)
|
||||
|
||||
|
||||
|
||||
class TestRMSNormEdgeCases:
|
||||
"""Test edge cases for RMSNorm."""
|
||||
|
||||
@@ -197,4 +198,4 @@ class TestRMSNormFloat32:
|
||||
tolerance=1e-4, # Tighter tolerance for FP32
|
||||
skip_ref_check=False,
|
||||
benchmark=False,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -26,14 +26,14 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from blackwell.tutorial_gemm import fp16_gemm_0
|
||||
from blackwell.tutorial_gemm import fp16_gemm_1
|
||||
from blackwell.tutorial_gemm import fp16_gemm_2
|
||||
from blackwell.tutorial_gemm import fp16_gemm_3
|
||||
from blackwell.tutorial_gemm import fp16_gemm_3_1
|
||||
from blackwell.tutorial_gemm import fp16_gemm_4
|
||||
from blackwell.tutorial_gemm import fp16_gemm_5
|
||||
from blackwell.tutorial_gemm import fp16_gemm_6
|
||||
from blackwell.tutorial.tutorial_gemm import fp16_gemm_0
|
||||
from blackwell.tutorial.tutorial_gemm import fp16_gemm_1
|
||||
from blackwell.tutorial.tutorial_gemm import fp16_gemm_2
|
||||
from blackwell.tutorial.tutorial_gemm import fp16_gemm_3
|
||||
from blackwell.tutorial.tutorial_gemm import fp16_gemm_3_1
|
||||
from blackwell.tutorial.tutorial_gemm import fp16_gemm_4
|
||||
from blackwell.tutorial.tutorial_gemm import fp16_gemm_5
|
||||
from blackwell.tutorial.tutorial_gemm import fp16_gemm_6
|
||||
|
||||
import pytest
|
||||
from typing import Tuple
|
||||
@@ -63,7 +63,6 @@ def test_fp16_gemm_1(
|
||||
fp16_gemm_1.run_dense_gemm(mnk, tolerance)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mnk",
|
||||
[(512, 512, 256)],
|
||||
|
||||
@@ -36,8 +36,10 @@ from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
@cute.kernel
|
||||
def _unary_ops_kernel(
|
||||
absf_inp: cute.Tensor, absf_out: cute.Tensor,
|
||||
floor_inp: cute.Tensor, floor_out: cute.Tensor,
|
||||
absf_inp: cute.Tensor,
|
||||
absf_out: cute.Tensor,
|
||||
floor_inp: cute.Tensor,
|
||||
floor_out: cute.Tensor,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
absf_out[tidx] = cute.math.absf(absf_inp[tidx])
|
||||
@@ -46,8 +48,10 @@ def _unary_ops_kernel(
|
||||
|
||||
@cute.jit
|
||||
def _unary_ops_host(
|
||||
absf_inp: cute.Tensor, absf_out: cute.Tensor,
|
||||
floor_inp: cute.Tensor, floor_out: cute.Tensor,
|
||||
absf_inp: cute.Tensor,
|
||||
absf_out: cute.Tensor,
|
||||
floor_inp: cute.Tensor,
|
||||
floor_out: cute.Tensor,
|
||||
):
|
||||
_unary_ops_kernel(absf_inp, absf_out, floor_inp, floor_out).launch(
|
||||
grid=[1, 1, 1], block=[absf_inp.shape[0], 1, 1]
|
||||
@@ -77,7 +81,9 @@ def test_unary_ops():
|
||||
|
||||
@cute.kernel
|
||||
def _binary_ops_kernel(
|
||||
mag_inp: cute.Tensor, sign_inp: cute.Tensor, out: cute.Tensor,
|
||||
mag_inp: cute.Tensor,
|
||||
sign_inp: cute.Tensor,
|
||||
out: cute.Tensor,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
out[tidx] = cute.math.copysign(mag_inp[tidx], sign_inp[tidx])
|
||||
@@ -85,7 +91,9 @@ def _binary_ops_kernel(
|
||||
|
||||
@cute.jit
|
||||
def _binary_ops_host(
|
||||
mag_inp: cute.Tensor, sign_inp: cute.Tensor, out: cute.Tensor,
|
||||
mag_inp: cute.Tensor,
|
||||
sign_inp: cute.Tensor,
|
||||
out: cute.Tensor,
|
||||
):
|
||||
_binary_ops_kernel(mag_inp, sign_inp, out).launch(
|
||||
grid=[1, 1, 1], block=[mag_inp.shape[0], 1, 1]
|
||||
|
||||
Reference in New Issue
Block a user