reorganize examples for tvm-ffi

Fung Xie
2025-11-25 18:41:31 -08:00
parent 739fffce27
commit afe2f71522
3 changed files with 33 additions and 78 deletions

View File

@@ -34,6 +34,7 @@ import cuda.bindings.driver as cuda
 import cutlass
 import cutlass.cute as cute
 import cutlass.cute.testing as testing
+from cutlass.cute.runtime import from_dlpack
 import cutlass.utils as utils
 import cutlass.pipeline as pipeline
 from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
@@ -1704,62 +1705,6 @@ def bmm(
     gemm_op(a, b, c, max_active_clusters, stream, epilogue_op)
 
 
-def compile_bmm(
-    gemm_op: PersistentDenseGemmKernel,
-    a_dtype: Type[cutlass.Numeric],
-    b_dtype: Type[cutlass.Numeric],
-    c_dtype: Type[cutlass.Numeric],
-    a_major: str,
-    b_major: str,
-    c_major: str,
-    max_active_clusters: cutlass.Constexpr,
-    stream: cuda.CUstream,
-    epilogue_op: cutlass.Constexpr = lambda x: x,
-):
-    from cutlass.cute.runtime import make_fake_compact_tensor
-
-    a_shape = (cute.sym_int(), cute.sym_int(divisibility=16), cute.sym_int())
-    b_shape = (cute.sym_int(), cute.sym_int(divisibility=16), cute.sym_int())
-    c_shape = (cute.sym_int(), cute.sym_int(divisibility=16), cute.sym_int())
-
-    if a_major == "k":
-        a_order = (2, 1, 0)  # k is leading dimension
-    elif a_major == "m":
-        a_order = (2, 0, 1)  # m is leading dimension
-
-    if b_major == "n":
-        b_order = (2, 1, 0)  # n is leading dimension
-    elif b_major == "k":
-        b_order = (2, 0, 1)  # k is leading dimension
-
-    if c_major == "n":
-        c_order = (2, 1, 0)  # n is leading dimension
-    elif c_major == "m":
-        c_order = (2, 0, 1)  # m is leading dimension
-
-    a = make_fake_compact_tensor(
-        a_dtype, a_shape, stride_order=a_order, assumed_align=16
-    )
-    b = make_fake_compact_tensor(
-        b_dtype, b_shape, stride_order=b_order, assumed_align=16
-    )
-    c = make_fake_compact_tensor(
-        c_dtype, c_shape, stride_order=c_order, assumed_align=16
-    )
-
-    return cute.compile(
-        bmm,
-        gemm_op,
-        a,
-        b,
-        c,
-        max_active_clusters,
-        stream,
-        epilogue_op,
-        options="--enable-tvm-ffi",
-    )
-
-
 def prepare_tensors(
     mnkl: Tuple[int, int, int, int],
     ab_dtype: Type[cutlass.Numeric],
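
Note: the deleted helper's core idea, compiling `bmm` against "fake" tensors built only from a dtype, symbolic extents with a divisibility promise, and a stride order, now lives in the dedicated tvm-ffi example below. A condensed, self-contained sketch of just that construction (the no-op kernel is a hypothetical stand-in for `bmm`, which is too large to repeat here):

    import cutlass
    import cutlass.cute as cute
    from cutlass.cute.runtime import make_fake_compact_tensor


    @cute.jit
    def noop(a: cute.Tensor):
        pass  # hypothetical stand-in for the real bmm kernel


    # (l, m, k) A operand: m promised divisible by 16, k contiguous;
    # stride_order=(2, 1, 0) is the a_major == "k" branch above.
    a_shape = (cute.sym_int(), cute.sym_int(divisibility=16), cute.sym_int())
    a = make_fake_compact_tensor(
        cutlass.Float16, a_shape, stride_order=(2, 1, 0), assumed_align=16
    )
    cute.compile(noop, a, options="--enable-tvm-ffi")

Because no real storage backs `a`, the kernel can be compiled ahead of any allocation, and the divisibility and alignment promises are visible to the compiler instead of being read off a concrete tensor.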
@@ -1878,16 +1823,15 @@ def run(
     print(f"Iterations: {iterations}")
     print(f"Skip reference checking: {skip_ref_check}")
     print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
-    print(f"Use TVM FFI")
 
     import torch
     from cutlass.torch import dtype as torch_dtype
 
     # Build GEMM object
-    gemm = PersistentDenseGemmKernel(
+    gemm_op = PersistentDenseGemmKernel(
         acc_dtype, use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, use_tma_store
     )
-    can_implement = gemm.can_implement(
+    can_implement = gemm_op.can_implement(
         mnkl, ab_dtype, c_dtype, a_major, b_major, c_major
     )
     if not can_implement:
@@ -1910,24 +1854,32 @@ def run(
         cluster_shape_mn[0] * cluster_shape_mn[1]
     )
 
-    compiled_fn = compile_bmm(
-        gemm,
-        ab_dtype,
-        ab_dtype,
-        c_dtype,
-        a_major,
-        b_major,
-        c_major,
-        max_active_clusters,
-        current_stream,
-    )
-
     # Run and verify BMM with torch
     a, b, c = prepare_tensors(mnkl, ab_dtype, c_dtype, a_major, b_major, c_major)
+
+    # Contiguous (leading) mode depends on each operand's major order
+    leading_dim_a = 2 if a_major == "k" else 1
+    leading_dim_b = 1 if b_major == "k" else 2
+    leading_dim_c = 2 if c_major == "n" else 1
+    a_ = from_dlpack(a).mark_layout_dynamic(leading_dim=leading_dim_a)
+    b_ = from_dlpack(b).mark_layout_dynamic(leading_dim=leading_dim_b)
+    c_ = from_dlpack(c).mark_layout_dynamic(leading_dim=leading_dim_c)
+
+    compiled_fn = cute.compile(
+        bmm,
+        gemm_op,
+        a_,
+        b_,
+        c_,
+        max_active_clusters,
+        current_stream,
+        epilogue_op=lambda x: x,
+    )
+
     if not skip_ref_check:
         # Use small random number for deterministic result for reference check
-        compiled_fn(a, b, c, torch_stream)
+        compiled_fn(a_, b_, c_, current_stream)
 
         # Manually quantize to be comparable
         ref = (
@@ -1953,7 +1905,10 @@ def run(
             c_major,
             init_random=not init_normal,
         )
-        return testing.JitArguments(a, b, c, torch_stream)
+        a_ = from_dlpack(a).mark_layout_dynamic(leading_dim=leading_dim_a)
+        b_ = from_dlpack(b).mark_layout_dynamic(leading_dim=leading_dim_b)
+        c_ = from_dlpack(c).mark_layout_dynamic(leading_dim=leading_dim_c)
+        return testing.JitArguments(a_, b_, c_, current_stream)
 
     workspace_count = 1
     if use_cold_l2:

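Note: after this change the example drives compilation from real torch tensors instead of fake ones: wrap each tensor with `from_dlpack`, mark its layout dynamic while pinning the contiguous mode, compile once, and reuse the result. A minimal sketch of that host-side pattern (the no-op kernel is hypothetical; the real example passes `bmm` plus its extra arguments):

    import torch
    import cutlass
    import cutlass.cute as cute
    from cutlass.cute.runtime import from_dlpack


    @cute.jit
    def noop(a: cute.Tensor):
        pass  # hypothetical stand-in for the real bmm kernel


    # "k"-major (l, m, k) operand: the last mode is contiguous, so
    # leading_dim=2, matching `leading_dim_a = 2 if a_major == "k" else 1`.
    a = torch.randn(4, 256, 128, dtype=torch.float16, device="cuda")
    a_ = from_dlpack(a).mark_layout_dynamic(leading_dim=2)

    compiled = cute.compile(noop, a_)
    compiled(a_)  # callable again for any tensor with the same layout class
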
View File

@@ -22,26 +22,26 @@ def run():
     shape = (3, 4)
     a = make_fake_compact_tensor(cutlass.Float16, (3, 4), stride_order=(1, 0))
-    cute.compile(print_tensor_type, a)
+    cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")
 
     # 32-bit symbolic integer with divisibility 8
     shape = (3, cute.sym_int32(divisibility=8))
     a = make_fake_compact_tensor(cutlass.Float16, shape, stride_order=(1, 0))
-    cute.compile(print_tensor_type, a)
+    cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")
 
     # with static stride
     a = make_fake_tensor(cutlass.Float16, shape, stride=(4, 1))
-    cute.compile(print_tensor_type, a)
+    cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")
 
     # with dynamic stride using 32bit integer
     stride = (cute.sym_int32(divisibility=8), 1)
     a = make_fake_tensor(cutlass.Float16, shape, stride=stride)
-    cute.compile(print_tensor_type, a)
+    cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")
 
     # with dynamic stride using 64bit integer
     stride = (cute.sym_int64(divisibility=8), 1)
     a = make_fake_tensor(cutlass.Float16, shape, stride=stride)
-    cute.compile(print_tensor_type, a)
+    cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")
 
 
 if __name__ == "__main__":
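
Note: the hunk above shows only the body of `run()`; the imports and the kernel being compiled sit outside it. A self-contained skeleton under the assumption that `print_tensor_type` is an ordinary `@cute.jit` function (the `pass` body is a placeholder, since the real definition, which presumably prints the inferred tensor type, is not part of this diff):

    import cutlass
    import cutlass.cute as cute
    from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_tensor


    @cute.jit
    def print_tensor_type(a: cute.Tensor):
        pass  # placeholder; the real kernel's definition is outside this hunk


    def run():
        # static shape, row-major compact layout
        a = make_fake_compact_tensor(cutlass.Float16, (3, 4), stride_order=(1, 0))
        cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")

        # dynamic extent: 32-bit symbolic, promised divisible by 8
        shape = (3, cute.sym_int32(divisibility=8))
        a = make_fake_compact_tensor(cutlass.Float16, shape, stride_order=(1, 0))
        cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")

        # explicit 64-bit symbolic stride instead of a compact stride order
        stride = (cute.sym_int64(divisibility=8), 1)
        a = make_fake_tensor(cutlass.Float16, shape, stride=stride)
        cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")


    if __name__ == "__main__":
        run()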