mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
reorganize examples for tvm-ffi
This commit is contained in:
@@ -34,6 +34,7 @@ import cuda.bindings.driver as cuda
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
||||
@@ -1704,62 +1705,6 @@ def bmm(
|
||||
gemm_op(a, b, c, max_active_clusters, stream, epilogue_op)
|
||||
|
||||
|
||||
def compile_bmm(
|
||||
gemm_op: PersistentDenseGemmKernel,
|
||||
a_dtype: Type[cutlass.Numeric],
|
||||
b_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
max_active_clusters: cutlass.Constexpr,
|
||||
stream: cuda.CUstream,
|
||||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||||
):
|
||||
from cutlass.cute.runtime import make_fake_compact_tensor
|
||||
|
||||
a_shape = (cute.sym_int(), cute.sym_int(divisibility=16), cute.sym_int())
|
||||
b_shape = (cute.sym_int(), cute.sym_int(divisibility=16), cute.sym_int())
|
||||
c_shape = (cute.sym_int(), cute.sym_int(divisibility=16), cute.sym_int())
|
||||
|
||||
if a_major == "k":
|
||||
a_order = (2, 1, 0) # k is leading dimension
|
||||
elif a_major == "m":
|
||||
a_order = (2, 0, 1) # m is leading dimension
|
||||
|
||||
if b_major == "n":
|
||||
b_order = (2, 1, 0) # n is leading dimension
|
||||
elif b_major == "k":
|
||||
b_order = (2, 0, 1) # k is leading dimension
|
||||
|
||||
if c_major == "n":
|
||||
c_order = (2, 1, 0) # n is leading dimension
|
||||
elif c_major == "m":
|
||||
c_order = (2, 0, 1) # m is leading dimension
|
||||
|
||||
a = make_fake_compact_tensor(
|
||||
a_dtype, a_shape, stride_order=a_order, assumed_align=16
|
||||
)
|
||||
b = make_fake_compact_tensor(
|
||||
b_dtype, b_shape, stride_order=b_order, assumed_align=16
|
||||
)
|
||||
c = make_fake_compact_tensor(
|
||||
c_dtype, c_shape, stride_order=c_order, assumed_align=16
|
||||
)
|
||||
|
||||
return cute.compile(
|
||||
bmm,
|
||||
gemm_op,
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
max_active_clusters,
|
||||
stream,
|
||||
epilogue_op,
|
||||
options="--enable-tvm-ffi",
|
||||
)
|
||||
|
||||
|
||||
def prepare_tensors(
|
||||
mnkl: Tuple[int, int, int, int],
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
@@ -1878,16 +1823,15 @@ def run(
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
||||
print(f"Use TVM FFI")
|
||||
|
||||
import torch
|
||||
from cutlass.torch import dtype as torch_dtype
|
||||
|
||||
# Build GEMM object
|
||||
gemm = PersistentDenseGemmKernel(
|
||||
gemm_op = PersistentDenseGemmKernel(
|
||||
acc_dtype, use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, use_tma_store
|
||||
)
|
||||
can_implement = gemm.can_implement(
|
||||
can_implement = gemm_op.can_implement(
|
||||
mnkl, ab_dtype, c_dtype, a_major, b_major, c_major
|
||||
)
|
||||
if not can_implement:
|
||||
@@ -1910,24 +1854,32 @@ def run(
|
||||
cluster_shape_mn[0] * cluster_shape_mn[1]
|
||||
)
|
||||
|
||||
compiled_fn = compile_bmm(
|
||||
gemm,
|
||||
ab_dtype,
|
||||
ab_dtype,
|
||||
c_dtype,
|
||||
a_major,
|
||||
b_major,
|
||||
c_major,
|
||||
max_active_clusters,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
# Run and verify BMM with torch
|
||||
a, b, c = prepare_tensors(mnkl, ab_dtype, c_dtype, a_major, b_major, c_major)
|
||||
|
||||
# Leading dim is 2
|
||||
leading_dim_a = 2 if a_major == "k" else 1
|
||||
leading_dim_b = 1 if b_major == "k" else 2
|
||||
leading_dim_c = 2 if c_major == "n" else 1
|
||||
|
||||
a_ = from_dlpack(a).mark_layout_dynamic(leading_dim=leading_dim_a)
|
||||
b_ = from_dlpack(b).mark_layout_dynamic(leading_dim=leading_dim_b)
|
||||
c_ = from_dlpack(c).mark_layout_dynamic(leading_dim=leading_dim_c)
|
||||
|
||||
compiled_fn = cute.compile(
|
||||
bmm,
|
||||
gemm_op,
|
||||
a_,
|
||||
b_,
|
||||
c_,
|
||||
max_active_clusters,
|
||||
current_stream,
|
||||
epilogue_op=lambda x: x,
|
||||
)
|
||||
|
||||
if not skip_ref_check:
|
||||
# Use small random number for deterministic result for reference check
|
||||
compiled_fn(a, b, c, torch_stream)
|
||||
compiled_fn(a_, b_, c_, current_stream)
|
||||
|
||||
# Manually quantize to be comparable
|
||||
ref = (
|
||||
@@ -1953,7 +1905,10 @@ def run(
|
||||
c_major,
|
||||
init_random=not init_normal,
|
||||
)
|
||||
return testing.JitArguments(a, b, c, torch_stream)
|
||||
a_ = from_dlpack(a).mark_layout_dynamic(leading_dim=leading_dim_a)
|
||||
b_ = from_dlpack(b).mark_layout_dynamic(leading_dim=leading_dim_b)
|
||||
c_ = from_dlpack(c).mark_layout_dynamic(leading_dim=leading_dim_c)
|
||||
return testing.JitArguments(a_, b_, c_, current_stream)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
|
||||
@@ -22,26 +22,26 @@ def run():
|
||||
|
||||
shape = (3, 4)
|
||||
a = make_fake_compact_tensor(cutlass.Float16, (3, 4), stride_order=(1, 0))
|
||||
cute.compile(print_tensor_type, a)
|
||||
cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")
|
||||
|
||||
# 32-bit symbolic integer with divisibility 8
|
||||
shape = (3, cute.sym_int32(divisibility=8))
|
||||
a = make_fake_compact_tensor(cutlass.Float16, shape, stride_order=(1, 0))
|
||||
cute.compile(print_tensor_type, a)
|
||||
cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")
|
||||
|
||||
# with static stride
|
||||
a = make_fake_tensor(cutlass.Float16, shape, stride=(4, 1))
|
||||
cute.compile(print_tensor_type, a)
|
||||
cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")
|
||||
|
||||
# with dynamic stride using 32bit integer
|
||||
stride = (cute.sym_int32(divisibility=8), 1)
|
||||
a = make_fake_tensor(cutlass.Float16, shape, stride=stride)
|
||||
cute.compile(print_tensor_type, a)
|
||||
cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")
|
||||
|
||||
# with dynamic stride using 64bit integer
|
||||
stride = (cute.sym_int64(divisibility=8), 1)
|
||||
a = make_fake_tensor(cutlass.Float16, shape, stride=stride)
|
||||
cute.compile(print_tensor_type, a)
|
||||
cute.compile(print_tensor_type, a, options="--enable-tvm-ffi")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Reference in New Issue
Block a user