mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-19 22:38:52 +00:00
Fix type annotations in cuda.nvbench, and in examples
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
|
||||
import cuda.nvbench as nvbench
|
||||
import numpy as np
|
||||
@@ -26,7 +25,7 @@ def as_cuda_Stream(cs: nvbench.CudaStream) -> cuda.cudadrv.driver.Stream:
|
||||
return cuda.external_stream(cs.addressof())
|
||||
|
||||
|
||||
def make_kernel(items_per_thread: int) -> Callable:
|
||||
def make_kernel(items_per_thread: int) -> cuda.compiler.AutoJitCUDAKernel:
|
||||
@cuda.jit
|
||||
def kernel(stride: np.uintp, elements: np.uintp, in_arr, out_arr):
|
||||
tid = cuda.grid(1)
|
||||
|
||||
Reference in New Issue
Block a user