v4.3.4 update. (#2892)

This commit is contained in:
Junkai-Wu
2025-12-22 00:49:12 +08:00
committed by GitHub
parent 331e2f451c
commit 7f5fe3edf1
31 changed files with 839 additions and 240 deletions

View File

@@ -2,9 +2,16 @@
# CUTLASS 4.x
## [4.3.3](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.3) (2025-12-12)
## [4.3.4](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.4) (2025-12-22)
* New features
- Added PDL support along with example [Kernel launch with Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py)
### CuTe DSL
* Bug fixing and improvements
- Fixed a frame reference-count (refcnt) issue with CUDA graphs
- Enhancement for the tvm-ffi AoT case to allow earlier module unload
- Fixed order issue in `make_smem_layout_a` in utils/hopper_helpers.py
## [4.3.3](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.3) (2025-12-12)
* New features
- Supported namedtuple and kwargs for JIT function arguments in tvm-ffi
- Supported variadic tuples for JIT function argument in tvm-ffi
@@ -14,8 +21,6 @@
- Clearer error message for the case of runtime error cudaErrorInsufficientDriver
## [4.3.2](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.2) (2025-12-05)
### CuTe DSL
* New features
- New env var `CUTE_DSL_CACHE_DIR` to specify the path for dumping caches

View File

@@ -1,9 +1,9 @@
![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# Overview
# CUTLASS 4.3.3
# CUTLASS 4.3.4
_CUTLASS 4.3.3 - Dec 2025_
_CUTLASS 4.3.4 - Dec 2025_
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
and related computations at all levels and scales within CUDA. It incorporates strategies for
@@ -56,6 +56,7 @@ To get started quickly - please refer :
- New env var `CUTE_DSL_CACHE_DIR` to specify the path for dumping caches.
- Supported namedtuple and kwargs for JIT function arguments in tvm-ffi.
- Supported variadic tuples for JIT function argument in tvm-ffi.
- Added PDL support along with example [Kernel launch with Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py)
* Debuggability improvements:
- Supported source location tracking for DSL APIs (Allow tools like ``nsight`` profiling to correlate perf metrics with Python source code)
- Supported dumping PTX and CUBIN code: [Hello World Example](https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/notebooks/hello_world.ipynb)
@@ -106,6 +107,9 @@ To get started quickly - please refer :
- Fixed an issue of allocating max smem when there's statically allocated smem
- Fixed an issue when JIT function argument with union type annotation for tvm-ffi
- Clearer error message for the case of runtime error cudaErrorInsufficientDriver
- Fixed a frame refcnt issue with cuda graph
- Enhancement for tvm-ffi AoT case for earlier module unload
- Fixed order issue in make_smem_layout_a in utils/hopper_helpers.py
## CUTLASS C++
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).

View File

@@ -401,11 +401,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
d_per_blk = 128
d_blks = cute.ceil_div(d, d_per_blk)
reduction(
o, m, l,
o_partial, m_partial, l_partial,
scale_o
).launch(
self.reduction(o, m, l, o_partial, m_partial, l_partial, scale_o).launch(
grid=[d_blks, h_q, b],
block=[d_per_blk, 1, 1],
cluster=[1, 1, 1],
@@ -1173,7 +1169,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Reduce colmax in smem
if lane_store_max:
smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane)
self.smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane)
# Wait for colmax then load
cute.arch.barrier(barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads)
@@ -1259,7 +1255,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Reduce cluster colmax
if warpgroup_widx == 0:
if lane_store_max:
dsmem_fmax(
self.dsmem_fmax(
sM_cluster.iterator + sM_layout((0, lane_idx)),
sM[(0, lane_idx)],
m_cluster_full_ptr
@@ -1280,7 +1276,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
else:
# other splits copy cluster colmax into local smem
if lane_store_max:
sM_cluster[0, lane_idx] = dsmem_load(
sM_cluster[0, lane_idx] = self.dsmem_load(
sM_cluster.iterator + sM_layout((0, lane_idx))
)
@@ -1300,8 +1296,10 @@ class MixedInputFusedMultiHeadAttentionDecode:
sM_lane = sM_lane * scale_qs
# Store colsum and colmax
gmem_fadd(gL_partial.iterator + gL_partial.layout(lane_idx), sL_lane)
gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
self.gmem_fadd(
gL_partial.iterator + gL_partial.layout(lane_idx), sL_lane
)
self.gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
if kv_split_in_cluster == 0:
gM_partial[lane_idx] = sM_lane
@@ -1350,7 +1348,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Store colsum and colmax
gL_partial[lane_idx] = sL_lane
gM_partial[lane_idx] = sM_lane
gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
self.gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
o_handle = o_consumer.wait_and_advance()
cute.copy(thr_load_s, tStO, tSrO)
@@ -1378,6 +1376,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
return
@staticmethod
@cute.kernel
def reduction(
o : cute.Tensor,
@@ -1414,6 +1413,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
return
@staticmethod
@cute.jit
def _mapa(ptr : Pointer, cta_rank_in_cluster : Int32 = 0):
llvm_ptr = ptr.llvm_ptr
@@ -1424,8 +1424,8 @@ class MixedInputFusedMultiHeadAttentionDecode:
)
@cute.jit
def dsmem_load(val_ptr : Pointer):
val_llvm_ptr = _mapa(val_ptr, 0)
def dsmem_load(self, val_ptr: Pointer):
val_llvm_ptr = self._mapa(val_ptr, 0)
ret = llvm.inline_asm(
Float32.mlir_type,
@@ -1439,6 +1439,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
return Float32(ret)
@staticmethod
@cute.jit
def warp_fmax(val : Float32):
ret = llvm.inline_asm(
@@ -1472,10 +1473,10 @@ class MixedInputFusedMultiHeadAttentionDecode:
)
@cute.jit
def dsmem_fmax(val_ptr : Pointer, val : Float32, mbar_ptr : Pointer):
def dsmem_fmax(self, val_ptr: Pointer, val: Float32, mbar_ptr: Pointer):
expect_tx_bytes = Int32(Float32.width // 8)
val_llvm_ptr = _mapa(val_ptr, 0)
mbar_llvm_ptr = _mapa(mbar_ptr, 0)
val_llvm_ptr = self._mapa(val_ptr, 0)
mbar_llvm_ptr = self._mapa(mbar_ptr, 0)
nvvm.mbarrier_txn(
mbar_llvm_ptr,
@@ -1499,6 +1500,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
asm_dialect=llvm.AsmDialect.AD_ATT,
)
@staticmethod
@cute.jit
def gmem_fmax(ptr : Pointer, val : Float32):
llvm.inline_asm(
@@ -1529,6 +1531,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
asm_dialect=llvm.AsmDialect.AD_ATT,
)
@staticmethod
@cute.jit
def gmem_fadd(ptr : Pointer, val : Float32):
llvm.inline_asm(

View File

@@ -0,0 +1,368 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
from cutlass.cute.runtime import from_dlpack
def supports_pdl():
    """Return True if the current CUDA device supports Programmatic
    Dependent Launch (PDL), i.e. compute capability 9.x (Hopper) or newer.

    Returns:
        bool: False when no CUDA device is available (instead of raising,
        so callers can gate the example gracefully on CPU-only machines),
        otherwise whether the device's major compute capability is >= 9.
    """
    import torch

    # Guard: get_device_capability() raises without an initialized CUDA
    # device, which would crash the __main__ capability check on CPU hosts.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability()[0] >= 9
"""
This example demonstrates the use of Programmatic Dependent Launch (PDL) using
CuTe DSL.
PDL is a mechanism which allows for overlapping execution of back-to-back kernels
within the same stream.
For example, consider the following two elementwise add operations, where the second
operation's first operand is the result of the first operation. While performing
``w = u + v`` we will load u and v, add them, and then store the result. Once we
have finished loading data, we are no longer utilizing the read bandwidth.
To effectively utilize the read bandwidth, we can start loading ``x``
immediately upon finishing reading. This is what PDL enables us to do.
.. code-block:: bash
w = u + v
y = w + x
To enable PDL, we need to do two things:
1. Insert the ``griddepcontrol.launch_dependents`` and ``griddepcontrol.wait`` instructions in the kernel.
2. Set the PDL launch attribute when launching the kernel.
The ``griddepcontrol.launch_dependents`` and ``griddepcontrol.wait``
instructions enable fine-grained control over kernel execution in PDL.
Once all thread blocks execute the ``griddepcontrol.launch_dependents``
instruction, the dependent kernels can opportunistically be early-launched.
``griddepcontrol.wait`` functions as a synchronization barrier - any warp
executing this instruction will block until the previous kernel finishes
execution. This allows precise control over data dependencies between kernels.
The following diagram shows the overlapping execution of two dependent kernels.
We refer to the instructions before ``griddepcontrol.wait`` as the prologue (``P0``),
which may include barrier initialization and loading of independent data, etc.
We refer to the instructions after ``griddepcontrol.launch_dependents`` as the epilogue
(``P2``), which may include math operations, data stores, etc. PDL enables
these prologue and epilogue phases to execute concurrently across dependent
kernels, improving GPU resource utilization. This is particularly beneficial
when prologue and epilogue are bound by different resources (e.g., memory
bandwidth vs compute throughput).
# P0: Prologue, P1: Main compute, P2: Epilogue
P0 P1 P2
K1: |=====|+++++|-----|
<-----> K2 can start early
(K1's P2 overlaps with K2's P0)
P0 P1 P2
K2: |=====| |+++++|-----|
^
|
wait for K1 to complete
Time ------------------------------------------------------>
We could run this example with and without PDL:
.. code-block:: bash
python examples/blackwell/programmatic_dependent_launch.py --benchmark
python examples/blackwell/programmatic_dependent_launch.py --benchmark --use_pdl
From the benchmark results, you can see some speedups for the PDL version in most cases, benefiting from
the overlapping execution of consecutive kernels. Moreover, you can use nsys to observe the overlapping execution.
.. code-block:: bash
nsys profile python examples/blackwell/programmatic_dependent_launch.py --benchmark --use_pdl
Note, PDL feature is supported on Hopper and later GPUs.
See [the programming guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization)
and the [PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol)
for more details.
"""
@cute.kernel
def elementwise_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
    cC: cute.Tensor,  # coordinate tensor
    shape: cute.Shape,
    thr_layout: cute.Layout,
    val_layout: cute.Layout,
    is_first_kernel: cutlass.Constexpr = True,
):
    """Per-block elementwise add (``gC = gA + gB``) demonstrating PDL control.

    The compile-time flag ``is_first_kernel`` selects the PDL role:
    the producer kernel issues ``griddepcontrol.launch_dependents`` after its
    loads, hinting the runtime to early-launch the next kernel; the consumer
    kernel first loads its independent operand, then blocks on
    ``griddepcontrol.wait`` before reading the operand produced upstream.

    Args:
        gA, gB: zipped-divided input tensors, indexed per thread block.
        gC: zipped-divided output tensor.
        cC: coordinate tensor used to build the out-of-bounds predicate.
        shape: logical problem shape the predicate is checked against.
        thr_layout: thread layout of the tiled copy.
        val_layout: per-thread value layout of the tiled copy.
        is_first_kernel: compile-time producer/consumer role selector.
    """
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    # Select this thread block's tile of every operand.
    blk_coord = ((None, None), bidx)
    blkA = gA[blk_coord]  # (TileM,TileN)
    blkB = gB[blk_coord]  # (TileM,TileN)
    blkC = gC[blk_coord]  # (TileM,TileN)
    blkCrd = cC[blk_coord]  # (TileM, TileN)

    copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
    copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)

    tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
    tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
    tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)

    thr_copy_A = tiled_copy_A.get_slice(tidx)
    thr_copy_B = tiled_copy_B.get_slice(tidx)
    thr_copy_C = tiled_copy_C.get_slice(tidx)

    # Per-thread views of the tile and matching register fragments.
    thrA = thr_copy_A.partition_S(blkA)
    thrB = thr_copy_B.partition_S(blkB)
    thrC = thr_copy_C.partition_S(blkC)

    frgA = cute.make_fragment_like(thrA)
    frgB = cute.make_fragment_like(thrB)
    frgC = cute.make_fragment_like(thrC)

    # Boolean predicate marking which elements of the tile are in bounds.
    thrCrd = thr_copy_C.partition_S(blkCrd)
    frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean)
    for i in range(cute.size(frgPred)):
        val = cute.elem_less(thrCrd[i], shape)
        frgPred[i] = val

    # Note: when not using cuda-graph, the kernel execution may be blocked by the host overhead.
    # In this case we won't see overlapping even when pdl is enabled.
    # In this example, we add a loop (10 times) for all the copy and compute operations in the following code
    # to make kernel running longer and make pdl benefits observable for both cuda-graph enabled and disabled cases.
    if is_first_kernel:
        for _ in range(10):
            cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
            cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
        # Here we add the launch dependents instruction for the first kernel as a hint to the runtime to early-launch
        # the next kernel. If the next kernel becomes concurrent, we will have overlap where the second kernel
        # can start reading x to ensure an E2E speedup. Note the placement of launch dependents has no implication
        # on correctness, only performance.
        cute.arch.griddepcontrol_launch_dependents()
    else:
        # In this example, the second kernel's second operand ``gB`` has no dependencies, its loading can overlap
        # with the computation of ``gC`` from the first kernel.
        for _ in range(10):
            cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
        # For the second kernel, its first operand ``gA`` is dependent on the previous kernel, we must call
        # griddepcontrol.wait to assure correctness. This instruction will block until the prior kernel finishes
        # and its memory operations are visible. Since gA is written by the prior kernel, this will block until gA
        # is visible to our kernel. Without it, we would have undefined behavior due to a race condition.
        cute.arch.griddepcontrol_wait()
        for _ in range(10):
            cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)

    for _ in range(10):
        result = frgA.load() + frgB.load()
        frgC.store(result)
        cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)
@cute.jit
def elementwise_add(
    mA,
    mB,
    mC,
    stream: cuda.CUstream,
    use_pdl: cutlass.Constexpr = True,
    is_first_kernel: cutlass.Constexpr = True,
):
    """Host-side JIT entry point: tile the operands and launch
    ``elementwise_add_kernel`` on ``stream``.

    Each thread moves 128 bits per copy, so the per-thread vector width is
    derived from the element type's bit width. ``use_pdl`` sets the PDL
    launch attribute; ``is_first_kernel`` selects the kernel's PDL role.
    """
    elem_type = mA.element_type
    # 128-bit per-thread copies: number of elements in one vector.
    vec_elems = 128 // elem_type.width

    thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
    val_layout = cute.make_ordered_layout((4, vec_elems), order=(1, 0))
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)

    # Partition each operand into ((TileM,TileN),(RestM,RestN)) views.
    gA = cute.zipped_divide(mA, tiler_mn)
    gB = cute.zipped_divide(mB, tiler_mn)
    gC = cute.zipped_divide(mC, tiler_mn)

    # Identity tensor supplies coordinates for boundary predication.
    coords = cute.make_identity_tensor(mC.shape)
    cC = cute.zipped_divide(coords, tiler=tiler_mn)

    kernel = elementwise_add_kernel(
        gA, gB, gC, cC, mC.shape, thr_layout, val_layout, is_first_kernel
    )
    kernel.launch(
        grid=[cute.size(gC, mode=[1]), 1, 1],
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
        stream=stream,
        use_pdl=use_pdl,
    )
