mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
v4.3.4 update. (#2892)
This commit is contained in:
13
CHANGELOG.md
13
CHANGELOG.md
@@ -2,9 +2,16 @@
|
||||
|
||||
# CUTLASS 4.x
|
||||
|
||||
## [4.3.3](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.3) (2025-12-12)
|
||||
## [4.3.4](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.4) (2025-12-22)
|
||||
* New features
|
||||
- Added PDL support along with example [Kernel launch with Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py)
|
||||
|
||||
### CuTe DSL
|
||||
* Bug fixing and improvements
|
||||
- Fixed a frame refcnt issue with cuda graph
|
||||
- Enhancement for tvm-ffi AoT case for earlier module unload
|
||||
- Fixed order issue in `make_smem_layout_a` in utils/hopper_helpers.py
|
||||
|
||||
## [4.3.3](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.3) (2025-12-12)
|
||||
* New features
|
||||
- Supported namedtuple and kwargs for JIT function arguments in tvm-ffi
|
||||
- Supported variadic tuples for JIT function argument in tvm-ffi
|
||||
@@ -14,8 +21,6 @@
|
||||
- Clearer error message for the case of runtime error cudaErrorInsufficientDriver
|
||||
|
||||
## [4.3.2](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.2) (2025-12-05)
|
||||
|
||||
### CuTe DSL
|
||||
* New features
|
||||
- New env var `CUTE_DSL_CACHE_DIR` to specify the path for dumping caches
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||

|
||||
# Overview
|
||||
|
||||
# CUTLASS 4.3.3
|
||||
# CUTLASS 4.3.4
|
||||
|
||||
_CUTLASS 4.3.3 - Dec 2025_
|
||||
_CUTLASS 4.3.4 - Dec 2025_
|
||||
|
||||
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
|
||||
and related computations at all levels and scales within CUDA. It incorporates strategies for
|
||||
@@ -56,6 +56,7 @@ To get started quickly - please refer :
|
||||
- New env var `CUTE_DSL_CACHE_DIR` to specify the path for dumping caches.
|
||||
- Supported namedtuple and kwargs for JIT function arguments in tvm-ffi.
|
||||
- Supported variadic tuples for JIT function argument in tvm-ffi.
|
||||
- Added PDL support along with example [Kernel launch with Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py)
|
||||
* Debuggability improvements:
|
||||
- Supported source location tracking for DSL APIs (Allow tools like ``nsight`` profiling to correlate perf metrics with Python source code)
|
||||
- Supported dumping PTX and CUBIN code: [Hello World Example](https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/notebooks/hello_world.ipynb)
|
||||
@@ -106,6 +107,9 @@ To get started quickly - please refer :
|
||||
- Fixed an issue of allocating max smem when there's statically allocated smem
|
||||
- Fixed an issue when JIT function argument with union type annotation for tvm-ffi
|
||||
- Clearer error message for the case of runtime error cudaErrorInsufficientDriver
|
||||
- Fixed a frame refcnt issue with cuda graph
|
||||
- Enhancement for tvm-ffi AoT case for earlier module unload
|
||||
- Fixed order issue in `make_smem_layout_a` in utils/hopper_helpers.py
|
||||
|
||||
## CUTLASS C++
|
||||
* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
|
||||
|
||||
@@ -401,11 +401,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
d_per_blk = 128
|
||||
d_blks = cute.ceil_div(d, d_per_blk)
|
||||
|
||||
reduction(
|
||||
o, m, l,
|
||||
o_partial, m_partial, l_partial,
|
||||
scale_o
|
||||
).launch(
|
||||
self.reduction(o, m, l, o_partial, m_partial, l_partial, scale_o).launch(
|
||||
grid=[d_blks, h_q, b],
|
||||
block=[d_per_blk, 1, 1],
|
||||
cluster=[1, 1, 1],
|
||||
@@ -1173,7 +1169,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
|
||||
# Reduce colmax in smem
|
||||
if lane_store_max:
|
||||
smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane)
|
||||
self.smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane)
|
||||
|
||||
# Wait for colmax then load
|
||||
cute.arch.barrier(barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads)
|
||||
@@ -1259,7 +1255,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
# Reduce cluster colmax
|
||||
if warpgroup_widx == 0:
|
||||
if lane_store_max:
|
||||
dsmem_fmax(
|
||||
self.dsmem_fmax(
|
||||
sM_cluster.iterator + sM_layout((0, lane_idx)),
|
||||
sM[(0, lane_idx)],
|
||||
m_cluster_full_ptr
|
||||
@@ -1280,7 +1276,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
else:
|
||||
# other splits copy cluster colmax into local smem
|
||||
if lane_store_max:
|
||||
sM_cluster[0, lane_idx] = dsmem_load(
|
||||
sM_cluster[0, lane_idx] = self.dsmem_load(
|
||||
sM_cluster.iterator + sM_layout((0, lane_idx))
|
||||
)
|
||||
|
||||
@@ -1300,8 +1296,10 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
sM_lane = sM_lane * scale_qs
|
||||
|
||||
# Store colsum and colmax
|
||||
gmem_fadd(gL_partial.iterator + gL_partial.layout(lane_idx), sL_lane)
|
||||
gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
|
||||
self.gmem_fadd(
|
||||
gL_partial.iterator + gL_partial.layout(lane_idx), sL_lane
|
||||
)
|
||||
self.gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
|
||||
if kv_split_in_cluster == 0:
|
||||
gM_partial[lane_idx] = sM_lane
|
||||
|
||||
@@ -1350,7 +1348,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
# Store colsum and colmax
|
||||
gL_partial[lane_idx] = sL_lane
|
||||
gM_partial[lane_idx] = sM_lane
|
||||
gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
|
||||
self.gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
|
||||
|
||||
o_handle = o_consumer.wait_and_advance()
|
||||
cute.copy(thr_load_s, tStO, tSrO)
|
||||
@@ -1378,6 +1376,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@cute.kernel
|
||||
def reduction(
|
||||
o : cute.Tensor,
|
||||
@@ -1414,6 +1413,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
def _mapa(ptr : Pointer, cta_rank_in_cluster : Int32 = 0):
|
||||
llvm_ptr = ptr.llvm_ptr
|
||||
@@ -1424,8 +1424,8 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
)
|
||||
|
||||
@cute.jit
|
||||
def dsmem_load(val_ptr : Pointer):
|
||||
val_llvm_ptr = _mapa(val_ptr, 0)
|
||||
def dsmem_load(self, val_ptr: Pointer):
|
||||
val_llvm_ptr = self._mapa(val_ptr, 0)
|
||||
|
||||
ret = llvm.inline_asm(
|
||||
Float32.mlir_type,
|
||||
@@ -1439,6 +1439,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
|
||||
return Float32(ret)
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
def warp_fmax(val : Float32):
|
||||
ret = llvm.inline_asm(
|
||||
@@ -1472,10 +1473,10 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
)
|
||||
|
||||
@cute.jit
|
||||
def dsmem_fmax(val_ptr : Pointer, val : Float32, mbar_ptr : Pointer):
|
||||
def dsmem_fmax(self, val_ptr: Pointer, val: Float32, mbar_ptr: Pointer):
|
||||
expect_tx_bytes = Int32(Float32.width // 8)
|
||||
val_llvm_ptr = _mapa(val_ptr, 0)
|
||||
mbar_llvm_ptr = _mapa(mbar_ptr, 0)
|
||||
val_llvm_ptr = self._mapa(val_ptr, 0)
|
||||
mbar_llvm_ptr = self._mapa(mbar_ptr, 0)
|
||||
|
||||
nvvm.mbarrier_txn(
|
||||
mbar_llvm_ptr,
|
||||
@@ -1499,6 +1500,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
def gmem_fmax(ptr : Pointer, val : Float32):
|
||||
llvm.inline_asm(
|
||||
@@ -1529,6 +1531,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
def gmem_fadd(ptr : Pointer, val : Float32):
|
||||
llvm.inline_asm(
|
||||
|
||||
@@ -0,0 +1,368 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import argparse
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
|
||||
def supports_pdl():
    """Return True when the current CUDA device can run this PDL example.

    The example requires a device whose compute-capability major version is
    at least 9 (Hopper or newer).

    Returns:
        bool: ``False`` when no CUDA device is available or the device is
        older than compute capability 9.x, ``True`` otherwise.
    """
    # torch is imported lazily so that merely importing this module does not
    # require PyTorch.
    import torch

    # Querying the capability of a non-existent device raises instead of
    # returning a value, so report "unsupported" explicitly in that case.
    if not torch.cuda.is_available():
        return False

    major, _minor = torch.cuda.get_device_capability()
    return major >= 9
|
||||
|
||||
|
||||
"""
|
||||
This example demonstrates the use of Programmatic Dependent Launch (PDL) using
|
||||
CuTe DSL.
|
||||
|
||||
PDL is a mechanism which allows for overlapping execution of back-to-back kernels
|
||||
within the same stream.
|
||||
For example, consider the following two elementwise add operations, where the second
|
||||
operation's first operand is the result of the first operation. While performing
|
||||
``w = u + v`` we will load u and v, add them, and then store the result. Once we
|
||||
have finished loading data, we are no longer utilizing the read bandwidth.
|
||||
To effectively utilize the read bandwidth, we can start loading ``x``
|
||||
immediately upon finishing reading. This is what PDL enables us to do.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
w = u + v
|
||||
y = w + x
|
||||
|
||||
To enable PDL, we need to do two things:
|
||||
|
||||
1. Insert the ``griddepcontrol.launch_dependents`` and ``griddepcontrol.wait`` instructions in the kernel.
|
||||
2. Set the PDL launch attribute when launching the kernel.
|
||||
|
||||
The ``griddepcontrol.launch_dependents`` and ``griddepcontrol.wait``
|
||||
instructions enable fine-grained control over kernel execution in PDL.
|
||||
Once all thread blocks execute the ``griddepcontrol.launch_dependents``
|
||||
instruction, the dependent kernels can opportunistically be early-launched.
|
||||
``griddepcontrol.wait`` functions as a synchronization barrier - any warp
|
||||
executing this instruction will block until the previous kernel finishes
|
||||
execution. This allows precise control over data dependencies between kernels.
|
||||
|
||||
The following diagram shows the overlapping execution of two dependent kernels.
|
||||
We call the instructions before ``griddepcontrol.wait`` as prologue (``P0``),
|
||||
which may include barrier initialization and loading of independent data, etc.
|
||||
We call the instructions after ``griddepcontrol.launch_dependents`` as epilogue
|
||||
(``P2``), which may include math operations, data stores, etc. PDL enables
|
||||
these prologue and epilogue phases to execute concurrently across dependent
|
||||
kernels, improving GPU resource utilization. This is particularly beneficial
|
||||
when prologue and epilogue are bound by different resources (e.g., memory
|
||||
bandwidth vs compute throughput).
|
||||
|
||||
# P0: Prologue, P1: Main compute, P2: Epilogue
|
||||
|
||||
P0 P1 P2
|
||||
K1: |=====|+++++|-----|
|
||||
|
||||
<-----> K2 can start early
|
||||
(K1's P2 overlaps with K2's P0)
|
||||
|
||||
P0 P1 P2
|
||||
K2: |=====| |+++++|-----|
|
||||
^
|
||||
|
|
||||
wait for K1 to complete
|
||||
Time ------------------------------------------------------>
|
||||
|
||||
You can run this example with and without PDL to compare the two:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/blackwell/programmatic_dependent_launch.py --benchmark
|
||||
python examples/blackwell/programmatic_dependent_launch.py --benchmark --use_pdl
|
||||
|
||||
From the benchmark results, you can see some speedups for the PDL version in most cases, benefiting from
|
||||
the overlapping execution of consecutive kernels. Moreover, you can use nsys to observe the overlapping execution.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
nsys profile python examples/blackwell/programmatic_dependent_launch.py --benchmark --use_pdl
|
||||
|
||||
Note: PDL is supported on Hopper and later GPUs (compute capability 9.0 and newer).
|
||||
|
||||
See [the programming guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization)
|
||||
and the [PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol)
|
||||
for more details.
|
||||
"""
|
||||
|
||||
|
||||
@cute.kernel
def elementwise_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
    cC: cute.Tensor,  # coordinate tensor
    shape: cute.Shape,
    thr_layout: cute.Layout,
    val_layout: cute.Layout,
    is_first_kernel: cutlass.Constexpr = True,
):
    """Elementwise add ``gC = gA + gB`` for one tile per block, instrumented for PDL.

    Each thread block selects tile ``bidx`` of ``gA``/``gB``/``gC``, copies its
    slice of A and B into per-thread fragments, adds them and copies the sum
    to C. All copies are predicated: an element takes part only when its
    coordinate (taken from ``cC``) compares less than ``shape``.

    Args:
        gA, gB: Tiled input tensors, indexed as ``((tile), block)``.
        gC: Tiled output tensor with the same tiling.
        cC: Coordinate tensor tiled like ``gC``; used only to build the predicate.
        shape: Bound that coordinates are compared against via ``cute.elem_less``.
        thr_layout, val_layout: Thread/value layouts handed to
            ``cute.make_tiled_copy_tv``.
        is_first_kernel: Compile-time switch selecting which PDL instruction is
            emitted. ``True``: load A and B, then issue
            ``griddepcontrol_launch_dependents``. ``False``: load B, issue
            ``griddepcontrol_wait``, then load A.
    """
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    # Keep the whole tile mode and pick the tile that belongs to this block.
    blk_coord = ((None, None), bidx)
    blkA = gA[blk_coord]  # (TileM,TileN)
    blkB = gB[blk_coord]  # (TileM,TileN)
    blkC = gC[blk_coord]  # (TileM,TileN)
    blkCrd = cC[blk_coord]  # (TileM, TileN)

    copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
    copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)

    tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
    tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
    tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)

    # Per-thread views of the tiled copies.
    thr_copy_A = tiled_copy_A.get_slice(tidx)
    thr_copy_B = tiled_copy_B.get_slice(tidx)
    thr_copy_C = tiled_copy_C.get_slice(tidx)

    thrA = thr_copy_A.partition_S(blkA)
    thrB = thr_copy_B.partition_S(blkB)
    thrC = thr_copy_C.partition_S(blkC)

    # Fragments shaped like this thread's partitions; they hold the staged data.
    frgA = cute.make_fragment_like(thrA)
    frgB = cute.make_fragment_like(thrB)
    frgC = cute.make_fragment_like(thrC)

    # Build one Boolean per element: True when the element's coordinate is
    # less than ``shape``.
    thrCrd = thr_copy_C.partition_S(blkCrd)
    frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean)

    for i in range(cute.size(frgPred)):
        val = cute.elem_less(thrCrd[i], shape)
        frgPred[i] = val

    # Note: when CUDA graphs are not used, kernel execution may be limited by
    # host launch overhead, in which case no overlap is visible even with PDL
    # enabled. This example therefore repeats every copy and compute step 10
    # times so that the kernels run long enough for the PDL benefit to be
    # observable both with and without CUDA graphs.
    if is_first_kernel:
        for _ in range(10):
            cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
            cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
        # Issue the launch-dependents instruction in the first kernel as a hint
        # to the runtime to early-launch the next kernel. If the next kernel
        # becomes concurrent, the second kernel can start reading x while this
        # one is still running, giving an end-to-end speedup. The placement of
        # launch-dependents affects performance only, not correctness.
        cute.arch.griddepcontrol_launch_dependents()
    else:
        # The second kernel's second operand ``gB`` has no dependency on the
        # first kernel, so loading it can overlap with the first kernel's
        # computation of its ``gC``.
        for _ in range(10):
            cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)

        # The second kernel's first operand ``gA`` is written by the previous
        # kernel, so griddepcontrol.wait must be called to ensure correctness.
        # The instruction blocks until the prior kernel has finished and its
        # memory operations are visible, i.e. until gA is visible to this
        # kernel. Without it there would be a race condition and therefore
        # undefined behavior.
        cute.arch.griddepcontrol_wait()

        for _ in range(10):
            cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)

    # Add the staged operands and write the sum back through the store atom.
    for _ in range(10):
        result = frgA.load() + frgB.load()
        frgC.store(result)
        cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)
|
||||
|
||||
|
||||
@cute.jit
def elementwise_add(
    mA,
    mB,
    mC,
    stream: cuda.CUstream,
    use_pdl: cutlass.Constexpr = True,
    is_first_kernel: cutlass.Constexpr = True,
):
    """Host-side JIT entry point: tile the operands and launch the add kernel.

    Args:
        mA, mB: Input tensors.
        mC: Output tensor; its shape is used for the coordinate tensor and as
            the bound passed to the kernel.
        stream: CUDA stream the kernel is launched on.
        use_pdl: Compile-time flag forwarded to ``launch(use_pdl=...)``.
        is_first_kernel: Compile-time flag forwarded to
            ``elementwise_add_kernel`` to select its PDL instruction.
    """
    dtype = mA.element_type
    # copy_bits for a thread is 128 bits, and we use 128 // dtype.width to get the vector size
    vector_size = 128 // dtype.width

    # (4, 32) thread layout and (4, vector_size) value layout, both built with
    # order=(1, 0); together they define the tile each block works on.
    thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
    val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0))
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)

    gA = cute.zipped_divide(mA, tiler_mn)  # ((TileM,TileN),(RestM,RestN))
    gB = cute.zipped_divide(mB, tiler_mn)  # ((TileM,TileN),(RestM,RestN))
    gC = cute.zipped_divide(mC, tiler_mn)  # ((TileM,TileN),(RestM,RestN))

    # Identity tensor over mC's shape, tiled the same way as the data tensors;
    # the kernel reads element coordinates from it to build its predicate.
    idC = cute.make_identity_tensor(mC.shape)
    cC = cute.zipped_divide(idC, tiler=tiler_mn)

    # Grid size is the size of gC's mode 1 (one block per tile); block size is
    # the size of tv_layout's mode 0.
    elementwise_add_kernel(
        gA, gB, gC, cC, mC.shape, thr_layout, val_layout, is_first_kernel
    ).launch(
        grid=[cute.size(gC, mode=[1]), 1, 1],
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
        stream=stream,
        use_pdl=use_pdl,
    )
