mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
Add tutorial fp16_gemm_1 (#2750)
* Add tutorial fp16_gemm_1
* refine
* refine
* refine
* revert changes in fp16_gemm_0.py
This commit is contained in:
527 examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_1.py Normal file
@@ -0,0 +1,527 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

# This is the second tutorial GEMM. It builds on the first tutorial by adding 2CTA MMA
# instructions with a 2x1 cluster.

import argparse
import torch
from typing import Tuple

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.torch as cutlass_torch
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.runtime import from_dlpack

"""
|
||||
The second tutorial GEMM demonstrating a simple kernel implementation in CuTeDSL
|
||||
|
||||
With large tile sizes, it can as well achieve very high performance on 8kx8kx8k problem sizes.
|
||||
Compared with fp16_gemm_0.py, this example adds 2CTA MMA & TMA multicast supports.
|
||||
For fp16_gemm_0.py running at relative high SM frequency, the dram latency will be a potential performance issue.
|
||||
This example can achieve better performance than fp16_gemm_0.py due to:
|
||||
|
||||
1. The 2CTA MMA can reduce B tensor smem size which allows larger ab stages to hide dram latency.
|
||||
|
||||
For both 1CTA & 2CTA, one stage of A tensor smem size is 128x64xsizeof(float16)=16KB
|
||||
Situation for B is different.
|
||||
For 1CTA, one stage of B tensor smem size is 256x64xsizeof(float16)=32KB,
|
||||
while for 2CTA, we can only take half size, i.e. 16KB.
|
||||
So, the maxmimum AB stage for 1CTA is 227 // (16 + 32) = 4, while for 2CTA is 227 // (16 + 16) = 7.
|
||||
The latency hiding capability is 512 * (4 - 1) = 1.5K cycles for 1CTA, while 512 * (7 - 1) = 3K cycles for 2CTA.
|
||||
|
||||
2. The L2 traffic is reduced due to the TMA multicast.
|
||||
|
||||
For a (m, n) cluster shape, the L2 traffic for one tile is 16KB / n + 32KB / m.
|
||||
|
||||
16KB / 1 + 32KB / 2 = 24KB for 2x1 cluster shape
|
||||
16KB / 4 + 32KB / 4 = 12KB for 4x4 cluster shape
|
||||
|
||||
If no TMA multicast enabled, the L2 traffic for one tile is less than 16KB + 32KB = 48KB, which depends on hardware optimization.
|
||||
|
||||
The first one can provide large latency hiding capability while the second one can reduce the data ready time.
|
||||
These two factors should be considered for latency/memory throughput bound cases.
|
||||
|
||||
|
||||
To run this example:
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/blackwell/tutorial_gemm/fp16_gemm_1.py \
|
||||
--mnk 8192,8192,8192
|
||||
|
||||
Constraints for this example:
|
||||
* The problem size of m and n must be divisible by the tile size m & n (256, 256)
|
||||
"""
|
||||
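
# A minimal sanity check of the arithmetic quoted in the docstring above. This
# helper is illustrative only (it is not part of the original tutorial and is
# not used by the kernel); the 227KB SMEM budget and the fp16 per-stage tile
# sizes are assumptions taken from the text.
def _check_docstring_math():
    smem_budget_kb = 227
    a_stage_kb = 128 * 64 * 2 // 1024  # one stage of A: 16KB (fp16)
    b_stage_1cta_kb = 256 * 64 * 2 // 1024  # one stage of B for 1CTA: 32KB
    b_stage_2cta_kb = b_stage_1cta_kb // 2  # each CTA holds half of B for 2CTA
    stages_1cta = smem_budget_kb // (a_stage_kb + b_stage_1cta_kb)
    stages_2cta = smem_budget_kb // (a_stage_kb + b_stage_2cta_kb)
    assert (stages_1cta, stages_2cta) == (4, 7)

    # Per-tile L2 traffic for an (m, n) cluster shape with TMA multicast
    def l2_traffic_kb(m, n):
        return a_stage_kb / n + b_stage_1cta_kb / m

    assert l2_traffic_kb(2, 1) == 32 and l2_traffic_kb(4, 4) == 12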

io_dtype = cutlass.Float16
acc_dtype = cutlass.Float32
cluster_shape_mnk = (2, 1, 1)
mma_inst_shape_mnk = (256, 256, 16)
mma_tiler_mnk = (256, 256, 64)
threads_per_cta = 128

# Pipeline stage configuration
ab_stages = 7
acc_stage = 1


@cute.struct
class SharedStorage:
    # Each stage has 2 kinds of barriers, i.e. empty & full
    ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, ab_stages * 2]
    acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, acc_stage * 2]
    tmem_dealloc_mbar_ptr: cutlass.Int64
    tmem_holding_buf: cutlass.Int32


@cute.kernel()
def kernel(
    tiled_mma: cute.TiledMma,
    tma_atom_a: cute.CopyAtom,
    mA_mkl: cute.Tensor,
    tma_atom_b: cute.CopyAtom,
    mB_nkl: cute.Tensor,
    mC_mnl: cute.Tensor,
    a_smem_layout: cute.ComposedLayout,
    b_smem_layout: cute.ComposedLayout,
    cta_layout_vmnk: cute.Layout,
):
    # Current thread/warp/block coordinates
    tidx, _, _ = cute.arch.thread_idx()
    warp_idx = cute.arch.warp_idx()
    warp_idx = cute.arch.make_warp_uniform(warp_idx)
    bidx, bidy, _ = cute.arch.block_idx()
    cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
    cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
    mma_coord_vmnk = (
        bidx % cute.size(cta_layout_vmnk, mode=[0]),
        bidx // cute.size(cta_layout_vmnk, mode=[0]),
        bidy,
        None,
    )
    mma_coord_mnk = mma_coord_vmnk[1:]
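    # For the fixed cluster_shape_mnk = (2, 1, 1) above, cta_layout_vmnk has
    # shape ((2), 1, 1, 1): bidx % 2 is the peer rank within the 2CTA MMA and
    # bidx // 2 is the M coordinate of the output tile, so both CTAs of a
    # cluster map to the same (m, n) output tile.
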
    #
    # 1. Prepare args
    #

    # Allocate SMEM
    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(SharedStorage)
    sA = smem.allocate_tensor(
        element_type=io_dtype,
        layout=a_smem_layout.outer,
        byte_alignment=128,
        swizzle=a_smem_layout.inner,
    )
    sB = smem.allocate_tensor(
        element_type=io_dtype,
        layout=b_smem_layout.outer,
        byte_alignment=128,
        swizzle=b_smem_layout.inner,
    )

    # Prefetch the TMA descriptors
    if warp_idx == 0:
        cpasync.prefetch_descriptor(tma_atom_a)
        cpasync.prefetch_descriptor(tma_atom_b)

    # Pipeline configuration
    num_tma_copy_bytes = (
        cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
        + cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
    ) * cute.size(cta_layout_vmnk, mode=[0])
    num_mcast_ctas_a = cute.size(cta_layout_vmnk.shape[2])
    num_mcast_ctas_b = cute.size(cta_layout_vmnk.shape[1])
    num_tma_producer = num_mcast_ctas_a + num_mcast_ctas_b - 1
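    # With the (2, 1, 1) cluster used here, both cluster modes M and N of
    # cta_layout_vmnk have size 1 (the two peer CTAs live in the V mode), so
    # num_mcast_ctas_a = num_mcast_ctas_b = 1 and num_tma_producer = 1.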
    ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
        num_stages=ab_stages,
        producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread, num_tma_producer
        ),
        tx_count=num_tma_copy_bytes,
        barrier_storage=storage.ab_mbar_ptr.data_ptr(),
        cta_layout_vmnk=cta_layout_vmnk,
    ).make_participants()
    acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create(
        num_stages=acc_stage,
        producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            cute.size(cta_layout_vmnk, mode=[0]) * threads_per_cta,
        ),
        barrier_storage=storage.acc_mbar_ptr.data_ptr(),
        cta_layout_vmnk=cta_layout_vmnk,
    ).make_participants()

    # Partition tensors for MMA and make fragments
    # (bM, bK, RestK)
    gA = cute.local_tile(mA_mkl, mma_tiler_mnk, mma_coord_mnk, proj=(1, None, 1))
    # (bN, bK, RestK)
    gB = cute.local_tile(mB_nkl, mma_tiler_mnk, mma_coord_mnk, proj=(None, 1, 1))
    # (bM, bN)
    gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None))
    thr_mma = tiled_mma.get_slice(mma_coord_vmnk[0])
    # (MMA, MMA_M, MMA_K)
    tCgA = thr_mma.partition_A(gA)
    # (MMA, MMA_N, MMA_K)
    tCgB = thr_mma.partition_B(gB)
    # (MMA, MMA_M, MMA_N)
    tCgC = thr_mma.partition_C(gC)
    # (MMA, MMA_M, MMA_K)
    tCrA = tiled_mma.make_fragment_A(sA)
    # (MMA, MMA_N, MMA_K)
    tCrB = tiled_mma.make_fragment_B(sB)
    # (MMA, MMA_M, MMA_N)
    acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
    # (MMA, MMA_M, MMA_N)
    tCtAcc = tiled_mma.make_fragment_C(acc_shape)

    # Partition tensors for TMA; this requires the tensors partitioned for MMA
    tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
        tma_atom_a,
        cta_in_cluster_coord_vmnk[2],
        cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
        cute.group_modes(sA, 0, 3),
        cute.group_modes(tCgA, 0, 3),
    )
    tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
        tma_atom_b,
        cta_in_cluster_coord_vmnk[1],
        cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
        cute.group_modes(sB, 0, 3),
        cute.group_modes(tCgB, 0, 3),
    )
    tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask(
        cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2
    )
    tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask(
        cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1
    )

    # Allocate TMEM and swap the pointer in tCtAcc
    tmem_alloc_barrier = pipeline.NamedBarrier(
        barrier_id=1,
        num_threads=threads_per_cta,
    )
    tmem = utils.TmemAllocator(
        storage.tmem_holding_buf,
        barrier_for_retrieve=tmem_alloc_barrier,
        is_two_cta=cute.size(cta_layout_vmnk, mode=[0]) > 1,
        two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
    )
    num_tmem_cols = 512
    tmem.allocate(num_tmem_cols)

    # CTA-wide sync before retrieving the pointer to the start of the allocated TMEM.
    # Only warp 0 does the allocation, so we need to sync before retrieving the TMEM start address.
    tmem.wait_for_alloc()
    tmem_ptr = tmem.retrieve_ptr(acc_dtype)
    # Swap the pointer in tCtAcc
    tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout)

    subtile_cnt = 4
    # (EpiTile)
    epi_tiler = (
        (cute.size(tCtAcc, mode=[0, 0]), cute.size(tCtAcc, mode=[0, 1]) // subtile_cnt),
    )
    # (EpiTile, NumTiles)
    tCtAcc_epi = cute.zipped_divide(tCtAcc, epi_tiler)
    # (EpiTile, NumTiles)
    gC_epi = cute.zipped_divide(tCgC, epi_tiler)
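    # With subtile_cnt = 4, the accumulator's N extent is split into 4 pieces,
    # so the TMEM -> RMEM -> GMEM loop in the epilogue iterates over
    # NumTiles = 4 subtiles; each subtile needs fewer registers than the full
    # accumulator while still exposing instruction-level parallelism.
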
    # Every thread loads 64 x fp32
    tmem_atom = cute.make_copy_atom(
        tcgen05.Ld32x32bOp(tcgen05.Repetition.x64),
        cutlass.Float32,
    )
    tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_atom, tCtAcc_epi[None, 0])
    tmem_thr_copy = tmem_tiled_copy.get_slice(tidx)

    # (TmemCpy, NumTmemCpy, NumTiles)
    tDtC = tmem_thr_copy.partition_S(tCtAcc_epi)
    # (TmemCpy, NumTmemCpy, NumTiles)
    tDgC = tmem_thr_copy.partition_D(gC_epi)

    # (TmemCpy, NumTmemCpy)
    tCrAcc = cute.make_rmem_tensor_like(tDgC[None, None, 0], acc_dtype)
    # (TmemCpy, NumTmemCpy)
    tCrC = cute.make_rmem_tensor_like(tDgC[None, None, 0], io_dtype)

    #
    # 2. Main loop
    #
    is_leader_cta = mma_coord_vmnk[0] == 0
    num_k_tiles = cute.size(gA, mode=[2])
    if warp_idx == 0:
        # Wait for an empty accumulator buffer
        if is_leader_cta:
            acc_producer.acquire_and_advance()
        for _ in cutlass.range(num_k_tiles, prefetch_stages=ab_stages - 2):
            # Issue TMA loads
            ab_empty = ab_producer.acquire_and_advance()
            cute.copy(
                tma_atom_a,
                tAgA[(None, ab_empty.count)],
                tAsA[(None, ab_empty.index)],
                tma_bar_ptr=ab_empty.barrier,
                mcast_mask=tma_mcast_mask_a,
            )
            cute.copy(
                tma_atom_b,
                tBgB[(None, ab_empty.count)],
                tBsB[(None, ab_empty.index)],
                tma_bar_ptr=ab_empty.barrier,
                mcast_mask=tma_mcast_mask_b,
            )

            # Execute one K-tile worth of MMA instructions (leader CTA only)
            if is_leader_cta:
                ab_full = ab_consumer.wait_and_advance()
                num_k_blocks = cute.size(tCrA, mode=[2])
                for k_block_idx in cutlass.range_constexpr(num_k_blocks):
                    k_block_coord = (None, None, k_block_idx, ab_full.index)
                    cute.gemm(
                        tiled_mma,
                        tCtAcc,
                        tCrA[k_block_coord],
                        tCrB[k_block_coord],
                        tCtAcc,
                    )
                    tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
                ab_full.release()

        # Signal that the accumulator is fully computed
        if is_leader_cta:
            acc_producer.commit()

    #
    # 3. Epilogue
    #

    # Release TMEM allocation lock
    tmem.relinquish_alloc_permit()

    # Wait for the accumulator buffer to be full
    acc_full = acc_consumer.wait_and_advance()
    # TMEM -> RMEM -> GMEM
    # Sub-tiling for better instruction-level parallelism
    for i in cutlass.range(cute.size(tDtC, mode=[2])):
        cute.copy(tmem_tiled_copy, tDtC[None, None, i], tCrAcc)
        tCrC.store(tCrAcc.load().to(io_dtype))
        cute.autovec_copy(tCrC, tDgC[None, None, i])
    acc_full.release()

    # Ensure used buffers are properly synchronized before the producer exits.
    # This avoids invalid DSMEM accesses caused by the leader CTA exiting early.
    if warp_idx == 0:
        ab_producer.tail()
        if is_leader_cta:
            acc_producer.tail()

    # Deallocate TMEM
    pipeline.sync(barrier_id=1)
    tmem.free(tmem_ptr)


@cute.jit
def host_function(
    a: cute.Tensor,
    b: cute.Tensor,
    c: cute.Tensor,
):
    # Construct tiled MMA
    op = tcgen05.MmaF16BF16Op(
        io_dtype,
        acc_dtype,
        mma_inst_shape_mnk,
        tcgen05.CtaGroup.TWO,
        tcgen05.OperandSource.SMEM,
        tcgen05.OperandMajorMode.K,
        tcgen05.OperandMajorMode.K,
    )
    tiled_mma = cute.make_tiled_mma(op)

    # Construct SMEM layouts for A and B
    a_smem_layout = utils.sm100.make_smem_layout_a(
        tiled_mma,
        mma_tiler_mnk,
        a.element_type,
        ab_stages,
    )
    b_smem_layout = utils.sm100.make_smem_layout_b(
        tiled_mma,
        mma_tiler_mnk,
        b.element_type,
        ab_stages,
    )
    a_smem_layout_one_stage = cute.select(a_smem_layout, mode=[0, 1, 2])
    b_smem_layout_one_stage = cute.select(b_smem_layout, mode=[0, 1, 2])

    # Construct the VMNK layout
    cta_layout_mnk = cute.make_layout(cluster_shape_mnk)
    cta_layout_vmnk = cute.tiled_divide(cta_layout_mnk, (tiled_mma.thr_id,))
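    # For cluster_shape_mnk = (2, 1, 1) and a 2CTA MMA (tiled_mma.thr_id has
    # size 2), this folds the two peer CTAs into the V mode, giving
    # cta_layout_vmnk the shape ((2), 1, 1, 1): a 1x1x1 grid of MMA tiles per
    # cluster, each computed cooperatively by a pair of CTAs.
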
    # Construct TMA load atoms
    op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO)
    a_tma_atom, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
        op,
        a,
        a_smem_layout_one_stage,
        mma_tiler_mnk,
        tiled_mma,
        cta_layout_vmnk.shape,
    )
    b_tma_atom, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
        op,
        b,
        b_smem_layout_one_stage,
        mma_tiler_mnk,
        tiled_mma,
        cta_layout_vmnk.shape,
    )

    grid_shape = cute.round_up(
        cute.ceil_div(
            (*c.layout.shape, 1), (mma_tiler_mnk[0] // 2, *mma_tiler_mnk[1:])
        ),
        cluster_shape_mnk,
    )
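    # Worked example: for the default --mnk 8192,8192,8192 this gives
    # ceil_div((8192, 8192, 1), (128, 256, 64)) = (64, 32, 1), which is already
    # a multiple of the (2, 1, 1) cluster shape, i.e. 2048 CTAs forming 1024
    # 2-CTA clusters, one cluster per 256x256 output tile.
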
    # Pretty-print kernel attributes useful for debugging
    # print(f"a = {cute.pretty_str(a)}")
    # print(f"b = {cute.pretty_str(b)}")
    # print(f"c = {cute.pretty_str(c)}")
    # print(f"tiled_mma = {cute.pretty_str(tiled_mma)}")
    # print(f"a_smem_layout = {cute.pretty_str(a_smem_layout)}")
    # print(f"b_smem_layout = {cute.pretty_str(b_smem_layout)}")
    # print(f"cta_layout_mnk = {cute.pretty_str(cta_layout_mnk)}")
    # print(f"cta_layout_vmnk = {cute.pretty_str(cta_layout_vmnk)}")
    # print(f"a_tma_atom = {cute.pretty_str(a_tma_atom)}")
    # print(f"b_tma_atom = {cute.pretty_str(b_tma_atom)}")
    # print(f"a_tma_tensor = {cute.pretty_str(a_tma_tensor)}")
    # print(f"b_tma_tensor = {cute.pretty_str(b_tma_tensor)}")
    # cute.printf("grid_shape = {}", grid_shape)

    # Launch the kernel
    kernel(
        tiled_mma,
        a_tma_atom,
        a_tma_tensor,
        b_tma_atom,
        b_tma_tensor,
        c,
        a_smem_layout,
        b_smem_layout,
        cta_layout_vmnk,
    ).launch(
        grid=grid_shape,
        block=[threads_per_cta, 1, 1],
        cluster=cluster_shape_mnk,
    )


def run_dense_gemm(
    mnk: Tuple[int, int, int],
    tolerance: float,
):
    print("===================================================================")
    print("Running Blackwell fp16 GEMM example 1 with:")
    print(f"  mnk: {mnk}")
    print(f"  tolerance: {tolerance}")
    print("===================================================================")
    print()

    m, n, k = mnk
    torch.manual_seed(1111)

    # Make K-major tensors (torch tensors are row-major)
    def make_tensors(mn, k, dtype):
        shape = (mn, k)
        return (
            torch.empty(*shape, dtype=torch.int32)
            .random_(-2, 2)
            .to(device="cuda", dtype=dtype)
        )
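    # Note: drawing from small integers keeps every product exactly
    # representable in fp16 and the fp32 accumulation exact; only the final
    # cast of C back to fp16 rounds, which the atol/rtol tolerance in the
    # verification below absorbs.
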
    a = make_tensors(m, k, cutlass_torch.dtype(io_dtype))
    b = make_tensors(n, k, cutlass_torch.dtype(io_dtype))
    c = make_tensors(m, n, cutlass_torch.dtype(io_dtype))
    a_tensor = (
        from_dlpack(a, assumed_align=32)
        .mark_layout_dynamic(leading_dim=1)
        .mark_compact_shape_dynamic(mode=1, divisibility=k)
    )
    b_tensor = (
        from_dlpack(b, assumed_align=32)
        .mark_layout_dynamic(leading_dim=1)
        .mark_compact_shape_dynamic(mode=1, divisibility=k)
    )
    c_tensor = (
        from_dlpack(c, assumed_align=32)
        .mark_layout_dynamic(leading_dim=1)
        .mark_compact_shape_dynamic(mode=1, divisibility=n)
    )

    # Entry point to the host JIT function
    host_function(
        a_tensor,
        b_tensor,
        c_tensor,
        no_cache=True,
    )

    # Compute reference result and verify
    ref = (torch.einsum("mk,nk->mn", a.to(torch.float32), b.to(torch.float32))).cpu()
    torch.testing.assert_close(
        c.cpu(), ref.to(cutlass_torch.dtype(io_dtype)), atol=tolerance, rtol=1e-05
    )


if __name__ == "__main__":

    def parse_comma_separated_ints(s: str):
        try:
            return [int(x.strip()) for x in s.split(",")]
        except ValueError:
            raise argparse.ArgumentTypeError(
                "Invalid format. Expected comma-separated integers."
            )

    if not torch.cuda.is_available():
        raise RuntimeError("A GPU is required to run this example")

    parser = argparse.ArgumentParser(description="Blackwell fp16 GEMM example 1")
    parser.add_argument(
        "--mnk",
        type=parse_comma_separated_ints,
        default=[8192, 8192, 8192],
        help="MNK dimensions (comma-separated)",
    )
    parser.add_argument(
        "--tolerance", type=float, default=1e-01, help="Tolerance for validation"
    )
    args = parser.parse_args()
    if len(args.mnk) != 3:
        parser.error("--mnk must contain exactly 3 values")
    if args.mnk[0] % mma_tiler_mnk[0] != 0 or args.mnk[1] % mma_tiler_mnk[1] != 0:
        parser.error("m and n must be divisible by the tile sizes m & n")

    run_dense_gemm(
        args.mnk,
        args.tolerance,
    )
    print("PASS")