mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 14:59:01 +00:00
Release v4.0.0 (#2294)
This commit is contained in:
392
examples/python/CuTeDSL/ampere/elementwise_add.py
Normal file
392
examples/python/CuTeDSL/ampere/elementwise_add.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import time
|
||||
from typing import Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import cutlass.torch as cutlass_torch
|
||||
|
||||
"""
|
||||
An Elementwise Addition Example using CuTe DSL.
|
||||
|
||||
This example kernel copies data from global memory to register memory (rmem), performs the elementwise
|
||||
addition operation, and stores the result back to global memory.
|
||||
|
||||
Primary goals of this example are to demonstrate how basic global memory copies can be expressed in
|
||||
CuTe DSL and illustrate canonical partitioning patterns in CuTe. It also implements canonical
|
||||
predication for tensors whose shape is not multiple of tile size to guard OOB reads.
|
||||
|
||||
Thread-value (or TV) layouts are central to canonical partitioning patterns in CuTe. They provide a
|
||||
mapping from thread and a thread's value to the set of coordinates within a tile that we have sliced
|
||||
out from a data tensor.
|
||||
|
||||
The input tensors are row-major layout, that leading dimension is the right most dimension. In order
|
||||
to efficiently copy data from global memory, we must map threads contiguously on row dimension.
|
||||
|
||||
Thread ID mapping to 2D coordinates with layout `(4,32):(32,1)`:
|
||||
|
||||
+----+----+----+----+-----+----+
|
||||
| | 0 | 1 | 2 | ... | 31 |
|
||||
+----+----+----+----+-----+----+
|
||||
| 0 | T0 | T1 | T2 | ... | T31|
|
||||
+----+----+----+----+-----+----+
|
||||
| 1 |T32 |T33 |T34 | ... |T63 |
|
||||
+----+----+----+----+-----+----+
|
||||
| 2 |T64 |T65 |T66 | ... |T95 |
|
||||
+----+----+----+----+-----+----+
|
||||
| 3 |T96 |T97 |T98 | ... |T127|
|
||||
+----+----+----+----+-----+----+
|
||||
|
||||
As Ampere GPU supports a maximum of 128bit per load/store instruction and each element is 32bit, we
|
||||
can load 4 elements per instruction. Having additional contiguous values allows for vectorization
|
||||
across threads (coalesced accesses) and is required for saturating the memory bandwidth.
|
||||
|
||||
We use `(4,4):(4,1)` as the val layout in this example. Notice that the major mode is the same as
|
||||
the major mode of the input tensor - without which vectorization would not be possible.
|
||||
|
||||
If you already know the TV layout you want to use for your tiled copy, CuTe DSL provides utility
|
||||
`cute.make_layout_tv` to build the tiled copy type around it and the atom of your choice.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
thr_layout = cute.make_layout((4, 32), stride=(32, 1))
|
||||
val_layout = cute.make_layout((4, 4), stride=(4, 1))
|
||||
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
|
||||
|
||||
# Tile input tensor to thread blocks: ((TileM,TileN),(RestM,RestN))
|
||||
gA = cute.zipped_divide(mA, tiler_mn)
|
||||
|
||||
where `tiler_mn` is the tile size per thread block and `tv_layout` is the TV layout which maps
|
||||
thread index and inter-thread index of data array per thread to logical coordinates of elements in
|
||||
input and output tensors.
|
||||
|
||||
Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy` utility.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
blkA = gA[((None, None), bidx)] # (TileM,TileN)
|
||||
|
||||
copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
|
||||
tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
|
||||
|
||||
# get slice of tiled_copy_A for current thread
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
|
||||
# partition per thread block tensor as source of tiled copy
|
||||
thrA = thr_copy_A.partition_S(blkA)
|
||||
|
||||
# allocate fragment for gmem->rmem
|
||||
frgA = cute.make_fragment_like(thrA)
|
||||
|
||||
# copy data from global memory to register memory
|
||||
cute.copy(copy_atom_load, thrA, frgA)
|
||||
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/ampere/elementwise_add.py --M 3 --N 12
|
||||
python examples/ampere/elementwise_add.py --M 1024 --N 512
|
||||
python examples/ampere/elementwise_add.py --M 1024 --N 1024 --benchmark --warmup_iterations 2 --iterations 1000
|
||||
|
||||
To collect performance with NCU profiler:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Don't iterate too many times when profiling with ncu
|
||||
ncu python examples/ampere/elementwise_add.py --M 2048 --N 2048 --benchmark --iterations 10 --skip_ref_check
|
||||
"""
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def elementwise_add_kernel(
|
||||
gA: cute.Tensor,
|
||||
gB: cute.Tensor,
|
||||
gC: cute.Tensor,
|
||||
cC: cute.Tensor, # coordinate tensor
|
||||
shape: cute.Shape,
|
||||
tv_layout: cute.Layout,
|
||||
tiler_mn: cute.Shape,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
|
||||
# slice for CTAs
|
||||
# logical id -> address
|
||||
blk_coord = ((None, None), bidx)
|
||||
blkA = gA[blk_coord] # (TileM,TileN)
|
||||
blkB = gB[blk_coord] # (TileM,TileN)
|
||||
blkC = gC[blk_coord] # (TileM,TileN)
|
||||
blkCrd = cC[blk_coord] # (TileM, TileN)
|
||||
|
||||
print(f"[DSL INFO] Sliced Tensors per thread block:")
|
||||
print(f"[DSL INFO] blkA = {blkA.type}")
|
||||
print(f"[DSL INFO] blkB = {blkB.type}")
|
||||
print(f"[DSL INFO] blkC = {blkC.type}")
|
||||
print(f"[DSL INFO] blkCrd = {blkCrd.type}")
|
||||
|
||||
# # declare the atoms which will be used later for memory copy
|
||||
copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
|
||||
copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)
|
||||
|
||||
tiled_copy_A = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
|
||||
tiled_copy_B = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
|
||||
tiled_copy_C = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn)
|
||||
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
thr_copy_B = tiled_copy_B.get_slice(tidx)
|
||||
thr_copy_C = tiled_copy_C.get_slice(tidx)
|
||||
|
||||
thrA = thr_copy_A.partition_S(blkA)
|
||||
thrB = thr_copy_B.partition_S(blkB)
|
||||
thrC = thr_copy_C.partition_S(blkC)
|
||||
|
||||
# allocate fragments for gmem->rmem
|
||||
frgA = cute.make_fragment_like(thrA)
|
||||
frgB = cute.make_fragment_like(thrB)
|
||||
frgC = cute.make_fragment_like(thrC)
|
||||
|
||||
thrCrd = thr_copy_C.partition_S(blkCrd)
|
||||
frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)
|
||||
|
||||
print(f"[DSL INFO] Sliced Tensors per thread:")
|
||||
print(f"[DSL INFO] thrA = {thrA.type}")
|
||||
print(f"[DSL INFO] thrB = {thrB.type}")
|
||||
print(f"[DSL INFO] thrC = {thrC.type}")
|
||||
print(f"[DSL INFO] thrCrd = {thrCrd.type}")
|
||||
|
||||
for i in cutlass.range_dynamic(0, cute.size(frgPred), 1):
|
||||
val = cute.elem_less(thrCrd[i], shape)
|
||||
frgPred[i] = val
|
||||
|
||||
# Print per thread predicate mask
|
||||
# if tidx == 0 and bidx == 0:
|
||||
# cute.printf("block_dim = {}", cute.arch.grid_dim())
|
||||
# cute.printf("shape = {}", shape)
|
||||
# cute.print_tensor(thrA)
|
||||
# cute.print_tensor(thrB)
|
||||
# cute.print_tensor(frgPred)
|
||||
|
||||
##########################################################
|
||||
# Move data to reg address space
|
||||
##########################################################
|
||||
|
||||
cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
|
||||
cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
|
||||
|
||||
# if tidx == 0 and bidx == 0:
|
||||
# cute.print_tensor(frgA)
|
||||
# cute.print_tensor(frgB)
|
||||
|
||||
# Load data before use. The compiler will optimize the copy and load
|
||||
# operations to convert some memory ld/st into register uses.
|
||||
result = frgA.load() + frgB.load()
|
||||
|
||||
# Save the results back to registers. Here we reuse b's registers.
|
||||
frgC.store(result)
|
||||
|
||||
# Copy the results back to c
|
||||
cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def elementwise_add(mA, mB, mC, copy_bits: cutlass.Constexpr = 128):
|
||||
dtype = mA.element_type
|
||||
vector_size = copy_bits // dtype.width
|
||||
|
||||
thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
|
||||
val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0))
|
||||
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
|
||||
|
||||
print(f"[DSL INFO] Input Tensors:")
|
||||
print(f"[DSL INFO] mA = {mA.type}")
|
||||
print(f"[DSL INFO] mB = {mB.type}")
|
||||
|
||||
print(f"[DSL INFO] Tiling Parameters:")
|
||||
print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block")
|
||||
print(f"[DSL INFO] tv_layout = {tv_layout}")
|
||||
|
||||
gA = cute.zipped_divide(mA, tiler_mn) # ((TileM,TileN),(RestM,RestN))
|
||||
gB = cute.zipped_divide(mB, tiler_mn) # ((TileM,TileN),(RestM,RestN))
|
||||
gC = cute.zipped_divide(mC, tiler_mn) # ((TileM,TileN),(RestM,RestN))
|
||||
print(f"[DSL INFO] Tiled Tensors:")
|
||||
print(f"[DSL INFO] gA = {gA.type}")
|
||||
print(f"[DSL INFO] gB = {gB.type}")
|
||||
print(f"[DSL INFO] gC = {gC.type}")
|
||||
|
||||
idC = cute.make_identity_tensor(mC.shape)
|
||||
cC = cute.zipped_divide(idC, tiler=tiler_mn)
|
||||
print(f"[DSL INFO] coord tensor = {cC.type}")
|
||||
|
||||
elementwise_add_kernel(gA, gB, gC, cC, mC.shape, tv_layout, tiler_mn).launch(
|
||||
grid=[cute.size(gC, mode=[1]), 1, 1],
|
||||
block=[cute.size(tv_layout, mode=[0]), 1, 1],
|
||||
)
|
||||
|
||||
|
||||
def run_elementwise_add(
|
||||
M,
|
||||
N,
|
||||
dtype: Type[cutlass.Numeric],
|
||||
is_a_dynamic_layout=False,
|
||||
is_b_dynamic_layout=False,
|
||||
is_result_dynamic_layout=False,
|
||||
skip_ref_check=False,
|
||||
benchmark=True,
|
||||
warmup_iterations=2,
|
||||
iterations=200,
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError(f"Ampere GPU is required to run this example!")
|
||||
|
||||
print(f"\nRunning Elementwise Add test with:")
|
||||
print(f"Tensor dimensions: [{M}, {N}]")
|
||||
print(f"Input and Output Data type: {dtype}")
|
||||
|
||||
torch_dtype = cutlass_torch.dtype(dtype)
|
||||
if dtype.is_integer:
|
||||
a = torch.randint(0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype)
|
||||
b = torch.randint(0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype)
|
||||
else:
|
||||
a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
|
||||
c = torch.zeros_like(a)
|
||||
|
||||
print(f"Input tensor shapes:")
|
||||
print(f"a: {a.shape}, dtype: {a.dtype}")
|
||||
print(f"b: {b.shape}, dtype: {b.dtype}")
|
||||
print(f"c: {c.shape}, dtype: {c.dtype}\n")
|
||||
|
||||
if not is_a_dynamic_layout:
|
||||
a_tensor = from_dlpack(a).mark_layout_dynamic()
|
||||
else:
|
||||
a_tensor = a
|
||||
|
||||
if not is_b_dynamic_layout:
|
||||
b_tensor = from_dlpack(b).mark_layout_dynamic()
|
||||
else:
|
||||
b_tensor = b
|
||||
|
||||
if not is_result_dynamic_layout:
|
||||
c_tensor = from_dlpack(c).mark_layout_dynamic()
|
||||
else:
|
||||
c_tensor = c
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
start_time = time.time()
|
||||
compiled_func = cute.compile(elementwise_add, a_tensor, b_tensor, c_tensor)
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
print("Executing vector add kernel...")
|
||||
|
||||
# Get current CUDA stream from PyTorch
|
||||
torch_stream = torch.cuda.current_stream()
|
||||
# Get the raw stream pointer as a CUstream
|
||||
current_stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
|
||||
if not skip_ref_check:
|
||||
compiled_func(a_tensor, b_tensor, c_tensor)
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(a + b, c)
|
||||
print("Results verified successfully!")
|
||||
|
||||
if not benchmark:
|
||||
return
|
||||
|
||||
# Create CUDA events for timing
|
||||
start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
|
||||
end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
compiled_func(a_tensor, b_tensor, c_tensor)
|
||||
|
||||
# Use the current stream for CUDA events instead of the default stream
|
||||
# Record start event
|
||||
cuda.cuEventRecord(start_event, current_stream)
|
||||
|
||||
# Execute the kernel
|
||||
for _ in range(iterations):
|
||||
compiled_func(a_tensor, b_tensor, c_tensor)
|
||||
|
||||
# Record end event
|
||||
cuda.cuEventRecord(end_event, current_stream)
|
||||
cuda.cuEventSynchronize(end_event)
|
||||
|
||||
# Calculate elapsed time
|
||||
err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
|
||||
avg_time = elapsed_time / iterations
|
||||
|
||||
# Print execution results
|
||||
print(f"Kernel execution time: {avg_time:.4f} ms")
|
||||
print(
|
||||
f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9:.2f} GB/s"
|
||||
)
|
||||
print(f"First few elements of result: \n{c[:3, :3]}")
|
||||
|
||||
# Destroy events
|
||||
cuda.cuEventDestroy(start_event)
|
||||
cuda.cuEventDestroy(end_event)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="example of elementwise add to demonstrate the numpy/pytorch as input for kernels"
|
||||
)
|
||||
parser.add_argument("--M", default=1024, type=int)
|
||||
parser.add_argument("--N", default=1024, type=int)
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument("--benchmark", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
run_elementwise_add(
|
||||
args.M,
|
||||
args.N,
|
||||
dtype=cutlass.Float32,
|
||||
is_a_dynamic_layout=True,
|
||||
is_b_dynamic_layout=True,
|
||||
is_result_dynamic_layout=True,
|
||||
skip_ref_check=args.skip_ref_check,
|
||||
benchmark=args.benchmark,
|
||||
warmup_iterations=args.warmup_iterations,
|
||||
iterations=args.iterations,
|
||||
)
|
||||
print("\nPASS")
|
||||
395
examples/python/CuTeDSL/ampere/elementwise_apply.py
Normal file
395
examples/python/CuTeDSL/ampere/elementwise_apply.py
Normal file
@@ -0,0 +1,395 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import argparse
|
||||
import operator
|
||||
import torch
|
||||
from typing import Type
|
||||
import time
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
"""
|
||||
An Elementwise Apply Example using CuTe DSL.
|
||||
|
||||
This example kernel demonstrates the meta-programming capability of the CuTe DSL by allowing
|
||||
customization of elementwise operations through lambda functions. The kernel copies data from
|
||||
global memory to register memory (rmem), applies a user-defined operation to the elements,
|
||||
and stores the result back to global memory.
|
||||
|
||||
Primary goals of this example:
|
||||
1. Demonstrate meta-programming capability by passing lambda functions to customize elementwise operations
|
||||
2. Show how to apply different operations (add, multiply, etc.) using the same kernel structure
|
||||
3. Illustrate how to parameterize CUDA kernels with operation types at compile time
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Run with addition operation
|
||||
python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op add
|
||||
|
||||
# Run with multiplication operation
|
||||
python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op mul
|
||||
|
||||
# Run with subtraction operation
|
||||
python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op sub
|
||||
|
||||
# Benchmark performance
|
||||
python examples/ampere/elementwise_apply.py --M 2048 --N 2048 --op add --benchmark --warmup_iterations 2 --iterations 10
|
||||
|
||||
The example demonstrates how to express complex CUDA kernels with customizable operations
|
||||
while maintaining high performance through efficient memory access patterns.
|
||||
"""
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def elementwise_apply_kernel(
|
||||
op: cutlass.Constexpr,
|
||||
gA: cute.Tensor,
|
||||
gB: cute.Tensor,
|
||||
gC: cute.Tensor,
|
||||
cC: cute.Tensor, # coordinate tensor
|
||||
shape: cute.Shape,
|
||||
tv_layout: cute.Layout, # (tid, vid) -> logic coord
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
|
||||
# slice for CTAs
|
||||
cta_coord = ((None, None), bidx)
|
||||
# logical coord -> address
|
||||
ctaA = gA[cta_coord] # (TileM, TileN)
|
||||
ctaB = gB[cta_coord] # (TileM, TileN)
|
||||
ctaC = gC[cta_coord] # (TileM, TileN)
|
||||
ctaCrd = cC[cta_coord] # (TileM, TileN)
|
||||
|
||||
print(f"[DSL INFO] Sliced Tensors per thread block:")
|
||||
print(f"[DSL INFO] ctaA = {ctaA.type}")
|
||||
print(f"[DSL INFO] ctaB = {ctaB.type}")
|
||||
print(f"[DSL INFO] ctaC = {ctaC.type}")
|
||||
print(f"[DSL INFO] ctaCrd = {ctaCrd.type}")
|
||||
|
||||
# compose with CTA TV layout
|
||||
# (tid, vid) -> address
|
||||
tidfrgA = cute.composition(ctaA, tv_layout)
|
||||
tidfrgB = cute.composition(ctaB, tv_layout)
|
||||
tidfrgC = cute.composition(ctaC, tv_layout)
|
||||
tidfrgCrd = cute.composition(ctaCrd, tv_layout)
|
||||
# print(f"{tv_layout = }")
|
||||
# print(f"{tidfrgA = }")
|
||||
|
||||
thr_coord = (tidx, (None, None))
|
||||
|
||||
# slice for threads
|
||||
# vid -> address
|
||||
thrA = tidfrgA[thr_coord] # (V)
|
||||
thrB = tidfrgB[thr_coord] # (V)
|
||||
thrC = tidfrgC[thr_coord] # (V)
|
||||
thrCrd = tidfrgCrd[thr_coord]
|
||||
|
||||
print(f"[DSL INFO] Sliced Tensors per thread:")
|
||||
print(f"[DSL INFO] thrA = {thrA.type}")
|
||||
print(f"[DSL INFO] thrB = {thrB.type}")
|
||||
print(f"[DSL INFO] thrC = {thrC.type}")
|
||||
print(f"[DSL INFO] thrCrd = {thrCrd.type}")
|
||||
|
||||
# allocate fragments for gmem->rmem
|
||||
frgA = cute.make_fragment_like(thrA, gA.element_type)
|
||||
frgB = cute.make_fragment_like(thrB, gB.element_type)
|
||||
frgC = cute.make_fragment_like(thrC, gC.element_type)
|
||||
frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)
|
||||
|
||||
for i in cutlass.range_dynamic(cute.size(frgPred), unroll=1):
|
||||
frgPred[i] = cute.elem_less(thrCrd[i], shape)
|
||||
|
||||
# if tidx == 0 and bidx == 0:
|
||||
# cute.print_tensor(frgPred)
|
||||
|
||||
##########################################################
|
||||
# Move data to reg address space
|
||||
##########################################################
|
||||
|
||||
# declare the atoms which will be used later for memory copy
|
||||
copy_atom_load = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
gA.element_type,
|
||||
num_bits_per_copy=gA.element_type.width,
|
||||
)
|
||||
copy_atom_store = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
gC.element_type,
|
||||
num_bits_per_copy=gC.element_type.width,
|
||||
)
|
||||
|
||||
cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
|
||||
cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
|
||||
|
||||
# Load data before use. The compiler will optimize the copy and load
|
||||
# operations to convert some memory ld/st into register uses.
|
||||
result = op(frgA.load(), frgB.load())
|
||||
|
||||
# Save the results back to registers. Here we reuse b's registers.
|
||||
frgC.store(result)
|
||||
|
||||
# Copy the results back to c
|
||||
cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def elementwise_apply(
|
||||
op: cutlass.Constexpr,
|
||||
a: cute.Tensor,
|
||||
b: cute.Tensor,
|
||||
result: cute.Tensor,
|
||||
):
|
||||
"""CUDA kernel applying binary operator on each element of two n-D input tensors in
|
||||
CuTe Python and store to result tensor.
|
||||
|
||||
:param op: Binary operator or lambda function to apply element-wise
|
||||
:type op: cutlass.Constexpr
|
||||
:param a: First input tensor
|
||||
:type a: cute.Tensor
|
||||
:param b: Second input tensor
|
||||
:type b: cute.Tensor
|
||||
:param result: Output tensor to store the results of op(a, b)
|
||||
:type result: cute.Tensor
|
||||
:return: None
|
||||
:rtype: None
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Example 1: Adding two tensors
|
||||
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, device="cuda")
|
||||
y = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32, device="cuda")
|
||||
result = torch.empty_like(x)
|
||||
elementwise_apply(operator.add, from_dlpack(x), from_dlpack(y), from_dlpack(result))
|
||||
# result:
|
||||
# tensor([[6.0, 8.0],
|
||||
# [10.0, 12.0]], device='cuda:0')
|
||||
|
||||
# Example 2: Using a lambda function
|
||||
elementwise_apply(lambda a, b: a * a + b * b, from_dlpack(x), from_dlpack(y), from_dlpack(result))
|
||||
# result:
|
||||
# tensor([[ 2., 8.],
|
||||
# [ 54., 512.]], device='cuda:0')
|
||||
"""
|
||||
|
||||
# Baseline: naive TV layout
|
||||
# * mA layout: (4096, 4096):(4096, 1)
|
||||
# * TV layout map to (512, 4) tile
|
||||
# * tidx maps to mode-0 but input layout is contiguous on mode-1, performance will be bad
|
||||
# tv_layout = cute.make_layout((128, (4, 4)), stride=(4, (512, 1)))
|
||||
# cta_tiler = (512, 4)
|
||||
|
||||
# Opt-1: better TV layout with better 1D thread layout (SOL with 1D thread layout)
|
||||
# * mA layout: (4096, 4096):(4096, 1)
|
||||
# * TV layout map to (4, 512) tile
|
||||
# * tidx maps to mode-1 which is leading mode of input tensor for coalesced load
|
||||
# tv_layout = cute.make_layout((128, (4, 4)), stride=(16, (4, 1)))
|
||||
# cta_tiler = (4, 512)
|
||||
|
||||
# Opt-2: 2D tile but worse
|
||||
# * mA layout: (4096, 4096):(4096, 1)
|
||||
# * TV layout map to (128, 16) logical tile
|
||||
# * V layout is bad as contiguous mode is not on right-most
|
||||
# * `cute.copy` only supports vectorize when stride-1 of v-layout on right-most )
|
||||
# tv_layout = cute.make_layout(((32, 4), (4, 4)), stride=((4, 512), (1, 128)))
|
||||
# cta_tiler = (128, 16)
|
||||
|
||||
# Opt-3: SOL with 2D thread tile
|
||||
# * mA layout: (4096, 4096):(4096, 1)
|
||||
# * TV layout map to (16, 128) logical tile
|
||||
# * tidx maps to mode-1 and input layout is contiguous on mode-1 for coalesced load-store
|
||||
thr_layout = cute.make_layout((4, 32), stride=(32, 1))
|
||||
val_layout = cute.make_layout((4, 4), stride=(4, 1))
|
||||
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
|
||||
|
||||
print(f"[DSL INFO] Input Tensors:")
|
||||
print(f"[DSL INFO] a = {a.type}")
|
||||
print(f"[DSL INFO] b = {b.type}")
|
||||
print(f"[DSL INFO] result = {result.type}")
|
||||
|
||||
print(f"[DSL INFO] Tiling Parameters:")
|
||||
print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block")
|
||||
print(f"[DSL INFO] tv_layout = {tv_layout}")
|
||||
|
||||
gA = cute.zipped_divide(a, tiler_mn) # ((TileM, TileN), (RestM, RestN))
|
||||
gB = cute.zipped_divide(b, tiler_mn) # ((TileM, TileN), (RestM, RestN))
|
||||
gC = cute.zipped_divide(result, tiler_mn) # ((TileM, TileN), (RestM, RestN))
|
||||
|
||||
print(f"[DSL INFO] Tiled Tensors:")
|
||||
print(f"[DSL INFO] gA = {gA.type}")
|
||||
print(f"[DSL INFO] gB = {gB.type}")
|
||||
print(f"[DSL INFO] gC = {gC.type}")
|
||||
|
||||
idC = cute.make_identity_tensor(result.shape)
|
||||
cC = cute.zipped_divide(idC, tiler=tiler_mn)
|
||||
print(f"[DSL INFO] coord tensor = {cC.type}")
|
||||
|
||||
# Launch the kernel asynchronously
|
||||
# Async token(s) can also be specified as dependencies
|
||||
elementwise_apply_kernel(
|
||||
op,
|
||||
gA,
|
||||
gB,
|
||||
gC,
|
||||
cC,
|
||||
result.shape,
|
||||
tv_layout,
|
||||
).launch(
|
||||
grid=[cute.size(gC, mode=[1]), 1, 1],
|
||||
block=[cute.size(tv_layout, mode=[0]), 1, 1],
|
||||
)
|
||||
|
||||
|
||||
def run_elementwise_apply_and_verify(
|
||||
op,
|
||||
M,
|
||||
N,
|
||||
dtype: Type[cutlass.Numeric],
|
||||
skip_ref_check=False,
|
||||
benchmark=True,
|
||||
warmup_iterations=2,
|
||||
iterations=100,
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError(f"Ampere GPU is required to run this example!")
|
||||
|
||||
print(f"\nRunning Elementwise Apply test with:")
|
||||
print(f"Tensor dimensions: [{M}, {N}]")
|
||||
print(f"Input and Output Data type: {dtype}")
|
||||
print(f"Warmup iterations: {warmup_iterations}")
|
||||
print(f"Measurement iterations: {iterations}\n")
|
||||
|
||||
torch_dtype = cutlass_torch.dtype(dtype)
|
||||
|
||||
# Allocate tensors with random values.
|
||||
a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
c = torch.zeros_like(a)
|
||||
|
||||
print(f"Input tensor shapes:")
|
||||
print(f"a: {a.shape}, dtype: {a.dtype}")
|
||||
print(f"b: {b.shape}, dtype: {b.dtype}")
|
||||
print(f"c: {c.shape}, dtype: {c.dtype}\n")
|
||||
|
||||
epsilon = 1.2
|
||||
if op in (operator.truediv, operator.floordiv):
|
||||
b = torch.where(b == 0, torch.tensor(epsilon), b)
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
start_time = time.time()
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
print("Executing elementwise apply kernel...")
|
||||
# Get current CUDA stream from PyTorch
|
||||
torch_stream = torch.cuda.current_stream()
|
||||
# Get the raw stream pointer as a CUstream
|
||||
current_stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
|
||||
if not skip_ref_check:
|
||||
elementwise_apply(
|
||||
op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()
|
||||
)
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(op(a, b), c)
|
||||
print("Results verified successfully!")
|
||||
|
||||
if not benchmark:
|
||||
return
|
||||
|
||||
# Create CUDA events for timing
|
||||
start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
|
||||
end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
elementwise_apply(
|
||||
op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()
|
||||
)
|
||||
|
||||
# Record start event
|
||||
cuda.cuEventRecord(start_event, current_stream)
|
||||
|
||||
# Execute the kernel
|
||||
for _ in range(iterations):
|
||||
elementwise_apply(
|
||||
op, from_dlpack(a), from_dlpack(b), from_dlpack(c).mark_layout_dynamic()
|
||||
)
|
||||
|
||||
# Record end event
|
||||
cuda.cuEventRecord(end_event, current_stream)
|
||||
cuda.cuEventSynchronize(end_event)
|
||||
|
||||
# Calculate elapsed time
|
||||
err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
|
||||
avg_time = elapsed_time / iterations
|
||||
|
||||
# Print execution results
|
||||
print(f"Kernel execution time: {avg_time:.4f} ms")
|
||||
print(
|
||||
f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9:.2f} GB/s"
|
||||
)
|
||||
print(f"First few elements of result: \n{c[:3, :3]}")
|
||||
|
||||
# Destroy events
|
||||
cuda.cuEventDestroy(start_event)
|
||||
cuda.cuEventDestroy(end_event)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="example of elementwise apply to demonstrate building elementwise kernels"
|
||||
)
|
||||
parser.add_argument("--M", default=128, type=int)
|
||||
parser.add_argument("--N", default=128, type=int)
|
||||
parser.add_argument("--op", default="add", type=str)
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument("--benchmark", action="store_true")
|
||||
args = parser.parse_args()
|
||||
run_elementwise_apply_and_verify(
|
||||
getattr(operator, args.op),
|
||||
args.M,
|
||||
args.N,
|
||||
dtype=cutlass.Float32,
|
||||
warmup_iterations=args.warmup_iterations,
|
||||
iterations=args.iterations,
|
||||
skip_ref_check=args.skip_ref_check,
|
||||
benchmark=args.benchmark,
|
||||
)
|
||||
print("\nPASS")
|
||||
1353
examples/python/CuTeDSL/ampere/flash_attention_v2.py
Normal file
1353
examples/python/CuTeDSL/ampere/flash_attention_v2.py
Normal file
File diff suppressed because it is too large
Load Diff
780
examples/python/CuTeDSL/ampere/sgemm.py
Normal file
780
examples/python/CuTeDSL/ampere/sgemm.py
Normal file
@@ -0,0 +1,780 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
"""
|
||||
A dense FP32 SIMT GEMM (C = A * B) example using CUTE DSL.
|
||||
- Matrix A is MxK, A can be row-major("K") or column-major("M")
|
||||
- Matrix B is NxK, B can be row-major("N") or column-major("K")
|
||||
- Matrix C is MxN, C can be row-major("N") or column-major("M")
|
||||
|
||||
This GEMM kernel supports the following features:
|
||||
- Utilizes FPU for matrix multiply-accumulate (MMA) operations
|
||||
- Use multistage pipeline to overlap computation and memory access
|
||||
* Shared memory pipeline: hides gmem-to-smem latency.
|
||||
* Register pipeline: overlaps shared memory-to-register transfers with
|
||||
computations and eliminates false data dependencies for
|
||||
better parallelism.
|
||||
- Use vectorized copies
|
||||
- Add padding to reduce bank conflicts in global -> shared memory copies
|
||||
- Use predication to avoid unnecessary copies or copies of stale data
|
||||
|
||||
This GEMM works as follows:
|
||||
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies.
|
||||
2. Perform matrix multiply-accumulate (MMA) operations using simple fused multiply-add atomics.
|
||||
3. Store results from registers (RMEM) to global memory (GMEM).
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/ampere/sgemm.py \
|
||||
--mnk 8192,8192,8192 \
|
||||
--a_major m --b_major n --c_major n
|
||||
|
||||
To collect performance with NCU profiler:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
ncu python examples/ampere/sgemm.py \
|
||||
--mnk 8192,8192,8192 \
|
||||
--a_major m --b_major n --c_major n \
|
||||
--skip_ref_check --iterations 2
|
||||
|
||||
Constraints:
|
||||
* Supported input, output, and accumulator data types: fp32
|
||||
* Default tile shape is set to be 128x128x8
|
||||
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned
|
||||
"""
|
||||
|
||||
|
||||
class SGemm:
|
||||
def __init__(
|
||||
self,
|
||||
cta_tiler: Tuple[int, int, int] = (128, 128, 8),
|
||||
num_stages: int = 3,
|
||||
num_threads: int = 256,
|
||||
):
|
||||
self._cta_tiler = cta_tiler
|
||||
self._num_stages = num_stages
|
||||
self._num_threads = num_threads
|
||||
assert num_threads > 0, "needs at least one thread"
|
||||
assert num_threads % 16 == 0, "multiples of 16 required for MMA thread layout"
|
||||
|
||||
self._bM, self._bN, self._bK = self._cta_tiler
|
||||
assert self._bM % 16 == 0, "multiple of 16 required for tile dimension M"
|
||||
assert self._bN % 16 == 0, "multiple of 16 required for tile dimension N"
|
||||
assert self._num_stages >= 3, "num_stages must be greater than or equal to 3"
|
||||
|
||||
@cute.jit
|
||||
def __call__(
|
||||
self,
|
||||
mA: cute.Tensor,
|
||||
mB: cute.Tensor,
|
||||
mC: cute.Tensor,
|
||||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||||
):
|
||||
self.a_major_mode = utils.LayoutEnum.from_tensor(mA)
|
||||
self.b_major_mode = utils.LayoutEnum.from_tensor(mB)
|
||||
self.c_major_mode = utils.LayoutEnum.from_tensor(mC)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create layouts for shared memory for A and B:
|
||||
# - sA/sB is m/n-major to vectorized copies from shared
|
||||
# memory to registers. This is because the MMA layouts
|
||||
# for sA/sB are also m/n-major
|
||||
# - When gA/gB is k-major, pad 4 elements to reduce bank conflicts
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
padding_a = 4 if self.a_major_mode == utils.LayoutEnum.ROW_MAJOR else 0
|
||||
padding_b = 4 if self.b_major_mode == utils.LayoutEnum.ROW_MAJOR else 0
|
||||
sA_layout = cute.make_layout(
|
||||
(self._bM, self._bK, self._num_stages),
|
||||
stride=(1, (self._bM + padding_a), self._bK * (self._bM + padding_a)),
|
||||
)
|
||||
sB_layout = cute.make_layout(
|
||||
(self._bN, self._bK, self._num_stages),
|
||||
stride=(1, (self._bN + padding_b), self._bK * (self._bN + padding_b)),
|
||||
)
|
||||
|
||||
smem_size = cute.size_in_bytes(mA.element_type, sA_layout) + cute.size_in_bytes(
|
||||
mB.element_type, sB_layout
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create copy layouts that will be used for asynchronous
|
||||
# global memory -> shared memory copies:
|
||||
# - The majorness of tA/tB follows the majorness of gA/gB
|
||||
# - For k-major, these layouts will copy values one-by-one from
|
||||
# from global memory, without vectorizing
|
||||
# - For m/n-major, it will vectorize to a 128bit copy for faster
|
||||
# data transfer between global and shared memory, as long
|
||||
# as the alignment of the tensor allows it. Otherwise, it
|
||||
# defaults to a non-vectorized copy
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
tA = cute.make_layout(
|
||||
(self._num_threads // self._bK, self._bK), stride=(self._bK, 1)
|
||||
)
|
||||
tB = cute.make_layout(
|
||||
(self._num_threads // self._bK, self._bK), stride=(self._bK, 1)
|
||||
)
|
||||
vA = cute.make_layout((1, 1))
|
||||
vB = cute.make_layout((1, 1))
|
||||
atom_async_copy_A = cute.make_copy_atom(
|
||||
cute.nvgpu.cpasync.CopyG2SOp(),
|
||||
mA.element_type,
|
||||
num_bits_per_copy=mA.element_type.width,
|
||||
)
|
||||
atom_async_copy_B = cute.make_copy_atom(
|
||||
cute.nvgpu.cpasync.CopyG2SOp(),
|
||||
mA.element_type,
|
||||
num_bits_per_copy=mB.element_type.width,
|
||||
)
|
||||
|
||||
if self.a_major_mode == utils.LayoutEnum.COL_MAJOR:
|
||||
num_vectorized = 4 if (mA.layout.max_alignment % 16 == 0) else 1
|
||||
atom_async_copy_A = cute.make_copy_atom(
|
||||
cute.nvgpu.cpasync.CopyG2SOp(),
|
||||
mA.element_type,
|
||||
num_bits_per_copy=mA.element_type.width * num_vectorized,
|
||||
)
|
||||
major_mode_size = self._bM // num_vectorized
|
||||
tA = cute.make_layout(
|
||||
(major_mode_size, self._num_threads // major_mode_size),
|
||||
stride=(1, major_mode_size),
|
||||
)
|
||||
vA = cute.make_layout((num_vectorized, 1))
|
||||
|
||||
if self.b_major_mode == utils.LayoutEnum.COL_MAJOR:
|
||||
num_vectorized = 4 if (mB.layout.max_alignment % 16 == 0) else 1
|
||||
atom_async_copy_B = cute.make_copy_atom(
|
||||
cute.nvgpu.cpasync.CopyG2SOp(),
|
||||
mA.element_type,
|
||||
num_bits_per_copy=mB.element_type.width * num_vectorized,
|
||||
)
|
||||
major_mode_size = self._bN // num_vectorized
|
||||
tB = cute.make_layout(
|
||||
(major_mode_size, self._num_threads // major_mode_size),
|
||||
stride=(1, major_mode_size),
|
||||
)
|
||||
vB = cute.make_layout((num_vectorized, 1))
|
||||
|
||||
tiled_copy_A = cute.make_tiled_copy_tv(atom_async_copy_A, tA, vA)
|
||||
tiled_copy_B = cute.make_tiled_copy_tv(atom_async_copy_B, tB, vB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create layouts for GEMM:
|
||||
# We tile an MMA atom across a tensor. `atoms_layout` is the layout
|
||||
# of atoms in the tiled MMA. (Because we use an `MmaUniversalOp`,
|
||||
# which has a trivial 1x1x1 MMA trait, `atoms_layout` is also
|
||||
# simply the thread layout for C.) `permutation_tiler` reorders the
|
||||
# elements of the tensor that the tiled MMA is applied to.
|
||||
# Different combinations of `atoms_layout` and `permutation_tiler`
|
||||
# values can create different MMA thread-value patterns.
|
||||
#
|
||||
# Here, the MMA layout is set so that each thread copies four
|
||||
# consecutive elements from shared memory to registers.
|
||||
# `permutation_tiler_M/N` maps the elements handled by each thread
|
||||
# to the permuted element in the tensor.
|
||||
# For increasing indices in the tensor, the thread ID that reads it is:
|
||||
# - (without permutation) ==>
|
||||
# 0 1 2 ... 15 0 1 2 ... 15 0 1 2 ... 15 0 1 2 ... 15 ......
|
||||
# - (with permutation) ==>
|
||||
# 0 0 0 0 1 1 1 1 2 2 2 2 ... 15 15 15 15 0 0 0 0 1 1 1 1 ......
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
atoms_layout = cute.make_layout(
|
||||
(self._num_threads // 16, 16, 1), stride=(16, 1, 0)
|
||||
)
|
||||
if self.c_major_mode == utils.LayoutEnum.COL_MAJOR:
|
||||
atoms_layout = cute.make_layout(
|
||||
(16, self._num_threads // 16, 1), stride=(1, 16, 0)
|
||||
)
|
||||
op = cute.nvgpu.MmaUniversalOp(cutlass.Float32)
|
||||
permutation_tiler_M = cute.make_layout(
|
||||
(atoms_layout.shape[0], 4), stride=(4, 1)
|
||||
)
|
||||
permutation_tiler_N = cute.make_layout(
|
||||
(atoms_layout.shape[1], 4), stride=(4, 1)
|
||||
)
|
||||
tiled_mma = cute.make_tiled_mma(
|
||||
op,
|
||||
atoms_layout,
|
||||
permutation_mnk=(permutation_tiler_M, permutation_tiler_N, None),
|
||||
)
|
||||
|
||||
# grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, 1)
|
||||
grid_dim = *cute.ceil_div(mC.shape, (self._bM, self._bN)), 1
|
||||
|
||||
self.kernel(
|
||||
mA,
|
||||
mB,
|
||||
mC,
|
||||
sA_layout,
|
||||
sB_layout,
|
||||
tiled_copy_A,
|
||||
tiled_copy_B,
|
||||
tiled_mma,
|
||||
epilogue_op,
|
||||
).launch(
|
||||
grid=grid_dim,
|
||||
block=[cute.size(atoms_layout), 1, 1],
|
||||
smem=smem_size,
|
||||
)
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
mA: cute.Tensor,
|
||||
mB: cute.Tensor,
|
||||
mC: cute.Tensor,
|
||||
sA_layout: cute.Layout,
|
||||
sB_layout: cute.Layout,
|
||||
tiled_copy_A: cute.TiledCopy,
|
||||
tiled_copy_B: cute.TiledCopy,
|
||||
tiled_mma: cute.TiledMma,
|
||||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||||
):
|
||||
# Thread and block indices
|
||||
tidx, tidy, tidz = cute.arch.thread_idx()
|
||||
bidx, bidy, bidz = cute.arch.block_idx()
|
||||
tiler_coord = (bidx, bidy, None)
|
||||
thr_mma = tiled_mma.get_slice(tidx)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Get the appropriate tiles for this thread block.
|
||||
# gA: (BLK_M, BLK_K, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
gA = cute.local_tile(
|
||||
mA, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, None, 1)
|
||||
)
|
||||
gB = cute.local_tile(
|
||||
mB, tiler=self._cta_tiler, coord=tiler_coord, proj=(None, 1, 1)
|
||||
)
|
||||
gC = cute.local_tile(
|
||||
mC, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, 1, None)
|
||||
)
|
||||
|
||||
# Move the pointer of gA/gB in the `-k`` direction, making the first
|
||||
# tile (instead of the last one) irregular in shape when k is irregular.
|
||||
# We first handle the irregular tile to avoid checking for this
|
||||
# condition within the mainloop.
|
||||
residue_k = mA.shape[1] - cutlass.Int32(self._bK) * gA.shape[2]
|
||||
gA = cute.domain_offset((0, residue_k, 0), gA)
|
||||
gB = cute.domain_offset((0, residue_k, 0), gB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Get the appropriate tiles for this thread.
|
||||
# sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE)
|
||||
# tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k)
|
||||
# tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create shared memory buffer
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
|
||||
sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
thr_copy_B = tiled_copy_B.get_slice(tidx)
|
||||
tAgA = thr_copy_A.partition_S(gA)
|
||||
tAsA = thr_copy_A.partition_D(sA)
|
||||
tBgB = thr_copy_B.partition_S(gB)
|
||||
tBsB = thr_copy_B.partition_D(sB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Predicate: Mark indices that need to copy when the problem shape
|
||||
# isn't a multiple of the tile shape. If tApA/B[i] is 0, then do not
|
||||
# do the copy atom associated with index i.
|
||||
# cA: (BLK_M, BLK_K) => (blk_m, blk_k)
|
||||
# cB: (BLK_N, BLK_K) => (blk_n, blk_k)
|
||||
# tAcA: (CPY, CPY_M, CPY_K) => (blk_m, blk_k)
|
||||
# tBcB: (CPY, CPY_N, CPY_K) => (blk_n, blk_k)
|
||||
# tApA: (rest_v, CPY_M, CPY_K), stride=(..., ..., 0)
|
||||
# tBpB: (rest_v, CPY_N, CPY_K), stride=(..., ..., 0)
|
||||
# CPY = (atom_v, rest_v)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Construct identity layout for sA and sB, used for predication
|
||||
mcA = cute.make_identity_tensor(mA.shape)
|
||||
mcB = cute.make_identity_tensor(mB.shape)
|
||||
cA = cute.local_tile(
|
||||
mcA, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, None, 1)
|
||||
)
|
||||
cB = cute.local_tile(
|
||||
mcB, tiler=self._cta_tiler, coord=tiler_coord, proj=(None, 1, 1)
|
||||
)
|
||||
cA = cute.domain_offset((0, residue_k, 0), cA)
|
||||
cB = cute.domain_offset((0, residue_k, 0), cB)
|
||||
# Repeat the partitioning with identity layouts
|
||||
tAcA = thr_copy_A.partition_S(cA)
|
||||
tBcB = thr_copy_B.partition_S(cB)
|
||||
# Allocate predicate tensors for m and n
|
||||
tApA = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tAsA.shape[0][1],
|
||||
cute.size(tAsA, mode=[1]),
|
||||
cute.size(tAsA, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tAsA, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
tBpB = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tBsB.shape[0][1],
|
||||
cute.size(tBsB, mode=[1]),
|
||||
cute.size(tBsB, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tBsB, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
# Allocate predicate tensors for m, n and k for residue k-tile
|
||||
tApA_residue_k = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tAsA.shape[0][1],
|
||||
cute.size(tAsA, mode=[1]),
|
||||
cute.size(tAsA, mode=[2]),
|
||||
),
|
||||
stride=(
|
||||
cute.size(tAsA, mode=[1]) * cute.size(tAsA, mode=[2]),
|
||||
cute.size(tAsA, mode=[2]),
|
||||
1,
|
||||
),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
tBpB_residue_k = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tBsB.shape[0][1],
|
||||
cute.size(tBsB, mode=[1]),
|
||||
cute.size(tBsB, mode=[2]),
|
||||
),
|
||||
stride=(
|
||||
cute.size(tBsB, mode=[1]) * cute.size(tBsB, mode=[2]),
|
||||
cute.size(tBsB, mode=[2]),
|
||||
1,
|
||||
),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
# Set predicates for m/n bounds for mainloop
|
||||
for rest_v in range(tApA.shape[0]):
|
||||
for m in range(tApA.shape[1]):
|
||||
tApA[rest_v, m, 0] = cute.elem_less(
|
||||
tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
|
||||
)
|
||||
for rest_v in range(tBpB.shape[0]):
|
||||
for n in range(tBpB.shape[1]):
|
||||
tBpB[rest_v, n, 0] = cute.elem_less(
|
||||
tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
|
||||
)
|
||||
|
||||
# Set predicates for m/n/k bounds for residue k tile
|
||||
for rest_v in range(tApA_residue_k.shape[0]):
|
||||
for m in range(tApA_residue_k.shape[1]):
|
||||
for k in range(tApA_residue_k.shape[2]):
|
||||
coord_A = tAcA[(0, rest_v), m, k, 0]
|
||||
tApA_residue_k[rest_v, m, k] = cute.elem_less(
|
||||
(coord_A[0], cutlass.Int32(-1)), (mA.shape[0], coord_A[1])
|
||||
)
|
||||
for rest_v in range(tBpB_residue_k.shape[0]):
|
||||
for n in range(tBpB_residue_k.shape[1]):
|
||||
for k in range(tBpB_residue_k.shape[2]):
|
||||
coord_B = tBcB[(0, rest_v), n, k, 0]
|
||||
tBpB_residue_k[rest_v, n, k] = cute.elem_less(
|
||||
(coord_B[0], cutlass.Int32(-1)), (mB.shape[0], coord_B[1])
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Prefetch Prologue
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Start async loads for 0th k-tile, where we take care of the k-residue
|
||||
k_pipe_max = cute.size(tAsA, mode=[3])
|
||||
k_tile_count = cute.size(tAgA, mode=[3])
|
||||
gmem_pipe_read = cutlass.Int32(0)
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, gmem_pipe_read],
|
||||
tAsA[None, None, None, 0],
|
||||
pred=tApA_residue_k,
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, gmem_pipe_read],
|
||||
tBsB[None, None, None, 0],
|
||||
pred=tBpB_residue_k,
|
||||
)
|
||||
cute.arch.cp_async_commit_group()
|
||||
gmem_pipe_read = (
|
||||
gmem_pipe_read + 1
|
||||
if gmem_pipe_read + 1 < k_tile_count
|
||||
else cutlass.Int32(0)
|
||||
)
|
||||
# Start async loads for 1st k-tile onwards, no k-residue handling needed
|
||||
for k_tile in range(1, k_pipe_max - 1):
|
||||
if k_tile < k_tile_count:
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, gmem_pipe_read],
|
||||
tAsA[None, None, None, k_tile],
|
||||
pred=tApA,
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, gmem_pipe_read],
|
||||
tBsB[None, None, None, k_tile],
|
||||
pred=tBpB,
|
||||
)
|
||||
|
||||
gmem_pipe_read = (
|
||||
gmem_pipe_read + 1
|
||||
if gmem_pipe_read + 1 < k_tile_count
|
||||
else cutlass.Int32(0)
|
||||
)
|
||||
cute.arch.cp_async_commit_group()
|
||||
|
||||
# all tiles have been copied from global memory, so clear the
|
||||
# predicate tensor
|
||||
if k_tile_count < k_pipe_max:
|
||||
for rest_v in range(tApA.shape[0]):
|
||||
for m in range(tApA.shape[1]):
|
||||
tApA[rest_v, m, 0] = cutlass.Boolean(0)
|
||||
for rest_v in range(tBpB.shape[0]):
|
||||
for n in range(tBpB.shape[1]):
|
||||
tBpB[rest_v, n, 0] = cutlass.Boolean(0)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Define A/B partitioning and C accumulators.
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
tCsA = thr_mma.partition_A(sA)
|
||||
tCsB = thr_mma.partition_B(sB)
|
||||
tCgC = thr_mma.partition_C(gC)
|
||||
tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
|
||||
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
|
||||
tCrC = tiled_mma.make_fragment_C(tCgC)
|
||||
# Clear the accumulator
|
||||
tCrC.fill(0.0)
|
||||
|
||||
# Current pipe index in smem to read from / write to
|
||||
smem_pipe_read = cutlass.Int32(0)
|
||||
smem_pipe_write = cutlass.Int32(k_pipe_max - 1)
|
||||
|
||||
tCsA_p = tCsA[None, None, None, smem_pipe_read]
|
||||
tCsB_p = tCsB[None, None, None, smem_pipe_read]
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# PREFETCH register pipeline
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
k_block_max = cute.size(tCrA, mode=[2])
|
||||
|
||||
if k_block_max > 1:
|
||||
# Wait until our first prefetched tile is loaded in
|
||||
cute.arch.cp_async_wait_group(k_pipe_max - 2)
|
||||
cute.arch.barrier()
|
||||
# Prefetch the first rmem from the first k-tile
|
||||
cute.autovec_copy(tCsA_p[None, None, 0], tCrA[None, None, 0])
|
||||
cute.autovec_copy(tCsB_p[None, None, 0], tCrB[None, None, 0])
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Mainloop
|
||||
# 1. Shared memory pipeline (gmem -> smem):
|
||||
# The default smem pipeline depth is 3, meaning that for shared
|
||||
# memory buffers, we allocate three times the size described by the
|
||||
# CTA tiler. We prefetch 2 of these buffers before entering the main
|
||||
# loop. Considering only the transfer from global memory to shared
|
||||
# memory, the general structure of the mainloop is:
|
||||
# (1) copy k-tile from gmem to smem;
|
||||
# (2) perform gemm computation on k-tile;
|
||||
# (3) wait for the next copy to finish.
|
||||
# The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command
|
||||
# waits for the number of unfinished 'copy' to be <= 1. The advantage
|
||||
# of this approach is that it allows for simultaneous production
|
||||
# (i.e., step (1)) and consumption (i.e., step (2)) of smem.
|
||||
# A common misconception is to prefetch N buffers and rewrite
|
||||
# the pipeline logic to wait on N-1 pending copies. The disadvantage
|
||||
# of this approach is that it requires fully consuming a buffer in
|
||||
# order to open an empty buffer for the next copy.
|
||||
# 2. Register pipeline (smem -> register):
|
||||
# Similarly, the register pipeline produces i+1, consumes i, and
|
||||
# produces i+2... Notably, i and i+1 do not use the same register,
|
||||
# eliminating dependencies on the same register for better parallelism.
|
||||
# 3. Combining the smem and register pipelines results in the mainloop.
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
for _ in cutlass.range_dynamic(k_tile_count, unroll=1):
|
||||
for k_block in range(k_block_max):
|
||||
if k_block == k_block_max - 1:
|
||||
tCsA_p = tCsA[None, None, None, smem_pipe_read]
|
||||
tCsB_p = tCsB[None, None, None, smem_pipe_read]
|
||||
cute.arch.cp_async_wait_group(k_pipe_max - 2)
|
||||
cute.arch.barrier()
|
||||
|
||||
# Load A, B from shared memory to registers for k_block + 1
|
||||
k_block_next = (k_block + 1) % k_block_max # static
|
||||
cute.autovec_copy(
|
||||
tCsA_p[None, None, k_block_next],
|
||||
tCrA[None, None, k_block_next],
|
||||
)
|
||||
cute.autovec_copy(
|
||||
tCsB_p[None, None, k_block_next],
|
||||
tCrB[None, None, k_block_next],
|
||||
)
|
||||
|
||||
# Fetch next A: To better interleave global memory access and
|
||||
# compute instructions, we intentionally use the sequence:
|
||||
# copy A, perform GEMM, then copy B.
|
||||
if k_block == 0:
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, gmem_pipe_read],
|
||||
tAsA[None, None, None, smem_pipe_write],
|
||||
# Use predicates because the m-mode may be irregular
|
||||
pred=tApA,
|
||||
)
|
||||
|
||||
# Thread-level register gemm for k_block
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCrC,
|
||||
tCrA[None, None, k_block],
|
||||
tCrB[None, None, k_block],
|
||||
tCrC,
|
||||
)
|
||||
|
||||
# Fetch next B and update smem pipeline read/write
|
||||
if k_block == 0:
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, gmem_pipe_read],
|
||||
tBsB[None, None, None, smem_pipe_write],
|
||||
# Use predicates because the n-mode may be irregular
|
||||
pred=tBpB,
|
||||
)
|
||||
cute.arch.cp_async_commit_group()
|
||||
smem_pipe_write = smem_pipe_read
|
||||
smem_pipe_read = smem_pipe_read + 1
|
||||
if smem_pipe_read == k_pipe_max:
|
||||
smem_pipe_read = cutlass.Int32(0)
|
||||
# After copying all tiles, we avoid clearing the predicate
|
||||
# tensor in the `mainloop` to prevent increasing its
|
||||
# instruction count. Instead, we continue copying the
|
||||
# first tile, though it won't be used. The 0-th tile is not
|
||||
# copied due to its irregular shape, which could lead to
|
||||
# illegal memory accesses.
|
||||
gmem_pipe_read = (
|
||||
gmem_pipe_read + 1
|
||||
if gmem_pipe_read + 1 < k_tile_count
|
||||
else cutlass.Int32(1)
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Epilogue
|
||||
# Applies the epilogue operation to the accumulated results and copies
|
||||
# them without vectorization.
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
cute.arch.cp_async_wait_group(0)
|
||||
cute.arch.barrier()
|
||||
tCrC.store(epilogue_op(tCrC.load()))
|
||||
|
||||
# predicate
|
||||
cC = cute.make_identity_tensor(gC.shape)
|
||||
tCpC = thr_mma.partition_C(cC)
|
||||
predC = cute.make_fragment(tCrC.layout, cutlass.Boolean)
|
||||
residue_m = mC.shape[0] - cutlass.Int32(self._bM) * bidx
|
||||
residue_n = mC.shape[1] - cutlass.Int32(self._bN) * bidy
|
||||
for i in range(cute.size(tCrC.shape)):
|
||||
predC[i] = cute.elem_less(tCpC[i], (residue_m, residue_n))
|
||||
numIterM = cute.size(tCrC, mode=[1])
|
||||
numIterN = cute.size(tCrC, mode=[2])
|
||||
atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mC.element_type)
|
||||
cute.copy(atom, tCrC, tCgC, pred=predC)
|
||||
return
|
||||
|
||||
|
||||
def main(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
problem_shape: Tuple[int, int, int],
|
||||
warmup_iterations: int = 2,
|
||||
iterations: int = 100,
|
||||
skip_ref_check: bool = False,
|
||||
):
|
||||
torch.manual_seed(1024)
|
||||
M, N, K = problem_shape
|
||||
|
||||
# Create and permute tensor A/B/C
|
||||
def create_and_permute_tensor(mode0, mode1, is_mode0_major, dtype):
|
||||
# is_mode0_major: (mode1, mode0) -> (mode0, mode1)
|
||||
# else: (mode0, mode1) -> (mode0, mode1)
|
||||
shape = (mode1, mode0) if is_mode0_major else (mode0, mode1)
|
||||
permute_order = (1, 0) if is_mode0_major else (0, 1)
|
||||
|
||||
return (
|
||||
torch.empty(*shape, dtype=torch.int32)
|
||||
.random_(-5, 5)
|
||||
.to(dtype=dtype)
|
||||
.permute(permute_order)
|
||||
.cuda()
|
||||
)
|
||||
|
||||
a = create_and_permute_tensor(M, K, a_major == "m", torch.float32)
|
||||
b = create_and_permute_tensor(N, K, b_major == "n", torch.float32)
|
||||
c = create_and_permute_tensor(M, N, c_major == "m", torch.float32)
|
||||
|
||||
divisibility_a = a.shape[1] if a_major == "k" else a.shape[0]
|
||||
divisibility_b = b.shape[1] if b_major == "k" else b.shape[0]
|
||||
divisibility_c = c.shape[1] if c_major == "n" else c.shape[0]
|
||||
|
||||
a_tensor = (
|
||||
from_dlpack(a, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if a_major == "k" else 0),
|
||||
divisibility=divisibility_a,
|
||||
)
|
||||
)
|
||||
|
||||
b_tensor = (
|
||||
from_dlpack(b, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if b_major == "k" else 0),
|
||||
divisibility=divisibility_b,
|
||||
)
|
||||
)
|
||||
|
||||
c_tensor = (
|
||||
from_dlpack(c, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if c_major == "n" else 0),
|
||||
divisibility=divisibility_c,
|
||||
)
|
||||
)
|
||||
|
||||
sgemm = SGemm()
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
start_time = time.time()
|
||||
gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor)
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
print("Executing GEMM kernel...")
|
||||
|
||||
# Get current CUDA stream from PyTorch
|
||||
torch_stream = torch.cuda.current_stream()
|
||||
|
||||
# Get the raw stream pointer as a CUstream
|
||||
current_stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
|
||||
# Create CUDA events for timing
|
||||
start_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
|
||||
end_event = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1]
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
|
||||
# Use the current stream for CUDA events instead of the default stream
|
||||
# Record start event
|
||||
cuda.cuEventRecord(start_event, current_stream)
|
||||
|
||||
# Execute the kernel
|
||||
for _ in range(iterations):
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
|
||||
# Record end event
|
||||
cuda.cuEventRecord(end_event, current_stream)
|
||||
cuda.cuEventSynchronize(end_event)
|
||||
|
||||
# Calculate elapsed time
|
||||
err, elapsed_time = cuda.cuEventElapsedTime(start_event, end_event)
|
||||
|
||||
# Print execution results
|
||||
print(f"Kernel execution time: {elapsed_time / iterations:.4f} ms")
|
||||
|
||||
# Destroy events
|
||||
cuda.cuEventDestroy(start_event)
|
||||
cuda.cuEventDestroy(end_event)
|
||||
|
||||
if not skip_ref_check:
|
||||
print("Verifying results...")
|
||||
ref = torch.einsum("mk,nk->mn", a, b)
|
||||
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
|
||||
try:
|
||||
return tuple(int(x.strip()) for x in s.split(","))
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format. Expected comma-separated integers."
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--mnk", type=parse_comma_separated_ints, default=(256, 256, 64)
|
||||
)
|
||||
parser.add_argument("--a_major", choices=["k", "m"], default="k")
|
||||
parser.add_argument("--b_major", choices=["k", "n"], default="k")
|
||||
parser.add_argument("--c_major", choices=["n", "m"], default="n")
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
print("Running SIMT GEMM example:")
|
||||
main(
|
||||
args.a_major,
|
||||
args.b_major,
|
||||
args.c_major,
|
||||
args.mnk,
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
)
|
||||
print("PASS")
|
||||
968
examples/python/CuTeDSL/ampere/tensorop_gemm.py
Normal file
968
examples/python/CuTeDSL/ampere/tensorop_gemm.py
Normal file
@@ -0,0 +1,968 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import time
|
||||
from typing import Tuple, Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
"""
|
||||
A dense GEMM (C = A * B) example for the NVIDIA Ampere architecture using CUTE DSL.
|
||||
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
|
||||
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
|
||||
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
|
||||
|
||||
This GEMM kernel supports the following features:
|
||||
- Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations
|
||||
- Supports multi-stage pipeline to overlap computation and memory access
|
||||
- Implements shared memory buffering for epilogue to increase coalesed global memory access
|
||||
|
||||
This GEMM works as follows:
|
||||
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies.
|
||||
2. Perform matrix multiply-accumulate (MMA) operations.
|
||||
3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM).
|
||||
|
||||
The Ampere tensor core instruction used operates as follows:
|
||||
- Read matrix A from SMEM
|
||||
- Read matrix B from SMEM
|
||||
- Perform MMA operation and store the result in Accumulator(register)
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/ampere/tensorop_gemm.py \
|
||||
--mnkl 8192,8192,8192,1 --atom_layout_mnk 2,2,1 \
|
||||
--ab_dtype Float16 \
|
||||
--c_dtype Float16 --acc_dtype Float32 \
|
||||
--a_major m --b_major n --c_major n
|
||||
|
||||
The above example command computes with M=8192, N=8192, K=8192,
|
||||
batch_count=1. The atom layout's shape is 2x2x1 and the input, mma
|
||||
accumulator, and output data type are set as fp16, fp32 and fp16,
|
||||
respectively.
|
||||
|
||||
To collect performance with NCU profiler:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
ncu python examples/ampere/tensorop_gemm.py \
|
||||
--mnkl 8192,8192,8192,1 --atom_layout_mnk 2,2,1 \
|
||||
--ab_dtype Float16 \
|
||||
--c_dtype Float16 --acc_dtype Float32 \
|
||||
--a_major m --b_major n --c_major n \
|
||||
--skip_ref_check --iterations 2
|
||||
|
||||
Constraints:
|
||||
* Supported input and output data types: fp16
|
||||
* Support accumulator data types: f32
|
||||
* Default tile shape is set to be 128x128x32
|
||||
* Atom layout's MNK shape is set so that tile shape can be divided by MMA
|
||||
instruction shape
|
||||
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
|
||||
i.e, number of elements is a multiple of 8
|
||||
"""
|
||||
|
||||
|
||||
class TensorOpGemm:
|
||||
def __init__(
|
||||
self,
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
acc_dtype: Type[cutlass.Numeric],
|
||||
atom_layout_mnk: Tuple[int, int, int],
|
||||
):
|
||||
self.ab_dtype = ab_dtype
|
||||
self.c_dtype = c_dtype
|
||||
self.acc_dtype = acc_dtype
|
||||
self.cta_tiler = (128, 128, 32)
|
||||
self.num_stages = 3
|
||||
self.atom_layout_mnk = atom_layout_mnk
|
||||
atom_lay_M, atom_lay_N, atom_lay_K = self.atom_layout_mnk
|
||||
self.num_threads = atom_lay_M * atom_lay_N * atom_lay_K * 32
|
||||
|
||||
self.bM, self.bN, self.bK = self.cta_tiler
|
||||
self.mma_inst_shape = (16, 8, 16)
|
||||
mmaM, mmaN, mmaK = self.mma_inst_shape
|
||||
|
||||
assert (
|
||||
self.bM % (atom_lay_M * mmaM) == 0
|
||||
), "bM must be divisible by MMA instruction"
|
||||
assert (
|
||||
self.bN % (atom_lay_N * mmaN) == 0
|
||||
), "bN must be divisible by MMA instruction"
|
||||
assert atom_lay_K == 1, "this example does not support atom layout K > 1"
|
||||
assert self.bK % mmaK == 0, "bK must be divisible by MMA instruction"
|
||||
assert self.num_stages >= 3, "num_stages must be greater than or equal to 3"
|
||||
|
||||
@cute.jit
|
||||
def __call__(
|
||||
self,
|
||||
mA: cute.Tensor,
|
||||
mB: cute.Tensor,
|
||||
mC: cute.Tensor,
|
||||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||||
):
|
||||
# The grid divides the problems's M, N, and L dimensions by the
|
||||
# respective modes of the tile shape (bM, bN, 1). The K dimension is
|
||||
# handled within a block via a multistage process.
|
||||
|
||||
self.a_major_mode = utils.LayoutEnum.from_tensor(mA)
|
||||
self.b_major_mode = utils.LayoutEnum.from_tensor(mB)
|
||||
self.c_major_mode = utils.LayoutEnum.from_tensor(mC)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Shared memory layout:
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# Creates a layout with the size required for the provided tile
|
||||
# size and num stages (stages are used for K dimension) that is also
|
||||
# sectioned into 64x8 or 8x32 layout atoms. The swizzle is set so that
|
||||
# the atom for the shared memory -> register copy does not encounter
|
||||
# bank conflicts
|
||||
|
||||
# assume the input is 16B align
|
||||
ab_copy_bits = 128
|
||||
sA_layout = self._make_smem_layout_AB(
|
||||
mA.element_type,
|
||||
self.a_major_mode,
|
||||
ab_copy_bits,
|
||||
(self.cta_tiler[0], self.cta_tiler[2], self.num_stages),
|
||||
)
|
||||
sB_layout = self._make_smem_layout_AB(
|
||||
mB.element_type,
|
||||
self.b_major_mode,
|
||||
ab_copy_bits,
|
||||
(self.cta_tiler[1], self.cta_tiler[2], self.num_stages),
|
||||
)
|
||||
|
||||
# Creates a similar layout but without num_stages or layout atoms
|
||||
sC_layout = self._make_smem_layout_C(
|
||||
mC.element_type,
|
||||
self.c_major_mode,
|
||||
ab_copy_bits,
|
||||
(self.cta_tiler[0], self.cta_tiler[1]),
|
||||
)
|
||||
|
||||
# Shared memory allocated for operations with A, B will be
|
||||
# overwritten for operations on C. This is to improve performance
|
||||
# by reducing the size of shared memory requested by each block
|
||||
smem_size = max(
|
||||
cute.size_in_bytes(mC.element_type, sC_layout),
|
||||
cute.size_in_bytes(mA.element_type, sA_layout)
|
||||
+ cute.size_in_bytes(mB.element_type, sB_layout),
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Tiled copy:
|
||||
# The majorness of tA/tB/tC follows the majorness of gA/gB/gC,
|
||||
# enabling merged accesses to global memory for faster data
|
||||
# transfer between global and shared memory.
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# Create a copy atom for a global to shared memory asynchronous copy
|
||||
atom_async_copy = cute.make_copy_atom(
|
||||
cute.nvgpu.cpasync.CopyG2SOp(
|
||||
cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL
|
||||
),
|
||||
mA.element_type,
|
||||
num_bits_per_copy=ab_copy_bits,
|
||||
)
|
||||
|
||||
# Create thread layouts for tiled copy from the copy atom where the
|
||||
# thread layout simply follows the leading dimension of the tensor
|
||||
tiled_copy_A = self._make_gmem_tiled_copy_AB(
|
||||
atom_async_copy, mA.element_type, self.a_major_mode, ab_copy_bits
|
||||
)
|
||||
tiled_copy_B = self._make_gmem_tiled_copy_AB(
|
||||
atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits
|
||||
)
|
||||
|
||||
# Creates a synchonous copy atom and thread layouts for the epilogue
|
||||
c_copy_bits = 128
|
||||
atom_sync_copy = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
mC.element_type,
|
||||
num_bits_per_copy=c_copy_bits,
|
||||
)
|
||||
tiled_copy_C = self._make_gmem_tiled_copy_C(
|
||||
atom_sync_copy, mC.element_type, self.c_major_mode, c_copy_bits
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Tiled MMA
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# Creates a mma atom with 16x8x16 shape for MNK
|
||||
op = cute.nvgpu.warp.MmaF16BF16Op(
|
||||
self.ab_dtype, self.acc_dtype, self.mma_inst_shape
|
||||
)
|
||||
|
||||
permutation_mnk = (
|
||||
self.atom_layout_mnk[0] * self.mma_inst_shape[0],
|
||||
# if atom layout's N-mode is 1, to leverage the largest coalesced
|
||||
# shared memory -> register copy, set the tiled mma's N mode to 16
|
||||
self.atom_layout_mnk[1] * self.mma_inst_shape[1] * 2,
|
||||
self.atom_layout_mnk[2] * self.mma_inst_shape[2],
|
||||
)
|
||||
|
||||
# Created a tiled mma that tiles the atom according to specified layout.
|
||||
# For a 2x2x1 atom layout, the mma atom is duplicated 4 times, twice
|
||||
# across M and twice across N
|
||||
tC = cute.make_layout(self.atom_layout_mnk)
|
||||
tiled_mma = cute.make_tiled_mma(
|
||||
op,
|
||||
tC,
|
||||
permutation_mnk=permutation_mnk,
|
||||
)
|
||||
|
||||
# grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, l)
|
||||
grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
|
||||
|
||||
self.kernel(
|
||||
mA,
|
||||
mB,
|
||||
mC,
|
||||
sA_layout,
|
||||
sB_layout,
|
||||
sC_layout,
|
||||
tiled_copy_A,
|
||||
tiled_copy_B,
|
||||
tiled_copy_C,
|
||||
tiled_mma,
|
||||
epilogue_op,
|
||||
).launch(
|
||||
grid=grid_dim,
|
||||
block=[self.num_threads, 1, 1],
|
||||
smem=smem_size,
|
||||
)
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
mA: cute.Tensor,
|
||||
mB: cute.Tensor,
|
||||
mC: cute.Tensor,
|
||||
sA_layout: cute.ComposedLayout,
|
||||
sB_layout: cute.ComposedLayout,
|
||||
sC_layout: cute.ComposedLayout,
|
||||
tiled_copy_A: cute.TiledCopy,
|
||||
tiled_copy_B: cute.TiledCopy,
|
||||
tiled_copy_C: cute.TiledCopy,
|
||||
tiled_mma: cute.TiledMma,
|
||||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||||
):
|
||||
# Thread index, block index
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, bidy, bidz = cute.arch.block_idx()
|
||||
tiler_coord = (bidx, bidy, None)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Get the appropriate tiles for this thread block.
|
||||
# gA: (BLK_M, BLK_N, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
gA = cute.local_tile(
|
||||
mA[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, None, 1),
|
||||
)
|
||||
gB = cute.local_tile(
|
||||
mB[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(None, 1, 1),
|
||||
)
|
||||
gC = cute.local_tile(
|
||||
mC[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, 1, None),
|
||||
)
|
||||
|
||||
# By default, if the tensor k mode does not divide into the tile k
|
||||
# size, then last tiles in the k dimension are irregular.
|
||||
# Instead, make the first tiles irregular when k is irregular.
|
||||
# This allows us to handle the irregular tile first to avoid
|
||||
# checking for this condition within the mainloop.
|
||||
|
||||
# residual_k is a negative number indicating the amount needed to
|
||||
# shift the pointer by in dimension k
|
||||
residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size(
|
||||
gA, mode=[2]
|
||||
)
|
||||
|
||||
# move the pointer of gA/gB in the `-k` direction
|
||||
gA = cute.domain_offset((0, residual_k, 0), gA)
|
||||
gB = cute.domain_offset((0, residual_k, 0), gB)
|
||||
# input is 16B aligned
|
||||
gA = cute.make_tensor(gA.iterator.align(16), gA.layout)
|
||||
gB = cute.make_tensor(gB.iterator.align(16), gB.layout)
|
||||
|
||||
# Construct identity layout for sA and sB (mirrors global tensors,
|
||||
# used for predication only)
|
||||
mcA = cute.make_identity_tensor(mA.layout.shape)
|
||||
mcB = cute.make_identity_tensor(mB.layout.shape)
|
||||
cA = cute.local_tile(
|
||||
mcA[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, None, 1),
|
||||
)
|
||||
cB = cute.local_tile(
|
||||
mcB[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(None, 1, 1),
|
||||
)
|
||||
|
||||
cA = cute.domain_offset((0, residual_k, 0), cA)
|
||||
cB = cute.domain_offset((0, residual_k, 0), cB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create shared memory buffers and get the appropriate fragments for this thread.
|
||||
# sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE)
|
||||
# tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k)
|
||||
# tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Shared memory buffer
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
|
||||
sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
|
||||
sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
|
||||
sC = cute.make_tensor(
|
||||
cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout
|
||||
)
|
||||
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
thr_copy_B = tiled_copy_B.get_slice(tidx)
|
||||
thr_copy_C = tiled_copy_C.get_slice(tidx)
|
||||
tAgA = thr_copy_A.partition_S(gA)
|
||||
tAsA = thr_copy_A.partition_D(sA)
|
||||
tBgB = thr_copy_B.partition_S(gB)
|
||||
tBsB = thr_copy_B.partition_D(sB)
|
||||
tCsC_epilogue = thr_copy_C.partition_S(sC)
|
||||
tCgC_epilogue = thr_copy_C.partition_D(gC)
|
||||
|
||||
# Repeat the partitioning with identity layouts
|
||||
tAcA = thr_copy_A.partition_S(cA)
|
||||
tBcB = thr_copy_B.partition_S(cB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
|
||||
# of tile_shape
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# For predication over the tensors A (M/K), B (N/K), and (in the
|
||||
# epilogue) C (M/N), we will compute it in a fashion similar to an
|
||||
# outer product. The predication along one of the dimensions is
|
||||
# evaluated and stored in a predication tensor. Then, the
|
||||
# predication for the remaining dimension is handled later via an
|
||||
# if/else branch at the copy.
|
||||
# For A and B, predication booleans along M/N are stored in a
|
||||
# predication tensor and along K is handled via a if/else branch.
|
||||
|
||||
# Allocate predicate tensors for M and N. Predication is checked
|
||||
# at the granularity of a copy atom, so the predicate tensor does not
|
||||
# need separate booleans for individual elements within a copy
|
||||
# atom (for example, the elements of tAgA.shape[0][0].)
|
||||
tApA = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tAgA.shape[0][1],
|
||||
cute.size(tAgA, mode=[1]),
|
||||
cute.size(tAgA, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tAgA, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
tBpB = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tBsB.shape[0][1],
|
||||
cute.size(tBsB, mode=[1]),
|
||||
cute.size(tBsB, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tBsB, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
# Set predicates for M/N bounds
|
||||
for rest_v in range(tApA.shape[0]):
|
||||
for m in range(tApA.shape[1]):
|
||||
tApA[rest_v, m, 0] = cute.elem_less(
|
||||
tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
|
||||
)
|
||||
for rest_v in range(tBpB.shape[0]):
|
||||
for n in range(tBpB.shape[1]):
|
||||
tBpB[rest_v, n, 0] = cute.elem_less(
|
||||
tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Prefetch Prologue
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Clear the smem tiles to account for predicated off loads
|
||||
tAsA.fill(0)
|
||||
tBsB.fill(0)
|
||||
cute.arch.sync_threads()
|
||||
# Start async loads for the first k-tile. Here we take care of the k residue
|
||||
# via if/else check along the k dimension. Because we shifted the identity tensor
|
||||
# by the residue_k and because the identity tensor is a counting tensor, the
|
||||
# values of any identity tensor element that is poison is less than -1
|
||||
num_smem_stages = cute.size(tAsA, mode=[3])
|
||||
k_tile_count = cute.size(tAgA, mode=[3])
|
||||
k_tile_index = cutlass.Int32(0)
|
||||
|
||||
for k in range(tApA.shape[2]):
|
||||
if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]):
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, k, k_tile_index],
|
||||
tAsA[None, None, k, 0],
|
||||
pred=tApA[None, None, k],
|
||||
)
|
||||
for k in range(tBpB.shape[2]):
|
||||
if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]):
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, k, k_tile_index],
|
||||
tBsB[None, None, k, 0],
|
||||
pred=tBpB[None, None, k],
|
||||
)
|
||||
k_tile_index = k_tile_index + 1
|
||||
cute.arch.cp_async_commit_group()
|
||||
|
||||
# Start async loads for rest of the k-tiles
|
||||
for k_tile in range(1, num_smem_stages - 1):
|
||||
if k_tile == k_tile_count:
|
||||
tApA.fill(0)
|
||||
tBpB.fill(0)
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, k_tile_index],
|
||||
tAsA[None, None, None, k_tile],
|
||||
pred=tApA,
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, k_tile_index],
|
||||
tBsB[None, None, None, k_tile],
|
||||
pred=tBpB,
|
||||
)
|
||||
k_tile_index = k_tile_index + 1
|
||||
cute.arch.cp_async_commit_group()
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Tile MMA compute thread partitions and allocate accumulators
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
thr_mma = tiled_mma.get_slice(tidx)
|
||||
tCsA = thr_mma.partition_A(sA)
|
||||
tCsB = thr_mma.partition_B(sB)
|
||||
tCsC = thr_mma.partition_C(sC)
|
||||
tCgC = thr_mma.partition_C(gC)
|
||||
tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
|
||||
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
|
||||
tCrC = tiled_mma.make_fragment_C(tCgC)
|
||||
# Clear the accumulator
|
||||
tCrC.fill(0.0)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Copy Atom A/B retiling
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# Create the copy atoms for the copy from shared memory to register
|
||||
atom_copy_s2r_A = cute.make_copy_atom(
|
||||
cute.nvgpu.warp.LdMatrix8x8x16bOp(
|
||||
self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
|
||||
),
|
||||
mA.element_type,
|
||||
)
|
||||
atom_copy_s2r_B = cute.make_copy_atom(
|
||||
cute.nvgpu.warp.LdMatrix8x8x16bOp(
|
||||
self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
|
||||
),
|
||||
mB.element_type,
|
||||
)
|
||||
|
||||
# Creates the tiled copy so that it matches the thread-value layout
|
||||
# expected by the tiled mma
|
||||
tiled_copy_s2r_A = cute.make_tiled_copy(
|
||||
atom_copy_s2r_A,
|
||||
layout_tv=tiled_mma.tv_layout_A_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
tiled_copy_s2r_B = cute.make_tiled_copy(
|
||||
atom_copy_s2r_B,
|
||||
layout_tv=tiled_mma.tv_layout_B_tiled,
|
||||
tiler_mn=(tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)),
|
||||
)
|
||||
|
||||
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
|
||||
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
|
||||
tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
|
||||
tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA)
|
||||
tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
|
||||
tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB)
|
||||
|
||||
# Current pipe index in smem to read from / write to
|
||||
smem_pipe_read = 0
|
||||
smem_pipe_write = num_smem_stages - 1
|
||||
|
||||
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
|
||||
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# PREFETCH register pipeline
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
num_k_block = cute.size(tCrA, mode=[2])
|
||||
if num_k_block > 1:
|
||||
# Wait until our first prefetched tile is loaded in
|
||||
cute.arch.cp_async_wait_group(num_smem_stages - 2)
|
||||
cute.arch.sync_threads()
|
||||
# Prefetch the first k-block rmem from the first k-tile
|
||||
cute.copy(
|
||||
tiled_copy_s2r_A,
|
||||
tCsA_p[None, None, 0],
|
||||
tCrA_copy_view[None, None, 0],
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_s2r_B,
|
||||
tCsB_p[None, None, 0],
|
||||
tCrB_copy_view[None, None, 0],
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Mainloop
|
||||
# 1. Shared memory pipeline (gmem -> smem):
|
||||
# The default smem pipeline depth is 3, meaning that for shared
|
||||
# memory buffers, we allocate three times the size described by the
|
||||
# CTA tiler. We prefetch 2 of these buffers before entering the main
|
||||
# loop. Considering only the transfer from global memory to shared
|
||||
# memory, the general structure of the mainloop is:
|
||||
# (1) copy k-tile from gmem to smem;
|
||||
# (2) perform gemm computation on k-tile;
|
||||
# (3) wait for the next copy to finish.
|
||||
# The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command
|
||||
# waits for the number of unfinished 'copy' to be <= 1. The advantage
|
||||
# of this approach is that it allows for simultaneous production
|
||||
# (i.e., step (1)) and consumption (i.e., step (2)) of smem.
|
||||
# A common misconception is to prefetch N buffers and rewrite
|
||||
# the pipeline logic to wait on N-1 pending copies. The disadvantage
|
||||
# of this approach is that it requires fully consuming a buffer in
|
||||
# order to open an empty buffer for the next copy.
|
||||
# 2. Register pipeline (smem -> register):
|
||||
# Similarly, the register pipeline produces i+1, consumes i, and
|
||||
# produces i+2... Notably, i and i+1 do not use the same register,
|
||||
# eliminating dependencies on the same register for better parallelism.
|
||||
# 3. Combining the smem and register pipelines results in the mainloop.
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
for k_tile in cutlass.range_dynamic(k_tile_count, unroll=1):
|
||||
for k_block in range(num_k_block):
|
||||
if k_block == num_k_block - 1:
|
||||
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
|
||||
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
|
||||
cute.arch.cp_async_wait_group(num_smem_stages - 2)
|
||||
cute.arch.sync_threads()
|
||||
|
||||
# Load A, B from shared memory to registers for k_block + 1
|
||||
k_block_next = (k_block + 1) % num_k_block # static
|
||||
cute.copy(
|
||||
tiled_copy_s2r_A,
|
||||
tCsA_p[None, None, k_block_next],
|
||||
tCrA_copy_view[None, None, k_block_next],
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_s2r_B,
|
||||
tCsB_p[None, None, k_block_next],
|
||||
tCrB_copy_view[None, None, k_block_next],
|
||||
)
|
||||
|
||||
# Fetch next A: To better interleave global memory access and compute
|
||||
# instructions, we intentionally use the sequence: copy A, perform GEMM,
|
||||
# then copy B.
|
||||
if k_block == 0:
|
||||
if k_tile + num_smem_stages - 1 < k_tile_count:
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, k_tile_index],
|
||||
tAsA[None, None, None, smem_pipe_write],
|
||||
pred=tApA,
|
||||
)
|
||||
|
||||
# Thread-level register gemm for k_block
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCrC,
|
||||
tCrA[None, None, k_block],
|
||||
tCrB[None, None, k_block],
|
||||
tCrC,
|
||||
)
|
||||
|
||||
# Fetch next B and update smem pipeline read/write
|
||||
if k_block == 0:
|
||||
if k_tile + num_smem_stages - 1 < k_tile_count:
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, k_tile_index],
|
||||
tBsB[None, None, None, smem_pipe_write],
|
||||
pred=tBpB,
|
||||
)
|
||||
k_tile_index = k_tile_index + 1
|
||||
cute.arch.cp_async_commit_group()
|
||||
smem_pipe_write = smem_pipe_read
|
||||
smem_pipe_read = smem_pipe_read + 1
|
||||
if smem_pipe_read == num_smem_stages:
|
||||
smem_pipe_read = 0
|
||||
|
||||
# Sync before epilogue
|
||||
cute.arch.cp_async_wait_group(0)
|
||||
cute.arch.sync_threads()
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Epilogue with fusion
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
tCrD = cute.make_fragment_like(tCrC, self.c_dtype)
|
||||
tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype)
|
||||
|
||||
# Copy results of D back to shared memory
|
||||
cute.autovec_copy(tCrD, tCsC)
|
||||
|
||||
# Create counting tensor for C
|
||||
ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
|
||||
mcC = cute.make_identity_tensor(
|
||||
(
|
||||
cute.size(ceilM) * self.cta_tiler[0],
|
||||
cute.size(ceilN) * self.cta_tiler[1],
|
||||
1,
|
||||
)
|
||||
)
|
||||
cC = cute.local_tile(
|
||||
mcC[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, 1, None),
|
||||
)
|
||||
tCcC = thr_copy_C.partition_S(cC)
|
||||
|
||||
tCrC_epilogue = cute.make_fragment_like(tCsC_epilogue)
|
||||
# Wait for all writes to shared memory to finish before starting copies
|
||||
# using the new layouts
|
||||
cute.arch.sync_threads()
|
||||
cute.autovec_copy(tCsC_epilogue, tCrC_epilogue)
|
||||
|
||||
# Create predication tensor for m
|
||||
tCpC = cute.make_fragment(
|
||||
cute.make_layout(
|
||||
(
|
||||
tCgC_epilogue.shape[0][1],
|
||||
cute.size(tCgC_epilogue, mode=[1]),
|
||||
cute.size(tCgC_epilogue, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tCgC_epilogue, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
for rest_v in range(tCpC.shape[0]):
|
||||
for m in range(tCpC.shape[1]):
|
||||
tCpC[rest_v, m, 0] = cute.elem_less(
|
||||
tCcC[(0, rest_v), m, 0][0], mC.shape[0]
|
||||
)
|
||||
|
||||
# Copy to global memory using better vectorization
|
||||
for rest_v in range(tCpC.shape[0]):
|
||||
for n in range(tCpC.shape[2]):
|
||||
if cute.elem_less(tCcC[(0, rest_v), 0, n][1], mC.shape[1]):
|
||||
cute.copy(
|
||||
tiled_copy_C,
|
||||
tCrC_epilogue[None, None, n],
|
||||
tCgC_epilogue[None, None, n],
|
||||
pred=tCpC[None, None, n],
|
||||
)
|
||||
return
|
||||
|
||||
def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler):
|
||||
major_mode_size = (
|
||||
smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0]
|
||||
)
|
||||
major_mode_size = 64 if major_mode_size >= 64 else major_mode_size
|
||||
|
||||
swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
|
||||
swizzle_bits = min(swizzle_bits, 3)
|
||||
|
||||
layout_atom_outer = (
|
||||
cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
|
||||
)
|
||||
layout_atom = cute.make_composed_layout(
|
||||
cute.make_swizzle(swizzle_bits, 3, 3),
|
||||
0,
|
||||
layout_atom_outer,
|
||||
)
|
||||
layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1, 2))
|
||||
return layout
|
||||
|
||||
def _make_smem_layout_C(self, dtype, major_mode, copy_bits, smem_tiler):
|
||||
major_mode_size = (
|
||||
smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0]
|
||||
)
|
||||
|
||||
swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
|
||||
swizzle_bits = min(swizzle_bits, 3)
|
||||
|
||||
layout_atom_outer = (
|
||||
cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
|
||||
)
|
||||
layout_atom = cute.make_composed_layout(
|
||||
cute.make_swizzle(swizzle_bits, 3, 4),
|
||||
0,
|
||||
layout_atom_outer,
|
||||
)
|
||||
|
||||
# Due to the thread layout of the mma, remove swizzle in C to
|
||||
# prevent shared memory fragments owned by an single thread from
|
||||
# holding swizzles
|
||||
if major_mode == utils.LayoutEnum.COL_MAJOR:
|
||||
layout_atom = cute.make_composed_layout(
|
||||
cute.make_swizzle(0, 3, 4), 0, layout_atom_outer
|
||||
)
|
||||
layout = cute.tile_to_shape(
|
||||
layout_atom,
|
||||
smem_tiler,
|
||||
(0, 1),
|
||||
)
|
||||
return layout
|
||||
|
||||
def _make_gmem_tiled_copy_AB(self, atom_copy, dtype, major_mode, copy_bits):
|
||||
copy_elems = copy_bits // dtype.width
|
||||
shape_dim_1 = cute.size(self.bK) // copy_elems
|
||||
# thread layout for copy
|
||||
thread_layout = cute.make_layout(
|
||||
(self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
|
||||
)
|
||||
if major_mode != utils.LayoutEnum.ROW_MAJOR:
|
||||
shape_dim_0 = cute.size(self.bM) // copy_elems
|
||||
thread_layout = cute.make_layout(
|
||||
(shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0)
|
||||
)
|
||||
# Value layout for copy
|
||||
value_layout = (
|
||||
cute.make_layout((1, copy_elems))
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((copy_elems, 1))
|
||||
)
|
||||
return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout)
|
||||
|
||||
def _make_gmem_tiled_copy_C(self, atom_copy, dtype, major_mode, copy_bits):
|
||||
copy_elems = copy_bits // dtype.width
|
||||
shape_dim_1 = cute.size(self.bN) // copy_elems
|
||||
# thread layout for copy
|
||||
thread_layout = cute.make_layout(
|
||||
(self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
|
||||
)
|
||||
if major_mode != utils.LayoutEnum.ROW_MAJOR:
|
||||
shape_dim_0 = cute.size(self.bM) // copy_elems
|
||||
thread_layout = cute.make_layout(
|
||||
(shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0)
|
||||
)
|
||||
value_layout = (
|
||||
cute.make_layout((1, copy_elems))
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((copy_elems, 1))
|
||||
)
|
||||
tiler_mn, layout_tv = cute.make_layout_tv(thread_layout, value_layout)
|
||||
return cute.make_tiled_copy(atom_copy, layout_tv, tiler_mn)
|
||||
|
||||
|
||||
def run_tensor_op_gemm(
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
ab_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
acc_dtype: Type[cutlass.Numeric],
|
||||
problem_shape: Tuple[int, int, int, int],
|
||||
atom_layout_mnk: Tuple[int, int, int],
|
||||
warmup_iterations: int = 2,
|
||||
iterations: int = 100,
|
||||
skip_ref_check: bool = False,
|
||||
):
|
||||
M, N, K, L = problem_shape
|
||||
|
||||
# Create and permute tensor A/B/C
|
||||
def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype):
|
||||
# is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l)
|
||||
# else: (l, mode0, mode1) -> (mode0, mode1, l)
|
||||
shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
|
||||
permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
|
||||
|
||||
return (
|
||||
torch.empty(*shape, dtype=torch.int32)
|
||||
.random_(-2, 2)
|
||||
.to(dtype=dtype)
|
||||
.permute(permute_order)
|
||||
.cuda()
|
||||
)
|
||||
|
||||
a = create_and_permute_tensor(
|
||||
L, M, K, a_major == "m", cutlass_torch.dtype(ab_dtype)
|
||||
)
|
||||
b = create_and_permute_tensor(
|
||||
L, N, K, b_major == "n", cutlass_torch.dtype(ab_dtype)
|
||||
)
|
||||
c = create_and_permute_tensor(L, M, N, c_major == "m", cutlass_torch.dtype(c_dtype))
|
||||
ref = torch.einsum("mkl,nkl->mnl", a, b).to(cutlass_torch.dtype(c_dtype))
|
||||
|
||||
tensor_op_gemm = TensorOpGemm(
|
||||
ab_dtype,
|
||||
c_dtype,
|
||||
acc_dtype,
|
||||
atom_layout_mnk,
|
||||
)
|
||||
|
||||
# assume input is 16B aligned
|
||||
a_tensor = (
|
||||
from_dlpack(a, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if a_major == "k" else 0),
|
||||
stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0),
|
||||
divisibility=(128 // ab_dtype.width),
|
||||
)
|
||||
)
|
||||
b_tensor = (
|
||||
from_dlpack(b, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if b_major == "k" else 0),
|
||||
stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0),
|
||||
divisibility=(128 // ab_dtype.width),
|
||||
)
|
||||
)
|
||||
c_tensor = (
|
||||
from_dlpack(c, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=(1 if c_major == "n" else 0),
|
||||
stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0),
|
||||
divisibility=(128 // c_dtype.width),
|
||||
)
|
||||
)
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
gemm = cute.compile(tensor_op_gemm, a_tensor, b_tensor, c_tensor)
|
||||
|
||||
print("Executing GEMM kernel...")
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
|
||||
# Execute the kernel
|
||||
for _ in range(iterations):
|
||||
gemm(a_tensor, b_tensor, c_tensor)
|
||||
|
||||
if not skip_ref_check:
|
||||
print("Verifying results...")
|
||||
torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
|
||||
try:
|
||||
return tuple(int(x.strip()) for x in s.split(","))
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format. Expected comma-separated integers."
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="example of multistage block matmul with CuTe on GPU"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mnkl", type=parse_comma_separated_ints, default=(112, 136, 40, 1)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--atom_layout_mnk", type=parse_comma_separated_ints, default=(2, 2, 1)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ab_dtype",
|
||||
type=cutlass.dtype,
|
||||
choices=[cutlass.Float16],
|
||||
default=cutlass.Float16,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--acc_dtype",
|
||||
type=cutlass.dtype,
|
||||
choices=[cutlass.Float32],
|
||||
default=cutlass.Float32,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--c_dtype",
|
||||
type=cutlass.dtype,
|
||||
choices=[cutlass.Float16],
|
||||
default=cutlass.Float16,
|
||||
)
|
||||
parser.add_argument("--a_major", choices=["k", "m"], default="m")
|
||||
parser.add_argument("--b_major", choices=["k", "n"], default="n")
|
||||
parser.add_argument("--c_major", choices=["n", "m"], default="n")
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
print("Running Ampere tensor core GEMM example:")
|
||||
run_tensor_op_gemm(
|
||||
args.a_major,
|
||||
args.b_major,
|
||||
args.c_major,
|
||||
args.ab_dtype,
|
||||
args.c_dtype,
|
||||
args.acc_dtype,
|
||||
args.mnkl,
|
||||
args.atom_layout_mnk,
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
)
|
||||
print("PASS")
|
||||
1922
examples/python/CuTeDSL/blackwell/dense_gemm.py
Normal file
1922
examples/python/CuTeDSL/blackwell/dense_gemm.py
Normal file
File diff suppressed because it is too large
Load Diff
2144
examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py
Normal file
2144
examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py
Normal file
File diff suppressed because it is too large
Load Diff
2984
examples/python/CuTeDSL/blackwell/fmha.py
Normal file
2984
examples/python/CuTeDSL/blackwell/fmha.py
Normal file
File diff suppressed because it is too large
Load Diff
2287
examples/python/CuTeDSL/blackwell/grouped_gemm.py
Normal file
2287
examples/python/CuTeDSL/blackwell/grouped_gemm.py
Normal file
File diff suppressed because it is too large
Load Diff
31
examples/python/CuTeDSL/notebooks/README.md
Normal file
31
examples/python/CuTeDSL/notebooks/README.md
Normal file
@@ -0,0 +1,31 @@
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
648
examples/python/CuTeDSL/notebooks/cuda_graphs.ipynb
Normal file
648
examples/python/CuTeDSL/notebooks/cuda_graphs.ipynb
Normal file
@@ -0,0 +1,648 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0e95f0df-4d1a-4e2e-92ff-90539bb4c517",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Example 06: CUDA Graphs\n",
|
||||
"\n",
|
||||
"In this example we demonstrate how to use CUDA graphs through PyTorch with CuTe DSL.\n",
|
||||
"The process of interacting with PyTorch's CUDA graph implementation requires exposing PyTorch's CUDA streams to CUTLASS.\n",
|
||||
"\n",
|
||||
"To use CUDA graphs with Blackwell requires a version of PyTorch that supports Blackwell.\n",
|
||||
"This can be obtained through:\n",
|
||||
"- The [PyTorch NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)\n",
|
||||
"- [PyTorch 2.7 with CUDA 12.8 or later](https://pytorch.org/) (e.g., `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128`)\n",
|
||||
"- Building PyTorch directly with your version of CUDA."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "46b8fb6f-9ac5-4a3d-b765-b6476f182bf7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# import torch for CUDA graphs\n",
|
||||
"import torch\n",
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute\n",
|
||||
"# import CUstream type from the cuda driver bindings\n",
|
||||
"from cuda.bindings.driver import CUstream\n",
|
||||
"# import the current_stream function from torch\n",
|
||||
"from torch.cuda import current_stream"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bcf5e06e-1f5b-4d72-ad73-9b36efb78ca0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Kernel Creation\n",
|
||||
"\n",
|
||||
"We create a kernel which prints \"Hello world\" as well as a host function to launch the kernel.\n",
|
||||
"We then compile the kernel for use in our graph, by passing in a default stream.\n",
|
||||
"\n",
|
||||
"Kernel compilation before graph capture is required since CUDA graphs cannot JIT compile kernels during graph execution."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "0c2a6ca8-98d7-4837-b91f-af769ca8fcd8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def hello_world_kernel():\n",
|
||||
" \"\"\"\n",
|
||||
" A kernel that prints hello world\n",
|
||||
" \"\"\"\n",
|
||||
" cute.printf(\"Hello world\")\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def hello_world(stream : CUstream):\n",
|
||||
" \"\"\"\n",
|
||||
" Host function that launches our (1,1,1), (1,1,1) grid in stream\n",
|
||||
" \"\"\"\n",
|
||||
" hello_world_kernel().launch(grid=[1, 1, 1], block=[1, 1, 1], stream=stream)\n",
|
||||
"\n",
|
||||
"# Grab a stream from PyTorch, this will also initialize our context\n",
|
||||
"# so we can omit cutlass.cuda.initialize_cuda_context()\n",
|
||||
"stream = current_stream()\n",
|
||||
"hello_world_compiled = cute.compile(hello_world, CUstream(stream.cuda_stream))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ecc850af-09f8-4a29-9c93-ff31fbb9326f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Creating and replaying a CUDA Graph\n",
|
||||
"\n",
|
||||
"We create a stream through torch as well as a graph.\n",
|
||||
"When we create the graph we can pass the stream we want to capture to torch. We similarly run the compiled kernel with the stream passed as a CUstream.\n",
|
||||
"\n",
|
||||
"Finally we can replay our graph and synchronize."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "f673e5ae-42bb-44d0-b652-3280606181c4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Hello world\n",
|
||||
"Hello world\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Create a CUDA Graph\n",
|
||||
"g = torch.cuda.CUDAGraph()\n",
|
||||
"# Capture our graph\n",
|
||||
"with torch.cuda.graph(g):\n",
|
||||
" # Turn our torch Stream into a cuStream stream.\n",
|
||||
" # This is done by getting the underlying CUstream with .cuda_stream\n",
|
||||
" graph_stream = CUstream(current_stream().cuda_stream)\n",
|
||||
" # Run 2 iterations of our compiled kernel\n",
|
||||
" for _ in range(2):\n",
|
||||
" # Run our kernel in the stream\n",
|
||||
" hello_world_compiled(graph_stream)\n",
|
||||
"\n",
|
||||
"# Replay our graph\n",
|
||||
"g.replay()\n",
|
||||
"# Synchronize all streams (equivalent to cudaDeviceSynchronize() in C++)\n",
|
||||
"torch.cuda.synchronize()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "db76d9c3-7617-4bf2-b326-11982e6803bf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Our run results in the following execution when viewed in NSight Systems:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"We can observe the launch of the two kernels followed by a `cudaDeviceSynchronize()`.\n",
|
||||
"\n",
|
||||
"Now we can confirm that this minimizes some launch overhead:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "3ebe15bf-dc97-42e9-913c-224ecfb472e8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n",
|
||||
"Hello world\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Get our CUDA stream from PyTorch\n",
|
||||
"stream = CUstream(current_stream().cuda_stream)\n",
|
||||
"\n",
|
||||
"# Create a larger CUDA Graph of 100 iterations\n",
|
||||
"g = torch.cuda.CUDAGraph()\n",
|
||||
"# Capture our graph\n",
|
||||
"with torch.cuda.graph(g):\n",
|
||||
" # Turn our torch Stream into a cuStream stream.\n",
|
||||
" # This is done by getting the underlying CUstream with .cuda_stream\n",
|
||||
" graph_stream = CUstream(current_stream().cuda_stream)\n",
|
||||
" # Run 2 iterations of our compiled kernel\n",
|
||||
" for _ in range(100):\n",
|
||||
" # Run our kernel in the stream\n",
|
||||
" hello_world_compiled(graph_stream)\n",
|
||||
"\n",
|
||||
"# Create CUDA events for measuring performance\n",
|
||||
"start = torch.cuda.Event(enable_timing=True)\n",
|
||||
"end = torch.cuda.Event(enable_timing=True)\n",
|
||||
"\n",
|
||||
"# Run our kernel to warm up the GPU\n",
|
||||
"for _ in range(100):\n",
|
||||
" hello_world_compiled(stream)\n",
|
||||
"\n",
|
||||
"# Record our start time\n",
|
||||
"start.record()\n",
|
||||
"# Run 100 kernels\n",
|
||||
"for _ in range(100):\n",
|
||||
" hello_world_compiled(stream)\n",
|
||||
"# Record our end time\n",
|
||||
"end.record()\n",
|
||||
"# Synchronize (cudaDeviceSynchronize())\n",
|
||||
"torch.cuda.synchronize()\n",
|
||||
"\n",
|
||||
"# Calculate the time spent when launching kernels in a stream\n",
|
||||
"# Results are in ms\n",
|
||||
"stream_time = start.elapsed_time(end) \n",
|
||||
"\n",
|
||||
"# Warmup our GPU again\n",
|
||||
"g.replay()\n",
|
||||
"# Record our start time\n",
|
||||
"start.record()\n",
|
||||
"# Run our graph\n",
|
||||
"g.replay()\n",
|
||||
"# Record our end time\n",
|
||||
"end.record()\n",
|
||||
"# Synchronize (cudaDeviceSynchronize())\n",
|
||||
"torch.cuda.synchronize()\n",
|
||||
"\n",
|
||||
"# Calculate the time spent when launching kernels in a graph\n",
|
||||
"# units are ms\n",
|
||||
"graph_time = start.elapsed_time(end)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "12b8151a-46b3-4c99-9945-301f6b628131",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"8.94% speedup when using CUDA graphs for this kernel!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Print out speedup when using CUDA graphs\n",
|
||||
"percent_speedup = (stream_time - graph_time) / graph_time\n",
|
||||
"print(f\"{percent_speedup * 100.0:.2f}% speedup when using CUDA graphs for this kernel!\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
1001
examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb
Normal file
1001
examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
310
examples/python/CuTeDSL/notebooks/data_types.ipynb
Normal file
310
examples/python/CuTeDSL/notebooks/data_types.ipynb
Normal file
@@ -0,0 +1,310 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import List\n",
|
||||
"\n",
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Understanding data structure in CuTe DSL\n",
|
||||
"\n",
|
||||
"In most cases, data structures in CuTe DSL work the same as Python data structures with the notable difference that Python data structures in most cases are considered as static data which are interpreted by the DSL compiler embedded inside Python interpreter.\n",
|
||||
"\n",
|
||||
"To differentiate between compile-time and runtime values, CuTe DSL introduces primitive types that \n",
|
||||
"represent dynamic values in JIT-compiled code.\n",
|
||||
"\n",
|
||||
"CuTe DSL provides a comprehensive set of primitive numeric types for representing dynamic values at \n",
|
||||
"runtime. These types are formally defined within the CuTe DSL typing system:\n",
|
||||
"\n",
|
||||
"### Integer Types\n",
|
||||
"- `Int8` - 8-bit signed integer\n",
|
||||
"- `Int16` - 16-bit signed integer \n",
|
||||
"- `Int32` - 32-bit signed integer\n",
|
||||
"- `Int64` - 64-bit signed integer\n",
|
||||
"- `Int128` - 128-bit signed integer\n",
|
||||
"- `Uint8` - 8-bit unsigned integer\n",
|
||||
"- `Uint16` - 16-bit unsigned integer\n",
|
||||
"- `Uint32` - 32-bit unsigned integer\n",
|
||||
"- `Uint64` - 64-bit unsigned integer\n",
|
||||
"- `Uint128` - 128-bit unsigned integer\n",
|
||||
"\n",
|
||||
"### Floating Point Types\n",
|
||||
"- `Float16` - 16-bit floating point\n",
|
||||
"- `Float32` - 32-bit floating point \n",
|
||||
"- `Float64` - 64-bit floating point\n",
|
||||
"- `BFloat16` - Brain Floating Point format (16-bit)\n",
|
||||
"- `TFloat32` - Tensor Float32 format (reduced precision format used in tensor operations)\n",
|
||||
"- `Float8E4M3` - 8-bit floating point with 4-bit exponent and 3-bit mantissa\n",
|
||||
"- `Float8E5M2` - 8-bit floating point with 5-bit exponent and 2-bit mantissa\n",
|
||||
"\n",
|
||||
"These specialized types are designed to represent dynamic values in CuTe DSL code that will be \n",
|
||||
"evaluated at runtime, in contrast to Python's built-in numeric types which are evaluated during \n",
|
||||
"compilation.\n",
|
||||
"\n",
|
||||
"### Example usage:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"x = cutlass.Int32(5) # Creates a 32-bit integer\n",
|
||||
"y = cutlass.Float32(3.14) # Creates a 32-bit float\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def foo(a: cutlass.Int32): # annotate `a` as 32-bit integer passed to jit function via ABI\n",
|
||||
" ...\n",
|
||||
"```\n",
|
||||
"To differentiate between compile-time and runtime values, CuTe DSL introduces primitive types that \n",
|
||||
"represent dynamic values in JIT-compiled code.\n",
|
||||
"\n",
|
||||
"CuTe DSL provides a comprehensive set of primitive numeric types for representing dynamic values at \n",
|
||||
"runtime. These types are formally defined within the CuTe DSL typing system:\n",
|
||||
"\n",
|
||||
"### Integer Types\n",
|
||||
"- `Int8` - 8-bit signed integer\n",
|
||||
"- `Int16` - 16-bit signed integer \n",
|
||||
"- `Int32` - 32-bit signed integer\n",
|
||||
"- `Int64` - 64-bit signed integer\n",
|
||||
"- `Int128` - 128-bit signed integer\n",
|
||||
"- `Uint8` - 8-bit unsigned integer\n",
|
||||
"- `Uint16` - 16-bit unsigned integer\n",
|
||||
"- `Uint32` - 32-bit unsigned integer\n",
|
||||
"- `Uint64` - 64-bit unsigned integer\n",
|
||||
"- `Uint128` - 128-bit unsigned integer\n",
|
||||
"\n",
|
||||
"### Floating Point Types\n",
|
||||
"- `Float16` - 16-bit floating point\n",
|
||||
"- `Float32` - 32-bit floating point \n",
|
||||
"- `Float64` - 64-bit floating point\n",
|
||||
"- `BFloat16` - Brain Floating Point format (16-bit)\n",
|
||||
"- `TFloat32` - Tensor Float32 format (reduced precision format used in tensor operations)\n",
|
||||
"- `Float8E4M3` - 8-bit floating point with 4-bit exponent and 3-bit mantissa\n",
|
||||
"- `Float8E5M2` - 8-bit floating point with 5-bit exponent and 2-bit mantissa\n",
|
||||
"\n",
|
||||
"These specialized types are designed to represent dynamic values in CuTe DSL code that will be \n",
|
||||
"evaluated at runtime, in contrast to Python's built-in numeric types which are evaluated during \n",
|
||||
"compilation.\n",
|
||||
"\n",
|
||||
"### Example usage:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"x = cutlass.Int32(5) # Creates a 32-bit integer\n",
|
||||
"y = cutlass.Float32(3.14) # Creates a 32-bit float\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def foo(a: cutlass.Int32): # annotate `a` as 32-bit integer passed to jit function via ABI\n",
|
||||
" ...\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"a(static) = ?\n",
|
||||
"b(static) = ?\n",
|
||||
"a(dynamic) = 3.140000\n",
|
||||
"b(dynamic) = 5\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def bar():\n",
|
||||
" a = cutlass.Float32(3.14)\n",
|
||||
" print(\"a(static) =\", a) # prints `a(static) = ?`\n",
|
||||
" cute.printf(\"a(dynamic) = {}\", a) # prints `a(dynamic) = 3.140000`\n",
|
||||
"\n",
|
||||
" b = cutlass.Int32(5)\n",
|
||||
" print(\"b(static) =\", b) # prints `b(static) = 5`\n",
|
||||
" cute.printf(\"b(dynamic) = {}\", b) # prints `b(dynamic) = 5`\n",
|
||||
"\n",
|
||||
"bar()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Type Conversion API\n",
|
||||
"\n",
|
||||
"CUTLASS numeric types provide type conversion through the `to()` method available on all Numeric types. This allows you to convert between different numeric data types at runtime.\n",
|
||||
"\n",
|
||||
"Syntax:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"new_value = value.to(target_type)\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The `to()` method supports conversion between:\n",
|
||||
"- Integer types (Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64)\n",
|
||||
"- Floating point types (Float16, Float32, Float64, BFloat16)\n",
|
||||
"- Mixed integer/floating point conversions\n",
|
||||
"\n",
|
||||
"Note that when converting from floating point to integer types, the decimal portion is truncated. When converting between types with different ranges, values may be clamped or lose precision if they exceed the target type's representable range."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Int32(42) => Float32(42.000000)\n",
|
||||
"Float32(3.140000) => Int32(3)\n",
|
||||
"Int32(127) => Int8(127)\n",
|
||||
"Int32(300) => Int8(44) (truncated due to range limitation)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def type_conversion():\n",
|
||||
" # Convert from Int32 to Float32\n",
|
||||
" x = cutlass.Int32(42)\n",
|
||||
" y = x.to(cutlass.Float32)\n",
|
||||
" cute.printf(\"Int32({}) => Float32({})\", x, y)\n",
|
||||
"\n",
|
||||
" # Convert from Float32 to Int32\n",
|
||||
" a = cutlass.Float32(3.14)\n",
|
||||
" b = a.to(cutlass.Int32)\n",
|
||||
" cute.printf(\"Float32({}) => Int32({})\", a, b)\n",
|
||||
"\n",
|
||||
" # Convert from Int32 to Int8\n",
|
||||
" c = cutlass.Int32(127)\n",
|
||||
" d = c.to(cutlass.Int8)\n",
|
||||
" cute.printf(\"Int32({}) => Int8({})\", c, d)\n",
|
||||
"\n",
|
||||
" # Convert from Int32 to Int8 with value exceeding Int8 range\n",
|
||||
" e = cutlass.Int32(300)\n",
|
||||
" f = e.to(cutlass.Int8)\n",
|
||||
" cute.printf(\"Int32({}) => Int8({}) (truncated due to range limitation)\", e, f)\n",
|
||||
"\n",
|
||||
"type_conversion()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Operator Overloading\n",
|
||||
"\n",
|
||||
"CUTLASS numeric types support Python's built-in operators, allowing you to write natural mathematical expressions. The operators work with both CUTLASS numeric types and Python native numeric types.\n",
|
||||
"\n",
|
||||
"Supported operators include:\n",
|
||||
"- Arithmetic: `+`, `-`, `*`, `/`, `//`, `%`, `**`\n",
|
||||
"- Comparison: `<`, `<=`, `==`, `!=`, `>=`, `>`\n",
|
||||
"- Bitwise: `&`, `|`, `^`, `<<`, `>>`\n",
|
||||
"- Unary: `-` (negation), `~` (bitwise NOT)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"a: Int32(10), b: Int32(3)\n",
|
||||
"x: Float32(5.500000)\n",
|
||||
"\n",
|
||||
"a + b = 13\n",
|
||||
"x * 2 = 11.000000\n",
|
||||
"a + x = 15.500000 (Int32 + Float32 promotes to Float32)\n",
|
||||
"a / b = 3.333333\n",
|
||||
"x / 2.0 = 2.750000\n",
|
||||
"a > b = 1\n",
|
||||
"a & b = 2\n",
|
||||
"-a = -10\n",
|
||||
"~a = -11\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def operator_demo():\n",
|
||||
" # Arithmetic operators\n",
|
||||
" a = cutlass.Int32(10)\n",
|
||||
" b = cutlass.Int32(3)\n",
|
||||
" cute.printf(\"a: Int32({}), b: Int32({})\", a, b)\n",
|
||||
"\n",
|
||||
" x = cutlass.Float32(5.5)\n",
|
||||
" cute.printf(\"x: Float32({})\", x)\n",
|
||||
"\n",
|
||||
" cute.printf(\"\")\n",
|
||||
"\n",
|
||||
" sum_result = a + b\n",
|
||||
" cute.printf(\"a + b = {}\", sum_result)\n",
|
||||
"\n",
|
||||
" y = x * 2 # Multiplying with Python native type\n",
|
||||
" cute.printf(\"x * 2 = {}\", y)\n",
|
||||
"\n",
|
||||
" # Mixed type arithmetic (Int32 + Float32) that integer is converted into float32\n",
|
||||
" mixed_result = a + x\n",
|
||||
" cute.printf(\"a + x = {} (Int32 + Float32 promotes to Float32)\", mixed_result)\n",
|
||||
"\n",
|
||||
" # Division with Int32 (note: integer division)\n",
|
||||
" div_result = a / b\n",
|
||||
" cute.printf(\"a / b = {}\", div_result)\n",
|
||||
"\n",
|
||||
" # Float division\n",
|
||||
" float_div = x / cutlass.Float32(2.0)\n",
|
||||
" cute.printf(\"x / 2.0 = {}\", float_div)\n",
|
||||
"\n",
|
||||
" # Comparison operators\n",
|
||||
" is_greater = a > b\n",
|
||||
" cute.printf(\"a > b = {}\", is_greater)\n",
|
||||
"\n",
|
||||
" # Bitwise operators\n",
|
||||
" bit_and = a & b\n",
|
||||
" cute.printf(\"a & b = {}\", bit_and)\n",
|
||||
"\n",
|
||||
" neg_a = -a\n",
|
||||
" cute.printf(\"-a = {}\", neg_a)\n",
|
||||
"\n",
|
||||
" not_a = ~a\n",
|
||||
" cute.printf(\"~a = {}\", not_a)\n",
|
||||
"\n",
|
||||
"operator_demo()\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
838
examples/python/CuTeDSL/notebooks/elementwise_add.ipynb
Normal file
838
examples/python/CuTeDSL/notebooks/elementwise_add.ipynb
Normal file
@@ -0,0 +1,838 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"editable": true,
|
||||
"slideshow": {
|
||||
"slide_type": ""
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from functools import partial\n",
|
||||
"\n",
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute\n",
|
||||
"from cutlass.cute.runtime import from_dlpack"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tutorial: Elementwise Add Kernel in CuTe DSL\n",
|
||||
"\n",
|
||||
"This tutorial demonstrates how to implement a simple elementwise\n",
|
||||
"addition kernel using the CuTe DSL (Domain Specific Language).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Elementwise Addition\n",
|
||||
"---------------------\n",
|
||||
"\n",
|
||||
"Elementwise addition is a fundamental operation in linear algebra.\n",
|
||||
"Given two tensors of the same shape, the operation performs element-wise\n",
|
||||
"addition to produce a result tensor of the same shape.\n",
|
||||
"\n",
|
||||
"For two 2D tensors :math:`A` and :math:`B` of shape :math:`(M, N)`,\n",
|
||||
"the elementwise addition operation :math:`C = A + B` is defined as:\n",
|
||||
"\n",
|
||||
"$\n",
|
||||
" C_{i,j} = A_{i,j} + B_{i,j}\n",
|
||||
"$\n",
|
||||
"\n",
|
||||
"where:\n",
|
||||
"\n",
|
||||
"- $i \\in [0, M-1]$ represents the row index\n",
|
||||
"- $j \\in [0, N-1]$ represents the column index\n",
|
||||
"- $A_{i,j}$, $B_{i,j}$, and $C_{i,j}$ are the elements at position $(i,j)$ \n",
|
||||
" in tensors $A$, $B$, and $C$ respectively\n",
|
||||
"\n",
|
||||
"This operation is performed independently for each element position,\n",
|
||||
"making it highly parallelizable and well-suited for GPU implementation.\n",
|
||||
"\n",
|
||||
"Naive Elementwise Add Kernel\n",
|
||||
"-----------------------------\n",
|
||||
"\n",
|
||||
"Let's start with a naive implementation that loads each element from\n",
|
||||
"$A$ and $B$, adds them, and stores the result back to $C$."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def naive_elementwise_add_kernel(\n",
|
||||
" gA: cute.Tensor,\n",
|
||||
" gB: cute.Tensor,\n",
|
||||
" gC: cute.Tensor,\n",
|
||||
"):\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
" bidx, _, _ = cute.arch.block_idx()\n",
|
||||
" bdim, _, _ = cute.arch.block_dim()\n",
|
||||
"\n",
|
||||
" thread_idx = bidx * bdim + tidx\n",
|
||||
"\n",
|
||||
" # Map thread index to logical index of input tensor\n",
|
||||
" m, n = gA.shape\n",
|
||||
" ni = thread_idx % n\n",
|
||||
" mi = thread_idx // n\n",
|
||||
"\n",
|
||||
" # Map logical index to physical address via tensor layout\n",
|
||||
" a_val = gA[mi, ni]\n",
|
||||
" b_val = gB[mi, ni]\n",
|
||||
"\n",
|
||||
" # Perform element-wise addition\n",
|
||||
" gC[mi, ni] = a_val + b_val"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Structure of the Kernel\n",
|
||||
"\n",
|
||||
"The naive kernel simply maps each thread to one element with a 1-to-1 mapping.\n",
|
||||
"In this kernel, we don't use CuTe layout algebra but only use basic\n",
|
||||
"addressing to index the tensor.\n",
|
||||
"\n",
|
||||
"We can launch the kernel with the following JIT function:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def naive_elementwise_add(\n",
|
||||
" mA: cute.Tensor,\n",
|
||||
" mB: cute.Tensor,\n",
|
||||
" mC: cute.Tensor\n",
|
||||
"):\n",
|
||||
" num_threads_per_block = 256\n",
|
||||
"\n",
|
||||
" m, n = mA.shape\n",
|
||||
" kernel = naive_elementwise_add_kernel(mA, mB, mC)\n",
|
||||
" kernel.launch(grid=((m * n) // num_threads_per_block, 1, 1),\n",
|
||||
" block=(num_threads_per_block, 1, 1))\n",
|
||||
"\n",
|
||||
"M, N = 2048, 2048\n",
|
||||
"\n",
|
||||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||||
"\n",
|
||||
"# Compile kernel\n",
|
||||
"naive_elementwise_add_ = cute.compile(naive_elementwise_add, a_, b_, c_)\n",
|
||||
"naive_elementwise_add_(a_, b_, c_)\n",
|
||||
"\n",
|
||||
"# verify correctness\n",
|
||||
"torch.testing.assert_close(c, a + b)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Benchmark performance\n",
|
||||
"\n",
|
||||
"Here's a utility function to benchmark our kernel implementations:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark(callable, *, num_warmups, num_iterations):\n",
|
||||
" start_event = torch.cuda.Event(enable_timing=True)\n",
|
||||
" end_event = torch.cuda.Event(enable_timing=True)\n",
|
||||
"\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
"\n",
|
||||
" for _ in range(num_warmups):\n",
|
||||
" callable()\n",
|
||||
"\n",
|
||||
" start_event.record(stream=torch.cuda.current_stream())\n",
|
||||
" for _ in range(num_iterations):\n",
|
||||
" callable()\n",
|
||||
" end_event.record(stream=torch.cuda.current_stream())\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
"\n",
|
||||
" elapsed_time = start_event.elapsed_time(end_event)\n",
|
||||
" avg_time = elapsed_time / num_iterations\n",
|
||||
"\n",
|
||||
" print(f\"Average execution time: {avg_time:.4f} ms\")\n",
|
||||
" print(f\"Throughput: {(3 * a.numel() * 2) / (avg_time / 1000) / 1e9:.2f} GB/s\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Average execution time: 0.0385 ms\n",
|
||||
"Throughput: 653.44 GB/s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"benchmark(partial(naive_elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=100)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Performance Analysis\n",
|
||||
"\n",
|
||||
"While our naive implementation maps thread indices to contiguous tensor\n",
|
||||
"dimensions for coalesced memory access, it doesn't have enough\n",
|
||||
"in-flight load & store operations to hide memory latency.\n",
|
||||
"\n",
|
||||
"According to Little's Law:\n",
|
||||
"\n",
|
||||
"$ L = \\lambda \\times W $\n",
|
||||
"\n",
|
||||
"Where:\n",
|
||||
"- $L$ is the average number of items in a system\n",
|
||||
"- $\\lambda$ is the average arrival rate of items (bandwidth)\n",
|
||||
"- $W$ is the average time an item spends in the system (latency)\n",
|
||||
"\n",
|
||||
"For our elementwise addition kernel:\n",
|
||||
"\n",
|
||||
"1. $L$: The number of load & store operations in-flight\n",
|
||||
"2. $\\lambda$ (Bandwidth): Data transfer rate between memory and compute units\n",
|
||||
"3. $W$ (Latency): Round-trip delay of memory requests\n",
|
||||
"\n",
|
||||
"For memory-bound operations like elementwise addition, performance is\n",
|
||||
"limited by the number of in-flight load & store operations.\n",
|
||||
"\n",
|
||||
"## Vectorized Load and Store\n",
|
||||
"\n",
|
||||
"To improve performance according to Little's Law, we need to increase the number\n",
|
||||
"of in-flight requests. We can do this by increasing the number of bytes handled\n",
|
||||
"in each load & store operation per thread through vectorized memory access.\n",
|
||||
"\n",
|
||||
"Since Ampere GPUs support up to 128-bit per load/store and each element is 32-bit,\n",
|
||||
"we can load 4 elements per vectorized operation on contiguous rows.\n",
|
||||
"CuTe tiling operations make this vectorization straightforward.\n",
|
||||
"\n",
|
||||
"Using ``tiled_tensor = cute.zipped_divide(tensor, tiler)``, we can partition the input\n",
|
||||
"``tensor`` into groups of ``tiler`` blocks. For vectorization, we specify ``tiler``\n",
|
||||
"as the block of data each thread accesses (4 contiguous elements in the same row, or ``(1,4)``).\n",
|
||||
"Different threads can then access different blocks by indexing into the 2nd mode of ``tiled_tensor``.\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"mA : cute.Tensor # (2048,2048):(2048,1)\n",
|
||||
"gA = cute.zipped_divide(a, tiler=(1, 4)) # tiled/vectorized => ((1,4),(2048,512)):((0,1),(2048,4))\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"$\n",
|
||||
" \\begin{array}{ccccc}\n",
|
||||
" & ((1,4) & , & (2048,512)) & : ((0,1),(2048,4)) \\\\\n",
|
||||
" & \\underbrace{\\phantom{(1,4)}}_{tiler} & & \\underbrace{\\phantom{(2048,512)}}_{threads} & \\\\\n",
|
||||
" & \\text{\\scriptsize per-thread} & & \\text{\\scriptsize num of tiles}\n",
|
||||
" \\end{array}\n",
|
||||
"$"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def vectorized_elementwise_add_kernel(\n",
|
||||
" gA: cute.Tensor,\n",
|
||||
" gB: cute.Tensor,\n",
|
||||
" gC: cute.Tensor,\n",
|
||||
"):\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
" bidx, _, _ = cute.arch.block_idx()\n",
|
||||
" bdim, _, _ = cute.arch.block_dim()\n",
|
||||
"\n",
|
||||
" thread_idx = bidx * bdim + tidx\n",
|
||||
"\n",
|
||||
" # Map thread index to logical index of input tensor\n",
|
||||
" m, n = gA.shape[1] # thread-domain\n",
|
||||
" ni = thread_idx % n\n",
|
||||
" mi = thread_idx // n\n",
|
||||
"\n",
|
||||
" # Map logical index to physical address via tensor layout\n",
|
||||
" a_val = gA[(None, (mi, ni))].load()\n",
|
||||
" b_val = gB[(None, (mi, ni))].load()\n",
|
||||
" print(f\"[DSL INFO] sliced gA = {gA[(None, (mi, ni))]}\")\n",
|
||||
" print(f\"[DSL INFO] sliced gB = {gB[(None, (mi, ni))]}\")\n",
|
||||
"\n",
|
||||
" # Perform element-wise addition\n",
|
||||
" gC[(None, (mi, ni))] = a_val + b_val"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This vectorized kernel follows a similar structure to its naive non-vectorized counterpart,\n",
|
||||
"with one key difference: the tensor slicing pattern. By using `(None, (mi, ni))` as the slice indices,\n",
|
||||
"we can extract a `(1,4)` sub-tensor from `gA`, `gB` and `gC` like \n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"gA[(None, (mi, ni))]\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Then tensor data can be loaded into vector via the `.load()` method.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
" slice\n",
|
||||
" ((1,4),(2048,512)):((0,1),(2048,4)) ==> ((1,4)):((0,1))\n",
|
||||
" ^ ^ ^\n",
|
||||
" | | |\n",
|
||||
" (None, (mi, ni))\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[DSL INFO] Tiled Tensors:\n",
|
||||
"[DSL INFO] gA = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n",
|
||||
"[DSL INFO] gB = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n",
|
||||
"[DSL INFO] gC = tensor<ptr<f16, gmem, align<16>> o ((1,4),(2048,512)):((0,1),(2048,4))>\n",
|
||||
"[DSL INFO] sliced gA = tensor<ptr<f16, gmem, align<8>> o ((1,4)):((0,1))>\n",
|
||||
"[DSL INFO] sliced gB = tensor<ptr<f16, gmem, align<8>> o ((1,4)):((0,1))>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def vectorized_elementwise_add(\n",
|
||||
" mA: cute.Tensor,\n",
|
||||
" mB: cute.Tensor,\n",
|
||||
" mC: cute.Tensor\n",
|
||||
"):\n",
|
||||
" threads_per_block = 256\n",
|
||||
"\n",
|
||||
" gA = cute.zipped_divide(mA, (1, 4))\n",
|
||||
" gB = cute.zipped_divide(mB, (1, 4))\n",
|
||||
" gC = cute.zipped_divide(mC, (1, 4))\n",
|
||||
"\n",
|
||||
" print(f\"[DSL INFO] Tiled Tensors:\")\n",
|
||||
" print(f\"[DSL INFO] gA = {gA}\")\n",
|
||||
" print(f\"[DSL INFO] gB = {gB}\")\n",
|
||||
" print(f\"[DSL INFO] gC = {gC}\")\n",
|
||||
"\n",
|
||||
" vectorized_elementwise_add_kernel(gA, gB, gC).launch(\n",
|
||||
" grid=(cute.size(gC, mode=[1]) // threads_per_block, 1, 1),\n",
|
||||
" block=(threads_per_block, 1, 1),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||||
"\n",
|
||||
"compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n",
|
||||
"compiled_func(a_, b_, c_)\n",
|
||||
"\n",
|
||||
"# verify correctness\n",
|
||||
"torch.testing.assert_close(c, a + b)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Average execution time: 0.0202 ms\n",
|
||||
"Throughput: 1244.98 GB/s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"benchmark(partial(compiled_func, a_, b_, c_), num_warmups=5, num_iterations=100)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## TV Layout\n",
|
||||
"\n",
|
||||
"Both the naive and vectorized kernels follow a common pattern to map thread indices\n",
|
||||
"to physical addresses:\n",
|
||||
"\n",
|
||||
"Step 1: Map thread index to logical M/N coordinates\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
" mi = thread_idx // n\n",
|
||||
" ni = thread_idx % n\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Step 2: Map logical M/N coordinates to physical addresses using the tensor layout\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
" a[(None, (mi, ni))].load()\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"CuTe uses TV layout to represent this mapping from thread index and value index\n",
|
||||
"(i.e., the 4 elements loaded per thread) to the logical coordinate space of a tensor.\n",
|
||||
"By configuring different TV layouts, we can experiment with different memory access\n",
|
||||
"patterns with minimal code changes.\n",
|
||||
"\n",
|
||||
"The following example demonstrates two levels of tiling: at the thread-block level\n",
|
||||
"and at the thread level.\n",
|
||||
"\n",
|
||||
"For thread-block level tiling, each input & output tensor is first divided\n",
|
||||
"into a group of ``(TileM, TileN)`` sub-tensors at the host side.\n",
|
||||
"\n",
|
||||
"Inside the GPU kernel, we provide the thread-block index to the 2nd mode of the tiled tensor\n",
|
||||
"(``gA[((None, None), bidx)]``), which returns a thread-block local view of\n",
|
||||
"a single ``(TileM, TileN)`` sub-tensor.\n",
|
||||
"\n",
|
||||
"For thread level tiling, we compose the sub-tensor (which maps from logical coordinates\n",
|
||||
"to physical addresses) with the TV layout (which maps from thread & value indices to\n",
|
||||
"logical coordinates). This gives us a tiled sub-tensor that maps from thread & value\n",
|
||||
"indices directly to physical addresses.\n",
|
||||
"\n",
|
||||
"We then provide the thread index to the tiled sub-tensor (``tidfrgA[(tidx, None)]``)\n",
|
||||
"to get a thread-local view of the data each thread accesses. Note that the thread index\n",
|
||||
"is now in the 1st mode, as the tiled sub-tensor puts the thread mode before the value mode."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def elementwise_add_kernel(\n",
|
||||
" gA: cute.Tensor,\n",
|
||||
" gB: cute.Tensor,\n",
|
||||
" gC: cute.Tensor,\n",
|
||||
" tv_layout: cute.Layout\n",
|
||||
"):\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
" bidx, _, _ = cute.arch.block_idx()\n",
|
||||
"\n",
|
||||
" #--------------------------------\n",
|
||||
" # slice for thread-block level view\n",
|
||||
" #--------------------------------\n",
|
||||
" blk_coord = ((None, None), bidx)\n",
|
||||
"\n",
|
||||
" # logical coord -> address\n",
|
||||
" blkA = gA[blk_coord] # (TileM, TileN) -> physical address\n",
|
||||
" blkB = gB[blk_coord] # (TileM, TileN) -> physical address\n",
|
||||
" blkC = gC[blk_coord] # (TileM, TileN) -> physical address\n",
|
||||
"\n",
|
||||
" #--------------------------------\n",
|
||||
" # compose for thread-index & value-index to physical mapping\n",
|
||||
" #--------------------------------\n",
|
||||
" # blockA: (TileM, TileN) -> physical address\n",
|
||||
" # tv_layout: (tid, vid) -> (TileM, TileN)\n",
|
||||
" # tidfrgA = blkA o tv_layout\n",
|
||||
" # tidfrgA: (tid, vid) -> physical address\n",
|
||||
" tidfrgA = cute.composition(blkA, tv_layout)\n",
|
||||
" tidfrgB = cute.composition(blkB, tv_layout)\n",
|
||||
" tidfrgC = cute.composition(blkC, tv_layout)\n",
|
||||
"\n",
|
||||
" print(f\"Composed with TV layout:\")\n",
|
||||
" print(f\" tidfrgA: {tidfrgA.type}\")\n",
|
||||
"\n",
|
||||
" #--------------------------------\n",
|
||||
" # slice for thread-level view\n",
|
||||
" #--------------------------------\n",
|
||||
" # `None` represent slice of the entire per-thread data\n",
|
||||
" thr_coord = (tidx, None)\n",
|
||||
"\n",
|
||||
" # slice for threads: vid -> address\n",
|
||||
" thrA = tidfrgA[thr_coord] # (V) -> physical address\n",
|
||||
" thrB = tidfrgB[thr_coord] # (V) -> physical address\n",
|
||||
" thrC = tidfrgC[thr_coord] # (V) -> physical address\n",
|
||||
"\n",
|
||||
" thrC[None] = thrA.load() + thrB.load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we take a closer look at the layout of zipped divided input tensor `gA`:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"Tiled to Thread Block:\n",
|
||||
"\n",
|
||||
" ((16,256),(128,8)) : ((2048,1),(32768,256))\n",
|
||||
" ~~~~~~~~ ~~~~~~ ~~~~~~~~\n",
|
||||
" | | |\n",
|
||||
" | | |\n",
|
||||
" | `------------------------> Number of Thread Blocks\n",
|
||||
" | |\n",
|
||||
" | |\n",
|
||||
" `--------------------'\n",
|
||||
" |\n",
|
||||
" V\n",
|
||||
" Thread Block\n",
|
||||
" Tile\n",
|
||||
"\n",
|
||||
"Sliced to Thread-Block local sub-tensor (a (16, 128) tile): gA[((None, None), bidx)]\n",
|
||||
"\n",
|
||||
" (16,256) : (2048,1)\n",
|
||||
" ~~~~~~ ~~~~~~\n",
|
||||
" | | Tiled/Composed with TV Layout\n",
|
||||
" | | \n",
|
||||
" | | o ((32,4),(8,4)):((128,4),(16,1))\n",
|
||||
" V V \n",
|
||||
"~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~ \n",
|
||||
"((32,4), (8,4)) : ((4,8192),(1,2048))\n",
|
||||
" | |\n",
|
||||
" | `--------> per thread fragment\n",
|
||||
" |\n",
|
||||
"Thread Block\n",
|
||||
" Shape\n",
|
||||
"\n",
|
||||
"Sliced to Thread local sub-tensor (a (4,8) tile): tidfrgA[(tidx, None)]\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The host code below shows the construction of the TV layout. By composing\n",
|
||||
"a thread layout of ``(4,32):(32,1)`` (32 threads read contiguous elements on the row dimension,\n",
|
||||
"then 4 warps read different rows) with a value layout of ``(4,8):(8,1)`` (each thread reads\n",
|
||||
"8 contiguous elements on the row dimension across 4 contiguous rows),\n",
|
||||
"we obtain the TV layout shown in the figure above."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Tiler: (16, 256)\n",
|
||||
"TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n",
|
||||
"Tiled Input Tensors:\n",
|
||||
" gA: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
" gB: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
" gC: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
"Composed with TV layout:\n",
|
||||
" tidfrgA: !cute.memref<f16, gmem, align<16>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def elementwise_add(\n",
|
||||
" mA: cute.Tensor,\n",
|
||||
" mB: cute.Tensor,\n",
|
||||
" mC: cute.Tensor,\n",
|
||||
"):\n",
|
||||
" # mA layout: (M, N):(N, 1)\n",
|
||||
" # TV layout map thread & value index to (16, 256) logical tile\n",
|
||||
" # - contiguous thread index maps to mode-1 because input layout is contiguous on\n",
|
||||
" # mode-1 for coalesced load-store\n",
|
||||
" # - each thread load 8 contiguous element each row and load 4 rows\n",
|
||||
" thr_layout = cute.make_layout((4, 32), stride=(32, 1))\n",
|
||||
" val_layout = cute.make_layout((4, 8), stride=(8, 1))\n",
|
||||
" tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n",
|
||||
" print(f\"Tiler: {tiler_mn}\")\n",
|
||||
" print(f\"TV Layout: {tv_layout}\")\n",
|
||||
"\n",
|
||||
" gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||||
" gB = cute.zipped_divide(mB, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||||
" gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||||
"\n",
|
||||
" print(f\"Tiled Input Tensors:\")\n",
|
||||
" print(f\" gA: {gA.type}\")\n",
|
||||
" print(f\" gB: {gB.type}\")\n",
|
||||
" print(f\" gC: {gC.type}\")\n",
|
||||
"\n",
|
||||
" # Launch the kernel asynchronously\n",
|
||||
" # Async token(s) can also be specified as dependencies\n",
|
||||
" elementwise_add_kernel(\n",
|
||||
" gA, gB, gC, tv_layout\n",
|
||||
" ).launch(\n",
|
||||
" grid=[cute.size(gC, mode=[1]), 1, 1],\n",
|
||||
" block=[cute.size(tv_layout, mode=[0]), 1, 1],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||||
"\n",
|
||||
"elementwise_add_ = cute.compile(elementwise_add, a_, b_, c_)\n",
|
||||
"elementwise_add_(a_, b_, c_)\n",
|
||||
"\n",
|
||||
"# verify correctness\n",
|
||||
"torch.testing.assert_close(c, a + b)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Average execution time: 0.0222 ms\n",
|
||||
"Throughput: 1133.58 GB/s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"benchmark(partial(elementwise_add_, a_, b_, c_), num_warmups=5, num_iterations=200)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Using Lambda Function\n",
|
||||
"\n",
|
||||
"CuTe DSL is built on top of Python. It can leverage Python to implement meta-programming to generate flexible kernels.\n",
|
||||
"E.g. we can write kernel template that take custom binary operations to generate kernels for arbitrary binary operations.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"@cute.jit\n",
|
||||
"def elementwise_apply(\n",
|
||||
" op: cutlass.Constexpr,\n",
|
||||
" mA: cute.Tensor,\n",
|
||||
" mB: cute.Tensor,\n",
|
||||
" mC: cute.Tensor\n",
|
||||
"):\n",
|
||||
" ...\n",
|
||||
"\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Tiler: (16, 256)\n",
|
||||
"TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n",
|
||||
"Tiled Input Tensors:\n",
|
||||
" gA: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
" gB: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
" gC: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
"Composed with TV layout:\n",
|
||||
" tidfrgA: !cute.memref<f16, gmem, align<16>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def elementwise_apply_kernel(\n",
|
||||
" op: cutlass.Constexpr, # lambda function must be const expr to generate code at compile time\n",
|
||||
" gA: cute.Tensor,\n",
|
||||
" gB: cute.Tensor,\n",
|
||||
" gC: cute.Tensor,\n",
|
||||
" tv_layout: cute.Layout\n",
|
||||
"):\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
" bidx, _, _ = cute.arch.block_idx()\n",
|
||||
"\n",
|
||||
" blk_coord = ((None, None), bidx)\n",
|
||||
"\n",
|
||||
" # logical coord -> address\n",
|
||||
" blkA = gA[blk_coord] # (TileM, TileN) -> physical address\n",
|
||||
" blkB = gB[blk_coord] # (TileM, TileN) -> physical address\n",
|
||||
" blkC = gC[blk_coord] # (TileM, TileN) -> physical address\n",
|
||||
"\n",
|
||||
" tidfrgA = cute.composition(blkA, tv_layout)\n",
|
||||
" tidfrgB = cute.composition(blkB, tv_layout)\n",
|
||||
" tidfrgC = cute.composition(blkC, tv_layout)\n",
|
||||
"\n",
|
||||
" print(f\"Composed with TV layout:\")\n",
|
||||
" print(f\" tidfrgA: {tidfrgA.type}\")\n",
|
||||
"\n",
|
||||
" thr_coord = (tidx, None)\n",
|
||||
"\n",
|
||||
" # slice for threads: vid -> address\n",
|
||||
" thrA = tidfrgA[thr_coord] # (V) -> physical address\n",
|
||||
" thrB = tidfrgB[thr_coord] # (V) -> physical address\n",
|
||||
" thrC = tidfrgC[thr_coord] # (V) -> physical address\n",
|
||||
"\n",
|
||||
" #--------------------------------\n",
|
||||
" # apply custom operation\n",
|
||||
" #--------------------------------\n",
|
||||
" thrC[None] = op(thrA.load(), thrB.load())\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def elementwise_op(\n",
|
||||
" op: cutlass.Constexpr,\n",
|
||||
" mA: cute.Tensor,\n",
|
||||
" mB: cute.Tensor,\n",
|
||||
" mC: cute.Tensor,\n",
|
||||
"):\n",
|
||||
" # mA layout: (M, N):(N, 1)\n",
|
||||
" # TV layout map thread & value index to (16, 256) logical tile\n",
|
||||
" # - contiguous thread index maps to mode-1 because input layout is contiguous on\n",
|
||||
" # mode-1 for coalesced load-store\n",
|
||||
" # - each thread load 8 contiguous element each row and load 4 rows\n",
|
||||
" thr_layout = cute.make_layout((4, 32), stride=(32, 1))\n",
|
||||
" val_layout = cute.make_layout((4, 8), stride=(8, 1))\n",
|
||||
" tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n",
|
||||
" print(f\"Tiler: {tiler_mn}\")\n",
|
||||
" print(f\"TV Layout: {tv_layout}\")\n",
|
||||
"\n",
|
||||
" gA = cute.zipped_divide(mA, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||||
" gB = cute.zipped_divide(mB, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||||
" gC = cute.zipped_divide(mC, tiler_mn) # ((TileM, TileN), (RestM, RestN))\n",
|
||||
"\n",
|
||||
" print(f\"Tiled Input Tensors:\")\n",
|
||||
" print(f\" gA: {gA.type}\")\n",
|
||||
" print(f\" gB: {gB.type}\")\n",
|
||||
" print(f\" gC: {gC.type}\")\n",
|
||||
"\n",
|
||||
" # Launch the kernel asynchronously\n",
|
||||
" # Async token(s) can also be specified as dependencies\n",
|
||||
" elementwise_apply_kernel(\n",
|
||||
" op, gA, gB, gC, tv_layout\n",
|
||||
" ).launch(\n",
|
||||
" grid=[cute.size(gC, mode=[1]), 1, 1],\n",
|
||||
" block=[cute.size(tv_layout, mode=[0]), 1, 1],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||||
"\n",
|
||||
"from operator import mul\n",
|
||||
"\n",
|
||||
"elementwise_op(mul, a_, b_, c_)\n",
|
||||
"\n",
|
||||
"# verify correctness\n",
|
||||
"torch.testing.assert_close(c, mul(a, b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Custom operators can be more complex. For example, here's a function that performs\n",
|
||||
"multiplication followed by ReLU:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Tiler: (16, 256)\n",
|
||||
"TV Layout: ((32,4),(8,4)):((128,4),(16,1))\n",
|
||||
"Tiled Input Tensors:\n",
|
||||
" gA: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
" gB: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
" gC: !cute.memref<f16, gmem, align<16>, \"((16,256),(128,8)):((2048,1),(32768,256))\">\n",
|
||||
"Composed with TV layout:\n",
|
||||
" tidfrgA: !cute.memref<f16, gmem, align<16>, \"((32,4),(8,4)):((8,8192),(1,2048))\">\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def mul_relu(a, b):\n",
|
||||
" tmp = a * b\n",
|
||||
" return cute.where(tmp > 0, tmp, cute.full_like(tmp, 0))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# As we uses cute.where in customized operation, we need to create another relu function\n",
|
||||
"def mul_relu_ref(a, b):\n",
|
||||
" tmp = a * b\n",
|
||||
" return torch.relu(tmp)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"elementwise_op(mul_relu, a_, b_, c_)\n",
|
||||
"\n",
|
||||
"# verify correctness\n",
|
||||
"torch.testing.assert_close(c, mul_relu_ref(a, b))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"widgets": {
|
||||
"application/vnd.jupyter.widget-state+json": {
|
||||
"state": {},
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
173
examples/python/CuTeDSL/notebooks/hello_world.ipynb
Normal file
173
examples/python/CuTeDSL/notebooks/hello_world.ipynb
Normal file
@@ -0,0 +1,173 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Your First Program with CuTe DSL\n",
|
||||
"\n",
|
||||
"## Introduction\n",
|
||||
"\n",
|
||||
"Welcome! In this tutorial, we'll write a simple \"Hello World\" program that runs on your GPU using CuTe DSL. This will help you understand the basics of GPU programming with our framework.\n",
|
||||
"\n",
|
||||
"### What You'll Learn\n",
|
||||
"\n",
|
||||
"- How to write code that runs on both CPU (host) and GPU (device),\n",
|
||||
"- How to launch a GPU kernel (a function that runs on the GPU),\n",
|
||||
"- Basic CUDA concepts like threads and thread blocks,\n",
|
||||
"\n",
|
||||
"### Step 1: Import Required Libraries\n",
|
||||
"\n",
|
||||
"First, let's import the libraries we need:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass \n",
|
||||
"import cutlass.cute as cute "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"### Step 2: Write Our GPU Kernel\n",
|
||||
"A GPU kernel is a function that runs on the GPU. Here's a simple kernel that prints \"Hello World\".\n",
|
||||
"Key concepts:\n",
|
||||
"- `@cute.kernel`: This decorator tells CUTLASS that this function should run on the GPU\n",
|
||||
"- `cute.arch.thread_idx()`: Gets the ID of the current GPU thread (like a worker's ID number)\n",
|
||||
"- We only want one thread to print the message (thread 0) to avoid multiple prints"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.kernel\n",
|
||||
"def kernel():\n",
|
||||
" # Get the x component of the thread index (y and z components are unused)\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
" # Only the first thread (thread 0) prints the message\n",
|
||||
" if tidx == 0:\n",
|
||||
" cute.printf(\"Hello world\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Step 3: Write Our Host Function\n",
|
||||
"\n",
|
||||
"Now we need a function that sets up the GPU and launches our kernel.\n",
|
||||
"Key concepts:\n",
|
||||
"- `@cute.jit`: This decorator is for functions that run on the CPU but can launch GPU code\n",
|
||||
"- We need to initialize CUDA before using the GPU\n",
|
||||
"- `.launch()` tells CUDA how many blocks, threads, shared memory, etc. to use"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def hello_world():\n",
|
||||
"\n",
|
||||
" # Print hello world from host code\n",
|
||||
" cute.printf(\"hello world\")\n",
|
||||
" \n",
|
||||
" # Initialize CUDA context for launching a kernel with error checking\n",
|
||||
" # We make context initialization explicit to allow users to control the context creation \n",
|
||||
" # and avoid potential issues with multiple contexts\n",
|
||||
" cutlass.cuda.initialize_cuda_context()\n",
|
||||
"\n",
|
||||
" # Launch kernel\n",
|
||||
" kernel().launch(\n",
|
||||
" grid=(1, 1, 1), # Single thread block\n",
|
||||
" block=(32, 1, 1) # One warp (32 threads) per thread block\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Step 4: Run Our Program\n",
|
||||
"\n",
|
||||
"There are 2 ways we can run our program:\n",
|
||||
"\n",
|
||||
"1. compile and run immediately\n",
|
||||
"2. separate compilation which allows us to compile the code once and run multiple times\n",
|
||||
" \n",
|
||||
"Please note the `Compiling...` for Method 2 prints before the \"Hello world\" of the first kernel. This shows the asynchronous behavior between CPU and GPU prints. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Running hello_world()...\n",
|
||||
"hello world\n",
|
||||
"Compiling...\n",
|
||||
"Hello world\n",
|
||||
"Running compiled version...\n",
|
||||
"hello world\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Method 1: Just-In-Time (JIT) compilation - compiles and runs the code immediately\n",
|
||||
"print(\"Running hello_world()...\")\n",
|
||||
"hello_world()\n",
|
||||
"\n",
|
||||
"# Method 2: Compile first (useful if you want to run the same code multiple times)\n",
|
||||
"print(\"Compiling...\")\n",
|
||||
"hello_world_compiled = cute.compile(hello_world)\n",
|
||||
"# Run the pre-compiled version\n",
|
||||
"print(\"Running compiled version...\")\n",
|
||||
"hello_world_compiled()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"widgets": {
|
||||
"application/vnd.jupyter.widget-state+json": {
|
||||
"state": {},
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
BIN
examples/python/CuTeDSL/notebooks/images/cuda_graphs_image.png
Normal file
BIN
examples/python/CuTeDSL/notebooks/images/cuda_graphs_image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.4 KiB |
425
examples/python/CuTeDSL/notebooks/print.ipynb
Normal file
425
examples/python/CuTeDSL/notebooks/print.ipynb
Normal file
@@ -0,0 +1,425 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Printing with CuTe DSL\n",
|
||||
"\n",
|
||||
"This notebook demonstrates the different ways to print values in CuTe and explains the important distinction between static (compile-time) and dynamic (runtime) values.\n",
|
||||
"\n",
|
||||
"## Key Concepts\n",
|
||||
"- Static values: Known at compile time\n",
|
||||
"- Dynamic values: Only known at runtime\n",
|
||||
"- Different printing methods for different scenarios\n",
|
||||
"- Layout representation in CuTe\n",
|
||||
"- Tensor visualization and formatting"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Print Example Function\n",
|
||||
"\n",
|
||||
"The `print_example` function demonstrates several important concepts:\n",
|
||||
"\n",
|
||||
"### 1. Python's `print` vs CuTe's `cute.printf`\n",
|
||||
"- `print`: Can only show static values at compile time\n",
|
||||
"- `cute.printf`: Can display both static and dynamic values at runtime\n",
|
||||
"\n",
|
||||
"### 2. Value Types\n",
|
||||
"- `a`: Dynamic `Int32` value (runtime)\n",
|
||||
"- `b`: Static `Constexpr[int]` value (compile-time)\n",
|
||||
"\n",
|
||||
"### 3. Layout Printing\n",
|
||||
"Shows how layouts are represented differently in static vs dynamic contexts:\n",
|
||||
"- Static context: Unknown values shown as `?`\n",
|
||||
"- Dynamic context: Actual values displayed"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def print_example(a: cutlass.Int32, b: cutlass.Constexpr[int]):\n",
|
||||
" \"\"\"\n",
|
||||
" Demonstrates different printing methods in CuTe and how they handle static vs dynamic values.\n",
|
||||
"\n",
|
||||
" This example shows:\n",
|
||||
" 1. How Python's `print` function works with static values at compile time but can't show dynamic values\n",
|
||||
" 2. How `cute.printf` can display both static and dynamic values at runtime\n",
|
||||
" 3. The difference between types in static vs dynamic contexts\n",
|
||||
" 4. How layouts are represented in both printing methods\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" a: A dynamic Int32 value that will be determined at runtime\n",
|
||||
" b: A static (compile-time constant) integer value\n",
|
||||
" \"\"\"\n",
|
||||
" # Use Python `print` to print static information\n",
|
||||
" print(\">>>\", b) # => 2\n",
|
||||
" # `a` is dynamic value\n",
|
||||
" print(\">>>\", a) # => ?\n",
|
||||
"\n",
|
||||
" # Use `cute.printf` to print dynamic information\n",
|
||||
" cute.printf(\">?? {}\", a) # => 8\n",
|
||||
" cute.printf(\">?? {}\", b) # => 2\n",
|
||||
"\n",
|
||||
" print(\">>>\", type(a)) # => <class 'cutlass.Int32'>\n",
|
||||
" print(\">>>\", type(b)) # => <class 'int'>\n",
|
||||
"\n",
|
||||
" layout = cute.make_layout((a, b))\n",
|
||||
" print(\">>>\", layout) # => (?,2):(1,?)\n",
|
||||
" cute.printf(\">?? {}\", layout) # => (8,2):(1,8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compile and Run\n",
|
||||
"\n",
|
||||
"**Direct Compilation and Run**\n",
|
||||
" - `print_example(cutlass.Int32(8), 2)`\n",
|
||||
" - Compiles and runs in one step will execute both static and dynamic print\n",
|
||||
" * `>>>` stands for static print\n",
|
||||
" * `>??` stands for dynamic print"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
">>> 2\n",
|
||||
">>> ?\n",
|
||||
">>> Int32\n",
|
||||
">>> <class 'int'>\n",
|
||||
">>> (?,2):(1,?)\n",
|
||||
">?? 8\n",
|
||||
">?? 2\n",
|
||||
">?? (8,2):(1,8)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print_example(cutlass.Int32(8), 2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Compile Function\n",
|
||||
"\n",
|
||||
"When compiles the function with `cute.compile(print_example, cutlass.Int32(8), 2)`, Python interpreter \n",
|
||||
"traces code and only evaluate static expression and print static information."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
">>> 2\n",
|
||||
">>> ?\n",
|
||||
">>> Int32\n",
|
||||
">>> <class 'int'>\n",
|
||||
">>> (?,2):(1,?)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print_example_compiled = cute.compile(print_example, cutlass.Int32(8), 2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Call compiled function\n",
|
||||
"\n",
|
||||
"Only print out runtime information"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
">?? 8\n",
|
||||
">?? 2\n",
|
||||
">?? (8,2):(1,8)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print_example_compiled(cutlass.Int32(8))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Format String Example\n",
|
||||
"\n",
|
||||
"The `format_string_example` function shows an important limitation:\n",
|
||||
"- F-strings in CuTe are evaluated at compile time\n",
|
||||
"- This means dynamic values won't show their runtime values in f-strings\n",
|
||||
"- Use `cute.printf` when you need to see runtime values"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Direct run output:\n",
|
||||
"a: ?, b: 2\n",
|
||||
"layout: (?,2):(1,?)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def format_string_example(a: cutlass.Int32, b: cutlass.Constexpr[int]):\n",
|
||||
" \"\"\"\n",
|
||||
" Format string is evaluated at compile time.\n",
|
||||
" \"\"\"\n",
|
||||
" print(f\"a: {a}, b: {b}\")\n",
|
||||
"\n",
|
||||
" layout = cute.make_layout((a, b))\n",
|
||||
" print(f\"layout: {layout}\")\n",
|
||||
"\n",
|
||||
"print(\"Direct run output:\")\n",
|
||||
"format_string_example(cutlass.Int32(8), 2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Printing Tensor Examples\n",
|
||||
"\n",
|
||||
"CuTe provides specialized functionality for printing tensors through the `print_tensor` operation. The `cute.print_tensor` takes the following parameter:\n",
|
||||
"- `Tensor` (required): A CuTe tensor object that you want to print. The tensor must support load and store operations\n",
|
||||
"- `verbose` (optional, default=False): A boolean flag that controls the level of detail in the output. When set to True, it will print indices details for each element in the tensor.\n",
|
||||
"\n",
|
||||
"Below example code shows the difference between verbose ON and OFF, and how to print a sub range of the given tensor."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cutlass.cute.runtime import from_dlpack\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def print_tensor_basic(x : cute.Tensor):\n",
|
||||
" # Print the tensor\n",
|
||||
" print(\"Basic output:\")\n",
|
||||
" cute.print_tensor(x)\n",
|
||||
" \n",
|
||||
"@cute.jit\n",
|
||||
"def print_tensor_verbose(x : cute.Tensor):\n",
|
||||
" # Print the tensor with verbose mode\n",
|
||||
" print(\"Verbose output:\")\n",
|
||||
" cute.print_tensor(x, verbose=True)\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def print_tensor_slice(x : cute.Tensor, coord : tuple):\n",
|
||||
" # slice a 2D tensor from the 3D tensor\n",
|
||||
" sliced_data = cute.slice_(x, coord)\n",
|
||||
" y = cute.make_fragment(sliced_data.layout, sliced_data.element_type)\n",
|
||||
" # Convert to TensorSSA format by loading the sliced data into the fragment\n",
|
||||
" y.store(sliced_data.load())\n",
|
||||
" print(\"Slice output:\")\n",
|
||||
" cute.print_tensor(y)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The default `cute.print_tensor` will output CuTe tensor with datatype, storage space, CuTe layout information, and print data in torch-style format."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Basic output:\n",
|
||||
"tensor(raw_ptr(0x000000000a5f1d50: f32, generic, align<4>) o (4,3,2):(6,2,1), data=\n",
|
||||
" [[[ 0.000000, 2.000000, 4.000000, ],\n",
|
||||
" [ 6.000000, 8.000000, 10.000000, ],\n",
|
||||
" [ 12.000000, 14.000000, 16.000000, ],\n",
|
||||
" [ 18.000000, 20.000000, 22.000000, ]],\n",
|
||||
"\n",
|
||||
" [[ 1.000000, 3.000000, 5.000000, ],\n",
|
||||
" [ 7.000000, 9.000000, 11.000000, ],\n",
|
||||
" [ 13.000000, 15.000000, 17.000000, ],\n",
|
||||
" [ 19.000000, 21.000000, 23.000000, ]]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def tensor_print_example1():\n",
|
||||
" shape = (4, 3, 2)\n",
|
||||
" \n",
|
||||
" # Creates [0,...,23] and reshape to (4, 3, 2)\n",
|
||||
" data = np.arange(24, dtype=np.float32).reshape(*shape) \n",
|
||||
" \n",
|
||||
" print_tensor_basic(from_dlpack(data))\n",
|
||||
"\n",
|
||||
"tensor_print_example1()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The verbosed print will show coodination details of each element in the tensor. The below example shows how we index element in a 2D 4x3 tensor space."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Verbose output:\n",
|
||||
"tensor(raw_ptr(0x000000000a814cc0: f32, generic, align<4>) o (4,3):(3,1), data= (\n",
|
||||
"\t(0,0)= 0.000000\n",
|
||||
"\t(0,1)= 1.000000\n",
|
||||
"\t(0,2)= 2.000000\n",
|
||||
"\t(1,0)= 3.000000\n",
|
||||
"\t(1,1)= 4.000000\n",
|
||||
"\t(1,2)= 5.000000\n",
|
||||
"\t(2,0)= 6.000000\n",
|
||||
"\t(2,1)= 7.000000\n",
|
||||
"\t(2,2)= 8.000000\n",
|
||||
"\t(3,0)= 9.000000\n",
|
||||
"\t(3,1)= 10.000000\n",
|
||||
"\t(3,2)= 11.000000\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def tensor_print_example2():\n",
|
||||
" shape = (4, 3)\n",
|
||||
" \n",
|
||||
" # Creates [0,...,11] and reshape to (4, 3)\n",
|
||||
" data = np.arange(12, dtype=np.float32).reshape(*shape) \n",
|
||||
" \n",
|
||||
" print_tensor_verbose(from_dlpack(data))\n",
|
||||
"\n",
|
||||
"tensor_print_example2()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To print a subset elements in the given Tensor, we can use cute.slice_ to select a range of the given tensor, load them into register and then print the values with `cute.print_tensor`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Slice output:\n",
|
||||
"tensor(raw_ptr(0x00007ffeeae1fc60: f32, rmem, align<32>) o (4):(3), data=\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [Slice output:\n",
|
||||
" 6.000000, ],\n",
|
||||
" [ 9.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00007ffeeae1fc60: f32, rmem, align<32>) o (3):(1), data=\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 4.000000, ],\n",
|
||||
" [ 5.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def tensor_print_example3():\n",
|
||||
" shape = (4, 3)\n",
|
||||
" \n",
|
||||
" # Creates [0,...,11] and reshape to (4, 3)\n",
|
||||
" data = np.arange(12, dtype=np.float32).reshape(*shape) \n",
|
||||
" \n",
|
||||
" print_tensor_slice(from_dlpack(data), (None, 0))\n",
|
||||
" print_tensor_slice(from_dlpack(data), (1, None))\n",
|
||||
"\n",
|
||||
"tensor_print_example3()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
390
examples/python/CuTeDSL/notebooks/tensor.ipynb
Normal file
390
examples/python/CuTeDSL/notebooks/tensor.ipynb
Normal file
@@ -0,0 +1,390 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tensor\n",
|
||||
"\n",
|
||||
"A tensor in CuTe is created through the composition of two key components:\n",
|
||||
"\n",
|
||||
"1. An **Engine** (E) - A random-access, pointer-like object that supports:\n",
|
||||
" - Offset operation: `e + d → e` (offset engine by elements of a layout's codomain)\n",
|
||||
" - Dereference operation: `*e → v` (dereference engine to produce value)\n",
|
||||
"\n",
|
||||
"2. A **Layout** (L) - Defines the mapping from coordinates to offsets\n",
|
||||
"\n",
|
||||
"A tensor is formally defined as the composition of an engine E with a layout L, expressed as `T = E ∘ L`. When evaluating a tensor at coordinate c, it:\n",
|
||||
"\n",
|
||||
"1. Maps the coordinate c to the codomain using the layout\n",
|
||||
"2. Offsets the engine accordingly\n",
|
||||
"3. Dereferences the result to obtain the tensor's value\n",
|
||||
"\n",
|
||||
"This can be expressed mathematically as:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"T(c) = (E ∘ L)(c) = *(E + L(c))\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"## Example Usage\n",
|
||||
"\n",
|
||||
"Here's a simple example of creating a tensor using pointer and layout `(8,5):(5,1)` and fill with ones:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def create_tensor_from_ptr(ptr: cute.Pointer):\n",
|
||||
" layout = cute.make_layout((8, 5), stride=(5, 1))\n",
|
||||
" tensor = cute.make_tensor(ptr, layout)\n",
|
||||
" tensor.fill(1)\n",
|
||||
" cute.print_tensor(tensor)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This creates a tensor where:\n",
|
||||
"- The engine is a pointer\n",
|
||||
"- The layout with shape `(8, 5)` and stride `(5, 1)`\n",
|
||||
"- The resulting tensor can be evaluated using coordinates defined by the layout\n",
|
||||
"\n",
|
||||
"We can test this by allocating buffer with torch and run test with pointer to torch tensor"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x000000000736b0c0: f32, generic, align<4>) o (8,5):(5,1), data=\n",
|
||||
" [[ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n",
|
||||
" [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n",
|
||||
" [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n",
|
||||
" ...\n",
|
||||
" [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n",
|
||||
" [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ],\n",
|
||||
" [ 1.000000, 1.000000, 1.000000, 1.000000, 1.000000, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from cutlass.torch import dtype as torch_dtype\n",
|
||||
"import cutlass.cute.runtime as cute_rt\n",
|
||||
"\n",
|
||||
"a = torch.randn(8, 5, dtype=torch_dtype(cutlass.Float32))\n",
|
||||
"ptr_a = cute_rt.make_ptr(cutlass.Float32, a.data_ptr())\n",
|
||||
"\n",
|
||||
"create_tensor_from_ptr(ptr_a)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## DLPACK support \n",
|
||||
"\n",
|
||||
"CuTe DSL is designed to support dlpack protocol natively. This offers easy integration with frameworks \n",
|
||||
"supporting DLPack, e.g. torch, numpy, jax, tensorflow, etc.\n",
|
||||
"\n",
|
||||
"For more information, please refer to DLPACK project: https://github.com/dmlc/dlpack\n",
|
||||
"\n",
|
||||
"Calling `from_dlpack` can convert any tensor or ndarray object supporting `__dlpack__` and `__dlpack_device__`.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cutlass.cute.runtime import from_dlpack\n",
|
||||
"\n",
|
||||
"@cute.jit\n",
|
||||
"def print_tensor_dlpack(src: cute.Tensor):\n",
|
||||
" print(src)\n",
|
||||
" cute.print_tensor(src)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor<ptr<f32, generic> o (8,5):(5,1)>\n",
|
||||
"tensor(raw_ptr(0x0000000007559340: f32, generic, align<4>) o (8,5):(5,1), data=\n",
|
||||
" [[-1.151769, 1.019397, -0.371175, -0.717776, 0.502176, ],\n",
|
||||
" [ 0.114282, 0.900084, 0.320770, 1.564574, -0.632329, ],\n",
|
||||
" [-0.570140, 0.178112, -0.423079, 1.936198, 0.003355, ],\n",
|
||||
" ...\n",
|
||||
" [-2.425393, -0.275528, 1.267157, -0.811101, -0.985456, ],\n",
|
||||
" [ 0.777889, -2.114074, 0.357184, -0.321312, -0.938138, ],\n",
|
||||
" [ 1.959564, 1.797602, 0.116901, 0.306198, -1.837295, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a = torch.randn(8, 5, dtype=torch_dtype(cutlass.Float32))\n",
|
||||
"\n",
|
||||
"print_tensor_dlpack(from_dlpack(a))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor<ptr<f32, generic> o (8,8):(8,1)>\n",
|
||||
"tensor(raw_ptr(0x0000000007979da0: f32, generic, align<4>) o (8,8):(8,1), data=\n",
|
||||
" [[ 0.122739, -0.605744, -1.442022, ..., -0.356501, -0.993329, -0.091110, ],\n",
|
||||
" [ 0.278448, 0.318482, -0.276867, ..., 1.542181, -1.701539, -0.309454, ],\n",
|
||||
" [ 0.563565, -0.753936, 0.131214, ..., 0.437912, -0.482277, -0.051540, ],\n",
|
||||
" ...\n",
|
||||
" [-1.974096, -0.177881, 0.426807, ..., -1.579115, -0.304974, 0.451164, ],\n",
|
||||
" [ 0.149851, -0.704689, -0.295063, ..., -0.653001, 0.008871, 0.903916, ],\n",
|
||||
" [ 1.188619, 1.519662, 1.270734, ..., 0.404082, 0.173200, 0.093476, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"a = np.random.randn(8, 8).astype(np.float32)\n",
|
||||
"\n",
|
||||
"print_tensor_dlpack(from_dlpack(a))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tensor Evaluation Methods\n",
|
||||
"\n",
|
||||
"Tensors support two primary methods of evaluation:\n",
|
||||
"\n",
|
||||
"### 1. Full Evaluation\n",
|
||||
"When applying the tensor evaluation with a complete coordinate c, it computes the offset, applies it to the engine, \n",
|
||||
"and dereferences it to return the stored value. This is the straightforward case where you want to access \n",
|
||||
"a specific element of the tensor.\n",
|
||||
"\n",
|
||||
"### 2. Partial Evaluation (Slicing)\n",
|
||||
"When evaluating with an incomplete coordinate c = c' ⊕ c* (where c* represents the unspecified portion), \n",
|
||||
"the result is a new tensor which is a slice of the original tensor with its engine offset to account for \n",
|
||||
"the coordinates that were provided. This operation can be expressed as:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"T(c) = (E ∘ L)(c) = (E + L(c')) ∘ L(c*) = T'(c*)\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Slicing effectively reduces the dimensionality of the tensor, creating a sub-tensor that can be \n",
|
||||
"further evaluated or manipulated."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"a[2] = 10.000000 (equivalent to a[(2,0)])\n",
|
||||
"a[9] = 6.000000 (equivalent to a[(1,1)])\n",
|
||||
"a[2,0] = 10.000000\n",
|
||||
"a[2,4] = 14.000000\n",
|
||||
"a[(2,4)] = 14.000000\n",
|
||||
"a[2,3] = 100.000000\n",
|
||||
"a[(2,4)] = 101.000000\n",
|
||||
"tensor([[ 0., 1., 2., 3., 4.],\n",
|
||||
" [ 5., 6., 7., 8., 9.],\n",
|
||||
" [ 10., 11., 12., 100., 101.],\n",
|
||||
" [ 15., 16., 17., 18., 19.],\n",
|
||||
" [ 20., 21., 22., 23., 24.],\n",
|
||||
" [ 25., 26., 27., 28., 29.],\n",
|
||||
" [ 30., 31., 32., 33., 34.],\n",
|
||||
" [ 35., 36., 37., 38., 39.]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def tensor_access_item(a: cute.Tensor):\n",
|
||||
" # access data using linear index\n",
|
||||
" cute.printf(\"a[2] = {} (equivalent to a[{}])\", a[2],\n",
|
||||
" cute.make_identity_tensor(a.layout.shape)[2])\n",
|
||||
" cute.printf(\"a[9] = {} (equivalent to a[{}])\", a[9],\n",
|
||||
" cute.make_identity_tensor(a.layout.shape)[9])\n",
|
||||
"\n",
|
||||
" # access data using n-d coordinates, following two are equivalent\n",
|
||||
" cute.printf(\"a[2,0] = {}\", a[2, 0])\n",
|
||||
" cute.printf(\"a[2,4] = {}\", a[2, 4])\n",
|
||||
" cute.printf(\"a[(2,4)] = {}\", a[2, 4])\n",
|
||||
"\n",
|
||||
" # assign value to tensor@(2,4)\n",
|
||||
" a[2,3] = 100.0\n",
|
||||
" a[2,4] = 101.0\n",
|
||||
" cute.printf(\"a[2,3] = {}\", a[2,3])\n",
|
||||
" cute.printf(\"a[(2,4)] = {}\", a[(2,4)])\n",
|
||||
"\n",
|
||||
"@cute.kernel\n",
|
||||
"def print_tensor_gpu(ptr: cute.Pointer):\n",
|
||||
" layout = cute.make_layout((8, 5), stride=(5, 1))\n",
|
||||
" tensor = cute.make_tensor(ptr, layout)\n",
|
||||
"\n",
|
||||
" tidx, _, _ = cute.arch.thread_idx()\n",
|
||||
"\n",
|
||||
" if tidx == 0:\n",
|
||||
" cute.print_tensor(tensor)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Create a tensor with sequential data using torch\n",
|
||||
"data = torch.arange(0, 8*5, dtype=torch.float32).reshape(8, 5)\n",
|
||||
"tensor_access_item(from_dlpack(data))\n",
|
||||
"\n",
|
||||
"print(data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Tensor as memory view\n",
|
||||
"\n",
|
||||
"In CUDA programming, different memory spaces have different characteristics in terms of access speed, scope, and lifetime:\n",
|
||||
"\n",
|
||||
"- **generic**: Default memory space that can refer to any other memory space.\n",
|
||||
"- **global memory (gmem)**: Accessible by all threads across all blocks, but has higher latency.\n",
|
||||
"- **shared memory (smem)**: Accessible by all threads within a block, with much lower latency than global memory.\n",
|
||||
"- **register memory (rmem)**: Thread-private memory with the lowest latency, but limited capacity.\n",
|
||||
"- **tensor memory (tmem)**: Specialized memory introduced in NVIDIA Blackwell architecture for tensor operations.\n",
|
||||
"\n",
|
||||
"When creating tensors in CuTe, you can specify the memory space to optimize performance based on your access patterns.\n",
|
||||
"\n",
|
||||
"For more information on CUDA memory spaces, see the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-hierarchy).\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Coordinate Tensor\n",
|
||||
"\n",
|
||||
"A coordinate tensor is a special type of tensor that maps coordinates to coordinates rather than to values. \n",
|
||||
"The key distinction is that while regular tensors map coordinates to some value type (like numbers), \n",
|
||||
"coordinate tensors map coordinates to other coordinates.\n",
|
||||
"\n",
|
||||
"For example, given a shape (4,4), a coordinate tensor using row-major layout would appear as:\n",
|
||||
"\n",
|
||||
"\\begin{bmatrix} \n",
|
||||
"(0,0) & (0,1) & (0,2) & (0,3) \\\\\n",
|
||||
"(1,0) & (1,1) & (1,2) & (1,3) \\\\\n",
|
||||
"(2,0) & (2,1) & (2,2) & (2,3) \\\\\n",
|
||||
"(3,0) & (3,1) & (3,2) & (3,3)\n",
|
||||
"\\end{bmatrix}\n",
|
||||
"\n",
|
||||
"The same shape with a column-major layout would appear as:\n",
|
||||
"\n",
|
||||
"\\begin{bmatrix}\n",
|
||||
"(0,0) & (1,0) & (2,0) & (3,0) \\\\\n",
|
||||
"(0,1) & (1,1) & (2,1) & (3,1) \\\\\n",
|
||||
"(0,2) & (1,2) & (2,2) & (3,2) \\\\\n",
|
||||
"(0,3) & (1,3) & (2,3) & (3,3)\n",
|
||||
"\\end{bmatrix}\n",
|
||||
"\n",
|
||||
"The key points about coordinate tensors are:\n",
|
||||
"- Each element in the tensor is itself a coordinate tuple (i,j) rather than a scalar value\n",
|
||||
"- The coordinates map to themselves - so position (1,2) contains the coordinate (1,2)\n",
|
||||
"- The layout (row-major vs column-major) determines how these coordinate tuples are arranged in memory\n",
|
||||
"\n",
|
||||
"For example, coordinate tensors can be created using the `make_identity_tensor` utility:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"coord_tensor = make_identity_tensor(layout.shape())\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"This creates a tensor that maps each coordinate to itself, providing a reference point for understanding how other layouts transform these coordinates."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor<(0,0) o (8,4):(1@0,1@1)>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def print_tensor_coord(a: cute.Tensor):\n",
|
||||
" coord_tensor = cute.make_identity_tensor(a.layout.shape)\n",
|
||||
" print(coord_tensor)\n",
|
||||
"\n",
|
||||
"a = torch.randn(8,4, dtype=torch_dtype(cutlass.Float32))\n",
|
||||
"print_tensor_coord(from_dlpack(a))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"widgets": {
|
||||
"application/vnd.jupyter.widget-state+json": {
|
||||
"state": {},
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
558
examples/python/CuTeDSL/notebooks/tensorssa.ipynb
Normal file
558
examples/python/CuTeDSL/notebooks/tensorssa.ipynb
Normal file
@@ -0,0 +1,558 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute\n",
|
||||
"from cutlass.cute.runtime import from_dlpack\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Introduction to the TensorSSA in CuTe DSL\n",
|
||||
"\n",
|
||||
"This tutorial introduces what is the `TensorSSA` and why we need it. We also give some examples to show how to use `TensorSSA`.\n",
|
||||
"\n",
|
||||
"## What is TensorSSA\n",
|
||||
"\n",
|
||||
"`TensorSSA` is a Python class that represents a tensor value in Static Single Assignment (SSA) form within the CuTe DSL. You can think of it as a tensor residing in a (simulated) register.\n",
|
||||
"\n",
|
||||
"## Why TensorSSA\n",
|
||||
"\n",
|
||||
"`TensorSSA` encapsulates the underlying MLIR tensor value into an object that's easier to manipulate in Python. By overloading numerous Python operators (like `+`, `-`, `*`, `/`, `[]`, etc.), it allows users to express tensor computations (primarily element-wise operations and reductions) in a more Pythonic way. These element-wise operations are then translated into optimized vectorization instructions.\n",
|
||||
"\n",
|
||||
"It's part of the CuTe DSL, serving as a bridge between the user-described computational logic and the lower-level MLIR IR, particularly for representing and manipulating register-level data.\n",
|
||||
"\n",
|
||||
"## When to use TensorSSA\n",
|
||||
"\n",
|
||||
"`TensorSSA` is primarily used in the following scenarios:\n",
|
||||
"\n",
|
||||
"### Load from memory and store to memory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"a_vec: tensor_value<vector<12xf32> o (3, 4)>\n",
|
||||
"b_vec: tensor_value<vector<12xf32> o (3, 4)>\n",
|
||||
"tensor(raw_ptr(0x0000000006cff170: f32, generic, align<4>) o (3,4):(4,1), data=\n",
|
||||
" [[ 2.000000, 2.000000, 2.000000, 2.000000, ],\n",
|
||||
" [ 2.000000, 2.000000, 2.000000, 2.000000, ],\n",
|
||||
" [ 2.000000, 2.000000, 2.000000, 2.000000, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def load_and_store(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
" \"\"\"\n",
|
||||
" Load data from memory and store the result to memory.\n",
|
||||
"\n",
|
||||
" :param res: The destination tensor to store the result.\n",
|
||||
" :param a: The source tensor to be loaded.\n",
|
||||
" :param b: The source tensor to be loaded.\n",
|
||||
" \"\"\"\n",
|
||||
" a_vec = a.load()\n",
|
||||
" print(f\"a_vec: {a_vec}\") # prints `a_vec: vector<12xf32> o (3, 4)`\n",
|
||||
" b_vec = b.load()\n",
|
||||
" print(f\"b_vec: {b_vec}\") # prints `b_vec: vector<12xf32> o (3, 4)`\n",
|
||||
" res.store(a_vec + b_vec)\n",
|
||||
" cute.print_tensor(res)\n",
|
||||
"\n",
|
||||
"a = np.ones(12).reshape((3, 4)).astype(np.float32)\n",
|
||||
"b = np.ones(12).reshape((3, 4)).astype(np.float32)\n",
|
||||
"c = np.zeros(12).reshape((3, 4)).astype(np.float32)\n",
|
||||
"load_and_store(from_dlpack(c), from_dlpack(a), from_dlpack(b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Register-Level Tensor Operations\n",
|
||||
"\n",
|
||||
"When writing kernel logic, various computations, transformations, slicing, etc., are performed on data loaded into registers."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor_value<vector<24xf32> o (4, 2, 3)> -> tensor_value<vector<12xf32> o (4, 3)>\n",
|
||||
"tensor(raw_ptr(0x00000000071acaf0: f32, generic, align<4>) o (4,3):(3,1), data=\n",
|
||||
" [[ 3.000000, 4.000000, 5.000000, ],\n",
|
||||
" [ 9.000000, 10.000000, 11.000000, ],\n",
|
||||
" [ 15.000000, 16.000000, 17.000000, ],\n",
|
||||
" [ 21.000000, 22.000000, 23.000000, ]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def apply_slice(src: cute.Tensor, dst: cute.Tensor, indices: cutlass.Constexpr):\n",
|
||||
" \"\"\"\n",
|
||||
" Apply slice operation on the src tensor and store the result to the dst tensor.\n",
|
||||
"\n",
|
||||
" :param src: The source tensor to be sliced.\n",
|
||||
" :param dst: The destination tensor to store the result.\n",
|
||||
" :param indices: The indices to slice the source tensor.\n",
|
||||
" \"\"\"\n",
|
||||
" src_vec = src.load()\n",
|
||||
" dst_vec = src_vec[indices]\n",
|
||||
" print(f\"{src_vec} -> {dst_vec}\")\n",
|
||||
" if isinstance(dst_vec, cute.TensorSSA):\n",
|
||||
" dst.store(dst_vec)\n",
|
||||
" cute.print_tensor(dst)\n",
|
||||
" else:\n",
|
||||
" dst[0] = dst_vec\n",
|
||||
" cute.print_tensor(dst)\n",
|
||||
"\n",
|
||||
"def slice_1():\n",
|
||||
" src_shape = (4, 2, 3)\n",
|
||||
" dst_shape = (4, 3)\n",
|
||||
" indices = (None, 1, None)\n",
|
||||
"\n",
|
||||
" \"\"\"\n",
|
||||
" a:\n",
|
||||
" [[[ 0. 1. 2.]\n",
|
||||
" [ 3. 4. 5.]]\n",
|
||||
"\n",
|
||||
" [[ 6. 7. 8.]\n",
|
||||
" [ 9. 10. 11.]]\n",
|
||||
"\n",
|
||||
" [[12. 13. 14.]\n",
|
||||
" [15. 16. 17.]]\n",
|
||||
"\n",
|
||||
" [[18. 19. 20.]\n",
|
||||
" [21. 22. 23.]]]\n",
|
||||
" \"\"\"\n",
|
||||
" a = np.arange(np.prod(src_shape)).reshape(*src_shape).astype(np.float32)\n",
|
||||
" dst = np.random.randn(*dst_shape).astype(np.float32)\n",
|
||||
" apply_slice(from_dlpack(a), from_dlpack(dst), indices)\n",
|
||||
"\n",
|
||||
"slice_1()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor_value<vector<24xf32> o (4, 2, 3)> -> ?\n",
|
||||
"tensor(raw_ptr(0x00000000013cbbe0: f32, generic, align<4>) o (1):(1), data=\n",
|
||||
" [ 10.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def slice_2():\n",
|
||||
" src_shape = (4, 2, 3)\n",
|
||||
" dst_shape = (1,)\n",
|
||||
" indices = 10\n",
|
||||
" a = np.arange(np.prod(src_shape)).reshape(*src_shape).astype(np.float32)\n",
|
||||
" dst = np.random.randn(*dst_shape).astype(np.float32)\n",
|
||||
" apply_slice(from_dlpack(a), from_dlpack(dst), indices)\n",
|
||||
"\n",
|
||||
"slice_2()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Arithmetic Operations\n",
|
||||
"\n",
|
||||
"As we mentioned earlier, there're many tensor operations whose operands are `TensorSSA`. And they are all element-wise operations. We give some examples below.\n",
|
||||
"\n",
|
||||
"### Binary Operations\n",
|
||||
"\n",
|
||||
"For binary operations, the LHS operand is `TensorSSA` and the RHS operand can be either `TensorSSA` or `Numeric`. When the RHS is `Numeric`, it will be broadcast to a `TensorSSA`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 3.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [-1.000000, ],\n",
|
||||
" [-1.000000, ],\n",
|
||||
" [-1.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 0.500000, ],\n",
|
||||
" [ 0.500000, ],\n",
|
||||
" [ 0.500000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 0.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 1.000000, ],\n",
|
||||
" [ 1.000000, ],\n",
|
||||
" [ 1.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_1(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
" a_vec = a.load()\n",
|
||||
" b_vec = b.load()\n",
|
||||
"\n",
|
||||
" add_res = a_vec + b_vec\n",
|
||||
" res.store(add_res)\n",
|
||||
" cute.print_tensor(res) # prints [3.000000, 3.000000, 3.000000]\n",
|
||||
"\n",
|
||||
" sub_res = a_vec - b_vec\n",
|
||||
" res.store(sub_res)\n",
|
||||
" cute.print_tensor(res) # prints [-1.000000, -1.000000, -1.000000]\n",
|
||||
"\n",
|
||||
" mul_res = a_vec * b_vec\n",
|
||||
" res.store(mul_res)\n",
|
||||
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
|
||||
"\n",
|
||||
" div_res = a_vec / b_vec\n",
|
||||
" res.store(div_res)\n",
|
||||
" cute.print_tensor(res) # prints [0.500000, 0.500000, 0.500000]\n",
|
||||
"\n",
|
||||
" floor_div_res = a_vec // b_vec\n",
|
||||
" res.store(floor_div_res)\n",
|
||||
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
|
||||
"\n",
|
||||
" mod_res = a_vec % b_vec\n",
|
||||
" res.store(mod_res)\n",
|
||||
" cute.print_tensor(res) # prints [1.000000, 1.000000, 1.000000]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"a = np.empty((3,), dtype=np.float32)\n",
|
||||
"a.fill(1.0)\n",
|
||||
"b = np.empty((3,), dtype=np.float32)\n",
|
||||
"b.fill(2.0)\n",
|
||||
"res = np.empty((3,), dtype=np.float32)\n",
|
||||
"binary_op_1(from_dlpack(res), from_dlpack(a), from_dlpack(b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 3.000000, ],\n",
|
||||
" [ 3.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [-1.000000, ],\n",
|
||||
" [-1.000000, ],\n",
|
||||
" [-1.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 0.500000, ],\n",
|
||||
" [ 0.500000, ],\n",
|
||||
" [ 0.500000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 0.000000, ],\n",
|
||||
" [ 0.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 1.000000, ],\n",
|
||||
" [ 1.000000, ],\n",
|
||||
" [ 1.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_2(res: cute.Tensor, a: cute.Tensor, c: cutlass.Constexpr):\n",
|
||||
" a_vec = a.load()\n",
|
||||
"\n",
|
||||
" add_res = a_vec + c\n",
|
||||
" res.store(add_res)\n",
|
||||
" cute.print_tensor(res) # prints [3.000000, 3.000000, 3.000000]\n",
|
||||
"\n",
|
||||
" sub_res = a_vec - c\n",
|
||||
" res.store(sub_res)\n",
|
||||
" cute.print_tensor(res) # prints [-1.000000, -1.000000, -1.000000]\n",
|
||||
"\n",
|
||||
" mul_res = a_vec * c\n",
|
||||
" res.store(mul_res)\n",
|
||||
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
|
||||
"\n",
|
||||
" div_res = a_vec / c\n",
|
||||
" res.store(div_res)\n",
|
||||
" cute.print_tensor(res) # prints [0.500000, 0.500000, 0.500000]\n",
|
||||
"\n",
|
||||
" floor_div_res = a_vec // c\n",
|
||||
" res.store(floor_div_res)\n",
|
||||
" cute.print_tensor(res) # prints [0.000000, 0.000000, 0.000000]\n",
|
||||
"\n",
|
||||
" mod_res = a_vec % c\n",
|
||||
" res.store(mod_res)\n",
|
||||
" cute.print_tensor(res) # prints [1.000000, 1.000000, 1.000000]\n",
|
||||
"\n",
|
||||
"a = np.empty((3,), dtype=np.float32)\n",
|
||||
"a.fill(1.0)\n",
|
||||
"c = 2.0\n",
|
||||
"res = np.empty((3,), dtype=np.float32)\n",
|
||||
"binary_op_2(from_dlpack(res), from_dlpack(a), c)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[False True False]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_3(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
" a_vec = a.load()\n",
|
||||
" b_vec = b.load()\n",
|
||||
"\n",
|
||||
" gt_res = a_vec > b_vec\n",
|
||||
" res.store(gt_res)\n",
|
||||
"\n",
|
||||
" \"\"\"\n",
|
||||
" ge_res = a_ >= b_ # [False, True, False]\n",
|
||||
" lt_res = a_ < b_ # [True, False, True]\n",
|
||||
" le_res = a_ <= b_ # [True, False, True]\n",
|
||||
" eq_res = a_ == b_ # [False, False, False]\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
"a = np.array([1, 2, 3], dtype=np.float32)\n",
|
||||
"b = np.array([2, 1, 4], dtype=np.float32)\n",
|
||||
"res = np.empty((3,), dtype=np.bool_)\n",
|
||||
"binary_op_3(from_dlpack(res), from_dlpack(a), from_dlpack(b))\n",
|
||||
"print(res) # prints [False, True, False]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[3 0 7]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def binary_op_4(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
|
||||
" a_vec = a.load()\n",
|
||||
" b_vec = b.load()\n",
|
||||
"\n",
|
||||
" xor_res = a_vec ^ b_vec\n",
|
||||
" res.store(xor_res)\n",
|
||||
"\n",
|
||||
" # or_res = a_vec | b_vec\n",
|
||||
" # res.store(or_res) # prints [3, 2, 7]\n",
|
||||
"\n",
|
||||
" # and_res = a_vec & b_vec\n",
|
||||
" # res.store(and_res) # prints [0, 2, 0]\n",
|
||||
"\n",
|
||||
"a = np.array([1, 2, 3], dtype=np.int32)\n",
|
||||
"b = np.array([2, 2, 4], dtype=np.int32)\n",
|
||||
"res = np.empty((3,), dtype=np.int32)\n",
|
||||
"binary_op_4(from_dlpack(res), from_dlpack(a), from_dlpack(b))\n",
|
||||
"print(res) # prints [3, 0, 7]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Unary Operations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ],\n",
|
||||
" [ 2.000000, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [-0.756802, ],\n",
|
||||
" [-0.756802, ],\n",
|
||||
" [-0.756802, ])\n",
|
||||
"tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
|
||||
" [ 16.000000, ],\n",
|
||||
" [ 16.000000, ],\n",
|
||||
" [ 16.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def unary_op_1(res: cute.Tensor, a: cute.Tensor):\n",
|
||||
" a_vec = a.load()\n",
|
||||
"\n",
|
||||
" sqrt_res = cute.math.sqrt(a_vec)\n",
|
||||
" res.store(sqrt_res)\n",
|
||||
" cute.print_tensor(res) # prints [2.000000, 2.000000, 2.000000]\n",
|
||||
"\n",
|
||||
" sin_res = cute.math.sin(a_vec)\n",
|
||||
" res.store(sin_res)\n",
|
||||
" cute.print_tensor(res) # prints [-0.756802, -0.756802, -0.756802]\n",
|
||||
"\n",
|
||||
" exp2_res = cute.math.exp2(a_vec)\n",
|
||||
" res.store(exp2_res)\n",
|
||||
" cute.print_tensor(res) # prints [16.000000, 16.000000, 16.000000]\n",
|
||||
"\n",
|
||||
"a = np.array([4.0, 4.0, 4.0], dtype=np.float32)\n",
|
||||
"res = np.empty((3,), dtype=np.float32)\n",
|
||||
"unary_op_1(from_dlpack(res), from_dlpack(a))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Reduction Operation\n",
|
||||
"\n",
|
||||
"The `TensorSSA`'s `reduce` method applies a specified reduction operation (`ReductionOp.ADD`, `ReductionOp.MUL`, `ReductionOp.MAX`, `ReductionOp.MIN`) starting with an initial value, and performs this reduction along the dimensions specified by the `reduction_profile.`. The result is typically a new `TensorSSA` with reduced dimensions or a scalar value if reduces across all axes."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"21.000000\n",
|
||||
"tensor(raw_ptr(0x00007ffd1ea2bca0: f32, rmem, align<32>) o (2):(1), data=\n",
|
||||
" [ 6.000000, ],\n",
|
||||
" [ 15.000000, ])\n",
|
||||
"tensor(raw_ptr(0x00007ffd1ea2bcc0: f32, rmem, align<32>) o (3):(1), data=\n",
|
||||
" [ 6.000000, ],\n",
|
||||
" [ 8.000000, ],\n",
|
||||
" [ 10.000000, ])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def reduction_op(a: cute.Tensor):\n",
|
||||
" \"\"\"\n",
|
||||
" Apply reduction operation on the src tensor.\n",
|
||||
"\n",
|
||||
" :param src: The source tensor to be reduced.\n",
|
||||
" \"\"\"\n",
|
||||
" a_vec = a.load()\n",
|
||||
" red_res = a_vec.reduce(\n",
|
||||
" cute.ReductionOp.ADD,\n",
|
||||
" 0.0,\n",
|
||||
" reduction_profile=0\n",
|
||||
" )\n",
|
||||
" cute.printf(red_res) # prints 21.000000\n",
|
||||
"\n",
|
||||
" red_res = a_vec.reduce(\n",
|
||||
" cute.ReductionOp.ADD,\n",
|
||||
" 0.0,\n",
|
||||
" reduction_profile=(None, 1)\n",
|
||||
" )\n",
|
||||
" # We can't print the TensorSSA directly at this point, so we store it to a new Tensor and print it.\n",
|
||||
" res = cute.make_fragment(red_res.shape, cutlass.Float32)\n",
|
||||
" res.store(red_res)\n",
|
||||
" cute.print_tensor(res) # prints [6.000000, 15.000000]\n",
|
||||
"\n",
|
||||
" red_res = a_vec.reduce(\n",
|
||||
" cute.ReductionOp.ADD,\n",
|
||||
" 1.0,\n",
|
||||
" reduction_profile=(1, None)\n",
|
||||
" )\n",
|
||||
" res = cute.make_fragment(red_res.shape, cutlass.Float32)\n",
|
||||
" res.store(red_res)\n",
|
||||
" cute.print_tensor(res) # prints [6.000000, 8.000000, 10.000000]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)\n",
|
||||
"reduction_op(from_dlpack(a))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
Reference in New Issue
Block a user