mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
Don't access data_ptr of fake tensor. Fix EFC w/o epilogue
This commit is contained in:
@@ -34,6 +34,7 @@ import cutlass
|
||||
from cutlass_api.arguments import (
|
||||
EpilogueArguments,
|
||||
GemmArguments,
|
||||
KernelOperand,
|
||||
)
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import (
|
||||
@@ -325,14 +326,12 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
|
||||
max_active_clusters = get_max_active_clusters(self.impl.cluster_shape_mn)
|
||||
|
||||
if args.epilogue is not None:
|
||||
epilogue_params = args.epilogue.parameters
|
||||
epilogue_params = [
|
||||
e.compile_time_tensor if isinstance(e, TensorWrapper) else e
|
||||
for e in args.epilogue.parameters
|
||||
]
|
||||
else:
|
||||
epilogue_params = [args.out]
|
||||
|
||||
epilogue_params = [
|
||||
e.compile_time_tensor if isinstance(e, TensorWrapper) else e
|
||||
for e in epilogue_params
|
||||
]
|
||||
epilogue_params = [args.out.tensor.compile_time_tensor]
|
||||
|
||||
# EFC needs special handling for supplemental arguments
|
||||
self.impl.efc.compile(epilogue_params)
|
||||
@@ -348,7 +347,9 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
|
||||
# Wrap the compiled kernel to handle supplemental argument packing at launch time
|
||||
def wrapped_launch(a_tensor, b_tensor, stream, *supplemental_args):
|
||||
runtime_args = [
|
||||
e.runtime_tensor if isinstance(e, TensorWrapper) else e
|
||||
e.runtime_tensor
|
||||
if isinstance(e, TensorWrapper)
|
||||
else (e.tensor.runtime_tensor if isinstance(e, KernelOperand) else e)
|
||||
for e in supplemental_args
|
||||
]
|
||||
return compiled_gemm(
|
||||
|
||||
@@ -91,7 +91,6 @@ def is_numpy_tensor(inp) -> bool:
|
||||
"""Check if the input is a numpy tensor."""
|
||||
if is_numpy_available():
|
||||
import numpy as np
|
||||
|
||||
return isinstance(inp, np.ndarray)
|
||||
return False
|
||||
|
||||
@@ -105,6 +104,15 @@ def is_torch_tensor(inp) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_torch_fake_tensor(inp) -> bool:
|
||||
"""Check if the input is a torch fake tensor."""
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
return isinstance(inp, torch._subclasses.fake_tensor.FakeTensor)
|
||||
return False
|
||||
|
||||
|
||||
def _lazy_import(mod_name: str) -> Any:
|
||||
"""Internal utility to lazily import a module only when needed."""
|
||||
|
||||
@@ -367,32 +375,31 @@ class TensorWrapper:
|
||||
if isinstance(tensor, cute.Tensor):
|
||||
# Regardless of whether TVM-FFI is enabled, if the tensor passed in is a cute.Tensor,
|
||||
# it can be used as the runtime tensor and compile time tensor.
|
||||
self.runtime_tensor = tensor
|
||||
self._runtime_tensor = tensor
|
||||
self.compile_time_tensor = tensor
|
||||
self._shape = tensor.shape
|
||||
self._stride = tensor.stride
|
||||
self._data_ptr = tensor.iterator._pointer
|
||||
elif GlobalOptions().use_tvm_ffi:
|
||||
# If TVM-FFI is enabled, runtime tensor is set simply as the tensor passed in, but
|
||||
# we must make a fake tensor for compilation.
|
||||
self.runtime_tensor = tensor
|
||||
if is_torch_tensor(self.runtime_tensor):
|
||||
dtype = cutlass_type_from_torch_type(self.runtime_tensor.dtype)
|
||||
elif is_torch_fake_tensor(tensor) or (
|
||||
is_torch_tensor(tensor) and GlobalOptions().use_tvm_ffi
|
||||
):
|
||||
self._shape = tuple(tensor.shape)
|
||||
self._stride = tensor.stride()
|
||||
|
||||
rank = self.runtime_tensor.dim()
|
||||
self._stride = self.runtime_tensor.stride()
|
||||
stride_order = get_stride_rank(self._stride)
|
||||
leading_dim_idx = stride_order.index(0)
|
||||
shape = [cute.SymInt() for _ in range(rank)]
|
||||
shape[leading_dim_idx] = cute.SymInt(
|
||||
divisibility=alignment_bytes * 8 // dtype.width
|
||||
)
|
||||
self._shape = tuple(self.runtime_tensor.shape)
|
||||
self._data_ptr = self.runtime_tensor.data_ptr()
|
||||
if is_torch_fake_tensor(tensor):
|
||||
self._data_ptr = 0
|
||||
self._runtime_tensor = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported tensor type: {type(self.runtime_tensor)}"
|
||||
)
|
||||
self._runtime_tensor = tensor
|
||||
self._data_ptr = tensor.data_ptr()
|
||||
|
||||
dtype = cutlass_type_from_torch_type(tensor.dtype)
|
||||
stride_order = get_stride_rank(tensor.stride())
|
||||
leading_dim_idx = stride_order.index(0)
|
||||
shape = [cute.SymInt() for _ in range(tensor.dim())]
|
||||
shape[leading_dim_idx] = cute.SymInt(
|
||||
divisibility=alignment_bytes * 8 // dtype.width
|
||||
)
|
||||
|
||||
self.compile_time_tensor = cute.runtime.make_fake_compact_tensor(
|
||||
dtype,
|
||||
@@ -400,8 +407,11 @@ class TensorWrapper:
|
||||
stride_order=stride_order,
|
||||
assumed_align=alignment_bytes,
|
||||
)
|
||||
|
||||
elif GlobalOptions().use_tvm_ffi:
|
||||
raise ValueError("TVM-FFI is currently only supported for torch tensors.")
|
||||
else:
|
||||
# TVM-FFI is disabled and the tensor passed in is not a cute.Tensor,
|
||||
# TVM-FFI is disabled and the tensor passed in is not a cute.Tensor or torch fake tensor,
|
||||
# We must convert it to a cute.Tensor
|
||||
if is_torch_tensor(tensor):
|
||||
dtype = to_cutlass_type(tensor.dtype)
|
||||
@@ -409,7 +419,7 @@ class TensorWrapper:
|
||||
else:
|
||||
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
|
||||
stride_order = get_stride_order(stride)
|
||||
self.runtime_tensor = (
|
||||
self._runtime_tensor = (
|
||||
from_dlpack(
|
||||
tensor,
|
||||
assumed_align=alignment_bytes,
|
||||
@@ -422,13 +432,21 @@ class TensorWrapper:
|
||||
)
|
||||
)
|
||||
|
||||
self._shape = self.runtime_tensor.shape
|
||||
self._stride = self.runtime_tensor.stride
|
||||
self._data_ptr = self.runtime_tensor.iterator._pointer
|
||||
self._shape = self._runtime_tensor.shape
|
||||
self._stride = self._runtime_tensor.stride
|
||||
self._data_ptr = self._runtime_tensor.iterator._pointer
|
||||
|
||||
# Since the runtime tensor is now a cute.Tensor, we can use it at
|
||||
# compile time as well
|
||||
self.compile_time_tensor = self.runtime_tensor
|
||||
self.compile_time_tensor = self._runtime_tensor
|
||||
|
||||
@property
|
||||
def runtime_tensor(self):
|
||||
if self._runtime_tensor is None:
|
||||
raise ValueError(
|
||||
"Attempting to access runtime tensor from argument constructed with a fake tensor."
|
||||
)
|
||||
return self._runtime_tensor
|
||||
|
||||
@property
|
||||
def element_type(self) -> type[cutlass.Numeric]:
|
||||
|
||||
@@ -180,6 +180,99 @@ def test_mxfp8_gemm_sm100(
|
||||
reference = reference_scaled_mm(A, B, SFA, SFB, c_dtype)
|
||||
torch.testing.assert_close(D, reference)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_103a",
|
||||
)
|
||||
def test_mxfp8_gemm_sm100_fake_tensor(fixture_enable_tvm_ffi):
|
||||
import torch._functorch.config
|
||||
|
||||
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = False
|
||||
|
||||
M, N, K, L = 256, 512, 1024, 1
|
||||
ab_dtype = torch.float8_e4m3fn
|
||||
c_dtype = torch.float32
|
||||
accumulator_type = torch.float32
|
||||
scale_dtype = torch.float8_e8m0fnu
|
||||
scale_size = 32
|
||||
|
||||
with torch._subclasses.fake_tensor.FakeTensorMode():
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda").to(ab_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
|
||||
B = torch.randint(-1, 2, (L, N, K), device="cuda").to(ab_dtype).transpose(1, 2)
|
||||
SFA = torch.rand(
|
||||
(
|
||||
L,
|
||||
M,
|
||||
prep_k(K, scale_size),
|
||||
),
|
||||
device="cuda",
|
||||
).to(scale_dtype)
|
||||
SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(scale_dtype)
|
||||
|
||||
fake_args = cutlass_api.arguments.GemmArguments(
|
||||
A=ScaledTensor(
|
||||
A,
|
||||
SFA,
|
||||
ScaleMode.Blockwise1x32,
|
||||
ScaleSwizzleMode.Swizzle32x4x4,
|
||||
),
|
||||
B=ScaledTensor(
|
||||
B,
|
||||
SFB,
|
||||
ScaleMode.Blockwise1x32,
|
||||
ScaleSwizzleMode.Swizzle32x4x4,
|
||||
),
|
||||
out=D,
|
||||
accumulator_type=accumulator_type,
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(fake_args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
assert kernel.supports(fake_args)
|
||||
compiled_artifact = kernel.compile(fake_args)
|
||||
|
||||
A_real = torch.randint(-1, 2, (L, M, K), device="cuda").to(ab_dtype)
|
||||
B_real = torch.randint(-1, 2, (L, N, K), device="cuda").to(ab_dtype).transpose(1, 2)
|
||||
D_real = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
|
||||
SFA_real = torch.rand(
|
||||
(
|
||||
L,
|
||||
M,
|
||||
prep_k(K, scale_size),
|
||||
),
|
||||
device="cuda",
|
||||
).to(scale_dtype)
|
||||
SFB_real = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(scale_dtype)
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=ScaledTensor(
|
||||
A_real,
|
||||
SFA_real,
|
||||
ScaleMode.Blockwise1x32,
|
||||
ScaleSwizzleMode.Swizzle32x4x4,
|
||||
),
|
||||
B=ScaledTensor(
|
||||
B_real,
|
||||
SFB_real,
|
||||
ScaleMode.Blockwise1x32,
|
||||
ScaleSwizzleMode.Swizzle32x4x4,
|
||||
),
|
||||
out=D_real,
|
||||
accumulator_type=accumulator_type,
|
||||
)
|
||||
kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)
|
||||
|
||||
# torch._scaled_mm does not support f8e5m2 * f8e5m2 currently.
|
||||
# Simply skip reference check in that case (but test that a CUTLASS API kernel
|
||||
# is found and runs)
|
||||
if ab_dtype != torch.float8_e5m2:
|
||||
reference = reference_scaled_mm(A_real, B_real, SFA_real, SFB_real, c_dtype)
|
||||
torch.testing.assert_close(D_real, reference)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, N, K",
|
||||
[
|
||||
|
||||
@@ -33,6 +33,7 @@ import random
|
||||
import torch
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.utils import is_device_cc_supported
|
||||
|
||||
|
||||
torch.manual_seed(2025)
|
||||
@@ -67,3 +68,54 @@ def test_incorrect_offset_length():
|
||||
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
assert len(kernels) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_contiguous_offset_dense_gemm_2d3d_fake_tensor(fixture_toggle_tvm_ffi):
|
||||
import torch._functorch.config
|
||||
|
||||
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = False
|
||||
|
||||
M, N, K, L = 256, 512, 128, 2
|
||||
ab_dtype = torch.float16
|
||||
c_dtype = torch.float16
|
||||
accumulator_type = torch.float32
|
||||
|
||||
with torch._subclasses.fake_tensor.FakeTensorMode():
|
||||
A = torch.randint(-1, 2, (M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, N, K), device="cuda", dtype=ab_dtype).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
out = torch.empty((M, N), device="cuda", dtype=c_dtype)
|
||||
offsets = torch.empty((L,), device="cuda", dtype=torch.int32)
|
||||
|
||||
fake_args = cutlass_api.arguments.GroupedGemmArguments(
|
||||
A=A, B=B, out=out, accumulator_type=accumulator_type, offsets=offsets
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(fake_args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
compiled_artifact = kernel.compile(fake_args)
|
||||
|
||||
A_real = torch.randint(-1, 2, (M, K), device="cuda", dtype=ab_dtype)
|
||||
B_real = torch.randint(-1, 2, (L, N, K), device="cuda", dtype=ab_dtype).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
out_real = torch.empty((M, N), device="cuda", dtype=c_dtype)
|
||||
offsets_real = torch.Tensor([128, 256]).to("cuda").to(torch.int32)
|
||||
args = cutlass_api.arguments.GroupedGemmArguments(
|
||||
A=A_real,
|
||||
B=B_real,
|
||||
out=out_real,
|
||||
accumulator_type=accumulator_type,
|
||||
offsets=offsets_real,
|
||||
)
|
||||
kernel.run(args, compiled_artifact=compiled_artifact)
|
||||
|
||||
reference = torch._grouped_mm(A_real, B_real, offsets_real)
|
||||
torch.testing.assert_close(out_real, reference.to(out_real.dtype))
|
||||
|
||||
@@ -64,3 +64,33 @@ def test_elementwise_add(M: int, N: int, dtype: torch.dtype, fixture_toggle_tvm_
|
||||
reference = A + B
|
||||
|
||||
assert torch.allclose(D, reference)
|
||||
|
||||
|
||||
def test_elementwise_add_fake_tensor(fixture_toggle_tvm_ffi):
|
||||
import torch._functorch.config
|
||||
|
||||
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = False
|
||||
|
||||
M, N = 256, 512
|
||||
dtype = torch.float16
|
||||
|
||||
with torch._subclasses.fake_tensor.FakeTensorMode():
|
||||
A = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
|
||||
B = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
|
||||
D = torch.empty((M, N), device="cuda", dtype=dtype)
|
||||
|
||||
fake_args = cutlass_api.arguments.ElementwiseArguments(A=A, B=B, out=D)
|
||||
kernels = cutlass_api.get_kernels(fake_args)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
compiled_artifact = kernel.compile(fake_args)
|
||||
|
||||
A = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
|
||||
B = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
|
||||
D = torch.empty((M, N), device="cuda", dtype=dtype)
|
||||
args = cutlass_api.arguments.ElementwiseArguments(A=A, B=B, out=D)
|
||||
kernel.run(args, compiled_artifact=compiled_artifact)
|
||||
|
||||
reference = A + B
|
||||
assert torch.allclose(D, reference)
|
||||
|
||||
@@ -273,3 +273,43 @@ def test_metadata_filter():
|
||||
assert kernel.metadata.operands.B.dtype == cutlass.Float16
|
||||
assert kernel.metadata.operands.out.dtype == cutlass.Float16
|
||||
assert kernel.metadata.operands.accumulator_type == cutlass.Float16
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_sm100_fake_tensor(fixture_toggle_tvm_ffi):
|
||||
import torch._functorch.config
|
||||
|
||||
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = False
|
||||
|
||||
M, N, K, L = 256, 512, 128, 1
|
||||
ab_dtype = torch.float16
|
||||
c_dtype = torch.float16
|
||||
accumulator_type = torch.float32
|
||||
|
||||
with torch._subclasses.fake_tensor.FakeTensorMode():
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
|
||||
|
||||
fake_args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
|
||||
|
||||
kernels = cutlass_api.get_kernels(fake_args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
|
||||
compiled_artifact = kernel.compile(fake_args)
|
||||
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
|
||||
|
||||
kernel.run(args, compiled_artifact=compiled_artifact)
|
||||
|
||||
reference = A @ B
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
@@ -61,12 +61,6 @@ def base_data_types():
|
||||
]
|
||||
|
||||
|
||||
def supports_sm100af():
|
||||
return is_device_cc_supported({100}) and (
|
||||
os.getenv("CUTE_DSL_ARCH", "") in ["", "sm_100a", "sm_100f"]
|
||||
)
|
||||
|
||||
|
||||
# Unary operation strings and functions
|
||||
identity = ("", lambda x: x)
|
||||
relu = ("relu", torch.relu)
|
||||
@@ -86,6 +80,36 @@ mul = (lambda a, b: f"{a} * {b}", lambda a, b: a * b)
|
||||
binary_ops = [add, sub, mul]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_no_fusion(fixture_toggle_tvm_ffi):
|
||||
"""
|
||||
Tests EFC GEMM with no fusion provided.
|
||||
|
||||
"""
|
||||
M, N, K, L = 256, 512, 128, 2
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=torch.float16)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=torch.float16)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=torch.float16)
|
||||
|
||||
def metadata_filter(metadata):
|
||||
return (
|
||||
metadata.kernel_class
|
||||
== cutlass_api.providers.cutedsl.gemm.sm100_static_persistent_efc.PersistentDenseGemmEFCKernel
|
||||
)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.float16)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100, metadata_filter=metadata_filter)
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = A @ B
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
|
||||
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
|
||||
@pytest.mark.parametrize(
|
||||
@@ -94,7 +118,8 @@ binary_ops = [add, sub, mul]
|
||||
)
|
||||
@pytest.mark.parametrize("unary_str, unary_op", unary_ops)
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_unary(
|
||||
@@ -132,7 +157,8 @@ def test_gemm_fusion_unary(
|
||||
@pytest.mark.parametrize("unary_str, unary_op", [relu])
|
||||
@pytest.mark.parametrize("unary_str2, unary_op2", [sigmoid, tanh])
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_unary_composition(
|
||||
@@ -180,7 +206,8 @@ def test_gemm_fusion_unary_composition(
|
||||
# Restrict unary to identity and relu to avoid rounding errors
|
||||
@pytest.mark.parametrize("unary_str, unary_op", [identity, relu])
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_unary_literal(
|
||||
@@ -216,7 +243,8 @@ def test_gemm_fusion_unary_literal(
|
||||
@pytest.mark.parametrize("unary_str, unary_op", [identity, relu])
|
||||
@pytest.mark.parametrize("binary_str, binary_op", binary_ops)
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_unary_binary_composition(
|
||||
@@ -270,7 +298,8 @@ def test_gemm_fusion_unary_binary_composition(
|
||||
@pytest.mark.parametrize("binary_str0, binary_op0", [add, sub])
|
||||
@pytest.mark.parametrize("binary_str1, binary_op1", [add, sub])
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_binary_binary_composition(
|
||||
@@ -317,7 +346,8 @@ def test_gemm_fusion_binary_binary_composition(
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_division():
|
||||
@@ -384,7 +414,8 @@ def test_gemm_fusion_division():
|
||||
)
|
||||
@pytest.mark.parametrize("unary_str, unary_op", [sigmoid, tanh])
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_unary_multi_output(
|
||||
@@ -427,7 +458,8 @@ def test_gemm_fusion_unary_multi_output(
|
||||
@pytest.mark.parametrize("c_dtype", [torch.float16, torch.float32])
|
||||
@pytest.mark.parametrize("binary_str, binary_op", binary_ops)
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_binary_multi_output(
|
||||
@@ -463,7 +495,8 @@ def test_gemm_fusion_binary_multi_output(
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_return_acc():
|
||||
@@ -501,7 +534,8 @@ def test_gemm_fusion_return_acc():
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_acc_as_multiple_input():
|
||||
@@ -562,7 +596,8 @@ def test_gemm_fusion_acc_as_multiple_input():
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_matmul_input_as_aux():
|
||||
@@ -627,7 +662,8 @@ def test_gemm_fusion_matmul_input_as_aux():
|
||||
"ab_dtype, c_dtype, d_dtype, accumulator_type", base_data_types()
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_alpha_beta(
|
||||
@@ -669,7 +705,69 @@ def test_gemm_alpha_beta(
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_alpha_beta_fake_tensor(fixture_toggle_tvm_ffi):
|
||||
import torch._functorch.config
|
||||
|
||||
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = False
|
||||
|
||||
M, N, K, L = 256, 512, 128, 2
|
||||
ab_dtype = torch.float16
|
||||
c_dtype = torch.float32
|
||||
d_dtype = torch.bfloat16
|
||||
accumulator_type = torch.float16
|
||||
|
||||
with torch._subclasses.fake_tensor.FakeTensorMode():
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
def epi(accum, C, alpha, beta):
|
||||
D = alpha * accum + beta * C
|
||||
return D
|
||||
|
||||
alpha = 0.5
|
||||
beta = 0.5
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(
|
||||
epi, C=C, alpha=alpha, beta=beta, D=D
|
||||
)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
|
||||
A_real = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B_real = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
C_real = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
D_real = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
for a, b in [(0.5, 0.5), (1.0, 0.0), (0.0, 1.0)]:
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(
|
||||
epi, C=C_real, alpha=a, beta=b, D=D_real
|
||||
)
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A_real,
|
||||
B=B_real,
|
||||
out=D_real,
|
||||
accumulator_type=accumulator_type,
|
||||
epilogue=epi_args,
|
||||
)
|
||||
kernel.run(args)
|
||||
reference = epi(A_real @ B_real, C_real, a, b)
|
||||
torch.testing.assert_close(D_real, reference.to(D_real.dtype))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_big_epi(fixture_toggle_tvm_ffi):
|
||||
@@ -798,7 +896,8 @@ def test_gemm_big_epi(fixture_toggle_tvm_ffi):
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
not is_device_cc_supported({100, 103})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_not_available(fixture_toggle_tvm_ffi):
|
||||
|
||||
@@ -209,3 +209,35 @@ def test_gemm_sm80_layouts(
|
||||
pytest.fail(
|
||||
f"Kernel {idx+1}/{len(kernels_to_test)} ({kernel.metadata.kernel_name}) failed for layout {layout_A}{layout_B}{layout_C}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def test_gemm_sm80_fake_tensor(fixture_toggle_tvm_ffi):
|
||||
import torch._functorch.config
|
||||
|
||||
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = False
|
||||
|
||||
M, N, K, L = 256, 512, 128, 2
|
||||
ab_dtype = torch.float16
|
||||
c_dtype = torch.float16
|
||||
accumulator_type = torch.float32
|
||||
|
||||
with torch._subclasses.fake_tensor.FakeTensorMode():
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
|
||||
|
||||
fake_args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
|
||||
kernels = cutlass_api.get_kernels(fake_args, cc=80)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
compiled_artifact = kernel.compile(fake_args)
|
||||
|
||||
A_real = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B_real = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
D_real = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
|
||||
args = cutlass_api.arguments.GemmArguments(A_real, B_real, D_real, accumulator_type)
|
||||
kernel.run(args, compiled_artifact=compiled_artifact)
|
||||
|
||||
reference = A_real @ B_real
|
||||
torch.testing.assert_close(D_real, reference.to(D_real.dtype))
|
||||
|
||||
Reference in New Issue
Block a user