mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-20 06:48:53 +00:00
Add implementation of and signature for State.getDevice
make batch/sync arguments of State.exec keyword-only Provide default column_name value for State.addElementCount method, so that it can be called state.addElementCount(count), or as state.addElementCount(count, column_name="Descriptive Name")
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Optional, Self
|
||||
from typing import Optional, Self, Union
|
||||
|
||||
class CudaStream:
|
||||
"""Represents CUDA stream
|
||||
@@ -15,6 +15,18 @@ class CudaStream:
|
||||
Special method implement CUDA stream protocol
|
||||
from `cuda.core`. Returns a pair of integers:
|
||||
(protocol_version, integral_value_of_cudaStream_t pointer)
|
||||
|
||||
Example
|
||||
-------
|
||||
import cuda.core.experimental as core
|
||||
import cuda.nvbench as nvbench
|
||||
|
||||
def bench(state: nvbench.State):
|
||||
dev = core.Device(state.getDevice())
|
||||
dev.set_current()
|
||||
# converts CudaString to core.Stream
|
||||
# using __cuda_stream__ protocol
|
||||
dev.create_stream(state.getStream())
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -68,6 +80,9 @@ class State:
|
||||
def hasPrinters(self) -> bool:
|
||||
"True if configuration has a printer"
|
||||
...
|
||||
def getDevice(self) -> Union[int, None]:
|
||||
"Get device_id of the device from this configuration"
|
||||
...
|
||||
def getStream(self) -> CudaStream:
|
||||
"CudaStream object from this configuration"
|
||||
...
|
||||
@@ -150,6 +165,8 @@ class State:
|
||||
def exec(
|
||||
self,
|
||||
fn: Callable[[Launch], None],
|
||||
/,
|
||||
*,
|
||||
batched: Optional[bool] = True,
|
||||
sync: Optional[bool] = False,
|
||||
):
|
||||
|
||||
@@ -344,6 +344,14 @@ PYBIND11_MODULE(_nvbench, m)
|
||||
pystate_cls.def("hasPrinters", [](nvbench::state &state) -> bool {
|
||||
return state.get_benchmark().get_printer().has_value();
|
||||
});
|
||||
pystate_cls.def("getDevice", [](nvbench::state &state) {
|
||||
auto dev = state.get_device();
|
||||
if (dev.has_value())
|
||||
{
|
||||
return py::cast(dev.value().get_id());
|
||||
}
|
||||
return py::object(py::none());
|
||||
});
|
||||
|
||||
pystate_cls.def(
|
||||
"getStream",
|
||||
@@ -359,7 +367,10 @@ PYBIND11_MODULE(_nvbench, m)
|
||||
pystate_cls.def("getString", &nvbench::state::get_string);
|
||||
pystate_cls.def("getString", &nvbench::state::get_string_or_default);
|
||||
|
||||
pystate_cls.def("addElementCount", &nvbench::state::add_element_count);
|
||||
pystate_cls.def("addElementCount",
|
||||
&nvbench::state::add_element_count,
|
||||
py::arg("count"),
|
||||
py::arg("column_name") = py::str(""));
|
||||
pystate_cls.def("setElementCount", &nvbench::state::set_element_count);
|
||||
pystate_cls.def("getElementCount", &nvbench::state::get_element_count);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user