mirror of https://github.com/NVIDIA/cutlass.git

enhanced the example for tvm-ffi
@@ -70,85 +70,102 @@ def bmm(
 def compile_bmm_dynamic_layout():
     from cutlass.cute.runtime import make_fake_compact_tensor
 
-    # Contiguous on K
-    a_shape = (cute.sym_int(), cute.sym_int(), cute.sym_int(divisibility=16))
-    # Contiguous on N
-    b_shape = (cute.sym_int(), cute.sym_int(), cute.sym_int(divisibility=16))
-    # Contiguous on N
-    c_shape = (cute.sym_int(), cute.sym_int(), cute.sym_int(divisibility=16))
+    m = cute.sym_int()
+    n = cute.sym_int(divisibility=16)
+    k = cute.sym_int(divisibility=16)
+    l = cute.sym_int()
 
     fake_a = make_fake_compact_tensor(
-        cutlass.Float16, a_shape, stride_order=(2, 1, 0), assumed_align=16
+        cutlass.Float16, (l, m, k), stride_order=(2, 1, 0), assumed_align=16
     )
+    # Contiguous on N
     fake_b = make_fake_compact_tensor(
-        cutlass.Float16, b_shape, stride_order=(2, 1, 0), assumed_align=16
+        cutlass.Float16, (l, k, n), stride_order=(2, 1, 0), assumed_align=16
     )
+    # Contiguous on N
     fake_c = make_fake_compact_tensor(
-        cutlass.Float16, c_shape, stride_order=(2, 1, 0), assumed_align=16
+        cutlass.Float16, (l, m, n), stride_order=(2, 1, 0), assumed_align=16
     )
 
     compiled_fn = cute.compile(bmm, fake_a, fake_b, fake_c, options="--enable-tvm-ffi")
     return compiled_fn
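Note on the dynamic variant: cute.sym_int() declares a dimension that stays symbolic at compile time, and divisibility=16 lets the compiler assume that extent is a multiple of 16; the error checks at the end of this diff show the constraint is re-validated when the compiled function is called. A minimal usage sketch, assuming the bmm kernel and imports from this example (the shape values are illustrative; n and k must be multiples of 16):

    compiled_fn = compile_bmm_dynamic_layout()
    a = torch.randn(4, 128, 32, dtype=torch.float16, device="cuda")
    b = torch.randn(4, 32, 64, dtype=torch.float16, device="cuda")
    c = torch.empty(4, 128, 64, dtype=torch.float16, device="cuda")
    compiled_fn(a, b, c)  # one compiled binary serves any conforming (l, m, k, n)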
 
 
-def compile_bmm_static_layout(a, b, c):
-    from cutlass.cute.runtime import make_fake_tensor, make_fake_compact_tensor
-
-    # fake_a = make_fake_tensor(cutlass.Float16, a.shape, a.stride(), assumed_align=16)
-    # fake_b = make_fake_tensor(cutlass.Float16, b.shape, b.stride(), assumed_align=16)
-    # fake_c = make_fake_tensor(cutlass.Float16, c.shape, c.stride(), assumed_align=16)
+def compile_bmm_static_layout(m, n, k, l):
+    from cutlass.cute.runtime import make_fake_compact_tensor
+
     fake_a = make_fake_compact_tensor(
-        cutlass.Float16, a.shape, stride_order=(2, 1, 0), assumed_align=16
+        cutlass.Float16, (l, m, k), stride_order=(2, 1, 0), assumed_align=16
     )
     fake_b = make_fake_compact_tensor(
-        cutlass.Float16, b.shape, stride_order=(2, 1, 0), assumed_align=16
+        cutlass.Float16, (l, k, n), stride_order=(2, 1, 0), assumed_align=16
     )
     fake_c = make_fake_compact_tensor(
-        cutlass.Float16, c.shape, stride_order=(2, 1, 0), assumed_align=16
+        cutlass.Float16, (l, m, n), stride_order=(2, 1, 0), assumed_align=16
     )
 
     compiled_fn = cute.compile(bmm, fake_a, fake_b, fake_c, options="--enable-tvm-ffi")
     return compiled_fn
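Note on the static variant: compile_bmm_static_layout now takes plain integers (m, n, k, l) rather than live tensors, so the fake tensors it builds are fully static and the kernel is specialized to that single problem size; error check 3 below shows that calling it with a different batch size fails. A sketch of the trade-off, reusing the tensors from the sketch above:

    fn_static = compile_bmm_static_layout(128, 64, 32, 4)  # this shape only
    fn_dynamic = compile_bmm_dynamic_layout()              # any conforming shape
    fn_static(a, b, c)   # OK: exact match with the compiled (m, n, k, l)
    fn_dynamic(a, b, c)  # OK: satisfies the symbolic constraints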
 
 
-def run_bmm_and_verify(compiled_fn, a, b, c):
-    # pass in torch tensor as input
-    compiled_fn(a, b, c)
-    torch.cuda.synchronize()
-
-    # measure the launch overhead of tvm ffi function
-    repeat = 100
-    start_time = time.time()
-    for i in range(repeat):
-        compiled_fn(a, b, c)
-    end_time = time.time()
-    print(
-        f"Launch overhead of tvm ffi function: {(end_time - start_time) / repeat} seconds"
-    )
-
-    ref = torch.bmm(a, b)
-    torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05)
-    print("\n[DSL INFO] Results verified successfully!")
-    print(f"First few elements of result: \n{c[:3, :3, :3]}")
-
-
-if __name__ == "__main__":
-    m, n, k, l = (512, 512, 256, 1)
+def run_bmm_and_verify(compiled_fn, m, n, k, l):
+    torch.manual_seed(1112)
 
     a = torch.randn(l, m, k, dtype=torch.float16, device="cuda")
     b = torch.randn(l, k, n, dtype=torch.float16, device="cuda")
     c = torch.randn(l, m, n, dtype=torch.float16, device="cuda")
 
-    print("Input tensor shapes:")
+    print("[Runtime INFO] Input tensor shapes:")
     print(f"a: {a.shape=}, {a.stride()=}, {a.dtype=}")
     print(f"b: {b.shape=}, {b.stride()=}, {b.dtype=}")
     print(f"c: {c.shape=}, {c.stride()=}, {c.dtype=}\n")
 
-    compiled_fn = compile_bmm_dynamic_layout()
-    run_bmm_and_verify(compiled_fn, a, b, c)
     # pass in torch tensor as input
     compiled_fn(a, b, c)
     torch.cuda.synchronize()
 
-    compiled_fn = compile_bmm_static_layout(a, b, c)
-    run_bmm_and_verify(compiled_fn, a, b, c)
+    ref = torch.bmm(a, b)
+    torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05)
+    print("[Runtime INFO] Verification successful!")
+    print(f"  First few elements of result: \n{c[:3, :3, :3]}")
+
+
+if __name__ == "__main__":
+    m, n, k, l = (512, 512, 256, 2)
+
+    compiled_fn_dynamic = compile_bmm_dynamic_layout()
+    run_bmm_and_verify(compiled_fn_dynamic, m, n, k, l)
+
+    compiled_fn_static = compile_bmm_static_layout(m, n, k, l)
+    run_bmm_and_verify(compiled_fn_static, m, n, k, l)
+
+    # Error Check:
+    # 1. mis-matched tensor dim raise error
+    a = torch.randn(l, m, k, dtype=torch.float16, device="cuda")
+    b = torch.randn(l, 2 * k, n, dtype=torch.float16, device="cuda")
+    c = torch.randn(l, m, n, dtype=torch.float16, device="cuda")
+    try:
+        compiled_fn_dynamic(a, b, c)
+    except Exception as e:
+        print(f"\n[Runtime Error]: {e}")
+
+    # 2. mis-matched divisibility
+    a = torch.randn(l, m, k + 1, dtype=torch.float16, device="cuda")
+    b = torch.randn(l, k + 1, n, dtype=torch.float16, device="cuda")
+    c = torch.randn(l, m, n, dtype=torch.float16, device="cuda")
+
+    try:
+        compiled_fn_dynamic(a, b, c)
+    except Exception as e:
+        print(f"\n[Runtime Error]: {e}")
+
+    # 3. mis-matched static shape constraint
+    a = torch.randn(l * 2, m, k, dtype=torch.float16, device="cuda")
+    b = torch.randn(l * 2, k, n, dtype=torch.float16, device="cuda")
+    c = torch.randn(l * 2, m, n, dtype=torch.float16, device="cuda")
+
+    try:
+        compiled_fn_static(a, b, c)
+    except Exception as e:
+        print(f"\n[Runtime Error]: {e}")
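Note on the error checks: all three mismatches (a wrong extent, a violated divisibility constraint, and a batch size other than the statically compiled one) surface as Python exceptions at the call boundary, which the example prints as [Runtime Error]. The launch-overhead timing that the previous revision ran inside run_bmm_and_verify can still be reproduced outside the helper; a sketch using the names from this example (no device synchronization inside the loop, since the point is CPU-side launch cost):

    a = torch.randn(l, m, k, dtype=torch.float16, device="cuda")
    b = torch.randn(l, k, n, dtype=torch.float16, device="cuda")
    c = torch.randn(l, m, n, dtype=torch.float16, device="cuda")
    repeat = 100
    start_time = time.time()
    for _ in range(repeat):
        compiled_fn_dynamic(a, b, c)
    end_time = time.time()
    print(f"Launch overhead of tvm ffi function: {(end_time - start_time) / repeat} seconds")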