Fix mypy error in cutlass_gemm example

This commit is contained in:
Oleksandr Pavlyk
2025-07-24 10:30:31 -05:00
parent 5428534124
commit 5c01c34793

View File

@@ -34,9 +34,11 @@ def as_core_Stream(cs: nvbench.CudaStream) -> core.Stream:
return core.Stream.from_handle(cs.addressof())
def make_cp_array(arr_h: np.ndarray, dev_buf: core.Buffer, dev_id: int) -> cp.ndarray:
def make_cp_array(
arr_h: np.ndarray, dev_buf: core.Buffer, dev_id: int | None
) -> cp.ndarray:
cp_memview = cp.cuda.UnownedMemory(
int(dev_buf.handle), dev_buf.size, dev_buf, dev_id
int(dev_buf.handle), dev_buf.size, dev_buf, -1 if dev_id is None else dev_id
)
zero_offset = 0
return cp.ndarray(