Fix type annotations in cuda.nvbench, and in examples

This commit is contained in:
Oleksandr Pavlyk
2025-07-22 13:02:22 -05:00
parent 13ad115ca3
commit a535a1d173
6 changed files with 52 additions and 43 deletions

View File

@@ -15,7 +15,6 @@
# limitations under the License.
import sys
from collections.abc import Callable
import cuda.nvbench as nvbench
import numpy as np
@@ -26,7 +25,7 @@ def as_cuda_Stream(cs: nvbench.CudaStream) -> cuda.cudadrv.driver.Stream:
return cuda.external_stream(cs.addressof())
def make_kernel(items_per_thread: int) -> Callable:
def make_kernel(items_per_thread: int) -> cuda.compiler.AutoJitCUDAKernel:
@cuda.jit
def kernel(stride: np.uintp, elements: np.uintp, in_arr, out_arr):
tid = cuda.grid(1)