diff --git a/python/examples/cute_dsl_sgemm.py b/python/examples/cute_dsl_sgemm.py index 61ac2b3..eae3758 100644 --- a/python/examples/cute_dsl_sgemm.py +++ b/python/examples/cute_dsl_sgemm.py @@ -642,7 +642,24 @@ def cutlass_gemm(state: bench.State) -> None: state.exec(launcher) +def patch_cute_dsl(): + def _no_op_diagnostic(self): + return + + try: + import cutlass.base_dsl.dsl as dsl_m + + base_dsl_k = dsl_m.BaseDSL + if hasattr(base_dsl_k, "diagnostic"): + base_dsl_k.diagnostic = _no_op_diagnostic + except (ModuleNotFoundError, AttributeError): + pass + + if __name__ == "__main__": + # see https://github.com/NVIDIA/cutlass/issues/3142 + patch_cute_dsl() + gemm_b = bench.register(cutlass_gemm) gemm_b.add_int64_axis("R", [16, 64, 256]) gemm_b.add_int64_axis("N", [256, 512, 1024, 2048])