mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Allow obtaining cuda stream handle from PyTorch stream when launching kernel (#297)
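Use the `cuda_stream` attribute of a PyTorch stream to obtain the CUDA stream handle when the stream is not an instance of CuPy's `cp.cuda.Stream` (for CuPy streams, the handle is still read from `ptr`).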
This commit is contained in:
@@ -50,7 +50,9 @@ class Kernel:
|
||||
],
|
||||
dtype=np.uint64,
|
||||
)
|
||||
cuda_stream = stream.ptr if stream else 0
|
||||
cuda_stream = 0
|
||||
if stream:
|
||||
cuda_stream = stream.ptr if isinstance(stream, cp.cuda.Stream) else stream.cuda_stream
|
||||
cp.cuda.driver.launchKernel(
|
||||
self._kernel, nblocks, 1, 1, nthreads, 1, 1, shared, cuda_stream, 0, config.ctypes.data
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user