Add implementation of and signature for State.getDevice

make batch/sync arguments of State.exec keyword-only

Provide default column_name value for State.addElementCount method,
so that it can be called state.addElementCount(count), or as
state.addElementCount(count, column_name="Descriptive Name")
This commit is contained in:
Oleksandr Pavlyk
2025-07-02 10:32:10 -05:00
parent 2507bc2263
commit 576c473481
2 changed files with 30 additions and 2 deletions

View File

@@ -1,7 +1,7 @@
# from __future__ import annotations
from collections.abc import Callable, Sequence
from typing import Optional, Self
from typing import Optional, Self, Union
class CudaStream:
"""Represents CUDA stream
@@ -15,6 +15,18 @@ class CudaStream:
Special method implement CUDA stream protocol
from `cuda.core`. Returns a pair of integers:
(protocol_version, integral_value_of_cudaStream_t pointer)
Example
-------
import cuda.core.experimental as core
import cuda.nvbench as nvbench
def bench(state: nvbench.State):
dev = core.Device(state.getDevice())
dev.set_current()
# converts CudaString to core.Stream
# using __cuda_stream__ protocol
dev.create_stream(state.getStream())
"""
...
@@ -68,6 +80,9 @@ class State:
def hasPrinters(self) -> bool:
"True if configuration has a printer"
...
def getDevice(self) -> Union[int, None]:
"Get device_id of the device from this configuration"
...
def getStream(self) -> CudaStream:
"CudaStream object from this configuration"
...
@@ -150,6 +165,8 @@ class State:
def exec(
self,
fn: Callable[[Launch], None],
/,
*,
batched: Optional[bool] = True,
sync: Optional[bool] = False,
):

View File

@@ -344,6 +344,14 @@ PYBIND11_MODULE(_nvbench, m)
pystate_cls.def("hasPrinters", [](nvbench::state &state) -> bool {
return state.get_benchmark().get_printer().has_value();
});
pystate_cls.def("getDevice", [](nvbench::state &state) {
auto dev = state.get_device();
if (dev.has_value())
{
return py::cast(dev.value().get_id());
}
return py::object(py::none());
});
pystate_cls.def(
"getStream",
@@ -359,7 +367,10 @@ PYBIND11_MODULE(_nvbench, m)
pystate_cls.def("getString", &nvbench::state::get_string);
pystate_cls.def("getString", &nvbench::state::get_string_or_default);
pystate_cls.def("addElementCount", &nvbench::state::add_element_count);
pystate_cls.def("addElementCount",
&nvbench::state::add_element_count,
py::arg("count"),
py::arg("column_name") = py::str(""));
pystate_cls.def("setElementCount", &nvbench::state::set_element_count);
pystate_cls.def("getElementCount", &nvbench::state::get_element_count);