|
||||
|
||||
|
||||
def run_pdl_example(
    M,
    N,
    skip_ref_check=False,
    benchmark=False,
    warmup_iterations=5,
    iterations=100,
    use_pdl=True,
):
    """Compile and run two dependent elementwise adds, optionally with PDL.

    Computes ``w = u + v`` followed by ``y = w + x`` on random float32
    ``[M, N]`` CUDA tensors, using two separately compiled variants of
    ``elementwise_add`` launched back-to-back on the same stream.

    Args:
        M, N: Tensor dimensions.
        skip_ref_check: When True, skip the comparison against ``u + v + x``.
        benchmark: When True, time the kernel pair with ``testing.benchmark``
            and print the average execution time.
        warmup_iterations: Warm-up iteration count forwarded to the benchmark.
        iterations: Timed iteration count forwarded to the benchmark.
        use_pdl: Whether the kernels are compiled with the PDL launch attribute.

    Raises:
        RuntimeError: If no CUDA device is available.
        AssertionError: If the reference check fails
            (raised by ``torch.testing.assert_close``).
    """
    import torch

    if not torch.cuda.is_available():
        raise RuntimeError("Blackwell/Hopper GPU is required to run this example!")

    print("\nRunning Elementwise Add test with:")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Use PDL: {use_pdl}")

    u = torch.randn(M, N, dtype=torch.float32, device="cuda")
    v = torch.randn(M, N, dtype=torch.float32, device="cuda")
    w = torch.randn(M, N, dtype=torch.float32, device="cuda")
    x = torch.randn(M, N, dtype=torch.float32, device="cuda")
    y = torch.empty(M, N, dtype=torch.float32, device="cuda")

    u_tensor = from_dlpack(u).mark_layout_dynamic()
    v_tensor = from_dlpack(v).mark_layout_dynamic()
    w_tensor = from_dlpack(w).mark_layout_dynamic()
    x_tensor = from_dlpack(x).mark_layout_dynamic()
    y_tensor = from_dlpack(y).mark_layout_dynamic()

    stream = torch.cuda.Stream()
    # The inputs above were produced on the current stream, while the kernels
    # run on ``stream``. Order the side stream after that work so the kernels
    # never read partially initialised inputs.
    stream.wait_stream(torch.cuda.current_stream())
    current_stream = cuda.CUstream(stream.cuda_stream)
    # Since is_first_kernel is cutlass.Constexpr, we need to compile for
    # the first and second kernel separately.
    compiled_func_first_kernel = cute.compile(
        elementwise_add,
        u_tensor,
        v_tensor,
        w_tensor,
        current_stream,
        use_pdl,
        is_first_kernel=True,
        options="--enable-tvm-ffi",
    )
    compiled_func_second_kernel = cute.compile(
        elementwise_add,
        w_tensor,
        x_tensor,
        y_tensor,
        current_stream,
        use_pdl,
        is_first_kernel=False,
        options="--enable-tvm-ffi",
    )

    # Launch and run the two consecutive kernels in the same stream.
    def run_func(current_stream, u, v, w, x, y):
        # Run first operation: w_tensor = u_tensor + v_tensor
        compiled_func_first_kernel(
            u,
            v,
            w,
            current_stream,
        )
        # Run second operation: y_tensor = w_tensor + x_tensor
        # Its first operand ``w_tensor`` is the result of the first operation;
        # both refer to the same memory.
        compiled_func_second_kernel(
            w,
            x,
            y,
            current_stream,
        )

    if not skip_ref_check:
        run_func(current_stream, u, v, w, x, y)
        # The kernels execute asynchronously on ``stream``, whereas the
        # ``.cpu()`` copies below are issued on the current stream. Wait for
        # the side stream so the comparison sees the finished result.
        stream.synchronize()
        print("Verifying results...")
        torch.testing.assert_close(u.cpu() + v.cpu() + x.cpu(), y.cpu())
        print("Results verified successfully!")

    if not benchmark:
        return

    def generate_kernel_arguments():
        # Fresh tensors for every benchmark workspace.
        u = torch.randn(M, N, dtype=torch.float32, device="cuda")
        v = torch.randn(M, N, dtype=torch.float32, device="cuda")
        w = torch.randn(M, N, dtype=torch.float32, device="cuda")
        x = torch.randn(M, N, dtype=torch.float32, device="cuda")
        y = torch.empty(M, N, dtype=torch.float32, device="cuda")

        return testing.JitArguments(current_stream, u, v, w, x, y)

    avg_time_us = testing.benchmark(
        run_func,
        workspace_generator=generate_kernel_arguments,
        workspace_count=10,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
        stream=current_stream,
        use_cuda_graphs=True,
    )
    print(f"Execution time: {avg_time_us:.4f} us")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the options, then run the example only
    # on devices that support PDL.
    cli = argparse.ArgumentParser(
        description="example of Programmatic Dependent Launch (PDL) using CuTe DSL"
    )
    # Problem size.
    for size_flag in ("--M", "--N"):
        cli.add_argument(size_flag, default=256, type=int)
    # Benchmark iteration counts.
    cli.add_argument("--warmup_iterations", default=5, type=int)
    cli.add_argument("--iterations", default=10, type=int)
    # Boolean switches, all off unless given.
    for switch in ("--skip_ref_check", "--benchmark", "--use_pdl"):
        cli.add_argument(switch, action="store_true")

    opts = cli.parse_args()
    if not supports_pdl():
        print(
            "PDL is not supported on this device, it requires Hopper or newer generations"
        )
    else:
        run_pdl_example(
            opts.M,
            opts.N,
            skip_ref_check=opts.skip_ref_check,
            benchmark=opts.benchmark,
            warmup_iterations=opts.warmup_iterations,
            iterations=opts.iterations,
            use_pdl=opts.use_pdl,
        )
        print("\nPASS")
