mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-05-12 09:15:47 +00:00
Add patch to cutlass.base_dsl.dsl.BaseDSL to work-around a bug
See https://github.com/NVIDIA/cutlass/issues/3142
This commit is contained in:
@@ -642,7 +642,24 @@ def cutlass_gemm(state: bench.State) -> None:
|
||||
state.exec(launcher)
|
||||
|
||||
|
||||
def patch_cute_dsl():
|
||||
def _no_op_diagnostic(self):
|
||||
return
|
||||
|
||||
try:
|
||||
import cutlass.base_dsl.dsl as dsl_m
|
||||
|
||||
base_dsl_k = dsl_m.BaseDSL
|
||||
if hasattr(base_dsl_k, "diagnostic"):
|
||||
base_dsl_k.diagnostic = _no_op_diagnostic
|
||||
except (ModuleNotFoundError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# see https://github.com/NVIDIA/cutlass/issues/3142
|
||||
patch_cute_dsl()
|
||||
|
||||
gemm_b = bench.register(cutlass_gemm)
|
||||
gemm_b.add_int64_axis("R", [16, 64, 256])
|
||||
gemm_b.add_int64_axis("N", [256, 512, 1024, 2048])
|
||||
|
||||
Reference in New Issue
Block a user