mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
Fix batch adding for EFC
This commit is contained in:
@@ -251,6 +251,15 @@ class PersistentDenseGemmEFCKernelImpl:
|
||||
a = cute.make_tensor(a.iterator, cute.select(a.layout, [1, 2, 0]))
|
||||
# B: (L, K, N) -> (N, K, L)
|
||||
b = cute.make_tensor(b.iterator, cute.select(b.layout, [2, 1, 0]))
|
||||
|
||||
# Add batch mode to epilogue parameters
|
||||
supplemental_parameters = (
|
||||
add_batch_mode(t)
|
||||
if isinstance(t, cute.Tensor)
|
||||
else t
|
||||
for t in supplemental_parameters
|
||||
)
|
||||
|
||||
# epilogue tensors: (L, M, N) -> (M, N, L)
|
||||
supplemental_parameters = (
|
||||
cute.make_tensor(t.iterator, cute.select(t.layout, [1, 2, 0]))
|
||||
|
||||
@@ -100,6 +100,42 @@ def test_gemm_sm100(
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"use_tvm_ffi",
|
||||
[True, False],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_sm100_2d(use_tvm_ffi: bool):
|
||||
ab_dtype = torch.float16
|
||||
c_dtype = torch.float16
|
||||
accumulator_type = torch.float32
|
||||
M = 256
|
||||
N = 512
|
||||
K = 128
|
||||
A = torch.randint(-1, 2, (M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((M, N), device="cuda", dtype=c_dtype)
|
||||
|
||||
GlobalOptions().use_tvm_ffi = use_tvm_ffi
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
|
||||
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
logger.debug(f"Picked kernel: {kernel.metadata.kernel_name}")
|
||||
logger.debug(f"Kernel metadata:\n{pformat(kernel.metadata)}")
|
||||
kernel.run(args)
|
||||
|
||||
reference = A @ B
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a"]),
|
||||
|
||||
Reference in New Issue
Block a user