mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
fix performance issues in cute-dsl examples for 4.4-ctk13.1 release (#2988)
* fix grouped gemm * fix mixed input gemm * fix mixed input grouped gemm * fix version checking * use advanced compiler options * fix comment * rename advanced compiler configs to advanced compiler control * fix comment * fix name * fix name
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
wěÎťĐM“I)ćŕ ×fĂ7Ý÷W>Ó¸Xęuh/
|
||||
¶Ž
|
||||
@@ -2076,6 +2076,17 @@ def run(
|
||||
# Initialize Stream
|
||||
current_stream = cutlass_torch.default_stream()
|
||||
|
||||
# try to check CUDA version to decide the opt level
|
||||
try:
|
||||
from cutlass import CUDA_VERSION
|
||||
opt_level = (
|
||||
3
|
||||
if CUDA_VERSION.major < 13
|
||||
or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor < 1)
|
||||
else 2
|
||||
)
|
||||
except ImportError:
|
||||
opt_level = 3
|
||||
# Compile grouped GEMM kernel
|
||||
compiled_grouped_gemm = cute.compile(
|
||||
grouped_gemm,
|
||||
@@ -2090,6 +2101,7 @@ def run(
|
||||
tensor_of_tensormap,
|
||||
max_active_clusters,
|
||||
current_stream,
|
||||
options=f"--opt-level {opt_level}",
|
||||
)
|
||||
|
||||
if not skip_ref_check:
|
||||
|
||||
@@ -393,6 +393,7 @@ class GroupedMixedInputGemmKernel:
|
||||
|
||||
self.smem_buffer_align_bytes = 1024
|
||||
|
||||
|
||||
def _setup_attributes(self):
|
||||
"""Set up configurations that are dependent on GEMM inputs
|
||||
|
||||
@@ -2859,6 +2860,33 @@ def compare(
|
||||
torch.testing.assert_close(kernel_result, ref_result, atol=tolerance, rtol=1e-05)
|
||||
|
||||
|
||||
def get_advanced_compiler_control_path():
    """
    Locate the advanced compiler control file that accompanies this example.

    The lookup only happens when ``cutlass.CUDA_VERSION`` reports exactly
    version 13.1. If ``CUDA_VERSION`` cannot be imported from ``cutlass``, or
    it reports any other version, no file is searched for.

    :return: The path ``../advanced_compiler_control/gemm0.bin`` joined onto
        the directory containing this file, if something exists at that path;
        otherwise None.
    """
    import os

    try:
        from cutlass import CUDA_VERSION

        # Compare major first so minor is only read for a 13.x toolkit.
        is_cuda_13_1 = CUDA_VERSION.major == 13 and CUDA_VERSION.minor == 1
    except ImportError:
        # No version information available: behave as if the file is not needed.
        is_cuda_13_1 = False

    if not is_cuda_13_1:
        return None

    # The control file sits in a sibling directory of this example's directory.
    example_dir = os.path.dirname(os.path.abspath(__file__))
    control_file = os.path.join(
        example_dir, "../advanced_compiler_control/gemm0.bin"
    )
    if not os.path.exists(control_file):
        return None

    print(f"Found advanced compiler control file at {control_file}")
    return control_file
|
||||
|
||||
|
||||
def run(
|
||||
mnkl: tuple[int, int, int, int],
|
||||
scale_granularity_m: int,
|
||||
@@ -2955,6 +2983,12 @@ def run(
|
||||
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(
|
||||
cluster_shape_mn[0] * cluster_shape_mn[1],
|
||||
)
|
||||
advanced_compiler_options = None
|
||||
advanced_compiler_control_path = get_advanced_compiler_control_path()
|
||||
if advanced_compiler_control_path:
|
||||
advanced_compiler_options = (
|
||||
f"--ptxas-options '--apply-controls={advanced_compiler_control_path}'"
|
||||
)
|
||||
compiled_kernel = cute.compile(
|
||||
mixed_input_gemm,
|
||||
a_tensor,
|
||||
@@ -2964,6 +2998,7 @@ def run(
|
||||
c_tensor,
|
||||
max_active_clusters,
|
||||
current_stream,
|
||||
options=advanced_compiler_options,
|
||||
)
|
||||
|
||||
if not skip_ref_check:
|
||||
|
||||
@@ -2476,6 +2476,17 @@ def run(
|
||||
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(
|
||||
cluster_shape_mn[0] * cluster_shape_mn[1],
|
||||
)
|
||||
# try to check CUDA version to decide the opt level
|
||||
try:
|
||||
from cutlass import CUDA_VERSION
|
||||
opt_level = (
|
||||
3
|
||||
if CUDA_VERSION.major < 13
|
||||
or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor < 1)
|
||||
else 2
|
||||
)
|
||||
except ImportError:
|
||||
opt_level = 3
|
||||
compiled_kernel = cute.compile(
|
||||
mixed_input_gemm,
|
||||
a_tensor,
|
||||
@@ -2484,6 +2495,7 @@ def run(
|
||||
c_tensor,
|
||||
max_active_clusters,
|
||||
current_stream,
|
||||
options=f"--opt-level {opt_level}",
|
||||
)
|
||||
|
||||
if not skip_ref_check:
|
||||
|
||||
Reference in New Issue
Block a user