Don't access data_ptr of fake tensor. Fix EFC without epilogue

This commit is contained in:
jkosaian
2026-01-14 18:00:08 -08:00
parent e222b2a9b9
commit e594def95e
8 changed files with 420 additions and 55 deletions

View File

@@ -34,6 +34,7 @@ import cutlass
from cutlass_api.arguments import (
EpilogueArguments,
GemmArguments,
KernelOperand,
)
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import (
@@ -325,14 +326,12 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
max_active_clusters = get_max_active_clusters(self.impl.cluster_shape_mn)
if args.epilogue is not None:
epilogue_params = args.epilogue.parameters
epilogue_params = [
e.compile_time_tensor if isinstance(e, TensorWrapper) else e
for e in args.epilogue.parameters
]
else:
epilogue_params = [args.out]
epilogue_params = [
e.compile_time_tensor if isinstance(e, TensorWrapper) else e
for e in epilogue_params
]
epilogue_params = [args.out.tensor.compile_time_tensor]
# EFC needs special handling for supplemental arguments
self.impl.efc.compile(epilogue_params)
@@ -348,7 +347,9 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
# Wrap the compiled kernel to handle supplemental argument packing at launch time
def wrapped_launch(a_tensor, b_tensor, stream, *supplemental_args):
runtime_args = [
e.runtime_tensor if isinstance(e, TensorWrapper) else e
e.runtime_tensor
if isinstance(e, TensorWrapper)
else (e.tensor.runtime_tensor if isinstance(e, KernelOperand) else e)
for e in supplemental_args
]
return compiled_gemm(

View File

@@ -91,7 +91,6 @@ def is_numpy_tensor(inp) -> bool:
"""Check if the input is a numpy tensor."""
if is_numpy_available():
import numpy as np
return isinstance(inp, np.ndarray)
return False
@@ -105,6 +104,15 @@ def is_torch_tensor(inp) -> bool:
return False
def is_torch_fake_tensor(inp) -> bool:
    """Return True when ``inp`` is a ``torch._subclasses.fake_tensor.FakeTensor``.

    Safe to call in environments without torch: when torch is unavailable the
    answer is always False and torch is never imported.
    """
    if not is_torch_available():
        return False
    import torch

    return isinstance(inp, torch._subclasses.fake_tensor.FakeTensor)
def _lazy_import(mod_name: str) -> Any:
"""Internal utility to lazily import a module only when needed."""
@@ -367,32 +375,31 @@ class TensorWrapper:
if isinstance(tensor, cute.Tensor):
# Regardless of whether TVM-FFI is enabled, if the tensor passed in is a cute.Tensor,
# it can be used as the runtime tensor and compile time tensor.
self.runtime_tensor = tensor
self._runtime_tensor = tensor
self.compile_time_tensor = tensor
self._shape = tensor.shape
self._stride = tensor.stride
self._data_ptr = tensor.iterator._pointer
elif GlobalOptions().use_tvm_ffi:
# If TVM-FFI is enabled, runtime tensor is set simply as the tensor passed in, but
# we must make a fake tensor for compilation.
self.runtime_tensor = tensor
if is_torch_tensor(self.runtime_tensor):
dtype = cutlass_type_from_torch_type(self.runtime_tensor.dtype)
elif is_torch_fake_tensor(tensor) or (
is_torch_tensor(tensor) and GlobalOptions().use_tvm_ffi
):
self._shape = tuple(tensor.shape)
self._stride = tensor.stride()
rank = self.runtime_tensor.dim()
self._stride = self.runtime_tensor.stride()
stride_order = get_stride_rank(self._stride)
leading_dim_idx = stride_order.index(0)
shape = [cute.SymInt() for _ in range(rank)]
shape[leading_dim_idx] = cute.SymInt(
divisibility=alignment_bytes * 8 // dtype.width
)
self._shape = tuple(self.runtime_tensor.shape)
self._data_ptr = self.runtime_tensor.data_ptr()
if is_torch_fake_tensor(tensor):
self._data_ptr = 0
self._runtime_tensor = None
else:
raise ValueError(
f"Unsupported tensor type: {type(self.runtime_tensor)}"
)
self._runtime_tensor = tensor
self._data_ptr = tensor.data_ptr()
dtype = cutlass_type_from_torch_type(tensor.dtype)
stride_order = get_stride_rank(tensor.stride())
leading_dim_idx = stride_order.index(0)
shape = [cute.SymInt() for _ in range(tensor.dim())]
shape[leading_dim_idx] = cute.SymInt(
divisibility=alignment_bytes * 8 // dtype.width
)
self.compile_time_tensor = cute.runtime.make_fake_compact_tensor(
dtype,
@@ -400,8 +407,11 @@ class TensorWrapper:
stride_order=stride_order,
assumed_align=alignment_bytes,
)
elif GlobalOptions().use_tvm_ffi:
raise ValueError("TVM-FFI is currently only supported for torch tensors.")
else:
# TVM-FFI is disabled and the tensor passed in is not a cute.Tensor,
# TVM-FFI is disabled and the tensor passed in is not a cute.Tensor or torch fake tensor,
# We must convert it to a cute.Tensor
if is_torch_tensor(tensor):
dtype = to_cutlass_type(tensor.dtype)
@@ -409,7 +419,7 @@ class TensorWrapper:
else:
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
stride_order = get_stride_order(stride)
self.runtime_tensor = (
self._runtime_tensor = (
from_dlpack(
tensor,
assumed_align=alignment_bytes,
@@ -422,13 +432,21 @@ class TensorWrapper:
)
)
self._shape = self.runtime_tensor.shape
self._stride = self.runtime_tensor.stride
self._data_ptr = self.runtime_tensor.iterator._pointer
self._shape = self._runtime_tensor.shape
self._stride = self._runtime_tensor.stride
self._data_ptr = self._runtime_tensor.iterator._pointer
# Since the runtime tensor is now a cute.Tensor, we can use it at
# compile time as well
self.compile_time_tensor = self.runtime_tensor
self.compile_time_tensor = self._runtime_tensor
@property
def runtime_tensor(self):
    """Tensor backed by real device memory, used at launch time.

    Raises:
        ValueError: when this wrapper was constructed from a fake tensor,
            which carries shape/dtype information only and has no runtime
            storage to launch against.
    """
    tensor = self._runtime_tensor
    if tensor is None:
        raise ValueError(
            "Attempting to access runtime tensor from argument constructed with a fake tensor."
        )
    return tensor
@property
def element_type(self) -> type[cutlass.Numeric]:

View File

@@ -180,6 +180,99 @@ def test_mxfp8_gemm_sm100(
reference = reference_scaled_mm(A, B, SFA, SFB, c_dtype)
torch.testing.assert_close(D, reference)
# Compile an MXFP8 scaled GEMM against fake tensors only, then run the
# compiled artifact against freshly allocated real tensors and check the
# numerical result. With fake_tensor_allow_unsafe_data_ptr_access disabled,
# any data_ptr() read during compilation raises, so this test pins the
# "don't access data_ptr of a fake tensor" fix.
# NOTE(review): indentation was stripped by the diff rendering; block scopes
# (in particular how far the FakeTensorMode `with` extends) must be confirmed
# against the repository before relying on this listing.
@pytest.mark.skipif(
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_gemm_sm100_fake_tensor(fixture_enable_tvm_ffi):
import torch._functorch.config
# Fail loudly if compilation touches a fake tensor's data_ptr.
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = False
M, N, K, L = 256, 512, 1024, 1
ab_dtype = torch.float8_e4m3fn
c_dtype = torch.float32
accumulator_type = torch.float32
scale_dtype = torch.float8_e8m0fnu
scale_size = 32
# Build operands and arguments under FakeTensorMode: no real allocations.
with torch._subclasses.fake_tensor.FakeTensorMode():
A = torch.randint(-1, 2, (L, M, K), device="cuda").to(ab_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
B = torch.randint(-1, 2, (L, N, K), device="cuda").to(ab_dtype).transpose(1, 2)
SFA = torch.rand(
(
L,
M,
prep_k(K, scale_size),
),
device="cuda",
).to(scale_dtype)
SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(scale_dtype)
fake_args = cutlass_api.arguments.GemmArguments(
A=ScaledTensor(
A,
SFA,
ScaleMode.Blockwise1x32,
ScaleSwizzleMode.Swizzle32x4x4,
),
B=ScaledTensor(
B,
SFB,
ScaleMode.Blockwise1x32,
ScaleSwizzleMode.Swizzle32x4x4,
),
out=D,
accumulator_type=accumulator_type,
)
# Kernel selection and compilation use only shape/dtype metadata.
kernels = cutlass_api.get_kernels(fake_args, cc=100)
assert len(kernels) > 0
kernel = kernels[0]
assert kernel.supports(fake_args)
compiled_artifact = kernel.compile(fake_args)
# Now allocate real tensors with the same shapes and run the artifact
# that was compiled purely from fake inputs.
A_real = torch.randint(-1, 2, (L, M, K), device="cuda").to(ab_dtype)
B_real = torch.randint(-1, 2, (L, N, K), device="cuda").to(ab_dtype).transpose(1, 2)
D_real = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
SFA_real = torch.rand(
(
L,
M,
prep_k(K, scale_size),
),
device="cuda",
).to(scale_dtype)
SFB_real = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(scale_dtype)
args = cutlass_api.arguments.GemmArguments(
A=ScaledTensor(
A_real,
SFA_real,
ScaleMode.Blockwise1x32,
ScaleSwizzleMode.Swizzle32x4x4,
),
B=ScaledTensor(
B_real,
SFB_real,
ScaleMode.Blockwise1x32,
ScaleSwizzleMode.Swizzle32x4x4,
),
out=D_real,
accumulator_type=accumulator_type,
)
kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)
# torch._scaled_mm does not support f8e5m2 * f8e5m2 currently.
# Simply skip reference check in that case (but test that a CUTLASS API kernel
# is found and runs)
if ab_dtype != torch.float8_e5m2:
reference = reference_scaled_mm(A_real, B_real, SFA_real, SFB_real, c_dtype)
torch.testing.assert_close(D_real, reference)
@pytest.mark.parametrize(
"M, N, K",
[

View File

@@ -33,6 +33,7 @@ import random
import torch
import cutlass_api
from cutlass_api.utils import is_device_cc_supported
torch.manual_seed(2025)
@@ -67,3 +68,54 @@ def test_incorrect_offset_length():
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) == 0
# Compile a grouped (2d x 3d, contiguous-offset) GEMM from fake tensors and
# fake offsets, then run the artifact with real tensors and check against
# torch._grouped_mm. Disabling unsafe data_ptr access makes any data_ptr()
# read during compilation raise.
# NOTE(review): indentation was stripped by the diff rendering; the extent of
# the FakeTensorMode `with` block must be confirmed against the repository.
@pytest.mark.skipif(
not is_device_cc_supported({100})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_contiguous_offset_dense_gemm_2d3d_fake_tensor(fixture_toggle_tvm_ffi):
import torch._functorch.config
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = False
M, N, K, L = 256, 512, 128, 2
ab_dtype = torch.float16
c_dtype = torch.float16
accumulator_type = torch.float32
# Fake operands: offsets are empty/uninitialized since only shape matters
# at compile time.
with torch._subclasses.fake_tensor.FakeTensorMode():
A = torch.randint(-1, 2, (M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, N, K), device="cuda", dtype=ab_dtype).permute(
0, 2, 1
)
out = torch.empty((M, N), device="cuda", dtype=c_dtype)
offsets = torch.empty((L,), device="cuda", dtype=torch.int32)
fake_args = cutlass_api.arguments.GroupedGemmArguments(
A=A, B=B, out=out, accumulator_type=accumulator_type, offsets=offsets
)
kernels = cutlass_api.get_kernels(fake_args, cc=100)
assert len(kernels) > 0
kernel = kernels[0]
compiled_artifact = kernel.compile(fake_args)
# Real operands, including concrete group offsets [128, 256].
A_real = torch.randint(-1, 2, (M, K), device="cuda", dtype=ab_dtype)
B_real = torch.randint(-1, 2, (L, N, K), device="cuda", dtype=ab_dtype).permute(
0, 2, 1
)
out_real = torch.empty((M, N), device="cuda", dtype=c_dtype)
offsets_real = torch.Tensor([128, 256]).to("cuda").to(torch.int32)
args = cutlass_api.arguments.GroupedGemmArguments(
A=A_real,
B=B_real,
out=out_real,
accumulator_type=accumulator_type,
offsets=offsets_real,
)
kernel.run(args, compiled_artifact=compiled_artifact)
reference = torch._grouped_mm(A_real, B_real, offsets_real)
torch.testing.assert_close(out_real, reference.to(out_real.dtype))

View File

@@ -64,3 +64,33 @@ def test_elementwise_add(M: int, N: int, dtype: torch.dtype, fixture_toggle_tvm_
reference = A + B
assert torch.allclose(D, reference)
# Compile an elementwise add from fake tensors, then re-create A/B/D as real
# tensors (same names rebound) and run the compiled artifact, checking the
# result against eager A + B.
# NOTE(review): indentation was stripped by the diff rendering; the extent of
# the FakeTensorMode `with` block must be confirmed against the repository.
def test_elementwise_add_fake_tensor(fixture_toggle_tvm_ffi):
import torch._functorch.config
# Any data_ptr() read on a fake tensor during compilation now raises.
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = False
M, N = 256, 512
dtype = torch.float16
with torch._subclasses.fake_tensor.FakeTensorMode():
A = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
B = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
D = torch.empty((M, N), device="cuda", dtype=dtype)
fake_args = cutlass_api.arguments.ElementwiseArguments(A=A, B=B, out=D)
kernels = cutlass_api.get_kernels(fake_args)
assert len(kernels) > 0
kernel = kernels[0]
compiled_artifact = kernel.compile(fake_args)
# Rebind to real device tensors and execute the precompiled kernel.
A = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
B = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
D = torch.empty((M, N), device="cuda", dtype=dtype)
args = cutlass_api.arguments.ElementwiseArguments(A=A, B=B, out=D)
kernel.run(args, compiled_artifact=compiled_artifact)
reference = A + B
assert torch.allclose(D, reference)

View File

@@ -273,3 +273,43 @@ def test_metadata_filter():
assert kernel.metadata.operands.B.dtype == cutlass.Float16
assert kernel.metadata.operands.out.dtype == cutlass.Float16
assert kernel.metadata.operands.accumulator_type == cutlass.Float16
# SM100 dense GEMM: compile from fake tensors, run with real ones, compare
# against eager A @ B.
# NOTE(review): indentation was stripped by the diff rendering; the extent of
# the FakeTensorMode `with` block must be confirmed against the repository.
@pytest.mark.skipif(
not is_device_cc_supported({100})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_sm100_fake_tensor(fixture_toggle_tvm_ffi):
import torch._functorch.config
# Any data_ptr() read on a fake tensor during compilation now raises.
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = False
M, N, K, L = 256, 512, 128, 1
ab_dtype = torch.float16
c_dtype = torch.float16
accumulator_type = torch.float32
with torch._subclasses.fake_tensor.FakeTensorMode():
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
fake_args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
kernels = cutlass_api.get_kernels(fake_args, cc=100)
assert len(kernels) > 0
kernel = kernels[0]
compiled_artifact = kernel.compile(fake_args)
# Rebind A/B/D to real device tensors and run the precompiled artifact.
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
kernel.run(args, compiled_artifact=compiled_artifact)
reference = A @ B
torch.testing.assert_close(D, reference.to(D.dtype))

View File

@@ -61,12 +61,6 @@ def base_data_types():
]
def supports_sm100af():
return is_device_cc_supported({100}) and (
os.getenv("CUTE_DSL_ARCH", "") in ["", "sm_100a", "sm_100f"]
)
# Unary operation strings and functions
identity = ("", lambda x: x)
relu = ("relu", torch.relu)
@@ -86,6 +80,36 @@ mul = (lambda a, b: f"{a} * {b}", lambda a, b: a * b)
binary_ops = [add, sub, mul]
# Exercises the EFC GEMM kernel class directly (via metadata_filter) with no
# epilogue fusion supplied — the "EFC w/o epilogue" path this commit fixes.
# NOTE(review): indentation was stripped by the diff rendering; statement
# nesting must be confirmed against the repository.
@pytest.mark.skipif(
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_no_fusion(fixture_toggle_tvm_ffi):
"""
Tests EFC GEMM with no fusion provided.
"""
M, N, K, L = 256, 512, 128, 2
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=torch.float16)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=torch.float16)
D = torch.empty((L, M, N), device="cuda", dtype=torch.float16)
# Restrict kernel selection to the EFC persistent dense GEMM kernel class.
def metadata_filter(metadata):
return (
metadata.kernel_class
== cutlass_api.providers.cutedsl.gemm.sm100_static_persistent_efc.PersistentDenseGemmEFCKernel
)
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.float16)
kernels = cutlass_api.get_kernels(args, cc=100, metadata_filter=metadata_filter)
assert len(kernels) > 0
kernels[0].run(args)
reference = A @ B
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
@pytest.mark.parametrize(
@@ -94,7 +118,8 @@ binary_ops = [add, sub, mul]
)
@pytest.mark.parametrize("unary_str, unary_op", unary_ops)
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_unary(
@@ -132,7 +157,8 @@ def test_gemm_fusion_unary(
@pytest.mark.parametrize("unary_str, unary_op", [relu])
@pytest.mark.parametrize("unary_str2, unary_op2", [sigmoid, tanh])
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_unary_composition(
@@ -180,7 +206,8 @@ def test_gemm_fusion_unary_composition(
# Restrict unary to identity and relu to avoid rounding errors
@pytest.mark.parametrize("unary_str, unary_op", [identity, relu])
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_unary_literal(
@@ -216,7 +243,8 @@ def test_gemm_fusion_unary_literal(
@pytest.mark.parametrize("unary_str, unary_op", [identity, relu])
@pytest.mark.parametrize("binary_str, binary_op", binary_ops)
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_unary_binary_composition(
@@ -270,7 +298,8 @@ def test_gemm_fusion_unary_binary_composition(
@pytest.mark.parametrize("binary_str0, binary_op0", [add, sub])
@pytest.mark.parametrize("binary_str1, binary_op1", [add, sub])
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_binary_binary_composition(
@@ -317,7 +346,8 @@ def test_gemm_fusion_binary_binary_composition(
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_division():
@@ -384,7 +414,8 @@ def test_gemm_fusion_division():
)
@pytest.mark.parametrize("unary_str, unary_op", [sigmoid, tanh])
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_unary_multi_output(
@@ -427,7 +458,8 @@ def test_gemm_fusion_unary_multi_output(
@pytest.mark.parametrize("c_dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("binary_str, binary_op", binary_ops)
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_binary_multi_output(
@@ -463,7 +495,8 @@ def test_gemm_fusion_binary_multi_output(
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_return_acc():
@@ -501,7 +534,8 @@ def test_gemm_fusion_return_acc():
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_acc_as_multiple_input():
@@ -562,7 +596,8 @@ def test_gemm_fusion_acc_as_multiple_input():
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_matmul_input_as_aux():
@@ -627,7 +662,8 @@ def test_gemm_fusion_matmul_input_as_aux():
"ab_dtype, c_dtype, d_dtype, accumulator_type", base_data_types()
)
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_alpha_beta(
@@ -669,7 +705,69 @@ def test_gemm_alpha_beta(
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
# Alpha/beta epilogue GEMM built from fake tensors, then re-run (without a
# cached artifact) over several (alpha, beta) pairs with real tensors and
# checked against the same epilogue applied eagerly.
# NOTE(review): indentation was stripped by the diff rendering; the extent of
# the FakeTensorMode `with` block and of the (a, b) loop must be confirmed
# against the repository.
def test_gemm_alpha_beta_fake_tensor(fixture_toggle_tvm_ffi):
import torch._functorch.config
# Any data_ptr() read on a fake tensor during kernel selection/compilation
# now raises.
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = False
M, N, K, L = 256, 512, 128, 2
ab_dtype = torch.float16
c_dtype = torch.float32
d_dtype = torch.bfloat16
accumulator_type = torch.float16
with torch._subclasses.fake_tensor.FakeTensorMode():
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
# Epilogue traced by the EFC machinery: D = alpha * accum + beta * C.
def epi(accum, C, alpha, beta):
D = alpha * accum + beta * C
return D
alpha = 0.5
beta = 0.5
epi_args = cutlass_api.arguments.EpilogueArguments(
epi, C=C, alpha=alpha, beta=beta, D=D
)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernel = kernels[0]
# Real tensors; the selected kernel is reused across (alpha, beta) pairs.
A_real = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B_real = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
C_real = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
D_real = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
for a, b in [(0.5, 0.5), (1.0, 0.0), (0.0, 1.0)]:
epi_args = cutlass_api.arguments.EpilogueArguments(
epi, C=C_real, alpha=a, beta=b, D=D_real
)
args = cutlass_api.arguments.GemmArguments(
A=A_real,
B=B_real,
out=D_real,
accumulator_type=accumulator_type,
epilogue=epi_args,
)
kernel.run(args)
reference = epi(A_real @ B_real, C_real, a, b)
torch.testing.assert_close(D_real, reference.to(D_real.dtype))
@pytest.mark.skipif(
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_big_epi(fixture_toggle_tvm_ffi):
@@ -798,7 +896,8 @@ def test_gemm_big_epi(fixture_toggle_tvm_ffi):
@pytest.mark.skipif(
not supports_sm100af(),
not is_device_cc_supported({100, 103})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_not_available(fixture_toggle_tvm_ffi):

View File

@@ -209,3 +209,35 @@ def test_gemm_sm80_layouts(
pytest.fail(
f"Kernel {idx+1}/{len(kernels_to_test)} ({kernel.metadata.kernel_name}) failed for layout {layout_A}{layout_B}{layout_C}: {e}"
)
# SM80 dense GEMM: compile from fake tensors, run the artifact with real
# tensors, compare against eager A @ B.
# NOTE(review): indentation was stripped by the diff rendering; the extent of
# the FakeTensorMode `with` block must be confirmed against the repository.
def test_gemm_sm80_fake_tensor(fixture_toggle_tvm_ffi):
import torch._functorch.config
# Any data_ptr() read on a fake tensor during compilation now raises.
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = False
M, N, K, L = 256, 512, 128, 2
ab_dtype = torch.float16
c_dtype = torch.float16
accumulator_type = torch.float32
with torch._subclasses.fake_tensor.FakeTensorMode():
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
fake_args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
kernels = cutlass_api.get_kernels(fake_args, cc=80)
assert len(kernels) > 0
kernel = kernels[0]
compiled_artifact = kernel.compile(fake_args)
# Real operands for execution of the precompiled artifact.
A_real = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B_real = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
D_real = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
args = cutlass_api.arguments.GemmArguments(A_real, B_real, D_real, accumulator_type)
kernel.run(args, compiled_artifact=compiled_artifact)
reference = A_real @ B_real
torch.testing.assert_close(D_real, reference.to(D_real.dtype))