v4.3.4 update. (#2892)

This commit is contained in:
Junkai-Wu
2025-12-22 00:49:12 +08:00
committed by GitHub
parent 331e2f451c
commit 7f5fe3edf1
31 changed files with 839 additions and 240 deletions

View File

@@ -401,11 +401,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
d_per_blk = 128
d_blks = cute.ceil_div(d, d_per_blk)
reduction(
o, m, l,
o_partial, m_partial, l_partial,
scale_o
).launch(
self.reduction(o, m, l, o_partial, m_partial, l_partial, scale_o).launch(
grid=[d_blks, h_q, b],
block=[d_per_blk, 1, 1],
cluster=[1, 1, 1],
@@ -1173,7 +1169,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Reduce colmax in smem
if lane_store_max:
smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane)
self.smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane)
# Wait for colmax then load
cute.arch.barrier(barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads)
@@ -1259,7 +1255,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Reduce cluster colmax
if warpgroup_widx == 0:
if lane_store_max:
dsmem_fmax(
self.dsmem_fmax(
sM_cluster.iterator + sM_layout((0, lane_idx)),
sM[(0, lane_idx)],
m_cluster_full_ptr
@@ -1280,7 +1276,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
else:
# other splits copy cluster colmax into local smem
if lane_store_max:
sM_cluster[0, lane_idx] = dsmem_load(
sM_cluster[0, lane_idx] = self.dsmem_load(
sM_cluster.iterator + sM_layout((0, lane_idx))
)
@@ -1300,8 +1296,10 @@ class MixedInputFusedMultiHeadAttentionDecode:
sM_lane = sM_lane * scale_qs
# Store colsum and colmax
gmem_fadd(gL_partial.iterator + gL_partial.layout(lane_idx), sL_lane)
gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
self.gmem_fadd(
gL_partial.iterator + gL_partial.layout(lane_idx), sL_lane
)
self.gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
if kv_split_in_cluster == 0:
gM_partial[lane_idx] = sM_lane
@@ -1350,7 +1348,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Store colsum and colmax
gL_partial[lane_idx] = sL_lane
gM_partial[lane_idx] = sM_lane
gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
self.gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
o_handle = o_consumer.wait_and_advance()
cute.copy(thr_load_s, tStO, tSrO)
@@ -1378,6 +1376,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
return
@staticmethod
@cute.kernel
def reduction(
o : cute.Tensor,
@@ -1414,6 +1413,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
return
@staticmethod
@cute.jit
def _mapa(ptr : Pointer, cta_rank_in_cluster : Int32 = 0):
llvm_ptr = ptr.llvm_ptr
@@ -1424,8 +1424,8 @@ class MixedInputFusedMultiHeadAttentionDecode:
)
@cute.jit
def dsmem_load(val_ptr : Pointer):
val_llvm_ptr = _mapa(val_ptr, 0)
def dsmem_load(self, val_ptr: Pointer):
val_llvm_ptr = self._mapa(val_ptr, 0)
ret = llvm.inline_asm(
Float32.mlir_type,
@@ -1439,6 +1439,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
return Float32(ret)
@staticmethod
@cute.jit
def warp_fmax(val : Float32):
ret = llvm.inline_asm(
@@ -1472,10 +1473,10 @@ class MixedInputFusedMultiHeadAttentionDecode:
)
@cute.jit
def dsmem_fmax(val_ptr : Pointer, val : Float32, mbar_ptr : Pointer):
def dsmem_fmax(self, val_ptr: Pointer, val: Float32, mbar_ptr: Pointer):
expect_tx_bytes = Int32(Float32.width // 8)
val_llvm_ptr = _mapa(val_ptr, 0)
mbar_llvm_ptr = _mapa(mbar_ptr, 0)
val_llvm_ptr = self._mapa(val_ptr, 0)
mbar_llvm_ptr = self._mapa(mbar_ptr, 0)
nvvm.mbarrier_txn(
mbar_llvm_ptr,
@@ -1499,6 +1500,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
asm_dialect=llvm.AsmDialect.AD_ATT,
)
@staticmethod
@cute.jit
def gmem_fmax(ptr : Pointer, val : Float32):
llvm.inline_asm(
@@ -1529,6 +1531,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
asm_dialect=llvm.AsmDialect.AD_ATT,
)
@staticmethod
@cute.jit
def gmem_fadd(ptr : Pointer, val : Float32):
llvm.inline_asm(

View File

@@ -0,0 +1,368 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
from cutlass.cute.runtime import from_dlpack
def supports_pdl():
    """Return True if the current CUDA device supports Programmatic Dependent
    Launch (Hopper / SM90 or newer).

    :return: ``True`` when a CUDA device is present and its major compute
        capability is at least 9, ``False`` otherwise.
    """
    import torch

    # Guard against machines without a usable CUDA device:
    # torch.cuda.get_device_capability() raises when CUDA is unavailable,
    # and __main__ calls this helper before doing any availability check.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability()[0] >= 9
"""
This example demonstrates the use of Programmatic Dependent Launch (PDL) using
CuTe DSL.
PDL is a mechanism which allows for overlapping execution of back-to-back kernels
within the same stream.
For example, consider the following two elementwise add operations, where the second
operation's first operand is the result of the first operation. While performing
``w = u + v`` we will load u and v, add them, and then store the result. Once we
have finished loading data, we are no longer utilizing the read bandwidth.
To effectively utilize the read bandwidth, we can start loading ``x``
immediately upon finishing reading. This is what PDL enables us to do.
.. code-block:: python
w = u + v
y = w + x
To enable PDL, we need to do two things:
1. Insert the ``griddepcontrol.launch_dependents`` and ``griddepcontrol.wait`` instructions in the kernel.
2. Set the PDL launch attribute when launching the kernel.
The ``griddepcontrol.launch_dependents`` and ``griddepcontrol.wait``
instructions enable fine-grained control over kernel execution in PDL.
Once all thread blocks execute the ``griddepcontrol.launch_dependents``
instruction, the dependent kernels can opportunistically be early-launched.
``griddepcontrol.wait`` functions as a synchronization barrier - any warp
executing this instruction will block until the previous kernel finishes
execution. This allows precise control over data dependencies between kernels.
The following diagram shows the overlapping execution of two dependent kernels.
We refer to the instructions before ``griddepcontrol.wait`` as the prologue (``P0``),
which may include barrier initialization and loading of independent data, etc.
We refer to the instructions after ``griddepcontrol.launch_dependents`` as the epilogue
(``P2``), which may include math operations, data stores, etc. PDL enables
these prologue and epilogue phases to execute concurrently across dependent
kernels, improving GPU resource utilization. This is particularly beneficial
when prologue and epilogue are bound by different resources (e.g., memory
bandwidth vs compute throughput).
# P0: Prologue, P1: Main compute, P2: Epilogue
P0 P1 P2
K1: |=====|+++++|-----|
<-----> K2 can start early
(K1's P2 overlaps with K2's P0)
P0 P1 P2
K2: |=====| |+++++|-----|
^
|
wait for K1 to complete
Time ------------------------------------------------------>
We can run this example with and without PDL:
.. code-block:: bash
python examples/blackwell/programmatic_dependent_launch.py --benchmark
python examples/blackwell/programmatic_dependent_launch.py --benchmark --use_pdl
From the benchmark results, you can see some speedups for the PDL version in most cases, benefiting from
the overlapping execution of consecutive kernels. Moreover, you can use nsys to observe the overlapping execution.
.. code-block:: bash
nsys profile python examples/blackwell/programmatic_dependent_launch.py --benchmark --use_pdl
Note, PDL feature is supported on Hopper and later GPUs.
See [the programming guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization)
and the [PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol)
for more details.
"""
@cute.kernel
def elementwise_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
    cC: cute.Tensor,  # coordinate tensor
    shape: cute.Shape,
    thr_layout: cute.Layout,
    val_layout: cute.Layout,
    is_first_kernel: cutlass.Constexpr = True,
):
    """Elementwise add ``gC = gA + gB`` with PDL grid-dependency controls.

    The producer role (``is_first_kernel=True``) issues
    ``griddepcontrol.launch_dependents`` after its loads so the next kernel in
    the stream may be launched early; the consumer role issues
    ``griddepcontrol.wait`` before reading ``gA``, the operand written by the
    previous kernel.

    :param gA: first input, tiled as ((TileM,TileN),(RestM,RestN))
    :param gB: second input, same tiling
    :param gC: output, same tiling
    :param cC: identity (coordinate) tensor used to build the bounds predicate
    :param shape: logical extent of the full (untiled) problem
    :param thr_layout: thread layout used by the tiled copies
    :param val_layout: per-thread value layout used by the tiled copies
    :param is_first_kernel: compile-time flag selecting producer/consumer role
    """
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    # Each CTA processes the single (TileM,TileN) tile selected by its block index.
    blk_coord = ((None, None), bidx)
    blkA = gA[blk_coord]  # (TileM,TileN)
    blkB = gB[blk_coord]  # (TileM,TileN)
    blkC = gC[blk_coord]  # (TileM,TileN)
    blkCrd = cC[blk_coord]  # (TileM, TileN)

    copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
    copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)

    tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
    tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
    tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)

    thr_copy_A = tiled_copy_A.get_slice(tidx)
    thr_copy_B = tiled_copy_B.get_slice(tidx)
    thr_copy_C = tiled_copy_C.get_slice(tidx)

    # Per-thread slices of each tile.
    thrA = thr_copy_A.partition_S(blkA)
    thrB = thr_copy_B.partition_S(blkB)
    thrC = thr_copy_C.partition_S(blkC)

    # Register fragments matching the per-thread slices.
    frgA = cute.make_fragment_like(thrA)
    frgB = cute.make_fragment_like(thrB)
    frgC = cute.make_fragment_like(thrC)

    # Predicate: only elements whose logical coordinate lies inside `shape`
    # are copied, guarding partial tiles at the tensor boundary.
    thrCrd = thr_copy_C.partition_S(blkCrd)
    frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean)
    for i in range(cute.size(frgPred)):
        val = cute.elem_less(thrCrd[i], shape)
        frgPred[i] = val

    # Note: when not using cuda-graph, the kernel execution may be blocked by the host overhead.
    # In this case we won't see overlapping even when pdl is enabled.
    # In this example, we add a loop (10 times) for all the copy and compute operations in the following code
    # to make kernel running longer and make pdl benefits observable for both cuda-graph enabled and disabled cases.
    if is_first_kernel:
        for _ in range(10):
            cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
            cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
        # Here we add the launch dependents instruction for the first kernel as a hint to the runtime to early-launch
        # the next kernel. If the next kernel becomes concurrent, we will have overlap where the second kernel
        # can start reading x to ensure an E2E speedup. Note the placement of launch dependents has no implication
        # on correctness, only performance.
        cute.arch.griddepcontrol_launch_dependents()
    else:
        # In this example, the second kernel's second operand ``gB`` has no dependencies, its loading can overlap
        # with the computation of ``gC`` from the first kernel.
        for _ in range(10):
            cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
        # For the second kernel, its first operand ``gA`` is dependent on the previous kernel, we must call
        # griddepcontrol.wait to assure correctness. This instruction will block until the prior kernels finishes
        # and its memory operations are visible. Since gA is written by the prior kernel, this will block until gA
        # is visible to our kernel. Without it, we would have undefined behavior due to a race condition.
        cute.arch.griddepcontrol_wait()
        for _ in range(10):
            cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)

    # Compute and store, looped for the same timing reason as the loads above.
    for _ in range(10):
        result = frgA.load() + frgB.load()
        frgC.store(result)
        cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)
@cute.jit
def elementwise_add(
    mA,
    mB,
    mC,
    stream: cuda.CUstream,
    use_pdl: cutlass.Constexpr = True,
    is_first_kernel: cutlass.Constexpr = True,
):
    """Tile the operands and launch :func:`elementwise_add_kernel`.

    :param mA: first input tensor
    :param mB: second input tensor
    :param mC: output tensor, same shape as the inputs
    :param stream: CUDA stream to launch on
    :param use_pdl: compile-time flag enabling the PDL launch attribute
    :param is_first_kernel: compile-time producer/consumer role selector
    """
    elem_t = mA.element_type
    # Each thread copies 128 bits at a time; derive the per-thread vector width.
    vector_size = 128 // elem_t.width

    threads = cute.make_ordered_layout((4, 32), order=(1, 0))
    values = cute.make_ordered_layout((4, vector_size), order=(1, 0))
    tiler_mn, tv_layout = cute.make_layout_tv(threads, values)

    # Partition every operand into ((TileM,TileN),(RestM,RestN)) tiles.
    tiled_a = cute.zipped_divide(mA, tiler_mn)
    tiled_b = cute.zipped_divide(mB, tiler_mn)
    tiled_c = cute.zipped_divide(mC, tiler_mn)

    # Identity tensor carrying logical coordinates, used by the kernel to
    # predicate out-of-bounds elements on partial tiles.
    coord_tiles = cute.zipped_divide(cute.make_identity_tensor(mC.shape), tiler=tiler_mn)

    elementwise_add_kernel(
        tiled_a, tiled_b, tiled_c, coord_tiles, mC.shape, threads, values, is_first_kernel
    ).launch(
        grid=[cute.size(tiled_c, mode=[1]), 1, 1],
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
        stream=stream,
        use_pdl=use_pdl,
    )
def run_pdl_example(
    M,
    N,
    skip_ref_check=False,
    benchmark=False,
    warmup_iterations=5,
    iterations=100,
    use_pdl=True,
):
    """Run (and optionally benchmark) two dependent elementwise adds.

    Computes ``w = u + v`` followed by ``y = w + x`` on the same CUDA stream;
    the second kernel consumes the first kernel's output, so with PDL enabled
    the back-to-back launches can partially overlap.

    :param M: number of rows of each tensor
    :param N: number of columns of each tensor
    :param skip_ref_check: skip the reference-result verification when True
    :param benchmark: run the timed benchmark when True
    :param warmup_iterations: benchmark warmup iteration count
    :param iterations: benchmark measured iteration count
    :param use_pdl: enable the PDL launch attribute on both kernels
    :raises RuntimeError: if no CUDA device is available
    """
    import torch

    if not torch.cuda.is_available():
        raise RuntimeError("Blackwell/Hopper GPU is required to run this example!")

    print("\nRunning Elementwise Add test with:")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Use PDL: {use_pdl}")

    u = torch.randn(M, N, dtype=torch.float32, device="cuda")
    v = torch.randn(M, N, dtype=torch.float32, device="cuda")
    w = torch.randn(M, N, dtype=torch.float32, device="cuda")
    x = torch.randn(M, N, dtype=torch.float32, device="cuda")
    y = torch.empty(M, N, dtype=torch.float32, device="cuda")

    u_tensor = from_dlpack(u).mark_layout_dynamic()
    v_tensor = from_dlpack(v).mark_layout_dynamic()
    w_tensor = from_dlpack(w).mark_layout_dynamic()
    x_tensor = from_dlpack(x).mark_layout_dynamic()
    y_tensor = from_dlpack(y).mark_layout_dynamic()

    stream = torch.cuda.Stream()
    current_stream = cuda.CUstream(stream.cuda_stream)

    # Since is_first_kernel is cutlass.Constexpr, we need to compile for
    # the first and second kernel separately.
    compiled_func_first_kernel = cute.compile(
        elementwise_add,
        u_tensor,
        v_tensor,
        w_tensor,
        current_stream,
        use_pdl,
        is_first_kernel=True,
        options="--enable-tvm-ffi",
    )
    compiled_func_second_kernel = cute.compile(
        elementwise_add,
        w_tensor,
        x_tensor,
        y_tensor,
        current_stream,
        use_pdl,
        is_first_kernel=False,
        options="--enable-tvm-ffi",
    )

    # launch and run the two consecutive kernels in a same stream.
    def run_func(current_stream, u, v, w, x, y):
        # Run first operation: w_tensor = u_tensor + v_tensor
        compiled_func_first_kernel(
            u,
            v,
            w,
            current_stream,
        )
        # Run second operation: y_tensor = w_tensor + x_tensor
        # its first operand ``w_tensor`` is the result of the first operation,
        # they use the same memory space.
        compiled_func_second_kernel(
            w,
            x,
            y,
            current_stream,
        )

    if not skip_ref_check:
        run_func(current_stream, u, v, w, x, y)
        print("Verifying results...")
        # y should equal (u + v) + x since w was overwritten by the first kernel.
        torch.testing.assert_close(u.cpu() + v.cpu() + x.cpu(), y.cpu())
        print("Results verified successfully!")

    if not benchmark:
        return

    # Fresh tensors per workspace so benchmark iterations do not reuse caches.
    def generate_kernel_arguments():
        u = torch.randn(M, N, dtype=torch.float32, device="cuda")
        v = torch.randn(M, N, dtype=torch.float32, device="cuda")
        w = torch.randn(M, N, dtype=torch.float32, device="cuda")
        x = torch.randn(M, N, dtype=torch.float32, device="cuda")
        y = torch.empty(M, N, dtype=torch.float32, device="cuda")
        return testing.JitArguments(current_stream, u, v, w, x, y)

    avg_time_us = testing.benchmark(
        run_func,
        workspace_generator=generate_kernel_arguments,
        workspace_count=10,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
        stream=current_stream,
        use_cuda_graphs=True,
    )
    print(f"Execution time: {avg_time_us:.4f} us")
if __name__ == "__main__":
    # Command-line entry point for the PDL example.
    arg_parser = argparse.ArgumentParser(
        description="example of Programmatic Dependent Launch (PDL) using CuTe DSL"
    )
    arg_parser.add_argument("--M", default=256, type=int)
    arg_parser.add_argument("--N", default=256, type=int)
    arg_parser.add_argument("--warmup_iterations", default=5, type=int)
    arg_parser.add_argument("--iterations", default=10, type=int)
    arg_parser.add_argument("--skip_ref_check", action="store_true")
    arg_parser.add_argument("--benchmark", action="store_true")
    arg_parser.add_argument("--use_pdl", action="store_true")
    opts = arg_parser.parse_args()

    # PDL requires Hopper (SM90) or newer; bail out early otherwise.
    if not supports_pdl():
        print(
            "PDL is not supported on this device, it requires Hopper or newer generations"
        )
    else:
        run_pdl_example(
            opts.M,
            opts.N,
            skip_ref_check=opts.skip_ref_check,
            benchmark=opts.benchmark,
            warmup_iterations=opts.warmup_iterations,
            iterations=opts.iterations,
            use_pdl=opts.use_pdl,
        )
        print("\nPASS")

View File

@@ -839,7 +839,7 @@ class HopperFusedMultiHeadAttentionForward:
)
s_max_layout = cute.make_layout(
cute.size(layout_acc_mn(pv_tiled_mma, acc_pv.layout), mode=[0])
cute.size(self.layout_acc_mn(pv_tiled_mma, acc_pv.layout), mode=[0])
)
s_max = cute.make_rmem_tensor_like(s_max_layout, self.qk_acc_dtype)
a_sum = cute.make_rmem_tensor_like(s_max, cutlass.Float32)
@@ -888,7 +888,7 @@ class HopperFusedMultiHeadAttentionForward:
# MMA QK
cute.nvgpu.warpgroup.fence()
gemm_zero_acc(
self.gemm_zero_acc(
qk_tiled_mma,
tSrQ[(None, None, None, q_handle.index)],
tSrK[(None, None, None, k_handle.index)],
@@ -901,7 +901,7 @@ class HopperFusedMultiHeadAttentionForward:
# Wait for the pipeline MMAs to drain
cute.nvgpu.warpgroup.wait_group(0)
s_max, a_sum = softmax_step(
s_max, a_sum = self.softmax_step(
True,
self.mask_type,
acc_qk,
@@ -919,7 +919,7 @@ class HopperFusedMultiHeadAttentionForward:
True,
)
acc_qk_fixed = make_acc_into_op(
acc_qk_fixed = self.make_acc_into_op(
acc_qk, pv_tiled_mma.tv_layout_A, self.q_dtype
)
@@ -928,7 +928,7 @@ class HopperFusedMultiHeadAttentionForward:
# MMA PV
cute.nvgpu.warpgroup.fence()
gemm_zero_acc(
self.gemm_zero_acc(
pv_tiled_mma,
acc_qk_fixed,
tOrV[(None, None, None, v_handle.index)],
@@ -1040,7 +1040,7 @@ class HopperFusedMultiHeadAttentionForward:
cute.nvgpu.warpgroup.wait_group(0)
# acc_pv updated
lse = tail(
lse = self.tail(
s_max, a_sum, acc_pv, pv_tiled_mma, scale_softmax, scale_output
)
@@ -1077,10 +1077,10 @@ class HopperFusedMultiHeadAttentionForward:
if tOcO[0][1] == 0:
tOgLSE_mn = cute.make_tensor(
tOgLSE.iterator, layout_acc_mn(pv_tiled_mma, tOgLSE.layout)
tOgLSE.iterator, self.layout_acc_mn(pv_tiled_mma, tOgLSE.layout)
)
tOcO_mn = cute.make_tensor(
tOcO.iterator, layout_acc_mn(pv_tiled_mma, tOcO.layout)
tOcO.iterator, self.layout_acc_mn(pv_tiled_mma, tOcO.layout)
)
for i in cutlass.range_constexpr(cute.size(tOgLSE_mn, mode=[0])):
if (
@@ -1241,7 +1241,7 @@ class HopperFusedMultiHeadAttentionForward:
# MMA QK
cute.nvgpu.warpgroup.fence()
gemm_zero_acc(
self.gemm_zero_acc(
qk_tiled_mma,
tSrQ[(None, None, None, q_handle.index)],
tSrK[(None, None, None, k_handle.index)],
@@ -1255,7 +1255,7 @@ class HopperFusedMultiHeadAttentionForward:
# Wait for the pipeline MMAs to drain
cute.nvgpu.warpgroup.wait_group(0)
s_max, a_sum = softmax_step(
s_max, a_sum = self.softmax_step(
fusion,
self.mask_type,
acc_qk,
@@ -1272,7 +1272,7 @@ class HopperFusedMultiHeadAttentionForward:
window_size_right,
)
acc_qk_fixed = make_acc_into_op(
acc_qk_fixed = self.make_acc_into_op(
acc_qk, pv_tiled_mma.tv_layout_A, self.q_dtype
)
@@ -1300,6 +1300,7 @@ class HopperFusedMultiHeadAttentionForward:
@cute.jit
def softmax_step(
self,
fusion: bool,
mask_type: fmha_utils.MaskEnum,
acc_qk: cute.ThrMma,
@@ -1328,10 +1329,10 @@ class HopperFusedMultiHeadAttentionForward:
)
acc_qk_mn = cute.make_tensor(
acc_qk.iterator, layout_acc_mn(tiled_mma_qk, acc_qk.layout)
acc_qk.iterator, self.layout_acc_mn(tiled_mma_qk, acc_qk.layout)
)
reduction_target_qk = reduction_target_n(tiled_mma_qk)
reduction_target_qk = self.reduction_target_n(tiled_mma_qk)
red_rank = cute.rank(reduction_target_qk)
s_max_prev = None
@@ -1346,7 +1347,7 @@ class HopperFusedMultiHeadAttentionForward:
s_max[i] = cute.arch.fmax(s_max[i], acc_qk_mn[i, j])
else:
acc_pv_mn = cute.make_tensor(
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
acc_pv.iterator, self.layout_acc_mn(tiled_mma_pv, acc_pv.layout)
)
s_max_prev = cute.make_rmem_tensor_like(s_max, s_max._dtype)
@@ -1396,15 +1397,15 @@ class HopperFusedMultiHeadAttentionForward:
return s_max, a_sum
@cute.jit
def reduction_target_n(tiled_mma):
separated = layout_separate(
def reduction_target_n(self, tiled_mma):
separated = self.layout_separate(
tiled_mma.shape_mnk[0],
cute.make_layout(tiled_mma.tv_layout_C.shape[0]),
tiled_mma.tv_layout_C.stride[0],
)
return separated[1]
@cute.jit
@staticmethod
def convert_c_layout_to_a_layout(c, a):
return cute.make_layout(
(a, c.shape[1], (c.shape[2], cute.size(c, mode=[0]) // cute.size(a))),
@@ -1416,9 +1417,9 @@ class HopperFusedMultiHeadAttentionForward:
)
@cute.jit
def make_acc_into_op(acc, operand_layout_tv, Element):
def make_acc_into_op(self, acc, operand_layout_tv, Element):
operand = cute.make_rmem_tensor_like(
convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]),
self.convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]),
Element,
)
operand_as_acc = cute.make_tensor(operand.iterator, acc.layout)
@@ -1499,7 +1500,7 @@ class HopperFusedMultiHeadAttentionForward:
return operand
@cute.jit
def tail(s_max, a_sum, acc_pv, tiled_mma_pv, scale_softmax, scale_output):
def tail(self, s_max, a_sum, acc_pv, tiled_mma_pv, scale_softmax, scale_output):
"""
Final processing step for FMHA that computes log-sum-exp (LSE) and scales the output.
@@ -1527,9 +1528,9 @@ class HopperFusedMultiHeadAttentionForward:
"""
# Create tensor view of accumulated P*V values with M*N layout
acc_pv_mn = cute.make_tensor(
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
acc_pv.iterator, self.layout_acc_mn(tiled_mma_pv, acc_pv.layout)
)
reduction_target = reduction_target_n(tiled_mma_pv)
reduction_target = self.reduction_target_n(tiled_mma_pv)
red_rank = cute.rank(reduction_target)
for r in cutlass.range_constexpr(red_rank):
for i in cutlass.range_constexpr(cute.size(acc_pv_mn, mode=[0])):
@@ -1538,7 +1539,7 @@ class HopperFusedMultiHeadAttentionForward:
)
acc_mn = cute.make_tensor(
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
acc_pv.iterator, self.layout_acc_mn(tiled_mma_pv, acc_pv.layout)
)
lse = cute.make_rmem_tensor_like(a_sum, a_sum._dtype)
@@ -1559,7 +1560,7 @@ class HopperFusedMultiHeadAttentionForward:
return lse
@cute.jit
@staticmethod
def layout_separate(thr, src, ref):
lt = cute.make_layout(())
ge = cute.make_layout(())
@@ -1577,6 +1578,7 @@ class HopperFusedMultiHeadAttentionForward:
r = cute.append(cute.append(cute.make_layout(()), lt), ge)
return r
@staticmethod
@cute.jit
def gemm_zero_acc(tiled_mma, A, B, C):
rA = cute.rank(A)
@@ -1606,8 +1608,8 @@ class HopperFusedMultiHeadAttentionForward:
assert 0
@cute.jit
def layout_acc_mn(tiled_mma, acc):
separated = layout_separate(
def layout_acc_mn(self, tiled_mma, acc):
separated = self.layout_separate(
tiled_mma.shape_mnk[0], acc[0], tiled_mma.tv_layout_C.stride[1]
)