enhanced the example for tvm-ffi

Fung Xie
2025-11-25 20:56:24 -08:00
parent b9154d65b3
commit 2664cac685


@@ -70,85 +70,102 @@ def bmm(
 def compile_bmm_dynamic_layout():
     from cutlass.cute.runtime import make_fake_compact_tensor
+    m = cute.sym_int()
+    n = cute.sym_int(divisibility=16)
+    k = cute.sym_int(divisibility=16)
+    l = cute.sym_int()
     # Contiguous on K
-    a_shape = (cute.sym_int(), cute.sym_int(), cute.sym_int(divisibility=16))
-    # Contiguous on N
-    b_shape = (cute.sym_int(), cute.sym_int(), cute.sym_int(divisibility=16))
-    # Contiguous on N
-    c_shape = (cute.sym_int(), cute.sym_int(), cute.sym_int(divisibility=16))
     fake_a = make_fake_compact_tensor(
-        cutlass.Float16, a_shape, stride_order=(2, 1, 0), assumed_align=16
+        cutlass.Float16, (l, m, k), stride_order=(2, 1, 0), assumed_align=16
     )
+    # Contiguous on N
     fake_b = make_fake_compact_tensor(
-        cutlass.Float16, b_shape, stride_order=(2, 1, 0), assumed_align=16
+        cutlass.Float16, (l, k, n), stride_order=(2, 1, 0), assumed_align=16
     )
+    # Contiguous on N
     fake_c = make_fake_compact_tensor(
-        cutlass.Float16, c_shape, stride_order=(2, 1, 0), assumed_align=16
+        cutlass.Float16, (l, m, n), stride_order=(2, 1, 0), assumed_align=16
     )
     compiled_fn = cute.compile(bmm, fake_a, fake_b, fake_c, options="--enable-tvm-ffi")
     return compiled_fn
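
Note: what the symbolic spec above buys is shape polymorphism. One compile of bmm should serve every batch whose n and k honor the declared divisibility=16, since m and l are unconstrained symbols. A minimal sketch; the concrete shape values here are illustrative, not from this commit:

    fn = compile_bmm_dynamic_layout()
    for l, m, n, k in [(1, 128, 256, 64), (4, 500, 512, 256)]:
        # n and k satisfy divisibility=16; m and l can be anything
        a = torch.randn(l, m, k, dtype=torch.float16, device="cuda")
        b = torch.randn(l, k, n, dtype=torch.float16, device="cuda")
        c = torch.empty(l, m, n, dtype=torch.float16, device="cuda")
        fn(a, b, c)  # same compiled artifact for both shapes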
-def compile_bmm_static_layout(a, b, c):
-    from cutlass.cute.runtime import make_fake_tensor, make_fake_compact_tensor
-    # fake_a = make_fake_tensor(cutlass.Float16, a.shape, a.stride(), assumed_align=16)
-    # fake_b = make_fake_tensor(cutlass.Float16, b.shape, b.stride(), assumed_align=16)
-    # fake_c = make_fake_tensor(cutlass.Float16, c.shape, c.stride(), assumed_align=16)
+def compile_bmm_static_layout(m, n, k, l):
+    from cutlass.cute.runtime import make_fake_compact_tensor
     fake_a = make_fake_compact_tensor(
-        cutlass.Float16, a.shape, stride_order=(2, 1, 0), assumed_align=16
+        cutlass.Float16, (l, m, k), stride_order=(2, 1, 0), assumed_align=16
     )
     fake_b = make_fake_compact_tensor(
-        cutlass.Float16, b.shape, stride_order=(2, 1, 0), assumed_align=16
+        cutlass.Float16, (l, k, n), stride_order=(2, 1, 0), assumed_align=16
     )
     fake_c = make_fake_compact_tensor(
-        cutlass.Float16, c.shape, stride_order=(2, 1, 0), assumed_align=16
+        cutlass.Float16, (l, m, n), stride_order=(2, 1, 0), assumed_align=16
     )
     compiled_fn = cute.compile(bmm, fake_a, fake_b, fake_c, options="--enable-tvm-ffi")
     return compiled_fn
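
Note: the deleted comment lines point at the explicit-stride alternative: make_fake_tensor takes a concrete shape and stride tuple instead of a stride_order, so non-compact layouts can be described as well. A minimal sketch built from those removed comments (the tensor itself is illustrative):

    from cutlass.cute.runtime import make_fake_tensor

    a = torch.randn(2, 512, 256, dtype=torch.float16, device="cuda")
    # shape and strides are baked into the fake tensor exactly as passed
    fake_a = make_fake_tensor(cutlass.Float16, a.shape, a.stride(), assumed_align=16)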
-def run_bmm_and_verify(compiled_fn, a, b, c):
+def run_bmm_and_verify(compiled_fn, m, n, k, l):
+    torch.manual_seed(1112)
-    # pass in torch tensor as input
-    compiled_fn(a, b, c)
-    torch.cuda.synchronize()
-    # measure the launch overhead of tvm ffi function
-    repeat = 100
-    start_time = time.time()
-    for i in range(repeat):
-        compiled_fn(a, b, c)
-    end_time = time.time()
-    print(
-        f"Launch overhead of tvm ffi function: {(end_time - start_time) / repeat} seconds"
-    )
-    ref = torch.bmm(a, b)
-    torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05)
-    print("\n[DSL INFO] Results verified successfully!")
-    print(f"First few elements of result: \n{c[:3, :3, :3]}")
-
-if __name__ == "__main__":
-    m, n, k, l = (512, 512, 256, 1)
     a = torch.randn(l, m, k, dtype=torch.float16, device="cuda")
     b = torch.randn(l, k, n, dtype=torch.float16, device="cuda")
     c = torch.randn(l, m, n, dtype=torch.float16, device="cuda")
-    print("Input tensor shapes:")
+    print("[Runtime INFO] Input tensor shapes:")
     print(f"a: {a.shape=}, {a.stride()=}, {a.dtype=}")
     print(f"b: {b.shape=}, {b.stride()=}, {b.dtype=}")
     print(f"c: {c.shape=}, {c.stride()=}, {c.dtype=}\n")
-    compiled_fn = compile_bmm_dynamic_layout()
-    run_bmm_and_verify(compiled_fn, a, b, c)
+    # pass in torch tensor as input
+    compiled_fn(a, b, c)
+    torch.cuda.synchronize()
-    compiled_fn = compile_bmm_static_layout(a, b, c)
-    run_bmm_and_verify(compiled_fn, a, b, c)
+    ref = torch.bmm(a, b)
+    torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05)
+    print("[Runtime INFO] Verification successful!")
+    print(f" First few elements of result: \n{c[:3, :3, :3]}")
+
+if __name__ == "__main__":
+    m, n, k, l = (512, 512, 256, 2)
+    compiled_fn_dynamic = compile_bmm_dynamic_layout()
+    run_bmm_and_verify(compiled_fn_dynamic, m, n, k, l)
+    compiled_fn_static = compile_bmm_static_layout(m, n, k, l)
+    run_bmm_and_verify(compiled_fn_static, m, n, k, l)
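
Note: the per-call timing loop no longer runs inside run_bmm_and_verify. If launch overhead is still of interest, a standalone sketch assembled from the removed lines (tensor setup added so it runs on its own):

    import time  # already imported at the top of the example

    m, n, k, l = (512, 512, 256, 2)
    a = torch.randn(l, m, k, dtype=torch.float16, device="cuda")
    b = torch.randn(l, k, n, dtype=torch.float16, device="cuda")
    c = torch.randn(l, m, n, dtype=torch.float16, device="cuda")
    compiled_fn = compile_bmm_dynamic_layout()
    compiled_fn(a, b, c)  # warm-up call
    torch.cuda.synchronize()
    repeat = 100
    start_time = time.time()
    for _ in range(repeat):
        compiled_fn(a, b, c)
    end_time = time.time()
    print(f"Launch overhead of tvm ffi function: {(end_time - start_time) / repeat} seconds")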
+    # Error check:
+    # 1. mismatched tensor dim raises an error
+    a = torch.randn(l, m, k, dtype=torch.float16, device="cuda")
+    b = torch.randn(l, 2 * k, n, dtype=torch.float16, device="cuda")
+    c = torch.randn(l, m, n, dtype=torch.float16, device="cuda")
+    try:
+        compiled_fn_dynamic(a, b, c)
+    except Exception as e:
+        print(f"\n[Runtime Error]: {e}")
+    # 2. mismatched divisibility raises an error
+    a = torch.randn(l, m, k + 1, dtype=torch.float16, device="cuda")
+    b = torch.randn(l, k + 1, n, dtype=torch.float16, device="cuda")
+    c = torch.randn(l, m, n, dtype=torch.float16, device="cuda")
+    try:
+        compiled_fn_dynamic(a, b, c)
+    except Exception as e:
+        print(f"\n[Runtime Error]: {e}")
+    # 3. mismatched static shape constraint raises an error
+    a = torch.randn(l * 2, m, k, dtype=torch.float16, device="cuda")
+    b = torch.randn(l * 2, k, n, dtype=torch.float16, device="cuda")
+    c = torch.randn(l * 2, m, n, dtype=torch.float16, device="cuda")
+    try:
+        compiled_fn_static(a, b, c)
+    except Exception as e:
+        print(f"\n[Runtime Error]: {e}")