diff --git a/python/examples/cutlass_gemm.py b/python/examples/cutlass_gemm.py index bba8633..1675d0c 100644 --- a/python/examples/cutlass_gemm.py +++ b/python/examples/cutlass_gemm.py @@ -34,9 +34,11 @@ def as_core_Stream(cs: nvbench.CudaStream) -> core.Stream: return core.Stream.from_handle(cs.addressof()) -def make_cp_array(arr_h: np.ndarray, dev_buf: core.Buffer, dev_id: int) -> cp.ndarray: +def make_cp_array( + arr_h: np.ndarray, dev_buf: core.Buffer, dev_id: int | None +) -> cp.ndarray: cp_memview = cp.cuda.UnownedMemory( - int(dev_buf.handle), dev_buf.size, dev_buf, dev_id + int(dev_buf.handle), dev_buf.size, dev_buf, -1 if dev_id is None else dev_id ) zero_offset = 0 return cp.ndarray(