Allow obtaining cuda stream handle from PyTorch stream when launching kernel (#297)

If the stream is not a CuPy stream (`cp.cuda.Stream`), use its `cuda_stream`
attribute, as provided by PyTorch streams, to obtain the CUDA stream handle.
This commit is contained in:
aashaka
2024-05-03 21:57:07 -07:00
committed by GitHub
parent 6226556ce2
commit 0650371b54

View File

@@ -50,7 +50,9 @@ class Kernel:
],
dtype=np.uint64,
)
cuda_stream = stream.ptr if stream else 0
cuda_stream = 0
if stream:
cuda_stream = stream.ptr if isinstance(stream, cp.cuda.Stream) else stream.cuda_stream
cp.cuda.driver.launchKernel(
self._kernel, nblocks, 1, 1, nthreads, 1, 1, shared, cuda_stream, 0, config.ctypes.data
)