def run_pdl_example(
    M,
    N,
    skip_ref_check=False,
    benchmark=False,
    warmup_iterations=5,
    iterations=100,
    use_pdl=True,
):
    """Compile and run two chained elementwise-add kernels, optionally with PDL.

    Computes ``w = u + v`` followed by ``y = w + x`` on the same CUDA stream,
    optionally verifies the result against a CPU reference, and optionally
    benchmarks the kernel pair using CUDA graphs.

    Args:
        M, N (int): tensor dimensions.
        skip_ref_check (bool): skip the correctness check against the CPU reference.
        benchmark (bool): run the timed benchmark after the (optional) check.
        warmup_iterations (int): benchmark warmup iteration count.
        iterations (int): benchmark measured iteration count.
        use_pdl (bool): enable Programmatic Dependent Launch for both kernels.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    import torch

    if not torch.cuda.is_available():
        raise RuntimeError("Blackwell/Hopper GPU is required to run this example!")

    print("\nRunning Elementwise Add test with:")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Use PDL: {use_pdl}")

    # w is both the output of the first kernel and an input of the second.
    u = torch.randn(M, N, dtype=torch.float32, device="cuda")
    v = torch.randn(M, N, dtype=torch.float32, device="cuda")
    w = torch.randn(M, N, dtype=torch.float32, device="cuda")
    x = torch.randn(M, N, dtype=torch.float32, device="cuda")
    y = torch.empty(M, N, dtype=torch.float32, device="cuda")

    u_tensor = from_dlpack(u).mark_layout_dynamic()
    v_tensor = from_dlpack(v).mark_layout_dynamic()
    w_tensor = from_dlpack(w).mark_layout_dynamic()
    x_tensor = from_dlpack(x).mark_layout_dynamic()
    y_tensor = from_dlpack(y).mark_layout_dynamic()

    stream = torch.cuda.Stream()
    current_stream = cuda.CUstream(stream.cuda_stream)

    # Since is_first_kernel is cutlass.Constexpr, we need to compile for
    # the first and second kernel separately.
    compiled_func_first_kernel = cute.compile(
        elementwise_add,
        u_tensor,
        v_tensor,
        w_tensor,
        current_stream,
        use_pdl,
        is_first_kernel=True,
        options="--enable-tvm-ffi",
    )

    compiled_func_second_kernel = cute.compile(
        elementwise_add,
        w_tensor,
        x_tensor,
        y_tensor,
        current_stream,
        use_pdl,
        is_first_kernel=False,
        options="--enable-tvm-ffi",
    )

    # launch and run the two consecutive kernels in a same stream.
    def run_func(current_stream, u, v, w, x, y):
        # Run first operation: w_tensor = u_tensor + v_tensor
        compiled_func_first_kernel(
            u,
            v,
            w,
            current_stream,
        )
        # Run second operation: y_tensor = w_tensor + x_tensor
        # its first operand ``w_tensor`` is the result of the first operation,
        # they use the same memory space.
        compiled_func_second_kernel(
            w,
            x,
            y,
            current_stream,
        )

    if not skip_ref_check:
        run_func(current_stream, u, v, w, x, y)
        print("Verifying results...")
        # End-to-end reference: y == (u + v) + x.
        torch.testing.assert_close(u.cpu() + v.cpu() + x.cpu(), y.cpu())
        print("Results verified successfully!")

    if not benchmark:
        return

    # Fresh tensors per workspace so benchmark iterations don't alias data.
    def generate_kernel_arguments():
        u = torch.randn(M, N, dtype=torch.float32, device="cuda")
        v = torch.randn(M, N, dtype=torch.float32, device="cuda")
        w = torch.randn(M, N, dtype=torch.float32, device="cuda")
        x = torch.randn(M, N, dtype=torch.float32, device="cuda")
        y = torch.empty(M, N, dtype=torch.float32, device="cuda")
        return testing.JitArguments(current_stream, u, v, w, x, y)

    avg_time_us = testing.benchmark(
        run_func,
        workspace_generator=generate_kernel_arguments,
        workspace_count=10,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
        stream=current_stream,
        use_cuda_graphs=True,
    )

    print(f"Execution time: {avg_time_us:.4f} us")
if __name__ == "__main__":
    # CLI entry point for the PDL example.
    parser = argparse.ArgumentParser(
        description="example of Programmatic Dependent Launch (PDL) using CuTe DSL"
    )
    int_flags = (
        ("--M", 256),
        ("--N", 256),
        ("--warmup_iterations", 5),
        ("--iterations", 10),
    )
    for flag, default in int_flags:
        parser.add_argument(flag, default=default, type=int)
    for flag in ("--skip_ref_check", "--benchmark", "--use_pdl"):
        parser.add_argument(flag, action="store_true")
    args = parser.parse_args()

    # Guard clause: bail out early on pre-Hopper / CPU-only hosts.
    if not supports_pdl():
        print(
            "PDL is not supported on this device, it requires Hopper or newer generations"
        )
    else:
        run_pdl_example(
            args.M,
            args.N,
            skip_ref_check=args.skip_ref_check,
            benchmark=args.benchmark,
            warmup_iterations=args.warmup_iterations,
            iterations=args.iterations,
            use_pdl=args.use_pdl,
        )
        print("\nPASS")

View File

@@ -839,7 +839,7 @@ class HopperFusedMultiHeadAttentionForward:
)
s_max_layout = cute.make_layout(
cute.size(layout_acc_mn(pv_tiled_mma, acc_pv.layout), mode=[0])
cute.size(self.layout_acc_mn(pv_tiled_mma, acc_pv.layout), mode=[0])
)
s_max = cute.make_rmem_tensor_like(s_max_layout, self.qk_acc_dtype)
a_sum = cute.make_rmem_tensor_like(s_max, cutlass.Float32)
@@ -888,7 +888,7 @@ class HopperFusedMultiHeadAttentionForward:
# MMA QK
cute.nvgpu.warpgroup.fence()
gemm_zero_acc(
self.gemm_zero_acc(
qk_tiled_mma,
tSrQ[(None, None, None, q_handle.index)],
tSrK[(None, None, None, k_handle.index)],
@@ -901,7 +901,7 @@ class HopperFusedMultiHeadAttentionForward:
# Wait for the pipeline MMAs to drain
cute.nvgpu.warpgroup.wait_group(0)
s_max, a_sum = softmax_step(
s_max, a_sum = self.softmax_step(
True,
self.mask_type,
acc_qk,
@@ -919,7 +919,7 @@ class HopperFusedMultiHeadAttentionForward:
True,
)
acc_qk_fixed = make_acc_into_op(
acc_qk_fixed = self.make_acc_into_op(
acc_qk, pv_tiled_mma.tv_layout_A, self.q_dtype
)
@@ -928,7 +928,7 @@ class HopperFusedMultiHeadAttentionForward:
# MMA PV
cute.nvgpu.warpgroup.fence()
gemm_zero_acc(
self.gemm_zero_acc(
pv_tiled_mma,
acc_qk_fixed,
tOrV[(None, None, None, v_handle.index)],
@@ -1040,7 +1040,7 @@ class HopperFusedMultiHeadAttentionForward:
cute.nvgpu.warpgroup.wait_group(0)
# acc_pv updated
lse = tail(
lse = self.tail(
s_max, a_sum, acc_pv, pv_tiled_mma, scale_softmax, scale_output
)
@@ -1077,10 +1077,10 @@ class HopperFusedMultiHeadAttentionForward:
if tOcO[0][1] == 0:
tOgLSE_mn = cute.make_tensor(
tOgLSE.iterator, layout_acc_mn(pv_tiled_mma, tOgLSE.layout)
tOgLSE.iterator, self.layout_acc_mn(pv_tiled_mma, tOgLSE.layout)
)
tOcO_mn = cute.make_tensor(
tOcO.iterator, layout_acc_mn(pv_tiled_mma, tOcO.layout)
tOcO.iterator, self.layout_acc_mn(pv_tiled_mma, tOcO.layout)
)
for i in cutlass.range_constexpr(cute.size(tOgLSE_mn, mode=[0])):
if (
@@ -1241,7 +1241,7 @@ class HopperFusedMultiHeadAttentionForward:
# MMA QK
cute.nvgpu.warpgroup.fence()
gemm_zero_acc(
self.gemm_zero_acc(
qk_tiled_mma,
tSrQ[(None, None, None, q_handle.index)],
tSrK[(None, None, None, k_handle.index)],
@@ -1255,7 +1255,7 @@ class HopperFusedMultiHeadAttentionForward:
# Wait for the pipeline MMAs to drain
cute.nvgpu.warpgroup.wait_group(0)
s_max, a_sum = softmax_step(
s_max, a_sum = self.softmax_step(
fusion,
self.mask_type,
acc_qk,
@@ -1272,7 +1272,7 @@ class HopperFusedMultiHeadAttentionForward:
window_size_right,
)
acc_qk_fixed = make_acc_into_op(
acc_qk_fixed = self.make_acc_into_op(
acc_qk, pv_tiled_mma.tv_layout_A, self.q_dtype
)
@@ -1300,6 +1300,7 @@ class HopperFusedMultiHeadAttentionForward:
@cute.jit
def softmax_step(
self,
fusion: bool,
mask_type: fmha_utils.MaskEnum,
acc_qk: cute.ThrMma,
@@ -1328,10 +1329,10 @@ class HopperFusedMultiHeadAttentionForward:
)
acc_qk_mn = cute.make_tensor(
acc_qk.iterator, layout_acc_mn(tiled_mma_qk, acc_qk.layout)
acc_qk.iterator, self.layout_acc_mn(tiled_mma_qk, acc_qk.layout)
)
reduction_target_qk = reduction_target_n(tiled_mma_qk)
reduction_target_qk = self.reduction_target_n(tiled_mma_qk)
red_rank = cute.rank(reduction_target_qk)
s_max_prev = None
@@ -1346,7 +1347,7 @@ class HopperFusedMultiHeadAttentionForward:
s_max[i] = cute.arch.fmax(s_max[i], acc_qk_mn[i, j])
else:
acc_pv_mn = cute.make_tensor(
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
acc_pv.iterator, self.layout_acc_mn(tiled_mma_pv, acc_pv.layout)
)
s_max_prev = cute.make_rmem_tensor_like(s_max, s_max._dtype)
@@ -1396,15 +1397,15 @@ class HopperFusedMultiHeadAttentionForward:
return s_max, a_sum
@cute.jit
def reduction_target_n(tiled_mma):
separated = layout_separate(
def reduction_target_n(self, tiled_mma):
separated = self.layout_separate(
tiled_mma.shape_mnk[0],
cute.make_layout(tiled_mma.tv_layout_C.shape[0]),
tiled_mma.tv_layout_C.stride[0],
)
return separated[1]
@cute.jit
@staticmethod
def convert_c_layout_to_a_layout(c, a):
return cute.make_layout(
(a, c.shape[1], (c.shape[2], cute.size(c, mode=[0]) // cute.size(a))),
@@ -1416,9 +1417,9 @@ class HopperFusedMultiHeadAttentionForward:
)
@cute.jit
def make_acc_into_op(acc, operand_layout_tv, Element):
def make_acc_into_op(self, acc, operand_layout_tv, Element):
operand = cute.make_rmem_tensor_like(
convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]),
self.convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]),
Element,
)
operand_as_acc = cute.make_tensor(operand.iterator, acc.layout)
@@ -1499,7 +1500,7 @@ class HopperFusedMultiHeadAttentionForward:
return operand
@cute.jit
def tail(s_max, a_sum, acc_pv, tiled_mma_pv, scale_softmax, scale_output):
def tail(self, s_max, a_sum, acc_pv, tiled_mma_pv, scale_softmax, scale_output):
"""
Final processing step for FMHA that computes log-sum-exp (LSE) and scales the output.
@@ -1527,9 +1528,9 @@ class HopperFusedMultiHeadAttentionForward:
"""
# Create tensor view of accumulated P*V values with M*N layout
acc_pv_mn = cute.make_tensor(
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
acc_pv.iterator, self.layout_acc_mn(tiled_mma_pv, acc_pv.layout)
)
reduction_target = reduction_target_n(tiled_mma_pv)
reduction_target = self.reduction_target_n(tiled_mma_pv)
red_rank = cute.rank(reduction_target)
for r in cutlass.range_constexpr(red_rank):
for i in cutlass.range_constexpr(cute.size(acc_pv_mn, mode=[0])):
@@ -1538,7 +1539,7 @@ class HopperFusedMultiHeadAttentionForward:
)
acc_mn = cute.make_tensor(
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
acc_pv.iterator, self.layout_acc_mn(tiled_mma_pv, acc_pv.layout)
)
lse = cute.make_rmem_tensor_like(a_sum, a_sum._dtype)
@@ -1559,7 +1560,7 @@ class HopperFusedMultiHeadAttentionForward:
return lse
@cute.jit
@staticmethod
def layout_separate(thr, src, ref):
lt = cute.make_layout(())
ge = cute.make_layout(())
@@ -1577,6 +1578,7 @@ class HopperFusedMultiHeadAttentionForward:
r = cute.append(cute.append(cute.make_layout(()), lt), ge)
return r
@staticmethod
@cute.jit
def gemm_zero_acc(tiled_mma, A, B, C):
rA = cute.rank(A)
@@ -1606,8 +1608,8 @@ class HopperFusedMultiHeadAttentionForward:
assert 0
@cute.jit
def layout_acc_mn(tiled_mma, acc):
separated = layout_separate(
def layout_acc_mn(self, tiled_mma, acc):
separated = self.layout_separate(
tiled_mma.shape_mnk[0], acc[0], tiled_mma.tv_layout_C.stride[1]
)

View File

@@ -36,7 +36,7 @@
#define CUTLASS_MAJOR 4
#define CUTLASS_MINOR 3
#define CUTLASS_PATCH 3
#define CUTLASS_PATCH 4
#ifdef CUTLASS_VERSIONS_GENERATED
#include "cutlass/version_extended.h"

View File

@@ -470,7 +470,56 @@ class DSLPreprocessor(ast.NodeTransformer):
names=[ast.alias(name=f"{top_module_name}.base_dsl", asname="_dsl_")]
)
)
transformed_tree.body = import_stmts + transformed_tree.body
assert len(transformed_tree.body) == 1
assert isinstance(transformed_tree.body[0], ast.FunctionDef)
transformed_tree.body[0].body = import_stmts + transformed_tree.body[0].body
# Remove all decorators from top level function
transformed_tree.body[0].decorator_list = []
# Step 4. Wrap the function with nonlocal captures, if has any
# if the function has a nonlocal variable, wrap it in a function and return the function
# pseudo code:
# def foo():
# nonlocal_var_0 = None
# nonlocal_var_1 = None
# def foo(args):
# ...
# return foo
# foo = foo()
nonlocals = {v: None for v in function_pointer.__code__.co_freevars}
if len(nonlocals) > 0:
assignments = []
for n, _ in nonlocals.items():
assignments.append(
ast.Assign(
targets=[ast.Name(id=n, ctx=ast.Store())],
value=ast.Constant(value=None),
)
)
return_expr = [ast.Return(value=ast.Name(id=func_name, ctx=ast.Load()))]
wrapper_fcn = ast.FunctionDef(
name=func_name,
args=ast.arguments(
posonlyargs=[],
args=[],
kwonlyargs=[],
kw_defaults=[],
defaults=[],
),
body=assignments + transformed_tree.body + return_expr,
decorator_list=[],
)
invoke = ast.Call(
func=ast.Name(id=func_name, ctx=ast.Load()), args=[], keywords=[]
)
assign = ast.Assign(
targets=[ast.Name(id=func_name, ctx=ast.Store())], value=invoke
)
transformed_tree.body = [wrapper_fcn, assign]
# Step 4. Import cutlass and base_dsl
ast.fix_missing_locations(transformed_tree)
@@ -1521,6 +1570,15 @@ class DSLPreprocessor(ast.NodeTransformer):
self.scope_manager.add_to_scope(node.name)
for arg in node.args.args:
self.scope_manager.add_to_scope(arg.arg)
arg.annotation = None
for arg in node.args.kwonlyargs:
self.scope_manager.add_to_scope(arg.arg)
arg.annotation = None
for arg in node.args.posonlyargs:
self.scope_manager.add_to_scope(arg.arg)
arg.annotation = None
self.generic_visit(node)

View File

@@ -622,18 +622,14 @@ class CompileCallable:
func,
)
# If it's a wrapped function created by jit decorator, get the original function
if hasattr(func, "__wrapped__"):
# If it's a wrapped function created by decorators, get the original function
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
# Lazy initialization of DSL object if has not been initialized
# Use local import to avoid circular import
from .dsl import BaseDSL
BaseDSL._lazy_initialize_dsl(func)
if not hasattr(func, "_dsl_object"):
raise DSLRuntimeError("Function is not decorated with jit decorator.")
raise DSLRuntimeError(
f"Function {func} is not decorated with jit decorator."
)
# process compile options, extract the options and remove them from the kwargs
options = kwargs.pop("options", None)
@@ -645,8 +641,4 @@ class CompileCallable:
else:
compile_options = self._compile_options
func._dsl_object.compile_options = compile_options
fcn_ptr = func._dsl_object._preprocess_and_execute(func)
if hasattr(func, "_decorator_frame"):
kwargs["_decorator_frame"] = func._decorator_frame
return func._dsl_object._func(fcn_ptr, *args, **kwargs)
return func._dsl_object._func(func, *args, **kwargs)

View File

@@ -31,9 +31,10 @@ import weakref
from functools import lru_cache, wraps
from collections import namedtuple, OrderedDict
from abc import ABC, abstractmethod
from typing import Any, Callable, List
from typing import Any, Callable, List, ClassVar
from types import SimpleNamespace
import warnings
import threading
from . import typing as t
from .env_manager import EnvironmentVarManager
@@ -228,49 +229,92 @@ def new_from_mlir_values(obj, values):
assert len(values) == 0, f"{obj} expects 0 values, but got {values}"
return obj
class DSLCallable:
@dataclass(frozen=True)
class DSLLocation:
"""
Wrapper class for a callable object used within the DSL.
DSLCallable is designed to wrap a function and provide additional
introspection utilities such as retrieving the argument specification
and signature. It ensures that the wrapped function can only be called
once, after which the reference to the function is cleared to prevent
further invocations. This is useful in scenarios where a function should
only be executed a single time within the DSL's execution model.
Represents Python source location information for MLIR DSL code.
Attributes:
func (callable): The function to be wrapped and managed.
filename (str): Name of the Python source file.
lineno (int): Line number in the source file.
col_offset (int): Column offset in the source line.
function_name (str): Name of the function in which the location occurs.
Methods:
__call__(*args, **kwargs): Calls the wrapped function and clears it.
This is used primarily to annotate or trace locations in generated MLIR IR
back to the original Python code for better diagnostic and debugging.
"""
def __init__(self, func):
self.func = func
self.name = func.__name__
def __call__(self, *args, **kwargs):
ret = self.__func__(*args, **kwargs)
self.func = None
return ret
@property
def __func__(self):
assert self.func is not None, "DSLCallable is already called"
return self.func
@property
def __signature__(self):
return inspect.signature(self.__func__)
@property
def __name__(self):
return self.name
filename: str
lineno: int
col_offset: int
function_name: str
class BaseDSL:
@dataclass
class PreprocessSessionData:
"""
Holds metadata and transformed AST related to a DSL preprocessing session.
Attributes:
decorator_globals (dict): The global variables from the decorator's environment,
captured for possible AST or code evaluation during preprocessing.
"""
decorator_globals: dict
class DSLSingletonMeta(type):
    """
    Metaclass implementing the Singleton pattern for DSL classes.

    The DSLSingletonMeta ensures that only one instance of a derived DSL class exists at any time.
    When a class is called, it checks if an instance already exists in the `_instances` dictionary.

    - If requesting `BaseDSL` itself, it asserts that a concrete subclass has been initialized,
      and returns the first available singleton instance among subclasses.
    - If requesting a concrete subclass, it creates a new instance if none exists, or returns
      the already created instance.

    This metaclass is useful for maintaining global state and configuration across the DSL system,
    ensuring that all parts of the application operate on the same DSL instance.

    Attributes:
        _instances (dict): Maps DSL classes to their singleton instances.
        _lock (threading.Lock): Serializes instance creation across threads.

    Example:
        class MyDSL(BaseDSL): ...
        dsl1 = MyDSL()
        dsl2 = MyDSL()
        assert dsl1 is dsl2  # Singleton property
    """

    # Maps each concrete DSL class to its unique instance; shared by all subclasses.
    _instances: ClassVar[dict] = {}
    # Guards instance creation so two threads cannot both construct the first instance.
    _lock: ClassVar[threading.Lock] = threading.Lock()

    def __call__(cls, *args, **kwargs):
        # Entire lookup-or-create path runs under the lock, including the
        # BaseDSL query branch, so `_instances` is never read while mutating.
        with cls._lock:
            log().info(f"DSLSingletonMeta __call__ for {cls}")
            if cls is BaseDSL:
                # BaseDSL is abstract: querying it returns an arbitrary existing
                # instance of a concrete subclass (here, simply the first one).
                assert cls._instances, (
                    "Need to initialize a concrete subclass of BaseDSL first"
                )
                return next(iter(cls._instances.values()))
            elif cls not in cls._instances:
                # First construction of this concrete subclass: build and memoize it.
                instance = super().__call__(*args, **kwargs)
                cls._instances[cls] = instance
            log().info(f"Active DSL singleton instances: {cls._instances}")
            return cls._instances[cls]

    def clear_instances(cls):
        # Drops only this class's cached instance; other subclasses keep theirs.
        # NOTE(review): not guarded by `_lock` — confirm callers are single-threaded.
        log().info(
            f"Clearing DSL singleton instances for {cls}, current instances: {cls._instances}"
        )
        if cls in cls._instances:
            del cls._instances[cls]
        log().info(f"DSL singleton instances after clearing: {cls._instances}")
class BaseDSL(metaclass=DSLSingletonMeta):
gpu_module = None
_env_class = EnvironmentVarManager
@@ -310,7 +354,8 @@ class BaseDSL:
self.name = name
self.compiler_provider = compiler_provider
self.pass_sm_arch_name = pass_sm_arch_name
self.frame = None
self.preprocess_session_data = None
self.decorator_location = None
self.no_cache = False
self.device_compilation_only = device_compilation_only
self.num_kernels = 0
@@ -379,7 +424,6 @@ class BaseDSL:
warnings.warn(message, UserWarning)
@classmethod
@lru_cache(maxsize=1)
def _get_dsl(cls):
# Instantiate the DSL Class once
main_dsl = cls()
@@ -414,38 +458,22 @@ class BaseDSL:
return fcn_ptr
@staticmethod
def _preprocess_and_execute(func):
def _preprocess_and_replace_code(func):
"""
Run ast transformation and return the materialized function pointer
"""
# Lazy initialization of DSL object if has not been initialized
if not hasattr(func, "_dsl_object"):
func._dsl_object = func._dsl_cls._get_dsl()
delattr(func, "_dsl_cls")
if not func._dsl_object.enable_preprocessor:
if hasattr(func, "_decorator_frame"):
delattr(func, "_decorator_frame")
if hasattr(func, "_transformed_ast"):
delattr(func, "_transformed_ast")
return func
if hasattr(func, "_transformed_ast"):
if hasattr(func, "_preprocess_session_data"):
# If the function ptr is already materialized, use the existing one
func._dsl_object.frame = func._decorator_frame
if func._transformed_ast is None:
func._transformed_ast = func._dsl_object.run_preprocessor(func)
if func._transformed_ast is None:
del func._transformed_ast
func._dsl_object.frame = None
return func
fcn_ptr = func._dsl_object.get_function_ptr(func)
# If the function is decorated, de-decorate it
fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__)
func._dsl_object.frame = None
return DSLCallable(fcn_ptr)
func._dsl_object.preprocess_session_data = func._preprocess_session_data
func._dsl_object.decorator_location = func._decorator_location
transformed_ast = func._dsl_object.run_preprocessor(func)
fcn_ptr = func._dsl_object.get_function_ptr(func, transformed_ast)
func.__code__ = (
fcn_ptr.__code__
if not isinstance(fcn_ptr, staticmethod)
else fcn_ptr.__func__.__code__
)
return func
@staticmethod
@@ -457,20 +485,27 @@ class BaseDSL:
def jit_runner_decorator(func):
# Run preprocessor that alters AST
func._dsl_cls = cls
if BaseDSL._can_preprocess(**dkwargs):
func._dsl_object = cls._get_dsl()
func._decorator_location = BaseDSL.get_location_from_frame(frame)
if (
func._dsl_object.enable_preprocessor
and func._dsl_object._can_preprocess(**dkwargs)
):
# For an annotated function, add some DSL attributes
# When materializing the AST, we need decorator's frame
func._decorator_frame = frame
# No transformed ast at this point
func._transformed_ast = None
func._preprocess_session_data = PreprocessSessionData(
decorator_globals=frame.f_globals,
)
BaseDSL._preprocess_and_replace_code(func)
@wraps(func)
def jit_wrapper(*args, **kwargs):
func_ptr = BaseDSL._preprocess_and_execute(func)
return getattr(func._dsl_object, executor_name)(
func_ptr, *args, **kwargs
)
return getattr(func._dsl_object, executor_name)(func, *args, **kwargs)
def set_name_prefix(name: str):
jit_wrapper._name_prefix = name
jit_wrapper.set_name_prefix = set_name_prefix
return jit_wrapper
@@ -479,15 +514,6 @@ class BaseDSL:
else:
return jit_runner_decorator
@staticmethod
def _lazy_initialize_dsl(func):
"""
Lazy initialization of DSL object if has not been initialized
"""
if hasattr(func, "_dsl_cls"):
func._dsl_object = func._dsl_cls._get_dsl()
delattr(func, "_dsl_cls")
@classmethod
def jit(cls, *dargs, **dkwargs):
"""
@@ -516,6 +542,7 @@ class BaseDSL:
"""
Build the module op that contains the kernels.
"""
log().info(f"[abstract] Building GPU module for {self.name}")
pass
@abstractmethod
@@ -688,9 +715,11 @@ class BaseDSL:
dictionary is used to execute the python code.
"""
all_globals = {}
if self.frame:
all_globals.update(self.frame.f_globals)
all_globals.update(self.frame.f_locals)
if (
self.preprocess_session_data
and self.preprocess_session_data.decorator_globals
):
all_globals.update(self.preprocess_session_data.decorator_globals)
return all_globals
@abstractmethod
@@ -955,25 +984,40 @@ class BaseDSL:
else:
ir._GlobalDebug.set_types(f"diagnostic-{args.diagnostic}")
def get_location(self, frame=None):
"""
Get python location information and generate MLIR location
"""
frame = self.frame if frame is None else frame
frame = inspect.currentframe().f_back if frame is None else frame
@staticmethod
def get_location_from_frame(frame):
frameInfo = inspect.getframeinfo(frame)
file_loc = ir.Location.file(
frame.f_code.co_filename,
frame.f_lineno,
frameInfo.positions.col_offset if hasattr(frameInfo, "positions") else 0,
)
loc = ir.Location.name(
(
return DSLLocation(
filename=frameInfo.filename,
lineno=frameInfo.lineno,
col_offset=(
frameInfo.positions.col_offset if hasattr(frameInfo, "positions") else 0
),
function_name=(
"".join([c.strip() for c in frameInfo.code_context])
if frameInfo.code_context
else frameInfo.function
),
)
def get_ir_location(self, location: DSLLocation = None):
"""
Get python location information and generate MLIR location
"""
if location is None:
if self.decorator_location:
location = self.decorator_location
if location is None:
return ir.Location.unknown()
file_loc = ir.Location.file(
location.filename,
location.lineno,
location.col_offset,
)
loc = ir.Location.name(
(location.function_name),
childLoc=file_loc,
)
return loc
@@ -1140,10 +1184,10 @@ class BaseDSL:
gpu_module_attrs,
args,
args_spec,
frame=None,
location=None,
):
def build_ir_module():
loc = self.get_location(frame)
loc = self.get_ir_location(location)
module = ir.Module.create(loc=loc)
unit_attr = ir.UnitAttr.get()
module.operation.attributes["gpu.container_module"] = unit_attr
@@ -1308,6 +1352,10 @@ class BaseDSL:
self.num_kernels = 0
# reset the compile options after the compilation is done.
self.compile_options = CompileOptions()
# reset preprocess session data after the compilation is done.
self.preprocess_session_data = None
# reset decorator location after the compilation is done.
self.decorator_location = None
def extract_dynamic_args(self, funcBody, args, kwargs, args_spec):
"""This function is used to extract the original dynamic arguments for AOT C header generation.
@@ -1348,11 +1396,10 @@ class BaseDSL:
pipeline,
no_cache,
compile_only,
loc=None,
frame=None,
location=None,
):
"""Generate MLIR module and compile iself.T_provider."""
with ir.Context(), self.get_location(frame):
with ir.Context(), self.get_ir_location(location):
try:
# Convert input arguments to MLIR arguments
exe_args, func_types, adapted_args = self.generate_mlir_function_types(
@@ -1374,7 +1421,7 @@ class BaseDSL:
gpu_module_attrs,
args,
args_spec,
frame=frame,
location=location,
)
# dryrun is used to only generate IR
@@ -1437,11 +1484,14 @@ class BaseDSL:
return transformed_ast
return None
def get_function_ptr(self, original_function):
def get_function_ptr(self, original_function, transformed_ast):
file_name = inspect.getsourcefile(original_function)
code_object = compile(
original_function._transformed_ast, filename=file_name, mode="exec"
transformed_ast,
filename=file_name,
mode="exec",
)
return self.preprocessor.exec(
original_function.__name__,
original_function,
@@ -1523,7 +1573,7 @@ class BaseDSL:
pipeline = kwargs.pop("pipeline", None)
gpu_module_attrs = kwargs.pop("gpu_module_attrs", {})
decorator_frame = kwargs.pop("_decorator_frame", None)
self.decorator_location = getattr(funcBody, "_decorator_location", None)
# Disable cache
no_cache = kwargs.pop("no_cache", False)
@@ -1556,7 +1606,7 @@ class BaseDSL:
function_name = self.mangle_name(function_name, canonicalized_args, args_spec)
self.compile_options.apply_envar_settings(self.envar, function_name)
if not self.compile_options.generate_line_info:
decorator_frame = None
self.decorator_location = None
# Generate MLIR Context and start generating IR
log().debug(f"Generating MLIR for function '{function_name}'")
@@ -1570,7 +1620,7 @@ class BaseDSL:
pipeline,
no_cache,
compile_only,
frame=decorator_frame,
location=self.decorator_location,
)
return result
@@ -1679,8 +1729,7 @@ class BaseDSL:
"""
ret = None
with ir.Context(), self.get_location():
loc = self.get_location()
with ir.Context(), self.get_ir_location() as loc:
module = ir.Module.create(loc=loc)
unit_attr = ir.UnitAttr.get()
module.operation.attributes["gpu.container_module"] = unit_attr
@@ -1819,7 +1868,7 @@ class BaseDSL:
)
)
loc = self.get_location()
loc = self.get_ir_location()
with self._enter_gpu_module():
log().debug("Generating device kernel")
if self.device_compilation_only:

View File

@@ -138,6 +138,7 @@ class MLIRBuilder(MLIRTypeBuilder):
super().__init__()
self.module: Optional[ir.Module] = None
self.const_str_table: dict[str, ir.Value] = {}
self.const_func_ptr_table: dict[str, ir.Value] = {}
self.get_element_extra_kwargs: dict[str, Any] = {}
# create constants
@@ -368,6 +369,64 @@ class MLIRBuilder(MLIRTypeBuilder):
self.const_str_table[content] = symbol
return symbol
    def get_or_load_global_func_ptr_from_text(
        self,
        current_block: ir.Block,
        function_name: str,
    ) -> ir.Value:
        """Get or create a function pointer global in the .text section and load it.

        This creates a constant global function pointer in the .text section
        (for AArch64 ADRP range compatibility) and performs a volatile load
        to prevent optimization.

        This forces the function pointer to be local to the code, bypassing GOT-entry
        ADRP lookup issues on AArch64 when the GOT and the .text section are more than
        4GB apart, which can happen when ASLR is applied.

        :param current_block: Block into which the (volatile) load is emitted.
        :param function_name: Symbol name of the target function.
        :return: SSA value holding the loaded function pointer.
        """
        # Check if we've already created this global; globals are memoized per
        # function name in `const_func_ptr_table` so each is emitted once per module.
        if function_name not in self.const_func_ptr_table:
            symbol = f"__func_ptr_{function_name}"
            module_body = self.module.body
            with ir.InsertionPoint(module_body):
                # 1. Create the global constant.
                #    'private' linkage so the symbol doesn't conflict across modules.
                global_ptr = llvm.GlobalOp(
                    self.ptr_type,
                    symbol,
                    ir.Attribute.parse("#llvm.linkage<private>"),
                    # Initialization via the constructor block appended below.
                )
                # 2. Set the necessary attributes for JIT safety and AArch64 range:
                #    'constant' marks it immutable, and section=".text" forces it to
                #    live next to the code (within ADRP reach).
                global_ptr.attributes["constant"] = ir.UnitAttr.get()
                global_ptr.attributes["section"] = ir.StringAttr.get(".text")
                # 3. Add a constructor block to the GlobalOp that initializes it
                #    with the address of the target function.
                initializer_block = global_ptr.initializer.blocks.append()
                with ir.InsertionPoint(initializer_block):
                    # Take the address of the external function...
                    func_addr = llvm.AddressOfOp(self.ptr_type, function_name).res
                    # ...and yield it as the global's initial value.
                    llvm.return_(arg=func_addr)
            self.const_func_ptr_table[function_name] = symbol
        else:
            symbol = self.const_func_ptr_table[function_name]
        # Load the global in the requested block with volatile semantics so the
        # indirection cannot be folded back into a direct GOT-based reference.
        with ir.InsertionPoint(current_block):
            symbol_addr = self.address_of(symbol, self.ptr_type)
            load_op = llvm.load(self.ptr_type, symbol_addr)
            # Set the volatile attribute to prevent optimization of the load.
            load_op.owner.attributes["volatile_"] = ir.UnitAttr.get()
        return load_op
# function
def function(
self,

View File

@@ -210,7 +210,7 @@ EnableTVMFFI = _dsl.EnableTVMFFI
# attach the TVM FFI ABI interface postprocessor to the DSL
from . import _tvm_ffi_args_spec_converter
_tvm_ffi_args_spec_converter.attach_args_spec_converter()
_tvm_ffi_args_spec_converter.attach_args_spec_converter(_dsl.CuTeDSL._get_dsl())
# Explicitly export all symbols for documentation generation
__all__ = [

View File

@@ -395,8 +395,6 @@ def _tvm_ffi_args_spec_converter(
return params, kwargs_wrapper_spec
def attach_args_spec_converter():
"""Attach TVM FFI ABI interface postprocessor to the DSL."""
from .. import cutlass_dsl as _dsl
_dsl.CuTeDSL._get_dsl()._tvm_ffi_args_spec_converter = _tvm_ffi_args_spec_converter
def attach_args_spec_converter(dsl):
    """Install the TVM FFI ABI interface postprocessor on *dsl*.

    Binds this module's ``_tvm_ffi_args_spec_converter`` onto the given DSL
    instance so the DSL can post-process JIT argument specs for TVM FFI.
    """
    setattr(dsl, "_tvm_ffi_args_spec_converter", _tvm_ffi_args_spec_converter)

View File

@@ -10,7 +10,7 @@
# is strictly prohibited.
from cutlass.base_dsl.arch import Arch
from cutlass.cutlass_dsl import CuTeDSL, T, dsl_user_op
from cutlass.cutlass_dsl import BaseDSL, T, dsl_user_op
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir.dialects import nvvm, scf
@@ -69,7 +69,7 @@ def elect_one(*, loc=None, ip=None) -> IfOpRegion:
# Only one thread in the warp executes the code in this context
pass
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
is_thread_leader = nvvm.elect_sync(T.bool())
if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip)
return IfOpRegion(if_op.then_block, loc=loc, ip=ip)

View File

@@ -11,7 +11,7 @@
from typing import Optional
from cutlass.base_dsl.arch import Arch
from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op
from cutlass.cutlass_dsl import BaseDSL, T, if_generate, dsl_user_op
from cutlass._mlir.dialects import nvvm
@@ -44,7 +44,7 @@ def mbarrier_init_fence(*, loc=None, ip=None) -> None:
"""
A fence operation that applies to the mbarrier initializations.
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
nvvm.fence_mbarrier_init(loc=loc, ip=ip)
@@ -63,7 +63,7 @@ def mbarrier_arrive_and_expect_tx(
the mbarrier is converted to a remote address in the peer CTA's
SMEM.
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
mbar_llvm_ptr = mbar_ptr.llvm_ptr
if peer_cta_rank_in_cluster is not None:
@@ -103,7 +103,7 @@ def mbarrier_expect_tx(
the mbarrier is converted to a remote address in the peer CTA's
SMEM.
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
mbar_llvm_ptr = mbar_ptr.llvm_ptr
if peer_cta_rank_in_cluster is not None:
@@ -138,7 +138,7 @@ def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None:
:param phase: The phase to wait for (either 0 or 1)
:type phase: Int
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
timeout_ns = 10000000
# This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX
@@ -164,7 +164,7 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo
:return: A boolean value indicating whether the wait operation was successful
:rtype: Boolean
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
return Boolean(
nvvm.mbarrier_wait_parity(
@@ -193,7 +193,7 @@ def mbarrier_conditional_try_wait(
:return: A boolean value indicating whether the wait operation was successful
:rtype: Boolean
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
return if_generate(
cond,
lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip),
@@ -225,7 +225,7 @@ def mbarrier_arrive(
"""
mbar_llvm_ptr = mbar_ptr.llvm_ptr
if peer_cta_rank_in_cluster is not None:
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
mbar_llvm_ptr = nvvm.mapa_shared_cluster(
mbar_llvm_ptr.type,
@@ -259,7 +259,7 @@ def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> N
:param mbar_ptr: A pointer to the mbarrier in SMEM
:type mbar_ptr: Pointer
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
mbar_llvm_ptr = mbar_ptr.llvm_ptr
nvvm.cp_async_mbarrier_arrive_shared(

View File

@@ -12,7 +12,8 @@
from cutlass.base_dsl.arch import Arch
from cutlass.base_dsl.common import DSLRuntimeError
from cutlass.cutlass_dsl import CuTeDSL, dsl_user_op
from cutlass.cutlass_dsl import BaseDSL, dsl_user_op
from cutlass._mlir import ir
from cutlass._mlir.dialects import builtin, arith, llvm, vector
@@ -53,7 +54,7 @@ def cvt_i8_bf16_intrinsic(vec_i8, length, *, loc=None, ip=None):
:return: The output 1D vector of bfloat16 with the same length as the input vector.
:rtype: 1D vector of bfloat16
"""
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if not arch in cvt_i8_bf16_intrinsic.supported_archs:
raise DSLRuntimeError(f"cvt_i8_bf16_intrinsic is not supported on {arch}")
src_pos = 0
@@ -130,7 +131,7 @@ def cvt_i4_bf16_intrinsic(vec_i4, length, *, loc=None, ip=None):
:return: The output 1D vector of bfloat16 with the same length as the input vector.
:rtype: 1D vector of bfloat16
"""
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if not arch in cvt_i4_bf16_intrinsic.supported_archs:
raise DSLRuntimeError(f"cvt_i4_bf16_intrinsic is not supported on {arch}")
src_pos = 0

View File

@@ -1305,6 +1305,46 @@ def exp_packed_f32x2(
return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip)
@dsl_user_op
def griddepcontrol_wait(*, loc=None, ip=None) -> None:
    """Wait for the predecessor grid to fully complete.

    Emits the PTX ``griddepcontrol.wait`` instruction: instructions after it
    are not issued until every block of the previous kernel has finished and
    its memory has been flushed.
    """
    # No results, no operands; marked side-effecting so the asm is never
    # optimized away despite producing no SSA value.
    asm_kwargs = dict(
        res=None,
        operands_=[],
        constraints="",
        has_side_effects=True,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
    llvm.inline_asm(asm_string="griddepcontrol.wait;", **asm_kwargs)
@dsl_user_op
def griddepcontrol_launch_dependents(*, loc=None, ip=None) -> None:
    """Hint the scheduler to launch dependent kernels earlier.

    Emits the PTX ``griddepcontrol.launch_dependents`` instruction. This is a
    performance hint only — it does not affect correctness: launching a
    dependent kernel too early can compete with the current kernel, while
    launching too late can add latency.
    """
    # No results, no operands; marked side-effecting so the asm is never
    # optimized away despite producing no SSA value.
    asm_kwargs = dict(
        res=None,
        operands_=[],
        constraints="",
        has_side_effects=True,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
    llvm.inline_asm(asm_string="griddepcontrol.launch_dependents;", **asm_kwargs)
@dsl_user_op
def cvt_f4e2m1_f16(src, *, loc=None, ip=None):

View File

@@ -15,7 +15,7 @@ from typing import Optional, Type
from cutlass import cute
from cutlass.base_dsl.arch import Arch
from cutlass.cutlass_dsl import CuTeDSL
from cutlass.cutlass_dsl import BaseDSL
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
@@ -146,7 +146,7 @@ class CopyBulkTensorTileG2SOp(TmaCopyOp):
self, "expects the 'cta_group' parameter to be a CtaGroup instance"
)
# Arch verification
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
if not arch >= Arch.sm_90:
raise OpError(
self,
@@ -263,7 +263,7 @@ class CopyBulkTensorTileG2SMulticastOp(TmaCopyOp):
self, "expects the 'cta_group' parameter to be a CtaGroup instance"
)
# Arch verification
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if not arch >= Arch.sm_90:
raise OpError(
self,
@@ -386,7 +386,7 @@ class CopyBulkTensorTileS2GOp(TmaCopyOp):
def __post_init__(self):
# Arch verification
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if not arch >= Arch.sm_90:
raise OpError(
self,
@@ -561,7 +561,7 @@ class CopyBulkG2SOp(CopyOp):
def __post_init__(self) -> None:
# Arch verification
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
if not arch >= Arch.sm_90:
raise OpError(
self,
@@ -646,7 +646,7 @@ class CopyBulkG2SMulticastOp(CopyOp):
def __post_init__(self) -> None:
# Arch verification
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
if not arch >= Arch.sm_90:
raise OpError(
self,
@@ -740,7 +740,7 @@ class CopyBulkS2GOp(CopyOp):
def __post_init__(self) -> None:
# Arch verification
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
if not arch >= Arch.sm_90:
raise OpError(
self,

View File

@@ -15,7 +15,7 @@ from typing import Type
from cutlass import cute
from cutlass.base_dsl.arch import Arch
from cutlass.cutlass_dsl import CuTeDSL
from cutlass.cutlass_dsl import BaseDSL
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
@@ -113,7 +113,7 @@ class _LdBase(CopyOp):
:raises OpError: If pack parameter is not a Pack instance
"""
# Arch verification
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if arch not in self.admissible_archs:
raise OpError(
self,
@@ -416,7 +416,7 @@ class _StBase(CopyOp):
def __post_init__(self) -> None:
# Arch verification
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if arch not in self.admissible_archs:
raise OpError(
self,
@@ -625,7 +625,7 @@ class _S2TCopyBase(CopyOp):
def __post_init__(self) -> None:
# Arch verification
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if not arch.is_family_of(Arch.sm_100f):
raise OpError(
self,

View File

@@ -15,7 +15,7 @@ from typing import Type, Any
from cutlass import cute
from cutlass.base_dsl.arch import Arch
from cutlass.cutlass_dsl import CuTeDSL, T
from cutlass.cutlass_dsl import BaseDSL, T
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
@@ -162,7 +162,7 @@ class MmaOp(Tcgen05MmaOp):
def __post_init__(self) -> None:
# Verify arch
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if arch not in self.admissible_archs:
raise OpError(
self,
@@ -314,7 +314,7 @@ class BlockScaledMmaOp(Tcgen05MmaOp):
def __post_init__(self) -> None:
# Verify arch
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if arch not in self.admissible_archs:
raise OpError(
self,
@@ -471,7 +471,7 @@ class SparseMmaOp(Tcgen05MmaOp):
def __post_init__(self) -> None:
# Verify arch
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if arch not in self.admissible_archs:
raise OpError(
self,

View File

@@ -15,7 +15,7 @@ from typing import Type, Any
from cutlass import cute
from cutlass.base_dsl.arch import Arch
from cutlass.cutlass_dsl import CuTeDSL, T
from cutlass.cutlass_dsl import BaseDSL, T
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
@@ -130,7 +130,7 @@ class MmaOp(WarpGroupMmaOp):
def __post_init__(self) -> None:
# Verify arch
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if not arch == Arch.sm_90a:
raise OpError(
self,

View File

@@ -925,7 +925,17 @@ def load_module(file_path: str, *, enable_tvm_ffi: bool = True):
if enable_tvm_ffi:
import tvm_ffi
return tvm_ffi.load_module(file_path)
try:
# keep_module_alive=False means the module will be unloaded once the
# returned module goes out of scope, which is useful for frequent
# loading and unloading of modules. The only requirement is that the
# module does not return objects whose deleters live in the module
# while outliving the module itself.
# DSL functions do not have such an issue, so it is desirable to set this to False.
return tvm_ffi.load_module(file_path, keep_module_alive=False)
except TypeError:
# compatible with tvm-ffi < 0.1.6
return tvm_ffi.load_module(file_path)
else:
raise DSLRuntimeError(
"Unimplemented, please load the module with enable_tvm_ffi=True."

View File

@@ -20,7 +20,7 @@ from cutlass.cutlass_dsl import (
T,
cutlass_arith,
_binary_op_type_promote,
CuTeDSL,
BaseDSL,
)
from cutlass._mlir import ir
import cutlass._mlir.dialects.cute as _cute_ir
@@ -1776,7 +1776,7 @@ class TensorSSA(cutlass_arith.ArithValue):
fast_cvt_func = cvt_i8_bf16_intrinsic
elif src_dtype == Int4 and dtype == BFloat16:
fast_cvt_func = cvt_i4_bf16_intrinsic
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if fast_cvt_func is not None and arch in fast_cvt_func.supported_archs:
res_vect = fast_cvt_func(src, size(self.shape), loc=loc, ip=ip)
else:

View File

@@ -407,7 +407,7 @@ def benchmark(
To use CUDA graphs, the callable must be a compiled @cute.jit annotated function.
When using CUDA graphs, the kernel must be launched in a non-default stream.
:param callable: The function to benchmark
:param callable: The function to benchmark. For a JIT function, it must be a compiled function.
:type callable: Callable
:param warmup_iterations: Number of warmup iterations, defaults to 10
:type warmup_iterations: int, optional
@@ -475,15 +475,6 @@ def benchmark(
elapsed_time = float("nan")
if use_cuda_graphs:
# Check if the callable is a JitCompiledFunction or JitExecutor
# These are functions that can be called to launch kernels
compiled_types = (
cutlass.base_dsl.jit_executor.JitCompiledFunction,
cutlass.base_dsl.jit_executor.JitExecutor,
)
if not isinstance(callable, compiled_types):
raise TypeError("Function must be precompiled to be used with CUDA Graphs")
# Check if the stream is a non-default stream
if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT):
raise ValueError(

View File

@@ -247,7 +247,10 @@ class CutlassBaseDSL(BaseDSL):
return False
def _build_gpu_module(self, attrs, loc=None):
log().info(f"self : {self}")
log().info(f"Building GPU module for {self.name}")
self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels"), loc=loc)
log().info(f"GPU module: {self.gpu_module}")
with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])):
pass
@@ -275,6 +278,9 @@ class CutlassBaseDSL(BaseDSL):
return pipeline
def _enter_gpu_module(self):
log().info(f"self: {self}")
log().info(f"Entering GPU module for {self.name}")
log().info(f"GPU module: {self.gpu_module}")
return ir.InsertionPoint(self.gpu_module.bodyRegion.blocks[0])
def _generate_kernel_attrs(self, config: BaseDSL.LaunchConfig) -> dict:

View File

@@ -126,16 +126,22 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
)
context.module.body.append(parsed_op)
with ir.InsertionPoint(current_block):
cuda_global_state_ptr = self.address_of(
self.cuda_global_state_symbol, self.ptr_type
)
cuda_init_ptr = self.address_of("cuda_init", self.ptr_type)
cuda_load_to_device_ptr = self.address_of("cuda_load_to_device", self.ptr_type)
set_error_ptr = self.address_of(
"TVMFFIErrorSetRaisedFromCStr", self.ptr_type
)
cuda_init_ptr = context.builder.get_or_load_global_func_ptr_from_text(
current_block, "cuda_init"
)
cuda_load_to_device_ptr = context.builder.get_or_load_global_func_ptr_from_text(
current_block, "cuda_load_to_device"
)
set_error_ptr = context.builder.get_or_load_global_func_ptr_from_text(
current_block, "TVMFFIErrorSetRaisedFromCStr"
)
with ir.InsertionPoint(current_block):
# Call the callback function with the loaded ptr value
init_result = llvm.call(
result=self.i32_type, # function returns i32
@@ -495,6 +501,13 @@ class TVMFFIJitCompiledFunctionBase(CudaDialectJitCompiledFunction):
"""Create the tvm_ffi.Function from the current execution engine.
"""
if self.engine is not None:
# trigger eager compile of init callbacks
cuda_init = self.engine.raw_lookup("cuda_init")
cuda_load_to_device = self.engine.raw_lookup("cuda_load_to_device")
if cuda_init is None:
raise DSLRuntimeError("cuda_init not found")
if cuda_load_to_device is None:
raise DSLRuntimeError("cuda_load_to_device not found")
tvm_ffi_function_ptr = self.engine.raw_lookup(
"__tvm_ffi_" + self.function_name
)

View File

@@ -261,7 +261,7 @@ def make_smem_layout_a(
a_smem_layout_staged = cute.tile_to_shape(
a_smem_layout_atom,
cute.append(a_smem_shape, num_stages),
order=(0, 1, 2) if is_k_major else (0, 1, 2),
order=(0, 1, 2) if is_k_major else (1, 0, 2),
loc=loc,
ip=ip,
)

View File

@@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.3.3
nvidia-cutlass-dsl==4.3.4

View File

@@ -133,7 +133,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
this.__version__ = '4.3.3'
this.__version__ = '4.3.4'
from cutlass_cppgen.backend import create_memory_pool
from cutlass_cppgen.emit.pytorch import pytorch

View File

@@ -51,7 +51,7 @@ setup_pycute.perform_setup()
setup(
name='cutlass_cppgen',
version='4.3.3',
version='4.3.4',
description='CUTLASS Pythonic Interface',
package_dir={'': '.'},
packages=[

View File

@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='cutlass_library',
version='4.3.3',
version='4.3.4',
description='CUTLASS library generation scripts',
packages=['cutlass_library']
)

View File

@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='4.3.3',
version='4.3.4',
description='Python implementation of CuTe',
packages=['pycute'],
)