|
||||
@@ -839,7 +839,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
)
|
||||
|
||||
s_max_layout = cute.make_layout(
|
||||
cute.size(layout_acc_mn(pv_tiled_mma, acc_pv.layout), mode=[0])
|
||||
cute.size(self.layout_acc_mn(pv_tiled_mma, acc_pv.layout), mode=[0])
|
||||
)
|
||||
s_max = cute.make_rmem_tensor_like(s_max_layout, self.qk_acc_dtype)
|
||||
a_sum = cute.make_rmem_tensor_like(s_max, cutlass.Float32)
|
||||
@@ -888,7 +888,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
# MMA QK
|
||||
cute.nvgpu.warpgroup.fence()
|
||||
|
||||
gemm_zero_acc(
|
||||
self.gemm_zero_acc(
|
||||
qk_tiled_mma,
|
||||
tSrQ[(None, None, None, q_handle.index)],
|
||||
tSrK[(None, None, None, k_handle.index)],
|
||||
@@ -901,7 +901,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
# Wait for the pipeline MMAs to drain
|
||||
cute.nvgpu.warpgroup.wait_group(0)
|
||||
|
||||
s_max, a_sum = softmax_step(
|
||||
s_max, a_sum = self.softmax_step(
|
||||
True,
|
||||
self.mask_type,
|
||||
acc_qk,
|
||||
@@ -919,7 +919,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
True,
|
||||
)
|
||||
|
||||
acc_qk_fixed = make_acc_into_op(
|
||||
acc_qk_fixed = self.make_acc_into_op(
|
||||
acc_qk, pv_tiled_mma.tv_layout_A, self.q_dtype
|
||||
)
|
||||
|
||||
@@ -928,7 +928,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
# MMA PV
|
||||
cute.nvgpu.warpgroup.fence()
|
||||
|
||||
gemm_zero_acc(
|
||||
self.gemm_zero_acc(
|
||||
pv_tiled_mma,
|
||||
acc_qk_fixed,
|
||||
tOrV[(None, None, None, v_handle.index)],
|
||||
@@ -1040,7 +1040,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
cute.nvgpu.warpgroup.wait_group(0)
|
||||
|
||||
# acc_pv updated
|
||||
lse = tail(
|
||||
lse = self.tail(
|
||||
s_max, a_sum, acc_pv, pv_tiled_mma, scale_softmax, scale_output
|
||||
)
|
||||
|
||||
@@ -1077,10 +1077,10 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
|
||||
if tOcO[0][1] == 0:
|
||||
tOgLSE_mn = cute.make_tensor(
|
||||
tOgLSE.iterator, layout_acc_mn(pv_tiled_mma, tOgLSE.layout)
|
||||
tOgLSE.iterator, self.layout_acc_mn(pv_tiled_mma, tOgLSE.layout)
|
||||
)
|
||||
tOcO_mn = cute.make_tensor(
|
||||
tOcO.iterator, layout_acc_mn(pv_tiled_mma, tOcO.layout)
|
||||
tOcO.iterator, self.layout_acc_mn(pv_tiled_mma, tOcO.layout)
|
||||
)
|
||||
for i in cutlass.range_constexpr(cute.size(tOgLSE_mn, mode=[0])):
|
||||
if (
|
||||
@@ -1241,7 +1241,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
# MMA QK
|
||||
cute.nvgpu.warpgroup.fence()
|
||||
|
||||
gemm_zero_acc(
|
||||
self.gemm_zero_acc(
|
||||
qk_tiled_mma,
|
||||
tSrQ[(None, None, None, q_handle.index)],
|
||||
tSrK[(None, None, None, k_handle.index)],
|
||||
@@ -1255,7 +1255,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
# Wait for the pipeline MMAs to drain
|
||||
cute.nvgpu.warpgroup.wait_group(0)
|
||||
|
||||
s_max, a_sum = softmax_step(
|
||||
s_max, a_sum = self.softmax_step(
|
||||
fusion,
|
||||
self.mask_type,
|
||||
acc_qk,
|
||||
@@ -1272,7 +1272,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
window_size_right,
|
||||
)
|
||||
|
||||
acc_qk_fixed = make_acc_into_op(
|
||||
acc_qk_fixed = self.make_acc_into_op(
|
||||
acc_qk, pv_tiled_mma.tv_layout_A, self.q_dtype
|
||||
)
|
||||
|
||||
@@ -1300,6 +1300,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
|
||||
@cute.jit
|
||||
def softmax_step(
|
||||
self,
|
||||
fusion: bool,
|
||||
mask_type: fmha_utils.MaskEnum,
|
||||
acc_qk: cute.ThrMma,
|
||||
@@ -1328,10 +1329,10 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
)
|
||||
|
||||
acc_qk_mn = cute.make_tensor(
|
||||
acc_qk.iterator, layout_acc_mn(tiled_mma_qk, acc_qk.layout)
|
||||
acc_qk.iterator, self.layout_acc_mn(tiled_mma_qk, acc_qk.layout)
|
||||
)
|
||||
|
||||
reduction_target_qk = reduction_target_n(tiled_mma_qk)
|
||||
reduction_target_qk = self.reduction_target_n(tiled_mma_qk)
|
||||
red_rank = cute.rank(reduction_target_qk)
|
||||
|
||||
s_max_prev = None
|
||||
@@ -1346,7 +1347,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
s_max[i] = cute.arch.fmax(s_max[i], acc_qk_mn[i, j])
|
||||
else:
|
||||
acc_pv_mn = cute.make_tensor(
|
||||
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
||||
acc_pv.iterator, self.layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
||||
)
|
||||
s_max_prev = cute.make_rmem_tensor_like(s_max, s_max._dtype)
|
||||
|
||||
@@ -1396,15 +1397,15 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
return s_max, a_sum
|
||||
|
||||
@cute.jit
|
||||
def reduction_target_n(tiled_mma):
|
||||
separated = layout_separate(
|
||||
def reduction_target_n(self, tiled_mma):
|
||||
separated = self.layout_separate(
|
||||
tiled_mma.shape_mnk[0],
|
||||
cute.make_layout(tiled_mma.tv_layout_C.shape[0]),
|
||||
tiled_mma.tv_layout_C.stride[0],
|
||||
)
|
||||
return separated[1]
|
||||
|
||||
@cute.jit
|
||||
@staticmethod
|
||||
def convert_c_layout_to_a_layout(c, a):
|
||||
return cute.make_layout(
|
||||
(a, c.shape[1], (c.shape[2], cute.size(c, mode=[0]) // cute.size(a))),
|
||||
@@ -1416,9 +1417,9 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
)
|
||||
|
||||
@cute.jit
|
||||
def make_acc_into_op(acc, operand_layout_tv, Element):
|
||||
def make_acc_into_op(self, acc, operand_layout_tv, Element):
|
||||
operand = cute.make_rmem_tensor_like(
|
||||
convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]),
|
||||
self.convert_c_layout_to_a_layout(acc.layout, operand_layout_tv.shape[1]),
|
||||
Element,
|
||||
)
|
||||
operand_as_acc = cute.make_tensor(operand.iterator, acc.layout)
|
||||
@@ -1499,7 +1500,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
return operand
|
||||
|
||||
@cute.jit
|
||||
def tail(s_max, a_sum, acc_pv, tiled_mma_pv, scale_softmax, scale_output):
|
||||
def tail(self, s_max, a_sum, acc_pv, tiled_mma_pv, scale_softmax, scale_output):
|
||||
"""
|
||||
Final processing step for FMHA that computes log-sum-exp (LSE) and scales the output.
|
||||
|
||||
@@ -1527,9 +1528,9 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
"""
|
||||
# Create tensor view of accumulated P*V values with M*N layout
|
||||
acc_pv_mn = cute.make_tensor(
|
||||
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
||||
acc_pv.iterator, self.layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
||||
)
|
||||
reduction_target = reduction_target_n(tiled_mma_pv)
|
||||
reduction_target = self.reduction_target_n(tiled_mma_pv)
|
||||
red_rank = cute.rank(reduction_target)
|
||||
for r in cutlass.range_constexpr(red_rank):
|
||||
for i in cutlass.range_constexpr(cute.size(acc_pv_mn, mode=[0])):
|
||||
@@ -1538,7 +1539,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
)
|
||||
|
||||
acc_mn = cute.make_tensor(
|
||||
acc_pv.iterator, layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
||||
acc_pv.iterator, self.layout_acc_mn(tiled_mma_pv, acc_pv.layout)
|
||||
)
|
||||
|
||||
lse = cute.make_rmem_tensor_like(a_sum, a_sum._dtype)
|
||||
@@ -1559,7 +1560,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
|
||||
return lse
|
||||
|
||||
@cute.jit
|
||||
@staticmethod
|
||||
def layout_separate(thr, src, ref):
|
||||
lt = cute.make_layout(())
|
||||
ge = cute.make_layout(())
|
||||
@@ -1577,6 +1578,7 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
r = cute.append(cute.append(cute.make_layout(()), lt), ge)
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
def gemm_zero_acc(tiled_mma, A, B, C):
|
||||
rA = cute.rank(A)
|
||||
@@ -1606,8 +1608,8 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
assert 0
|
||||
|
||||
@cute.jit
|
||||
def layout_acc_mn(tiled_mma, acc):
|
||||
separated = layout_separate(
|
||||
def layout_acc_mn(self, tiled_mma, acc):
|
||||
separated = self.layout_separate(
|
||||
tiled_mma.shape_mnk[0], acc[0], tiled_mma.tv_layout_C.stride[1]
|
||||
)
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
|
||||
#define CUTLASS_MAJOR 4
|
||||
#define CUTLASS_MINOR 3
|
||||
#define CUTLASS_PATCH 3
|
||||
#define CUTLASS_PATCH 4
|
||||
|
||||
#ifdef CUTLASS_VERSIONS_GENERATED
|
||||
#include "cutlass/version_extended.h"
|
||||
|
||||
@@ -470,7 +470,56 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
names=[ast.alias(name=f"{top_module_name}.base_dsl", asname="_dsl_")]
|
||||
)
|
||||
)
|
||||
transformed_tree.body = import_stmts + transformed_tree.body
|
||||
|
||||
assert len(transformed_tree.body) == 1
|
||||
assert isinstance(transformed_tree.body[0], ast.FunctionDef)
|
||||
transformed_tree.body[0].body = import_stmts + transformed_tree.body[0].body
|
||||
# Remove all decorators from top level function
|
||||
transformed_tree.body[0].decorator_list = []
|
||||
|
||||
# Step 4. Wrap the function with nonlocal captures, if has any
|
||||
# if the function has a nonlocal variable, wrap it in a function and return the function
|
||||
# pseudo code:
|
||||
# def foo():
|
||||
# nonlocal_var_0 = None
|
||||
# nonlocal_var_1 = None
|
||||
# def foo(args):
|
||||
# ...
|
||||
# return foo
|
||||
# foo = foo()
|
||||
nonlocals = {v: None for v in function_pointer.__code__.co_freevars}
|
||||
|
||||
if len(nonlocals) > 0:
|
||||
assignments = []
|
||||
for n, _ in nonlocals.items():
|
||||
assignments.append(
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id=n, ctx=ast.Store())],
|
||||
value=ast.Constant(value=None),
|
||||
)
|
||||
)
|
||||
|
||||
return_expr = [ast.Return(value=ast.Name(id=func_name, ctx=ast.Load()))]
|
||||
|
||||
wrapper_fcn = ast.FunctionDef(
|
||||
name=func_name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[],
|
||||
args=[],
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
defaults=[],
|
||||
),
|
||||
body=assignments + transformed_tree.body + return_expr,
|
||||
decorator_list=[],
|
||||
)
|
||||
invoke = ast.Call(
|
||||
func=ast.Name(id=func_name, ctx=ast.Load()), args=[], keywords=[]
|
||||
)
|
||||
assign = ast.Assign(
|
||||
targets=[ast.Name(id=func_name, ctx=ast.Store())], value=invoke
|
||||
)
|
||||
transformed_tree.body = [wrapper_fcn, assign]
|
||||
|
||||
# Step 5. Import cutlass and base_dsl
|
||||
ast.fix_missing_locations(transformed_tree)
|
||||
@@ -1521,6 +1570,15 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
self.scope_manager.add_to_scope(node.name)
|
||||
for arg in node.args.args:
|
||||
self.scope_manager.add_to_scope(arg.arg)
|
||||
arg.annotation = None
|
||||
|
||||
for arg in node.args.kwonlyargs:
|
||||
self.scope_manager.add_to_scope(arg.arg)
|
||||
arg.annotation = None
|
||||
|
||||
for arg in node.args.posonlyargs:
|
||||
self.scope_manager.add_to_scope(arg.arg)
|
||||
arg.annotation = None
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
@@ -622,18 +622,14 @@ class CompileCallable:
|
||||
func,
|
||||
)
|
||||
|
||||
# If it's a wrapped function created by jit decorator, get the original function
|
||||
if hasattr(func, "__wrapped__"):
|
||||
# If it's a wrapped function created by decorators, get the original function
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
|
||||
# Lazy initialization of DSL object if has not been initialized
|
||||
# Use local import to avoid circular import
|
||||
from .dsl import BaseDSL
|
||||
|
||||
BaseDSL._lazy_initialize_dsl(func)
|
||||
|
||||
if not hasattr(func, "_dsl_object"):
|
||||
raise DSLRuntimeError("Function is not decorated with jit decorator.")
|
||||
raise DSLRuntimeError(
|
||||
f"Function {func} is not decorated with jit decorator."
|
||||
)
|
||||
|
||||
# process compile options, extract the options and remove them from the kwargs
|
||||
options = kwargs.pop("options", None)
|
||||
@@ -645,8 +641,4 @@ class CompileCallable:
|
||||
else:
|
||||
compile_options = self._compile_options
|
||||
func._dsl_object.compile_options = compile_options
|
||||
fcn_ptr = func._dsl_object._preprocess_and_execute(func)
|
||||
|
||||
if hasattr(func, "_decorator_frame"):
|
||||
kwargs["_decorator_frame"] = func._decorator_frame
|
||||
return func._dsl_object._func(fcn_ptr, *args, **kwargs)
|
||||
return func._dsl_object._func(func, *args, **kwargs)
|
||||
|
||||
@@ -31,9 +31,10 @@ import weakref
|
||||
from functools import lru_cache, wraps
|
||||
from collections import namedtuple, OrderedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, List
|
||||
from typing import Any, Callable, List, ClassVar
|
||||
from types import SimpleNamespace
|
||||
import warnings
|
||||
import threading
|
||||
|
||||
from . import typing as t
|
||||
from .env_manager import EnvironmentVarManager
|
||||
@@ -228,49 +229,92 @@ def new_from_mlir_values(obj, values):
|
||||
assert len(values) == 0, f"{obj} expects 0 values, but got {values}"
|
||||
return obj
|
||||
|
||||
|
||||
class DSLCallable:
|
||||
@dataclass(frozen=True)
|
||||
class DSLLocation:
|
||||
"""
|
||||
Wrapper class for a callable object used within the DSL.
|
||||
|
||||
DSLCallable is designed to wrap a function and provide additional
|
||||
introspection utilities such as retrieving the argument specification
|
||||
and signature. It ensures that the wrapped function can only be called
|
||||
once, after which the reference to the function is cleared to prevent
|
||||
further invocations. This is useful in scenarios where a function should
|
||||
only be executed a single time within the DSL's execution model.
|
||||
Represents Python source location information for MLIR DSL code.
|
||||
|
||||
Attributes:
|
||||
func (callable): The function to be wrapped and managed.
|
||||
filename (str): Name of the Python source file.
|
||||
lineno (int): Line number in the source file.
|
||||
col_offset (int): Column offset in the source line.
|
||||
function_name (str): Name of the function in which the location occurs.
|
||||
|
||||
Methods:
|
||||
__call__(*args, **kwargs): Calls the wrapped function and clears it.
|
||||
This is used primarily to annotate or trace locations in generated MLIR IR
|
||||
back to the original Python code for better diagnostics and debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
self.name = func.__name__
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
ret = self.__func__(*args, **kwargs)
|
||||
self.func = None
|
||||
return ret
|
||||
|
||||
@property
|
||||
def __func__(self):
|
||||
assert self.func is not None, "DSLCallable is already called"
|
||||
return self.func
|
||||
|
||||
@property
|
||||
def __signature__(self):
|
||||
return inspect.signature(self.__func__)
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
return self.name
|
||||
filename: str
|
||||
lineno: int
|
||||
col_offset: int
|
||||
function_name: str
|
||||
|
||||
|
||||
class BaseDSL:
|
||||
@dataclass
|
||||
class PreprocessSessionData:
|
||||
"""
|
||||
Holds metadata and transformed AST related to a DSL preprocessing session.
|
||||
|
||||
Attributes:
|
||||
decorator_globals (dict): The global variables from the decorator's environment,
|
||||
captured for possible AST or code evaluation during preprocessing.
|
||||
"""
|
||||
decorator_globals: dict
|
||||
|
||||
|
||||
class DSLSingletonMeta(type):
|
||||
"""
|
||||
Metaclass implementing the Singleton pattern for DSL classes.
|
||||
|
||||
The DSLSingletonMeta ensures that only one instance of a derived DSL class exists at any time.
|
||||
When a class is called, it checks if an instance already exists in the `_instances` dictionary.
|
||||
- If requesting `BaseDSL` itself, it asserts that a concrete subclass has been initialized,
|
||||
and returns the first available singleton instance among subclasses.
|
||||
- If requesting a concrete subclass, it creates a new instance if none exists, or returns
|
||||
the already created instance.
|
||||
|
||||
This metaclass is useful for maintaining global state and configuration across the DSL system,
|
||||
ensuring that all parts of the application operate on the same DSL instance.
|
||||
|
||||
Attributes:
|
||||
_instances (dict): Maps DSL classes to their singleton instances.
|
||||
|
||||
Example:
|
||||
class MyDSL(BaseDSL): ...
|
||||
dsl1 = MyDSL()
|
||||
dsl2 = MyDSL()
|
||||
assert dsl1 is dsl2 # Singleton property
|
||||
"""
|
||||
|
||||
_instances: ClassVar[dict] = {}
|
||||
_lock: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
with cls._lock:
|
||||
log().info(f"DSLSingletonMeta __call__ for {cls}")
|
||||
if cls is BaseDSL:
|
||||
# When BaseDSL itself (which is abstract) is requested, returning an arbitrary instance of a concrete subclass is fine.
|
||||
# Here we just return the first instance of a concrete subclass.
|
||||
assert cls._instances, (
|
||||
"Need to initialize a concrete subclass of BaseDSL first"
|
||||
)
|
||||
return next(iter(cls._instances.values()))
|
||||
elif cls not in cls._instances:
|
||||
instance = super().__call__(*args, **kwargs)
|
||||
cls._instances[cls] = instance
|
||||
log().info(f"Active DSL singleton instances: {cls._instances}")
|
||||
return cls._instances[cls]
|
||||
|
||||
def clear_instances(cls):
|
||||
log().info(
|
||||
f"Clearing DSL singleton instances for {cls}, current instances: {cls._instances}"
|
||||
)
|
||||
if cls in cls._instances:
|
||||
del cls._instances[cls]
|
||||
log().info(f"DSL singleton instances after clearing: {cls._instances}")
|
||||
|
||||
|
||||
class BaseDSL(metaclass=DSLSingletonMeta):
|
||||
gpu_module = None
|
||||
_env_class = EnvironmentVarManager
|
||||
|
||||
@@ -310,7 +354,8 @@ class BaseDSL:
|
||||
self.name = name
|
||||
self.compiler_provider = compiler_provider
|
||||
self.pass_sm_arch_name = pass_sm_arch_name
|
||||
self.frame = None
|
||||
self.preprocess_session_data = None
|
||||
self.decorator_location = None
|
||||
self.no_cache = False
|
||||
self.device_compilation_only = device_compilation_only
|
||||
self.num_kernels = 0
|
||||
@@ -379,7 +424,6 @@ class BaseDSL:
|
||||
warnings.warn(message, UserWarning)
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_dsl(cls):
|
||||
# Instantiate the DSL Class once
|
||||
main_dsl = cls()
|
||||
@@ -414,38 +458,22 @@ class BaseDSL:
|
||||
return fcn_ptr
|
||||
|
||||
@staticmethod
|
||||
def _preprocess_and_execute(func):
|
||||
def _preprocess_and_replace_code(func):
|
||||
"""
|
||||
Run ast transformation and return the materialized function pointer
|
||||
"""
|
||||
|
||||
# Lazy initialization of DSL object if has not been initialized
|
||||
if not hasattr(func, "_dsl_object"):
|
||||
func._dsl_object = func._dsl_cls._get_dsl()
|
||||
delattr(func, "_dsl_cls")
|
||||
|
||||
if not func._dsl_object.enable_preprocessor:
|
||||
if hasattr(func, "_decorator_frame"):
|
||||
delattr(func, "_decorator_frame")
|
||||
if hasattr(func, "_transformed_ast"):
|
||||
delattr(func, "_transformed_ast")
|
||||
return func
|
||||
|
||||
if hasattr(func, "_transformed_ast"):
|
||||
if hasattr(func, "_preprocess_session_data"):
|
||||
# If the function ptr is already materialized, use the existing one
|
||||
func._dsl_object.frame = func._decorator_frame
|
||||
if func._transformed_ast is None:
|
||||
func._transformed_ast = func._dsl_object.run_preprocessor(func)
|
||||
if func._transformed_ast is None:
|
||||
del func._transformed_ast
|
||||
func._dsl_object.frame = None
|
||||
return func
|
||||
|
||||
fcn_ptr = func._dsl_object.get_function_ptr(func)
|
||||
# If the function is decorated, de-decorate it
|
||||
fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__)
|
||||
func._dsl_object.frame = None
|
||||
return DSLCallable(fcn_ptr)
|
||||
func._dsl_object.preprocess_session_data = func._preprocess_session_data
|
||||
func._dsl_object.decorator_location = func._decorator_location
|
||||
transformed_ast = func._dsl_object.run_preprocessor(func)
|
||||
fcn_ptr = func._dsl_object.get_function_ptr(func, transformed_ast)
|
||||
func.__code__ = (
|
||||
fcn_ptr.__code__
|
||||
if not isinstance(fcn_ptr, staticmethod)
|
||||
else fcn_ptr.__func__.__code__
|
||||
)
|
||||
return func
|
||||
|
||||
@staticmethod
|
||||
@@ -457,20 +485,27 @@ class BaseDSL:
|
||||
|
||||
def jit_runner_decorator(func):
|
||||
# Run preprocessor that alters AST
|
||||
func._dsl_cls = cls
|
||||
if BaseDSL._can_preprocess(**dkwargs):
|
||||
func._dsl_object = cls._get_dsl()
|
||||
func._decorator_location = BaseDSL.get_location_from_frame(frame)
|
||||
if (
|
||||
func._dsl_object.enable_preprocessor
|
||||
and func._dsl_object._can_preprocess(**dkwargs)
|
||||
):
|
||||
# For an annotated function, add some DSL attributes
|
||||
# When materializing the AST, we need decorator's frame
|
||||
func._decorator_frame = frame
|
||||
# No transformed ast at this point
|
||||
func._transformed_ast = None
|
||||
func._preprocess_session_data = PreprocessSessionData(
|
||||
decorator_globals=frame.f_globals,
|
||||
)
|
||||
BaseDSL._preprocess_and_replace_code(func)
|
||||
|
||||
@wraps(func)
|
||||
def jit_wrapper(*args, **kwargs):
|
||||
func_ptr = BaseDSL._preprocess_and_execute(func)
|
||||
return getattr(func._dsl_object, executor_name)(
|
||||
func_ptr, *args, **kwargs
|
||||
)
|
||||
return getattr(func._dsl_object, executor_name)(func, *args, **kwargs)
|
||||
|
||||
def set_name_prefix(name: str):
|
||||
jit_wrapper._name_prefix = name
|
||||
|
||||
jit_wrapper.set_name_prefix = set_name_prefix
|
||||
|
||||
return jit_wrapper
|
||||
|
||||
@@ -479,15 +514,6 @@ class BaseDSL:
|
||||
else:
|
||||
return jit_runner_decorator
|
||||
|
||||
@staticmethod
|
||||
def _lazy_initialize_dsl(func):
|
||||
"""
|
||||
Lazy initialization of DSL object if has not been initialized
|
||||
"""
|
||||
if hasattr(func, "_dsl_cls"):
|
||||
func._dsl_object = func._dsl_cls._get_dsl()
|
||||
delattr(func, "_dsl_cls")
|
||||
|
||||
@classmethod
|
||||
def jit(cls, *dargs, **dkwargs):
|
||||
"""
|
||||
@@ -516,6 +542,7 @@ class BaseDSL:
|
||||
"""
|
||||
Build the module op that contains the kernels.
|
||||
"""
|
||||
log().info(f"[abstract] Building GPU module for {self.name}")
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -688,9 +715,11 @@ class BaseDSL:
|
||||
dictionary is used to execute the python code.
|
||||
"""
|
||||
all_globals = {}
|
||||
if self.frame:
|
||||
all_globals.update(self.frame.f_globals)
|
||||
all_globals.update(self.frame.f_locals)
|
||||
if (
|
||||
self.preprocess_session_data
|
||||
and self.preprocess_session_data.decorator_globals
|
||||
):
|
||||
all_globals.update(self.preprocess_session_data.decorator_globals)
|
||||
return all_globals
|
||||
|
||||
@abstractmethod
|
||||
@@ -955,25 +984,40 @@ class BaseDSL:
|
||||
else:
|
||||
ir._GlobalDebug.set_types(f"diagnostic-{args.diagnostic}")
|
||||
|
||||
def get_location(self, frame=None):
|
||||
"""
|
||||
Get python location information and generate MLIR location
|
||||
"""
|
||||
frame = self.frame if frame is None else frame
|
||||
frame = inspect.currentframe().f_back if frame is None else frame
|
||||
@staticmethod
|
||||
def get_location_from_frame(frame):
|
||||
frameInfo = inspect.getframeinfo(frame)
|
||||
|
||||
file_loc = ir.Location.file(
|
||||
frame.f_code.co_filename,
|
||||
frame.f_lineno,
|
||||
frameInfo.positions.col_offset if hasattr(frameInfo, "positions") else 0,
|
||||
)
|
||||
loc = ir.Location.name(
|
||||
(
|
||||
return DSLLocation(
|
||||
filename=frameInfo.filename,
|
||||
lineno=frameInfo.lineno,
|
||||
col_offset=(
|
||||
frameInfo.positions.col_offset if hasattr(frameInfo, "positions") else 0
|
||||
),
|
||||
function_name=(
|
||||
"".join([c.strip() for c in frameInfo.code_context])
|
||||
if frameInfo.code_context
|
||||
else frameInfo.function
|
||||
),
|
||||
)
|
||||
|
||||
def get_ir_location(self, location: DSLLocation = None):
|
||||
"""
|
||||
Get python location information and generate MLIR location
|
||||
"""
|
||||
if location is None:
|
||||
if self.decorator_location:
|
||||
location = self.decorator_location
|
||||
|
||||
if location is None:
|
||||
return ir.Location.unknown()
|
||||
|
||||
file_loc = ir.Location.file(
|
||||
location.filename,
|
||||
location.lineno,
|
||||
location.col_offset,
|
||||
)
|
||||
loc = ir.Location.name(
|
||||
(location.function_name),
|
||||
childLoc=file_loc,
|
||||
)
|
||||
return loc
|
||||
@@ -1140,10 +1184,10 @@ class BaseDSL:
|
||||
gpu_module_attrs,
|
||||
args,
|
||||
args_spec,
|
||||
frame=None,
|
||||
location=None,
|
||||
):
|
||||
def build_ir_module():
|
||||
loc = self.get_location(frame)
|
||||
loc = self.get_ir_location(location)
|
||||
module = ir.Module.create(loc=loc)
|
||||
unit_attr = ir.UnitAttr.get()
|
||||
module.operation.attributes["gpu.container_module"] = unit_attr
|
||||
@@ -1308,6 +1352,10 @@ class BaseDSL:
|
||||
self.num_kernels = 0
|
||||
# reset the compile options after the compilation is done.
|
||||
self.compile_options = CompileOptions()
|
||||
# reset preprocess session data after the compilation is done.
|
||||
self.preprocess_session_data = None
|
||||
# reset decorator location after the compilation is done.
|
||||
self.decorator_location = None
|
||||
|
||||
def extract_dynamic_args(self, funcBody, args, kwargs, args_spec):
|
||||
"""This function is used to extract the original dynamic arguments for AOT C header generation.
|
||||
@@ -1348,11 +1396,10 @@ class BaseDSL:
|
||||
pipeline,
|
||||
no_cache,
|
||||
compile_only,
|
||||
loc=None,
|
||||
frame=None,
|
||||
location=None,
|
||||
):
|
||||
"""Generate MLIR module and compile it."""
|
||||
with ir.Context(), self.get_location(frame):
|
||||
with ir.Context(), self.get_ir_location(location):
|
||||
try:
|
||||
# Convert input arguments to MLIR arguments
|
||||
exe_args, func_types, adapted_args = self.generate_mlir_function_types(
|
||||
@@ -1374,7 +1421,7 @@ class BaseDSL:
|
||||
gpu_module_attrs,
|
||||
args,
|
||||
args_spec,
|
||||
frame=frame,
|
||||
location=location,
|
||||
)
|
||||
|
||||
# dryrun is used to only generate IR
|
||||
@@ -1437,11 +1484,14 @@ class BaseDSL:
|
||||
return transformed_ast
|
||||
return None
|
||||
|
||||
def get_function_ptr(self, original_function):
|
||||
def get_function_ptr(self, original_function, transformed_ast):
|
||||
file_name = inspect.getsourcefile(original_function)
|
||||
code_object = compile(
|
||||
original_function._transformed_ast, filename=file_name, mode="exec"
|
||||
transformed_ast,
|
||||
filename=file_name,
|
||||
mode="exec",
|
||||
)
|
||||
|
||||
return self.preprocessor.exec(
|
||||
original_function.__name__,
|
||||
original_function,
|
||||
@@ -1523,7 +1573,7 @@ class BaseDSL:
|
||||
|
||||
pipeline = kwargs.pop("pipeline", None)
|
||||
gpu_module_attrs = kwargs.pop("gpu_module_attrs", {})
|
||||
decorator_frame = kwargs.pop("_decorator_frame", None)
|
||||
self.decorator_location = getattr(funcBody, "_decorator_location", None)
|
||||
|
||||
# Disable cache
|
||||
no_cache = kwargs.pop("no_cache", False)
|
||||
@@ -1556,7 +1606,7 @@ class BaseDSL:
|
||||
function_name = self.mangle_name(function_name, canonicalized_args, args_spec)
|
||||
self.compile_options.apply_envar_settings(self.envar, function_name)
|
||||
if not self.compile_options.generate_line_info:
|
||||
decorator_frame = None
|
||||
self.decorator_location = None
|
||||
|
||||
# Generate MLIR Context and start generating IR
|
||||
log().debug(f"Generating MLIR for function '{function_name}'")
|
||||
@@ -1570,7 +1620,7 @@ class BaseDSL:
|
||||
pipeline,
|
||||
no_cache,
|
||||
compile_only,
|
||||
frame=decorator_frame,
|
||||
location=self.decorator_location,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -1679,8 +1729,7 @@ class BaseDSL:
|
||||
"""
|
||||
ret = None
|
||||
|
||||
with ir.Context(), self.get_location():
|
||||
loc = self.get_location()
|
||||
with ir.Context(), self.get_ir_location() as loc:
|
||||
module = ir.Module.create(loc=loc)
|
||||
unit_attr = ir.UnitAttr.get()
|
||||
module.operation.attributes["gpu.container_module"] = unit_attr
|
||||
@@ -1819,7 +1868,7 @@ class BaseDSL:
|
||||
)
|
||||
)
|
||||
|
||||
loc = self.get_location()
|
||||
loc = self.get_ir_location()
|
||||
with self._enter_gpu_module():
|
||||
log().debug("Generating device kernel")
|
||||
if self.device_compilation_only:
|
||||
|
||||
@@ -138,6 +138,7 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
super().__init__()
|
||||
self.module: Optional[ir.Module] = None
|
||||
self.const_str_table: dict[str, ir.Value] = {}
|
||||
self.const_func_ptr_table: dict[str, ir.Value] = {}
|
||||
self.get_element_extra_kwargs: dict[str, Any] = {}
|
||||
|
||||
# create constants
|
||||
@@ -368,6 +369,64 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
self.const_str_table[content] = symbol
|
||||
return symbol
|
||||
|
||||
def get_or_load_global_func_ptr_from_text(
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
function_name: str,
|
||||
) -> ir.Value:
|
||||
"""Get or create a function pointer global in .text section and load it.
|
||||
|
||||
This creates a constant global function pointer in the .text section
|
||||
(for AArch64 ADRP range compatibility) and performs a volatile load
|
||||
to prevent optimization.
|
||||
|
||||
This forces the function pointer to be local to the code, bypassing GOT entry
|
||||
ADRP lookup issues on AArch64 when GOT and .text section are more than 4GB
|
||||
apart which can happen when ASLR is applied.
|
||||
"""
|
||||
# Check if we've already created this global
|
||||
if function_name not in self.const_func_ptr_table:
|
||||
symbol = f"__func_ptr_{function_name}"
|
||||
|
||||
module_body = self.module.body
|
||||
with ir.InsertionPoint(module_body):
|
||||
# 1. Create the global constant
|
||||
# We use 'private' linkage so it doesn't conflict across modules
|
||||
global_ptr = llvm.GlobalOp(
|
||||
self.ptr_type,
|
||||
symbol,
|
||||
ir.Attribute.parse("#llvm.linkage<private>"),
|
||||
# Initialization via block below
|
||||
)
|
||||
|
||||
# 2. Set the necessary attributes for JIT safety and AArch64 range
|
||||
# We use 'constant' to mark it as immutable
|
||||
# We use 'section = ".text"' to force it into the code block
|
||||
global_ptr.attributes["constant"] = ir.UnitAttr.get()
|
||||
global_ptr.attributes["section"] = ir.StringAttr.get(".text")
|
||||
|
||||
# 3. Add a constructor block to the GlobalOp to initialize it
|
||||
# with the address of the target function
|
||||
initializer_block = global_ptr.initializer.blocks.append()
|
||||
with ir.InsertionPoint(initializer_block):
|
||||
# Get the address of the external function
|
||||
func_addr = llvm.AddressOfOp(self.ptr_type, function_name).res
|
||||
# Return the address as the initial value of the global
|
||||
llvm.return_(arg=func_addr)
|
||||
|
||||
self.const_func_ptr_table[function_name] = symbol
|
||||
else:
|
||||
symbol = self.const_func_ptr_table[function_name]
|
||||
|
||||
# Load it with volatile semantics in the current block
|
||||
with ir.InsertionPoint(current_block):
|
||||
symbol_addr = self.address_of(symbol, self.ptr_type)
|
||||
# Perform a volatile load to prevent optimization
|
||||
load_op = llvm.load(self.ptr_type, symbol_addr)
|
||||
# Set volatile attribute to prevent optimization
|
||||
load_op.owner.attributes["volatile_"] = ir.UnitAttr.get()
|
||||
return load_op
|
||||
|
||||
# function
|
||||
def function(
|
||||
self,
|
||||
|
||||
@@ -210,7 +210,7 @@ EnableTVMFFI = _dsl.EnableTVMFFI
|
||||
# attach the TVM FFI ABI interface postprocessor to the DSL
|
||||
from . import _tvm_ffi_args_spec_converter
|
||||
|
||||
_tvm_ffi_args_spec_converter.attach_args_spec_converter()
|
||||
_tvm_ffi_args_spec_converter.attach_args_spec_converter(_dsl.CuTeDSL._get_dsl())
|
||||
|
||||
# Explicitly export all symbols for documentation generation
|
||||
__all__ = [
|
||||
|
||||
@@ -395,8 +395,6 @@ def _tvm_ffi_args_spec_converter(
|
||||
return params, kwargs_wrapper_spec
|
||||
|
||||
|
||||
def attach_args_spec_converter():
|
||||
"""Attach TVM FFI ABI interface postprocessor to the DSL."""
|
||||
from .. import cutlass_dsl as _dsl
|
||||
|
||||
_dsl.CuTeDSL._get_dsl()._tvm_ffi_args_spec_converter = _tvm_ffi_args_spec_converter
|
||||
def attach_args_spec_converter(dsl):
|
||||
"""Attach TVM FFI ABI interface postprocessor to the DSL instance."""
|
||||
dsl._tvm_ffi_args_spec_converter = _tvm_ffi_args_spec_converter
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
# is strictly prohibited.
|
||||
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.cutlass_dsl import CuTeDSL, T, dsl_user_op
|
||||
from cutlass.cutlass_dsl import BaseDSL, T, dsl_user_op
|
||||
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir.dialects import nvvm, scf
|
||||
@@ -69,7 +69,7 @@ def elect_one(*, loc=None, ip=None) -> IfOpRegion:
|
||||
# Only one thread in the warp executes the code in this context
|
||||
pass
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
is_thread_leader = nvvm.elect_sync(T.bool())
|
||||
if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip)
|
||||
return IfOpRegion(if_op.then_block, loc=loc, ip=ip)
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op
|
||||
from cutlass.cutlass_dsl import BaseDSL, T, if_generate, dsl_user_op
|
||||
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
|
||||
@@ -44,7 +44,7 @@ def mbarrier_init_fence(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
A fence operation that applies to the mbarrier initializations.
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
nvvm.fence_mbarrier_init(loc=loc, ip=ip)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ def mbarrier_arrive_and_expect_tx(
|
||||
the mbarrier is converted to a remote address in the peer CTA's
|
||||
SMEM.
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
if peer_cta_rank_in_cluster is not None:
|
||||
@@ -103,7 +103,7 @@ def mbarrier_expect_tx(
|
||||
the mbarrier is converted to a remote address in the peer CTA's
|
||||
SMEM.
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
if peer_cta_rank_in_cluster is not None:
|
||||
@@ -138,7 +138,7 @@ def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None:
|
||||
:param phase: The phase to wait for (either 0 or 1)
|
||||
:type phase: Int
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
|
||||
timeout_ns = 10000000
|
||||
# This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX
|
||||
@@ -164,7 +164,7 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo
|
||||
:return: A boolean value indicating whether the wait operation was successful
|
||||
:rtype: Boolean
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
|
||||
return Boolean(
|
||||
nvvm.mbarrier_wait_parity(
|
||||
@@ -193,7 +193,7 @@ def mbarrier_conditional_try_wait(
|
||||
:return: A boolean value indicating whether the wait operation was successful
|
||||
:rtype: Boolean
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
return if_generate(
|
||||
cond,
|
||||
lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip),
|
||||
@@ -225,7 +225,7 @@ def mbarrier_arrive(
|
||||
"""
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
if peer_cta_rank_in_cluster is not None:
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
|
||||
mbar_llvm_ptr = nvvm.mapa_shared_cluster(
|
||||
mbar_llvm_ptr.type,
|
||||
@@ -259,7 +259,7 @@ def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> N
|
||||
:param mbar_ptr: A pointer to the mbarrier in SMEM
|
||||
:type mbar_ptr: Pointer
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
nvvm.cp_async_mbarrier_arrive_shared(
|
||||
|
||||
@@ -12,7 +12,8 @@
|
||||
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.base_dsl.common import DSLRuntimeError
|
||||
from cutlass.cutlass_dsl import CuTeDSL, dsl_user_op
|
||||
from cutlass.cutlass_dsl import BaseDSL, dsl_user_op
|
||||
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import builtin, arith, llvm, vector
|
||||
|
||||
@@ -53,7 +54,7 @@ def cvt_i8_bf16_intrinsic(vec_i8, length, *, loc=None, ip=None):
|
||||
:return: The output 1D vector of bfloat16 with the same length as the input vector.
|
||||
:rtype: 1D vector of bfloat16
|
||||
"""
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch in cvt_i8_bf16_intrinsic.supported_archs:
|
||||
raise DSLRuntimeError(f"cvt_i8_bf16_intrinsic is not supported on {arch}")
|
||||
src_pos = 0
|
||||
@@ -130,7 +131,7 @@ def cvt_i4_bf16_intrinsic(vec_i4, length, *, loc=None, ip=None):
|
||||
:return: The output 1D vector of bfloat16 with the same length as the input vector.
|
||||
:rtype: 1D vector of bfloat16
|
||||
"""
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch in cvt_i4_bf16_intrinsic.supported_archs:
|
||||
raise DSLRuntimeError(f"cvt_i4_bf16_intrinsic is not supported on {arch}")
|
||||
src_pos = 0
|
||||
|
||||
@@ -1305,6 +1305,46 @@ def exp_packed_f32x2(
|
||||
return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def griddepcontrol_wait(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
This instruction waits for the previous kernel's grid to complete
|
||||
(all blocks of the previous kernel have finished and flushed their memory writes), i.e.,
|
||||
no instruction after this one will be issued until the previous
|
||||
grid has finished.
|
||||
"""
|
||||
llvm.inline_asm(
|
||||
res=None,
|
||||
operands_=[],
|
||||
asm_string="griddepcontrol.wait;",
|
||||
constraints="",
|
||||
has_side_effects=True,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def griddepcontrol_launch_dependents(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Issuing the launch_dependents instruction hints that a dependent kernel may launch earlier.
|
||||
launch_dependents affects only performance, not functionality:
|
||||
Launching a dependent kernel too early can compete with current kernels,
|
||||
while launching too late can lead to a long latency.
|
||||
"""
|
||||
llvm.inline_asm(
|
||||
res=None,
|
||||
operands_=[],
|
||||
asm_string="griddepcontrol.launch_dependents;",
|
||||
constraints="",
|
||||
has_side_effects=True,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cvt_f4e2m1_f16(src, *, loc=None, ip=None):
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import Optional, Type
|
||||
|
||||
from cutlass import cute
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.cutlass_dsl import CuTeDSL
|
||||
from cutlass.cutlass_dsl import BaseDSL
|
||||
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
@@ -146,7 +146,7 @@ class CopyBulkTensorTileG2SOp(TmaCopyOp):
|
||||
self, "expects the 'cta_group' parameter to be a CtaGroup instance"
|
||||
)
|
||||
# Arch verification
|
||||
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch >= Arch.sm_90:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -263,7 +263,7 @@ class CopyBulkTensorTileG2SMulticastOp(TmaCopyOp):
|
||||
self, "expects the 'cta_group' parameter to be a CtaGroup instance"
|
||||
)
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch >= Arch.sm_90:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -386,7 +386,7 @@ class CopyBulkTensorTileS2GOp(TmaCopyOp):
|
||||
|
||||
def __post_init__(self):
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch >= Arch.sm_90:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -561,7 +561,7 @@ class CopyBulkG2SOp(CopyOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch >= Arch.sm_90:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -646,7 +646,7 @@ class CopyBulkG2SMulticastOp(CopyOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch >= Arch.sm_90:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -740,7 +740,7 @@ class CopyBulkS2GOp(CopyOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch >= Arch.sm_90:
|
||||
raise OpError(
|
||||
self,
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import Type
|
||||
|
||||
from cutlass import cute
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.cutlass_dsl import CuTeDSL
|
||||
from cutlass.cutlass_dsl import BaseDSL
|
||||
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
@@ -113,7 +113,7 @@ class _LdBase(CopyOp):
|
||||
:raises OpError: If pack parameter is not a Pack instance
|
||||
"""
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -416,7 +416,7 @@ class _StBase(CopyOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -625,7 +625,7 @@ class _S2TCopyBase(CopyOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch.is_family_of(Arch.sm_100f):
|
||||
raise OpError(
|
||||
self,
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import Type, Any
|
||||
|
||||
from cutlass import cute
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.cutlass_dsl import CuTeDSL, T
|
||||
from cutlass.cutlass_dsl import BaseDSL, T
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
@@ -162,7 +162,7 @@ class MmaOp(Tcgen05MmaOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -314,7 +314,7 @@ class BlockScaledMmaOp(Tcgen05MmaOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -471,7 +471,7 @@ class SparseMmaOp(Tcgen05MmaOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import Type, Any
|
||||
|
||||
from cutlass import cute
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.cutlass_dsl import CuTeDSL, T
|
||||
from cutlass.cutlass_dsl import BaseDSL, T
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
@@ -130,7 +130,7 @@ class MmaOp(WarpGroupMmaOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch == Arch.sm_90a:
|
||||
raise OpError(
|
||||
self,
|
||||
|
||||
@@ -925,7 +925,17 @@ def load_module(file_path: str, *, enable_tvm_ffi: bool = True):
|
||||
if enable_tvm_ffi:
|
||||
import tvm_ffi
|
||||
|
||||
return tvm_ffi.load_module(file_path)
|
||||
try:
|
||||
# keep_module_alive=False means the module will be unloaded
|
||||
# after the returned module goes out of scope, this is useful
|
||||
# for frequent loading and unloading of modules. The only requirement
|
||||
# is that the module does not return an object that has its deleter in the module
|
||||
# and the returned object lives longer than the module.
|
||||
# DSL functions do not have this issue, so it is desirable to set this to False.
|
||||
return tvm_ffi.load_module(file_path, keep_module_alive=False)
|
||||
except TypeError:
|
||||
# compatible with tvm-ffi < 0.1.6
|
||||
return tvm_ffi.load_module(file_path)
|
||||
else:
|
||||
raise DSLRuntimeError(
|
||||
"Unimplemented, please load the module with enable_tvm_ffi=True."
|
||||
|
||||
@@ -20,7 +20,7 @@ from cutlass.cutlass_dsl import (
|
||||
T,
|
||||
cutlass_arith,
|
||||
_binary_op_type_promote,
|
||||
CuTeDSL,
|
||||
BaseDSL,
|
||||
)
|
||||
from cutlass._mlir import ir
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
@@ -1776,7 +1776,7 @@ class TensorSSA(cutlass_arith.ArithValue):
|
||||
fast_cvt_func = cvt_i8_bf16_intrinsic
|
||||
elif src_dtype == Int4 and dtype == BFloat16:
|
||||
fast_cvt_func = cvt_i4_bf16_intrinsic
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if fast_cvt_func is not None and arch in fast_cvt_func.supported_archs:
|
||||
res_vect = fast_cvt_func(src, size(self.shape), loc=loc, ip=ip)
|
||||
else:
|
||||
|
||||
@@ -407,7 +407,7 @@ def benchmark(
|
||||
To use CUDA graphs, the callable must be a compiled @cute.jit annotated function.
|
||||
When using CUDA graphs, the kernel must be launched in a non-default stream.
|
||||
|
||||
:param callable: The function to benchmark
|
||||
:param callable: The function to benchmark. A JIT function must be passed in its compiled form.
|
||||
:type callable: Callable
|
||||
:param warmup_iterations: Number of warmup iterations, defaults to 10
|
||||
:type warmup_iterations: int, optional
|
||||
@@ -475,15 +475,6 @@ def benchmark(
|
||||
elapsed_time = float("nan")
|
||||
|
||||
if use_cuda_graphs:
|
||||
# Check if the callable is a JitCompiledFunction or JitExecutor
|
||||
# These are functions that can be called to launch kernels
|
||||
compiled_types = (
|
||||
cutlass.base_dsl.jit_executor.JitCompiledFunction,
|
||||
cutlass.base_dsl.jit_executor.JitExecutor,
|
||||
)
|
||||
if not isinstance(callable, compiled_types):
|
||||
raise TypeError("Function must be precompiled to be used with CUDA Graphs")
|
||||
|
||||
# Check if the stream is a non-default stream
|
||||
if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT):
|
||||
raise ValueError(
|
||||
|
||||
@@ -247,7 +247,10 @@ class CutlassBaseDSL(BaseDSL):
|
||||
return False
|
||||
|
||||
def _build_gpu_module(self, attrs, loc=None):
|
||||
log().info(f"self : {self}")
|
||||
log().info(f"Building GPU module for {self.name}")
|
||||
self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels"), loc=loc)
|
||||
log().info(f"GPU module: {self.gpu_module}")
|
||||
with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])):
|
||||
pass
|
||||
|
||||
@@ -275,6 +278,9 @@ class CutlassBaseDSL(BaseDSL):
|
||||
return pipeline
|
||||
|
||||
def _enter_gpu_module(self):
|
||||
log().info(f"self: {self}")
|
||||
log().info(f"Entering GPU module for {self.name}")
|
||||
log().info(f"GPU module: {self.gpu_module}")
|
||||
return ir.InsertionPoint(self.gpu_module.bodyRegion.blocks[0])
|
||||
|
||||
def _generate_kernel_attrs(self, config: BaseDSL.LaunchConfig) -> dict:
|
||||
|
||||
@@ -126,16 +126,22 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
|
||||
)
|
||||
context.module.body.append(parsed_op)
|
||||
|
||||
|
||||
with ir.InsertionPoint(current_block):
|
||||
cuda_global_state_ptr = self.address_of(
|
||||
self.cuda_global_state_symbol, self.ptr_type
|
||||
)
|
||||
cuda_init_ptr = self.address_of("cuda_init", self.ptr_type)
|
||||
cuda_load_to_device_ptr = self.address_of("cuda_load_to_device", self.ptr_type)
|
||||
set_error_ptr = self.address_of(
|
||||
"TVMFFIErrorSetRaisedFromCStr", self.ptr_type
|
||||
)
|
||||
|
||||
cuda_init_ptr = context.builder.get_or_load_global_func_ptr_from_text(
|
||||
current_block, "cuda_init"
|
||||
)
|
||||
cuda_load_to_device_ptr = context.builder.get_or_load_global_func_ptr_from_text(
|
||||
current_block, "cuda_load_to_device"
|
||||
)
|
||||
set_error_ptr = context.builder.get_or_load_global_func_ptr_from_text(
|
||||
current_block, "TVMFFIErrorSetRaisedFromCStr"
|
||||
)
|
||||
|
||||
with ir.InsertionPoint(current_block):
|
||||
# Call the callback function with the loaded ptr value
|
||||
init_result = llvm.call(
|
||||
result=self.i32_type, # function returns i32
|
||||
@@ -495,6 +501,13 @@ class TVMFFIJitCompiledFunctionBase(CudaDialectJitCompiledFunction):
|
||||
"""Create the tvm_ffi.Function from the current execution engine.
|
||||
"""
|
||||
if self.engine is not None:
|
||||
# trigger eager compile of init callbacks
|
||||
cuda_init = self.engine.raw_lookup("cuda_init")
|
||||
cuda_load_to_device = self.engine.raw_lookup("cuda_load_to_device")
|
||||
if cuda_init is None:
|
||||
raise DSLRuntimeError("cuda_init not found")
|
||||
if cuda_load_to_device is None:
|
||||
raise DSLRuntimeError("cuda_load_to_device not found")
|
||||
tvm_ffi_function_ptr = self.engine.raw_lookup(
|
||||
"__tvm_ffi_" + self.function_name
|
||||
)
|
||||
|
||||
@@ -261,7 +261,7 @@ def make_smem_layout_a(
|
||||
a_smem_layout_staged = cute.tile_to_shape(
|
||||
a_smem_layout_atom,
|
||||
cute.append(a_smem_shape, num_stages),
|
||||
order=(0, 1, 2) if is_k_major else (0, 1, 2),
|
||||
order=(0, 1, 2) if is_k_major else (1, 0, 2),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl==4.3.3
|
||||
nvidia-cutlass-dsl==4.3.4
|
||||
|
||||
@@ -133,7 +133,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
this.__version__ = '4.3.3'
|
||||
this.__version__ = '4.3.4'
|
||||
|
||||
from cutlass_cppgen.backend import create_memory_pool
|
||||
from cutlass_cppgen.emit.pytorch import pytorch
|
||||
|
||||
@@ -51,7 +51,7 @@ setup_pycute.perform_setup()
|
||||
|
||||
setup(
|
||||
name='cutlass_cppgen',
|
||||
version='4.3.3',
|
||||
version='4.3.4',
|
||||
description='CUTLASS Pythonic Interface',
|
||||
package_dir={'': '.'},
|
||||
packages=[
|
||||
|
||||
@@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='cutlass_library',
|
||||
version='4.3.3',
|
||||
version='4.3.4',
|
||||
description='CUTLASS library generation scripts',
|
||||
packages=['cutlass_library']
|
||||
)
|
||||
|
||||
@@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='pycute',
|
||||
version='4.3.3',
|
||||
version='4.3.4',
|
||||
description='Python implementation of CuTe',
|
||||
packages=['pycute'],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user