mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-19 22:38:52 +00:00
Replace uses of deprecated typing.Tuple, typing.Callable, etc.
Also use typing.Self to encode that `Benchmark.addInt64Axis` returns self.
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
from typing import Callable, Sequence, Tuple
|
||||
# from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Optional, Self
|
||||
|
||||
class CudaStream:
|
||||
"""Represents CUDA stream
|
||||
@@ -7,7 +10,7 @@ class CudaStream:
|
||||
----
|
||||
The class is not directly constructible.
|
||||
"""
|
||||
def __cuda_stream__(self) -> Tuple[int]:
|
||||
def __cuda_stream__(self) -> tuple[int]:
|
||||
"""
|
||||
Special method implement CUDA stream protocol
|
||||
from `cuda.core`. Returns a pair of integers:
|
||||
@@ -31,13 +34,13 @@ class Benchmark:
|
||||
def getName(self) -> str:
|
||||
"Get benchmark name"
|
||||
...
|
||||
def addInt64Axis(self, name: str, values: Sequence[int]) -> Benchmark:
|
||||
def addInt64Axis(self, name: str, values: Sequence[int]) -> Self:
|
||||
"Add integral type parameter axis with given name and values to sweep over"
|
||||
...
|
||||
def addFloat64Axis(self, name: str, values: Sequence[float]) -> Benchmark:
|
||||
def addFloat64Axis(self, name: str, values: Sequence[float]) -> Self:
|
||||
"Add floating-point type parameter axis with given name and values to sweep over"
|
||||
...
|
||||
def addStringAxis(sef, name: str, values: Sequence[str]) -> Benchmark:
|
||||
def addStringAxis(sef, name: str, values: Sequence[str]) -> Self:
|
||||
"Add string type parameter axis with given name and values to sweep over"
|
||||
...
|
||||
|
||||
@@ -68,16 +71,16 @@ class State:
|
||||
def getStream(self) -> CudaStream:
|
||||
"CudaStream object from this configuration"
|
||||
...
|
||||
def getInt64(self, name: str, default_value: int = None) -> int:
|
||||
def getInt64(self, name: str, default_value: Optional[int] = None) -> int:
|
||||
"Get value for given Int64 axis from this configuration"
|
||||
...
|
||||
def getFloat64(self, name: str, default_value: float = None) -> float:
|
||||
def getFloat64(self, name: str, default_value: Optional[float] = None) -> float:
|
||||
"Get value for given Float64 axis from this configuration"
|
||||
...
|
||||
def getString(self, name: str, default_value: str = None) -> str:
|
||||
def getString(self, name: str, default_value: Optional[str] = None) -> str:
|
||||
"Get value for given String axis from this configuration"
|
||||
...
|
||||
def addElementCount(self, count: int, column_name: str = None) -> None:
|
||||
def addElementCount(self, count: int, column_name: Optional[str] = None) -> None:
|
||||
"Add element count"
|
||||
...
|
||||
def setElementCount(self, count: int) -> None:
|
||||
@@ -145,7 +148,10 @@ class State:
|
||||
"True if (some) CUPTI metrics are being collected"
|
||||
...
|
||||
def exec(
|
||||
self, fn: Callable[[Launch], None], batched: bool = True, sync: bool = False
|
||||
self,
|
||||
fn: Callable[[Launch], None],
|
||||
batched: Optional[bool] = True,
|
||||
sync: Optional[bool] = False,
|
||||
):
|
||||
"""Execute callable running the benchmark.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user