cuda.nvbench -> cuda.bench

Per PR review suggestion:
   - `cuda.parallel`    - device-wide algorithms/Thrust
   - `cuda.cooperative` - Cooperative algorithsm/CUB
   - `cuda.bench`       - Benchmarking/NVBench
This commit is contained in:
Oleksandr Pavlyk
2025-08-04 13:42:43 -05:00
parent c2a2acc9b6
commit b5e4b4ba31
19 changed files with 136 additions and 140 deletions

View File

@@ -16,10 +16,10 @@
import sys
import cuda.bench as bench
import cuda.cccl.parallel.experimental.algorithms as algorithms
import cuda.cccl.parallel.experimental.iterators as iterators
import cuda.core.experimental as core
import cuda.nvbench as nvbench
import cupy as cp
import numpy as np
@@ -34,22 +34,22 @@ class CCCLStream:
return (0, self._ptr)
def as_core_Stream(cs: nvbench.CudaStream) -> core.Stream:
def as_core_Stream(cs: bench.CudaStream) -> core.Stream:
return core.Stream.from_handle(cs.addressof())
def as_cccl_Stream(cs: nvbench.CudaStream) -> CCCLStream:
def as_cccl_Stream(cs: bench.CudaStream) -> CCCLStream:
return CCCLStream(cs.addressof())
def as_cp_ExternalStream(
cs: nvbench.CudaStream, dev_id: int | None = -1
cs: bench.CudaStream, dev_id: int | None = -1
) -> cp.cuda.ExternalStream:
h = cs.addressof()
return cp.cuda.ExternalStream(h, dev_id)
def segmented_reduce(state: nvbench.State):
def segmented_reduce(state: bench.State):
"Benchmark segmented_reduce example"
n_elems = state.get_int64("numElems")
n_cols = state.get_int64("numCols")
@@ -100,7 +100,7 @@ def segmented_reduce(state: nvbench.State):
with cp_stream:
temp_storage = cp.empty(temp_nbytes, dtype=cp.uint8)
def launcher(launch: nvbench.Launch):
def launcher(launch: bench.Launch):
s = as_cccl_Stream(launch.get_stream())
alg(
temp_storage,
@@ -117,8 +117,8 @@ def segmented_reduce(state: nvbench.State):
if __name__ == "__main__":
b = nvbench.register(segmented_reduce)
b = bench.register(segmented_reduce)
b.add_int64_axis("numElems", [2**20, 2**22, 2**24])
b.add_int64_axis("numCols", [1024, 2048, 4096, 8192])
nvbench.run_all_benchmarks(sys.argv)
bench.run_all_benchmarks(sys.argv)