Fix batch adding for EFC

This commit is contained in:
jkosaian
2025-12-16 14:08:23 -08:00
parent ead2fbfe13
commit dfcb55de16
2 changed files with 45 additions and 0 deletions

View File

@@ -251,6 +251,15 @@ class PersistentDenseGemmEFCKernelImpl:
a = cute.make_tensor(a.iterator, cute.select(a.layout, [1, 2, 0]))
# B: (L, K, N) -> (N, K, L)
b = cute.make_tensor(b.iterator, cute.select(b.layout, [2, 1, 0]))
# Add batch mode to epilogue parameters
supplemental_parameters = (
add_batch_mode(t)
if isinstance(t, cute.Tensor)
else t
for t in supplemental_parameters
)
# epilogue tensors: (L, M, N) -> (M, N, L)
supplemental_parameters = (
cute.make_tensor(t.iterator, cute.select(t.layout, [1, 2, 0]))

View File

@@ -100,6 +100,42 @@ def test_gemm_sm100(
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.parametrize(
"use_tvm_ffi",
[True, False],
)
@pytest.mark.skipif(
not is_device_cc_supported({100})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_sm100_2d(use_tvm_ffi: bool):
ab_dtype = torch.float16
c_dtype = torch.float16
accumulator_type = torch.float32
M = 256
N = 512
K = 128
A = torch.randint(-1, 2, (M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (K, N), device="cuda", dtype=ab_dtype)
D = torch.empty((M, N), device="cuda", dtype=c_dtype)
GlobalOptions().use_tvm_ffi = use_tvm_ffi
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernel = kernels[0]
logger.debug(f"Picked kernel: {kernel.metadata.kernel_name}")
logger.debug(f"Kernel metadata:\n{pformat(kernel.metadata)}")
kernel.run(args)
reference = A @ B
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.skipif(
not is_device_cc_supported({100})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a"]),