mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
v4.3.4 update. (#2892)
This commit is contained in:
@@ -401,11 +401,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
d_per_blk = 128
|
||||
d_blks = cute.ceil_div(d, d_per_blk)
|
||||
|
||||
reduction(
|
||||
o, m, l,
|
||||
o_partial, m_partial, l_partial,
|
||||
scale_o
|
||||
).launch(
|
||||
self.reduction(o, m, l, o_partial, m_partial, l_partial, scale_o).launch(
|
||||
grid=[d_blks, h_q, b],
|
||||
block=[d_per_blk, 1, 1],
|
||||
cluster=[1, 1, 1],
|
||||
@@ -1173,7 +1169,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
|
||||
# Reduce colmax in smem
|
||||
if lane_store_max:
|
||||
smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane)
|
||||
self.smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane)
|
||||
|
||||
# Wait for colmax then load
|
||||
cute.arch.barrier(barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads)
|
||||
@@ -1259,7 +1255,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
# Reduce cluster colmax
|
||||
if warpgroup_widx == 0:
|
||||
if lane_store_max:
|
||||
dsmem_fmax(
|
||||
self.dsmem_fmax(
|
||||
sM_cluster.iterator + sM_layout((0, lane_idx)),
|
||||
sM[(0, lane_idx)],
|
||||
m_cluster_full_ptr
|
||||
@@ -1280,7 +1276,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
else:
|
||||
# other splits copy cluster colmax into local smem
|
||||
if lane_store_max:
|
||||
sM_cluster[0, lane_idx] = dsmem_load(
|
||||
sM_cluster[0, lane_idx] = self.dsmem_load(
|
||||
sM_cluster.iterator + sM_layout((0, lane_idx))
|
||||
)
|
||||
|
||||
@@ -1300,8 +1296,10 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
sM_lane = sM_lane * scale_qs
|
||||
|
||||
# Store colsum and colmax
|
||||
gmem_fadd(gL_partial.iterator + gL_partial.layout(lane_idx), sL_lane)
|
||||
gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
|
||||
self.gmem_fadd(
|
||||
gL_partial.iterator + gL_partial.layout(lane_idx), sL_lane
|
||||
)
|
||||
self.gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
|
||||
if kv_split_in_cluster == 0:
|
||||
gM_partial[lane_idx] = sM_lane
|
||||
|
||||
@@ -1350,7 +1348,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
# Store colsum and colmax
|
||||
gL_partial[lane_idx] = sL_lane
|
||||
gM_partial[lane_idx] = sM_lane
|
||||
gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
|
||||
self.gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
|
||||
|
||||
o_handle = o_consumer.wait_and_advance()
|
||||
cute.copy(thr_load_s, tStO, tSrO)
|
||||
@@ -1378,6 +1376,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@cute.kernel
|
||||
def reduction(
|
||||
o : cute.Tensor,
|
||||
@@ -1414,6 +1413,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
def _mapa(ptr : Pointer, cta_rank_in_cluster : Int32 = 0):
|
||||
llvm_ptr = ptr.llvm_ptr
|
||||
@@ -1424,8 +1424,8 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
)
|
||||
|
||||
@cute.jit
|
||||
def dsmem_load(val_ptr : Pointer):
|
||||
val_llvm_ptr = _mapa(val_ptr, 0)
|
||||
def dsmem_load(self, val_ptr: Pointer):
|
||||
val_llvm_ptr = self._mapa(val_ptr, 0)
|
||||
|
||||
ret = llvm.inline_asm(
|
||||
Float32.mlir_type,
|
||||
@@ -1439,6 +1439,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
|
||||
return Float32(ret)
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
def warp_fmax(val : Float32):
|
||||
ret = llvm.inline_asm(
|
||||
@@ -1472,10 +1473,10 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
)
|
||||
|
||||
@cute.jit
|
||||
def dsmem_fmax(val_ptr : Pointer, val : Float32, mbar_ptr : Pointer):
|
||||
def dsmem_fmax(self, val_ptr: Pointer, val: Float32, mbar_ptr: Pointer):
|
||||
expect_tx_bytes = Int32(Float32.width // 8)
|
||||
val_llvm_ptr = _mapa(val_ptr, 0)
|
||||
mbar_llvm_ptr = _mapa(mbar_ptr, 0)
|
||||
val_llvm_ptr = self._mapa(val_ptr, 0)
|
||||
mbar_llvm_ptr = self._mapa(mbar_ptr, 0)
|
||||
|
||||
nvvm.mbarrier_txn(
|
||||
mbar_llvm_ptr,
|
||||
@@ -1499,6 +1500,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
def gmem_fmax(ptr : Pointer, val : Float32):
|
||||
llvm.inline_asm(
|
||||
@@ -1529,6 +1531,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
def gmem_fadd(ptr : Pointer, val : Float32):
|
||||
llvm.inline_asm(
|
||||
|
||||
@@ -0,0 +1,368 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import argparse
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
|
||||
def supports_pdl():
    """Return True when the current CUDA device supports PDL.

    PDL requires Hopper or newer GPUs, i.e. a compute capability whose
    major version is at least 9.
    """
    # Imported lazily so the module can be imported without torch installed.
    import torch

    major, _minor = torch.cuda.get_device_capability()
    return major >= 9
|
||||
|
||||
|
||||
"""
|
||||
This example demonstrates the use of Programmatic Dependent Launch (PDL) using
|
||||
CuTe DSL.
|
||||
|
||||
PDL is a mechanism which allows for overlapping execution of back-to-back kernels
|
||||
within the same stream.
|
||||
For example, consider the following two elementwise add operations, where the second
|
||||
operation's first operand is the result of the first operation. While performing
|
||||
``w = u + v`` we will load u and v, add them, and then store the result. Once we
|
||||
have finished loading data, we are no longer utilizing the read bandwidth.
|
||||
To effectively utilize the read bandwidth, we can start loading ``x``
|
||||
immediately upon finishing reading. This is what PDL enables us to do.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
w = u + v
|
||||
y = w + x
|
||||
|
||||
To enable PDL, we need to do two things:
|
||||
|
||||
1. Insert the ``griddepcontrol.launch_dependents`` and ``griddepcontrol.wait`` instructions in the kernel.
|
||||
2. Set the PDL launch attribute when launching the kernel.
|
||||
|
||||
The ``griddepcontrol.launch_dependents`` and ``griddepcontrol.wait``
|
||||
instructions enable fine-grained control over kernel execution in PDL.
|
||||
Once all thread blocks execute the ``griddepcontrol.launch_dependents``
|
||||
instruction, the dependent kernels can opportunistically be early-launched.
|
||||
``griddepcontrol.wait`` functions as a synchronization barrier - any warp
|
||||
executing this instruction will block until the previous kernel finishes
|
||||
execution. This allows precise control over data dependencies between kernels.
|
||||
|
||||
The following diagram shows the overlapping execution of two dependent kernels.
|
||||
We refer to the instructions before ``griddepcontrol.wait`` as the prologue (``P0``),
|
||||
which may include barrier initialization and loading of independent data, etc.
|
||||
We refer to the instructions after ``griddepcontrol.launch_dependents`` as the epilogue
|
||||
(``P2``), which may include math operations, data stores, etc. PDL enables
|
||||
these prologue and epilogue phases to execute concurrently across dependent
|
||||
kernels, improving GPU resource utilization. This is particularly beneficial
|
||||
when prologue and epilogue are bound by different resources (e.g., memory
|
||||
bandwidth vs compute throughput).
|
||||
|
||||
# P0: Prologue, P1: Main compute, P2: Epilogue
|
||||
|
||||
P0 P1 P2
|
||||
K1: |=====|+++++|-----|
|
||||
|
||||
<-----> K2 can start early
|
||||
(K1's P2 overlaps with K2's P0)
|
||||
|
||||
P0 P1 P2
|
||||
K2: |=====| |+++++|-----|
|
||||
^
|
||||
|
|
||||
wait for K1 to complete
|
||||
Time ------------------------------------------------------>
|
||||
|
||||
We could run this example with and without PDL:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/blackwell/programmatic_dependent_launch.py --benchmark
|
||||
python examples/blackwell/programmatic_dependent_launch.py --benchmark --use_pdl
|
||||
|
||||
From the benchmark results, you can see some speedups for the PDL version in most cases, benefiting from
|
||||
the overlapping execution of consecutive kernels. Moreover, you can use nsys to observe the overlapping execution.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
nsys profile python examples/blackwell/programmatic_dependent_launch.py --benchmark --use_pdl
|
||||
|
||||
Note, PDL feature is supported on Hopper and later GPUs.
|
||||
|
||||
See [the programming guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization)
|
||||
and the [PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol)
|
||||
for more details.
|
||||
"""
|
||||
|
||||
|
||||
@cute.kernel
def elementwise_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
    cC: cute.Tensor,  # coordinate tensor
    shape: cute.Shape,
    thr_layout: cute.Layout,
    val_layout: cute.Layout,
    is_first_kernel: cutlass.Constexpr = True,
):
    """Per-tile elementwise add kernel: writes ``gC = gA + gB``.

    Each thread block processes one tile (selected by ``bidx``); a predicate
    built from the coordinate tensor ``cC`` against ``shape`` masks off
    out-of-bounds elements. The ``is_first_kernel`` constexpr selects which
    griddepcontrol instruction is emitted: the first kernel issues
    ``launch_dependents`` after its loads, the second issues ``wait`` before
    loading its dependent operand ``gA``.

    Args:
        gA, gB: input tensors, tiled as ((TileM,TileN),(RestM,RestN)).
        gC: output tensor, same tiling.
        cC: identity (coordinate) tensor with the same tiling, used for
            bounds predication.
        shape: logical extent of the full problem, compared elementwise
            against coordinates.
        thr_layout, val_layout: thread/value layouts used to build the
            tiled copies.
        is_first_kernel: compile-time flag choosing the PDL role of this
            kernel instance.
    """
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    # Select this block's tile from the zipped-divided tensors.
    blk_coord = ((None, None), bidx)
    blkA = gA[blk_coord]  # (TileM,TileN)
    blkB = gB[blk_coord]  # (TileM,TileN)
    blkC = gC[blk_coord]  # (TileM,TileN)
    blkCrd = cC[blk_coord]  # (TileM, TileN)

    copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
    copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)

    tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
    tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
    tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)

    thr_copy_A = tiled_copy_A.get_slice(tidx)
    thr_copy_B = tiled_copy_B.get_slice(tidx)
    thr_copy_C = tiled_copy_C.get_slice(tidx)

    # Per-thread partitions of the tile.
    thrA = thr_copy_A.partition_S(blkA)
    thrB = thr_copy_B.partition_S(blkB)
    thrC = thr_copy_C.partition_S(blkC)

    # Register fragments matching the per-thread partitions.
    frgA = cute.make_fragment_like(thrA)
    frgB = cute.make_fragment_like(thrB)
    frgC = cute.make_fragment_like(thrC)

    # Build the bounds predicate: a coordinate is valid when it is
    # elementwise-less than the problem shape.
    thrCrd = thr_copy_C.partition_S(blkCrd)
    frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean)

    for i in range(cute.size(frgPred)):
        val = cute.elem_less(thrCrd[i], shape)
        frgPred[i] = val

    # Note: when not using cuda-graph, the kernel execution may be blocked by the host overhead.
    # In this case we won't see overlapping even when pdl is enabled.
    # In this example, we add a loop (10 times) for all the copy and compute operations in the following code
    # to make kernel running longer and make pdl benefits observable for both cuda-graph enabled and disabled cases.
    if is_first_kernel:
        for _ in range(10):
            cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
            cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
        # Here we add the launch dependents instruction for the first kernel as a hint to the runtime to early-launch
        # the next kernel. If the next kernel becomes concurrent, we will have overlap where the second kernel
        # can start reading x to ensure an E2E speedup. Note the placement of launch dependents has no implication
        # on correctness, only performance.
        cute.arch.griddepcontrol_launch_dependents()
    else:
        # In this example, the second kernel's second operand ``gB`` has no dependencies, its loading can overlap
        # with the computation of ``gC`` from the first kernel.
        for _ in range(10):
            cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)

        # For the second kernel, its first operand ``gA`` is dependent on the previous kernel, we must call
        # griddepcontrol.wait to assure correctness. This instruction will block until the prior kernels finishes
        # and its memory operations are visible. Since gA is written by the prior kernel, this will block until gA
        # is visible to our kernel. Without it, we would have undefined behavior due to a race condition.
        cute.arch.griddepcontrol_wait()

        for _ in range(10):
            cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)

    # Compute and store (repeated to lengthen the kernel — see note above
    # about making PDL overlap observable).
    for _ in range(10):
        result = frgA.load() + frgB.load()
        frgC.store(result)
        cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)
|
||||
|
||||
|
||||
@cute.jit
def elementwise_add(
    mA,
    mB,
    mC,
    stream: cuda.CUstream,
    use_pdl: cutlass.Constexpr = True,
    is_first_kernel: cutlass.Constexpr = True,
):
    """Host-side JIT wrapper: tile ``mA``/``mB``/``mC`` and launch the add kernel.

    Builds 128-bit-vectorized thread/value layouts, zipped-divides the input
    tensors into tiles, and launches ``elementwise_add_kernel`` on ``stream``
    with the PDL launch attribute controlled by ``use_pdl``.

    Args:
        mA, mB: input tensors.
        mC: output tensor; also provides the logical shape for predication.
        stream: CUDA stream the kernel is launched on.
        use_pdl: compile-time flag enabling the PDL launch attribute.
        is_first_kernel: compile-time flag forwarded to the kernel to pick
            its griddepcontrol role.
    """
    dtype = mA.element_type
    # copy_bits for a thread is 128 bits, and we use 128 // dtype.width to get the vector size
    vector_size = 128 // dtype.width

    thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
    val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0))
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)

    gA = cute.zipped_divide(mA, tiler_mn)  # ((TileM,TileN),(RestM,RestN))
    gB = cute.zipped_divide(mB, tiler_mn)  # ((TileM,TileN),(RestM,RestN))
    gC = cute.zipped_divide(mC, tiler_mn)  # ((TileM,TileN),(RestM,RestN))

    # Identity tensor carries each element's logical coordinate, used by the
    # kernel for out-of-bounds predication.
    idC = cute.make_identity_tensor(mC.shape)
    cC = cute.zipped_divide(idC, tiler=tiler_mn)

    elementwise_add_kernel(
        gA, gB, gC, cC, mC.shape, thr_layout, val_layout, is_first_kernel
    ).launch(
        grid=[cute.size(gC, mode=[1]), 1, 1],  # one block per tile
        block=[cute.size(tv_layout, mode=[0]), 1, 1],  # threads per tile
        stream=stream,
        use_pdl=use_pdl,
    )
