mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
Add safe fallback for untyped_storage() on older PyTorch versions
This commit is contained in:
@@ -166,6 +166,8 @@ class MscclppAlltoAllV:
|
||||
self._cached_output_size = 0
|
||||
self._cached_total_output_elems = 0
|
||||
self._cached_dtype = None
|
||||
# One-time check for untyped_storage (available since PyTorch 1.13)
|
||||
self._has_untyped_storage = hasattr(torch.Tensor, 'untyped_storage')
|
||||
# Pre-built extras dict (GPU pointers don't change)
|
||||
self._extras = {
|
||||
"sendCounts": self._d_send_counts.data_ptr(),
|
||||
@@ -293,8 +295,12 @@ class MscclppAlltoAllV:
|
||||
# Use the full underlying storage size for context key stability.
|
||||
# When the test reuses the same large tensor with different split sizes,
|
||||
# storage size stays constant → same context key → reuses channels.
|
||||
input_alloc_size = input.untyped_storage().size()
|
||||
output_alloc_size = output.untyped_storage().size()
|
||||
if self._has_untyped_storage:
|
||||
input_alloc_size = input.untyped_storage().size()
|
||||
output_alloc_size = output.untyped_storage().size()
|
||||
else:
|
||||
input_alloc_size = input.nelement() * input.element_size()
|
||||
output_alloc_size = output.nelement() * output.element_size()
|
||||
|
||||
if _DEBUG:
|
||||
# Clear stale CUDA errors (the C++ code checks cudaGetLastError
|
||||
@@ -304,6 +310,7 @@ class MscclppAlltoAllV:
|
||||
if _last_err != 0:
|
||||
print(f" [rank {self._rank}] WARNING: cleared stale CUDA error code {_last_err} before execute", flush=True)
|
||||
print(f" [rank {self._rank}] alltoallv: calling algo.execute(input_alloc={input_alloc_size}, output_alloc={output_alloc_size})", flush=True)
|
||||
|
||||
result = self._algo.execute(
|
||||
self._comm,
|
||||
input.data_ptr(),
|
||||
@@ -318,6 +325,7 @@ class MscclppAlltoAllV:
|
||||
0, # nthreads_per_block (auto)
|
||||
self._extras,
|
||||
)
|
||||
|
||||
if _DEBUG:
|
||||
print(f" [rank {self._rank}] alltoallv: algo.execute returned {result}", flush=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user