mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-20 06:48:53 +00:00
cuda.nvbench -> cuda.bench
Per PR review suggestion: - `cuda.parallel` - device-wide algorithms/Thrust - `cuda.cooperative` - Cooperative algorithsm/CUB - `cuda.bench` - Benchmarking/NVBench
This commit is contained in:
@@ -17,19 +17,19 @@
|
||||
|
||||
import sys
|
||||
|
||||
import cuda.bench as bench
|
||||
import cuda.bindings.driver as driver
|
||||
import cuda.core.experimental as core
|
||||
import cuda.nvbench as nvbench
|
||||
import cupy as cp
|
||||
import cutlass
|
||||
import numpy as np
|
||||
|
||||
|
||||
def as_bindings_Stream(cs: nvbench.CudaStream) -> driver.CUstream:
|
||||
def as_bindings_Stream(cs: bench.CudaStream) -> driver.CUstream:
|
||||
return driver.CUstream(cs.addressof())
|
||||
|
||||
|
||||
def as_core_Stream(cs: nvbench.CudaStream) -> core.Stream:
|
||||
def as_core_Stream(cs: bench.CudaStream) -> core.Stream:
|
||||
return core.Stream.from_handle(cs.addressof())
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ def make_cp_array(
|
||||
)
|
||||
|
||||
|
||||
def cutlass_gemm(state: nvbench.State) -> None:
|
||||
def cutlass_gemm(state: bench.State) -> None:
|
||||
n = state.get_int64("N")
|
||||
r = state.get_int64("R")
|
||||
|
||||
@@ -96,7 +96,7 @@ def cutlass_gemm(state: nvbench.State) -> None:
|
||||
# warm-up to ensure compilation is not timed
|
||||
plan.run(stream=s)
|
||||
|
||||
def launcher(launch: nvbench.Launch) -> None:
|
||||
def launcher(launch: bench.Launch) -> None:
|
||||
s = as_bindings_Stream(launch.get_stream())
|
||||
plan.run(stream=s, sync=False)
|
||||
|
||||
@@ -104,10 +104,10 @@ def cutlass_gemm(state: nvbench.State) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
gemm_b = nvbench.register(cutlass_gemm)
|
||||
gemm_b = bench.register(cutlass_gemm)
|
||||
gemm_b.add_int64_axis("R", [16, 64, 256])
|
||||
gemm_b.add_int64_axis("N", [256, 512, 1024, 2048])
|
||||
|
||||
gemm_b.add_float64_axis("alpha", [1e-2])
|
||||
|
||||
nvbench.run_all_benchmarks(sys.argv)
|
||||
bench.run_all_benchmarks(sys.argv)
|
||||
|
||||
Reference in New Issue
Block a user