|
||||
|
||||
|
||||
def run_pdl_example(
    M,
    N,
    skip_ref_check=False,
    benchmark=False,
    warmup_iterations=5,
    iterations=100,
    use_pdl=True,
):
    """Compile and run two dependent elementwise adds, optionally benchmarking.

    Computes ``w = u + v`` then ``y = w + x`` on the same CUDA stream; the
    second kernel depends on the first through ``w``. The two kernels are
    compiled separately because ``is_first_kernel`` is a ``cutlass.Constexpr``.

    Args:
        M, N: tensor dimensions.
        skip_ref_check: skip the torch reference comparison when True.
        benchmark: run the timing loop (via ``testing.benchmark``) when True.
        warmup_iterations, iterations: benchmark loop parameters.
        use_pdl: enable the PDL launch attribute on both kernels.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    import torch

    if not torch.cuda.is_available():
        raise RuntimeError("Blackwell/Hopper GPU is required to run this example!")

    print("\nRunning Elementwise Add test with:")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Use PDL: {use_pdl}")

    # Operands: w = u + v, then y = w + x.
    u = torch.randn(M, N, dtype=torch.float32, device="cuda")
    v = torch.randn(M, N, dtype=torch.float32, device="cuda")
    w = torch.randn(M, N, dtype=torch.float32, device="cuda")
    x = torch.randn(M, N, dtype=torch.float32, device="cuda")
    y = torch.empty(M, N, dtype=torch.float32, device="cuda")

    # Wrap as CuTe tensors with dynamic layouts so one compilation serves
    # any stride pattern.
    u_tensor = from_dlpack(u).mark_layout_dynamic()
    v_tensor = from_dlpack(v).mark_layout_dynamic()
    w_tensor = from_dlpack(w).mark_layout_dynamic()
    x_tensor = from_dlpack(x).mark_layout_dynamic()
    y_tensor = from_dlpack(y).mark_layout_dynamic()

    stream = torch.cuda.Stream()
    current_stream = cuda.CUstream(stream.cuda_stream)
    # Since is_first_kernel is cutlass.Constexpr, we need to compile for
    # the first and second kernel separately.
    compiled_func_first_kernel = cute.compile(
        elementwise_add,
        u_tensor,
        v_tensor,
        w_tensor,
        current_stream,
        use_pdl,
        is_first_kernel=True,
        options="--enable-tvm-ffi",
    )
    compiled_func_second_kernel = cute.compile(
        elementwise_add,
        w_tensor,
        x_tensor,
        y_tensor,
        current_stream,
        use_pdl,
        is_first_kernel=False,
        options="--enable-tvm-ffi",
    )

    # launch and run the two consecutive kernels in a same stream.
    def run_func(current_stream, u, v, w, x, y):
        # Run first operation: w_tensor = u_tensor + v_tensor
        compiled_func_first_kernel(
            u,
            v,
            w,
            current_stream,
        )
        # Run second operation: y_tensor = w_tensor + x_tensor
        # its first operand ``w_tensor`` is the result of the first operation,
        # they use the same memory space.
        compiled_func_second_kernel(
            w,
            x,
            y,
            current_stream,
        )

    if not skip_ref_check:
        run_func(current_stream, u, v, w, x, y)
        print("Verifying results...")
        # End-to-end check: y should equal (u + v) + x.
        torch.testing.assert_close(u.cpu() + v.cpu() + x.cpu(), y.cpu())
        print("Results verified successfully!")

    if not benchmark:
        return

    # Fresh operand sets per workspace so the benchmark is not skewed by
    # cache effects on a single set of buffers.
    def generate_kernel_arguments():
        u = torch.randn(M, N, dtype=torch.float32, device="cuda")
        v = torch.randn(M, N, dtype=torch.float32, device="cuda")
        w = torch.randn(M, N, dtype=torch.float32, device="cuda")
        x = torch.randn(M, N, dtype=torch.float32, device="cuda")
        y = torch.empty(M, N, dtype=torch.float32, device="cuda")

        return testing.JitArguments(current_stream, u, v, w, x, y)

    avg_time_us = testing.benchmark(
        run_func,
        workspace_generator=generate_kernel_arguments,
        workspace_count=10,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
        stream=current_stream,
        use_cuda_graphs=True,  # cuda graphs avoid host-overhead serialization (see kernel note)
    )
    print(f"Execution time: {avg_time_us:.4f} us")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: parse options, then run the PDL demonstration when the
    # device supports it.
    arg_parser = argparse.ArgumentParser(
        description="example of Programmatic Dependent Launch (PDL) using CuTe DSL"
    )
    for flag, default in (
        ("--M", 256),
        ("--N", 256),
        ("--warmup_iterations", 5),
        ("--iterations", 10),
    ):
        arg_parser.add_argument(flag, default=default, type=int)
    for flag in ("--skip_ref_check", "--benchmark", "--use_pdl"):
        arg_parser.add_argument(flag, action="store_true")

    cli = arg_parser.parse_args()
    if not supports_pdl():
        print(
            "PDL is not supported on this device, it requires Hopper or newer generations"
        )
    else:
        run_pdl_example(
            cli.M,
            cli.N,
            skip_ref_check=cli.skip_ref_check,
            benchmark=cli.benchmark,
            warmup_iterations=cli.warmup_iterations,
            iterations=cli.iterations,
            use_pdl=cli.use_pdl,
        )
        print("\nPASS")
|
||||
@@ -839,7 +839,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
)
|
||||
|
||||
s_max_layout = cute.make_layout(
|
||||
cute.size(layout_acc_mn(pv_tiled_mma, acc_pv.layout), mode=[0])
|
||||
cute.size(self.layout_acc_mn(pv_tiled_mma, acc_pv.layout), mode=[0])
|
||||
)
|
||||
s_max = cute.make_rmem_tensor_like(s_max_layout, self.qk_acc_dtype)
|
||||
a_sum = cute.make_rmem_tensor_like(s_max, cutlass.Float32)
|
||||
@@ -888,7 +888,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
# MMA QK
|
||||
cute.nvgpu.warpgroup.fence()
|
||||
|
||||
gemm_zero_acc(
|
||||
self.gemm_zero_acc(
|
||||
qk_tiled_mma,
|
||||
tSrQ[(None, None, None, q_handle.index)],
|
||||
tSrK[(None, None, None, k_handle.index)],
|
||||
@@ -901,7 +901,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
# Wait for the pipeline MMAs to drain
|
||||
cute.nvgpu.warpgroup.wait_group(0)
|
||||
|
||||
s_max, a_sum = softmax_step(
|
||||
s_max, a_sum = self.softmax_step(
|
||||
True,
|
||||
self.mask_type,
|
||||
acc_qk,
|
||||
@@ -919,7 +919,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
True,
|
||||
)
|
||||
|
||||
acc_qk_fixed = make_acc_into_op(
|
||||
acc_qk_fixed = self.make_acc_into_op(
|
||||
acc_qk, pv_tiled_mma.tv_layout_A, self.q_dtype
|
||||
)
|
||||
|
||||
@@ -928,7 +928,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
# MMA PV
|
||||
cute.nvgpu.warpgroup.fence()
|
||||
|
||||
gemm_zero_acc(
|
||||
self.gemm_zero_acc(
|
||||
pv_tiled_mma,
|
||||
acc_qk_fixed,
|
||||
tOrV[(None, None, None, v_handle.index)],
|
||||
@@ -1040,7 +1040,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
cute.nvgpu.warpgroup.wait_group(0)
|
||||
|
||||
# acc_pv updated
|
||||
lse = tail(
|
||||
lse = self.tail(
|
||||
s_max, a_sum, acc_pv, pv_tiled_mma, scale_softmax, scale_output
|
||||
)
|
||||
|
||||
@@ -1077,10 +1077,10 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
|
||||
if tOcO[0][1] == 0:
|
||||
tOgLSE_mn = cute.make_tensor(
|
||||
tOgLSE.iterator, layout_acc_mn(pv_tiled_mma, tOgLSE.layout)
|
||||
tOgLSE.iterator, self.layout_acc_mn(pv_tiled_mma, tOgLSE.layout)
|
||||
)
|
||||
tOcO_mn = cute.make_tensor(
|
||||
tOcO.iterator, layout_acc_mn(pv_tiled_mma, tOcO.layout)
|
||||
tOcO.iterator, self.layout_acc_mn(pv_tiled_mma, tOcO.layout)
|
||||
)
|
||||
for i in cutlass.range_constexpr(cute.size(tOgLSE_mn, mode=[0])):
|
||||
if (
|
||||
@@ -1241,7 +1241,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
# MMA QK
|
||||
cute.nvgpu.warpgroup.fence()
|
||||
|
||||
gemm_zero_acc(
|
||||
self.gemm_zero_acc(
|
||||
qk_tiled_mma,
|
||||
tSrQ[(None, None, None, q_handle.index)],
|
||||
tSrK[(None, None, None, k_handle.index)],
|
||||
@@ -1255,7 +1255,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
# Wait for the pipeline MMAs to drain
|
||||
cute.nvgpu.warpgroup.wait_group(0)
|
||||
|
||||
s_max, a_sum = softmax_step(
|
||||
s_max, a_sum = self.softmax_step(
|
||||
fusion,
|
||||
self.mask_type,
|
||||
acc_qk,
|
||||
@@ -1272,7 +1272,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
window_size_right,
|
||||
)
|
||||
|
||||
acc_qk_fixed = make_acc_into_op(
|
||||
acc_qk_fixed = self.make_acc_into_op(
|
||||
acc_qk, pv_tiled_mma.tv_layout_A, self.q_dtype
|
||||
)
|
||||
|
||||
@@ -1300,6 +1300,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
|
||||
@cute.jit
|
||||
def softmax_step(
|
||||
self,
|
||||
fusion: bool,
|
||||
mask_type: fmha_utils.MaskEnum,
|
||||
acc_qk: cute.ThrMma,
|
||||
@@ -1328,10 +1329,10 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
)
|
||||
|
||||
acc_qk_mn = cute.make_tensor(
|
||||
acc_qk.iterator, layout_acc_mn(tiled_mma_qk, acc_qk.layout)
|
||||
acc_qk.iterator, self.layout_acc_mn(tiled_mma_qk, acc_qk.layout)
|
||||
)
|
||||
|
||||
reduction_target_qk = reduction_target_n(tiled_mma_qk)
|
||||
reduction_target_qk = self.reduction_target_n(tiled_mma_qk)
|
||||
red_rank = cute.rank(reduction_target_qk)
|
||||
|
||||
s_max_prev = None
|
||||
@@ -1346,7 +1347,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
s_max[i] = cute.arch.fmax(s_max[i], acc_qk_mn[i, j])
|
||||
else:
|
||||
acc_pv_mn = cute.make_tensor(
|
||||
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
||||
acc_pv.iterator, self.layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
||||
)
|
||||
s_max_prev = cute.make_rmem_tensor_like(s_max, s_max._dtype)
|
||||
|
||||
@@ -1396,15 +1397,15 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
return s_max, a_sum
|
||||
|
||||
@cute.jit
|
||||
def reduction_target_n(tiled_mma):
|
||||
separated = layout_separate(
|
||||
def reduction_target_n(self, tiled_mma):
|
||||
separated = self.layout_separate(
|
||||
tiled_mma.shape_mnk[0],
|
||||
cute.make_layout(tiled_mma.tv_layout_C.shape[0]),
|
||||
tiled_mma.tv_layout_C.stride[0],
|
||||
)
|
||||
return separated[1]
|
||||
|
||||
@cute.jit
|
||||
@staticmethod
|
||||
def convert_c_layout_to_a_layout(c, a):
|
||||
return cute.make_layout(
|
||||
(a, c.shape[1], (c.shape[2], cute.size(c, mode=[0]) // cute.size(a))),
|
||||
@@ -1416,9 +1417,9 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
)
|
||||
|
||||
@cute.jit
|
||||
def make_acc_into_op(acc, operand_layout_tv, Element):
|
||||
def make_acc_into_op(self, acc, operand_layout_tv, Element):
|
||||
operand = cute.make_rmem_tensor_like(
|
||||
convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]),
|
||||
self.convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]),
|
||||
Element,
|
||||
)
|
||||
operand_as_acc = cute.make_tensor(operand.iterator, acc.layout)
|
||||
@@ -1499,7 +1500,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
return operand
|
||||
|
||||
@cute.jit
|
||||
def tail(s_max, a_sum, acc_pv, tiled_mma_pv, scale_softmax, scale_output):
|
||||
def tail(self, s_max, a_sum, acc_pv, tiled_mma_pv, scale_softmax, scale_output):
|
||||
"""
|
||||
Final processing step for FMHA that computes log-sum-exp (LSE) and scales the output.
|
||||
|
||||
@@ -1527,9 +1528,9 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
"""
|
||||
# Create tensor view of accumulated P*V values with M*N layout
|
||||
acc_pv_mn = cute.make_tensor(
|
||||
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
||||
acc_pv.iterator, self.layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
||||
)
|
||||
reduction_target = reduction_target_n(tiled_mma_pv)
|
||||
reduction_target = self.reduction_target_n(tiled_mma_pv)
|
||||
red_rank = cute.rank(reduction_target)
|
||||
for r in cutlass.range_constexpr(red_rank):
|
||||
for i in cutlass.range_constexpr(cute.size(acc_pv_mn, mode=[0])):
|
||||
@@ -1538,7 +1539,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
)
|
||||
|
||||
acc_mn = cute.make_tensor(
|
||||
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
||||
acc_pv.iterator, self.layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
||||
)
|
||||
|
||||
lse = cute.make_rmem_tensor_like(a_sum, a_sum._dtype)
|
||||
@@ -1559,7 +1560,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
|
||||
return lse
|
||||
|
||||
@cute.jit
|
||||
@staticmethod
|
||||
def layout_separate(thr, src, ref):
|
||||
lt = cute.make_layout(())
|
||||
ge = cute.make_layout(())
|
||||
@@ -1577,6 +1578,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
r = cute.append(cute.append(cute.make_layout(()), lt), ge)
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
def gemm_zero_acc(tiled_mma, A, B, C):
|
||||
rA = cute.rank(A)
|
||||
@@ -1606,8 +1608,8 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
assert 0
|
||||
|
||||
@cute.jit
|
||||
def layout_acc_mn(tiled_mma, acc):
|
||||
separated = layout_separate(
|
||||
def layout_acc_mn(self, tiled_mma, acc):
|
||||
separated = self.layout_separate(
|
||||
tiled_mma.shape_mnk[0], acc[0], tiled_mma.tv_layout_C.stride[1]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user