Add patch to cutlass.base_dsl.dsl.BaseDSL to work-around a bug

See https://github.com/NVIDIA/cutlass/issues/3142
This commit is contained in:
Oleksandr Pavlyk
2026-04-02 10:29:31 -05:00
parent 93bc59d05c
commit 9f75642387

View File

@@ -642,7 +642,24 @@ def cutlass_gemm(state: bench.State) -> None:
state.exec(launcher)
def patch_cute_dsl():
def _no_op_diagnostic(self):
return
try:
import cutlass.base_dsl.dsl as dsl_m
base_dsl_k = dsl_m.BaseDSL
if hasattr(base_dsl_k, "diagnostic"):
base_dsl_k.diagnostic = _no_op_diagnostic
except (ModuleNotFoundError, AttributeError):
pass
if __name__ == "__main__":
# see https://github.com/NVIDIA/cutlass/issues/3142
patch_cute_dsl()
gemm_b = bench.register(cutlass_gemm)
gemm_b.add_int64_axis("R", [16, 64, 256])
gemm_b.add_int64_axis("N", [256, 512, 1024, 2048])