mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-03-14 20:27:24 +00:00
Fix mypy error in cutlass_gemm example
This commit is contained in:
@@ -34,9 +34,11 @@ def as_core_Stream(cs: nvbench.CudaStream) -> core.Stream:
|
||||
return core.Stream.from_handle(cs.addressof())
|
||||
|
||||
|
||||
def make_cp_array(arr_h: np.ndarray, dev_buf: core.Buffer, dev_id: int) -> cp.ndarray:
|
||||
def make_cp_array(
|
||||
arr_h: np.ndarray, dev_buf: core.Buffer, dev_id: int | None
|
||||
) -> cp.ndarray:
|
||||
cp_memview = cp.cuda.UnownedMemory(
|
||||
int(dev_buf.handle), dev_buf.size, dev_buf, dev_id
|
||||
int(dev_buf.handle), dev_buf.size, dev_buf, -1 if dev_id is None else dev_id
|
||||
)
|
||||
zero_offset = 0
|
||||
return cp.ndarray(
|
||||
|
||||
Reference in New Issue
Block a user