Merge branch 'main' into copilot/remove-gtest-use-custom-framework

This commit is contained in:
Changho Hwang
2026-03-03 15:49:01 -08:00
committed by GitHub
7 changed files with 173 additions and 3 deletions

View File

@@ -468,3 +468,132 @@ stream_handle = torch.cuda.current_stream().cuda_stream
All examples are in [`examples/torch-integration/`](../../examples/torch-integration/).
---
## Performance Tuning
By default, an algorithm is selected with a fixed, message-size-based heuristic. For production workloads, you can achieve significantly better performance with **auto-tuning** — benchmarking every candidate algorithm, block count, and thread count for each message size at startup, then dispatching to the fastest configuration at runtime.
**Full example:** [customized_comm_with_tuning.py](../../examples/torch-integration/customized_comm_with_tuning.py)
### How It Works
1. **Candidate selection** — For each power-of-two message size from 1 KB to 128 MB, the tuner picks the applicable algorithms:
- Small messages (≤ 4 MB): `default_allreduce_nvls_packet`, `default_allreduce_packet`
- Large messages (≥ 512 KB): `default_allreduce_rsag_zero_copy`
- Overlapping sizes get all three candidates.
2. **Grid search** — Each candidate is run with every combination of block counts (`4, 8, 16, … 128`) and thread counts (`512, 768, 1024`). Results are captured in a CUDA graph and timed.
3. **Cross-rank consensus** — Elapsed times are averaged across all ranks with an allreduce so every GPU selects the same configuration.
4. **Runtime dispatch** — `get_tuned_config()` rounds the actual message size up to the next power of two and returns the winning `(algorithm, nblocks, nthreads)` triple.
### Loading Candidate Algorithms
The same `load_algorithms` helper from Approach 1 is reused. The tuner extracts multiple algorithm objects:
```python
algorithms = load_algorithms(scratch_buffer=self.scratch_buffer, rank=self.rank)

def _select(name):
    # First loaded allreduce algorithm whose name matches exactly.
    # Raises IndexError if the algorithm was not loaded.
    matches = [a for a in algorithms if a.collective == "allreduce" and a.name == name]
    return matches[0]

self._algorithm_nvls_packet = _select("default_allreduce_nvls_packet")
self._algorithm_rsag_zero_copy = _select("default_allreduce_rsag_zero_copy")
self._algorithm_packet = _select("default_allreduce_packet")
```
### The Tuning Loop
The tuning loop iterates over message sizes, candidate algorithms, and kernel launch parameters. CUDA graphs are used for accurate timing:
```python
def _tune(self, n_warmup, n_graph_launches, n_ops_per_graph):
    """Grid-search (algorithm, nblocks, nthreads) for each message size.

    Args:
        n_warmup: Warmup iterations before timing (consumed by the elided
            graph-capture/timing code below).
        n_graph_launches: Timed CUDA-graph launches per candidate.
        n_ops_per_graph: Collective ops captured into each graph.

    Fills self.best_configs, mapping message size in bytes to the winning
    (algorithm, nblocks, nthreads) triple.
    """
    # Power-of-two message sizes from 1 KB (1 << 10) to 128 MB (1 << 27).
    sizes = [1 << i for i in range(10, 28)]
    # Seed an entry so the all_reduce() used for cross-rank timing below can
    # look up a config; (0, 0) presumably selects defaults -- TODO confirm.
    self.best_configs = {1024: (self._algorithm_nvls_packet, 0, 0)}
    tune_tensor = torch.rand(1 << 27, dtype=torch.float16, device="cuda")
    candidates_nblocks = [4, 8, 16, 24, 32, 48, 64, 128]
    candidates_nthreads = [512, 768, 1024]
    for size in sizes:
        # Candidate algorithms for this size; the ranges overlap between
        # 512 KB and 4 MB, where all three compete.
        algos = []
        if size <= 4 * 1024 * 1024:
            algos.append(self._algorithm_nvls_packet)
            algos.append(self._algorithm_packet)
        if size >= 512 * 1024:
            algos.append(self._algorithm_rsag_zero_copy)
        best_time = float("inf")
        best_config = None
        for algo in algos:
            for nb in candidates_nblocks:
                for nt in candidates_nthreads:
                    if self._run_algo(algo, tune_tensor, size, nb, nt) != 0:
                        continue  # skip unsupported configs
                    # Warmup, then time with CUDA graphs
                    # ... (see full example for graph capture logic)
                    # Average timing across ranks
                    # NOTE: `elapsed` is produced by the elided timing code above.
                    time_tensor = torch.full(
                        (self.world_size,), elapsed, dtype=torch.float64, device="cuda"
                    ).to(dtype=torch.float32)
                    # SUM-allreduce the per-rank time so every rank computes the
                    # same average and therefore selects the identical config.
                    self.all_reduce(time_tensor, op=torch.distributed.ReduceOp.SUM)
                    avg_time = time_tensor[self.rank].item() / self.world_size
                    if avg_time < best_time:
                        best_time = avg_time
                        best_config = (algo, nb, nt)
        if best_config:
            self.best_configs[size] = best_config
```
### Dispatching with Tuned Configuration
At runtime, round the message size to the next power of two and look up the best configuration:
```python
def get_tuned_config(self, size):
    """Return the tuned (algorithm, nblocks, nthreads) for a message size.

    The size is clamped into the tuned range [1 KB, 128 MB] and rounded up
    to the next power of two, matching the keys written by _tune().

    Args:
        size: Actual message size in bytes.

    Returns:
        The winning (algorithm, nblocks, nthreads) triple, or None if no
        entry exists for the bucket.
    """
    if size < 1024:
        target_size = 1024
    elif size > 128 * 1024 * 1024:
        # Clamp to the largest tuned bucket. _tune() only benchmarks sizes up
        # to 1 << 27 (128 MB), so the previous 256 MB cap always missed the
        # table and silently fell back to the default config.
        target_size = 128 * 1024 * 1024
    else:
        target_size = 1 << (size - 1).bit_length()
    return self.best_configs.get(target_size)
def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, stream=None):
    """In-place allreduce on `tensor` using the tuned configuration.

    Falls back to the NVLS-packet algorithm with default launch parameters
    (nblocks=0, nthreads=0) when no tuned entry covers the message size.
    """
    tuned = self.get_tuned_config(tensor.nbytes)
    if not tuned:
        tuned = (self._algorithm_nvls_packet, 0, 0)
    algorithm, num_blocks, num_threads = tuned
    # Use the caller-provided stream when given, otherwise the current stream.
    cuda_stream = (
        stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream
    )
    algorithm.execute(
        comm=self.comm.communicator,
        input_buffer=tensor.data_ptr(),
        output_buffer=tensor.data_ptr(),
        input_size=tensor.nbytes,
        output_size=tensor.nbytes,
        dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
        op=mscclpp.ReduceOp.SUM,
        stream=cuda_stream,
        nblocks=num_blocks,
        nthreads_per_block=num_threads,
    )
```
### Running the Tuning Example
```bash
MSCCLPP_MASTER_ADDR=<ip> MSCCLPP_MASTER_PORT=<port> \
torchrun --nnodes=1 --nproc_per_node=8 customized_comm_with_tuning.py
```

View File

@@ -84,6 +84,11 @@ class Algorithm {
/// @return The Constraint struct specifying worldSize and nRanksPerNode requirements.
virtual Constraint constraint() const = 0;
/// Set the valid message size range for this algorithm.
/// @param minMessageSize Minimum supported message size in bytes.
/// @param maxMessageSize Maximum supported message size in bytes.
virtual void setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) = 0;
/// Execute the algorithm.
/// @param comm The communicator to use.
/// @param input Pointer to the input buffer.
@@ -233,6 +238,7 @@ class NativeAlgorithm : public Algorithm {
const std::string& name() const override;
const std::string& collective() const override;
const std::pair<size_t, size_t>& messageRange() const override;
void setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) override;
const std::unordered_map<std::string, uint64_t>& tags() const override;
const CollectiveBufferMode& bufferMode() const override;
AlgorithmType type() const override { return AlgorithmType::Native; }
@@ -273,6 +279,7 @@ class DslAlgorithm : public Algorithm, public AlgorithmBuilder, public std::enab
const std::string& name() const override;
const std::string& collective() const override;
const std::pair<size_t, size_t>& messageRange() const override;
void setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) override;
const std::unordered_map<std::string, uint64_t>& tags() const override;
const CollectiveBufferMode& bufferMode() const override;
CommResult execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,

View File

@@ -29,7 +29,9 @@ class Proxy {
public:
/// Constructor.
/// @param handler Handler for each FIFO trigger.
/// @param threadInit Optional function run in proxy thread before FIFO consumption.
/// @param threadInit Optional function run once in the proxy thread before FIFO consumption.
/// The function should initialize thread runtime context before any CUDA API call in that thread
/// (for example, set CUDA device and optionally bind NUMA affinity).
/// @param fifoSize FIFO size (default: DEFAULT_FIFO_SIZE).
Proxy(ProxyHandler handler, std::function<void()> threadInit, int fifoSize = DEFAULT_FIFO_SIZE);

View File

@@ -60,6 +60,12 @@ void register_algorithm(nb::module_& m) {
.def_prop_ro("name", &Algorithm::name)
.def_prop_ro("collective", &Algorithm::collective)
.def_prop_ro("message_range", &Algorithm::messageRange)
.def(
"set_message_size_range",
[](Algorithm& self, size_t minMessageSize, size_t maxMessageSize) {
self.setMessageSizeRange(minMessageSize, maxMessageSize);
},
nb::arg("min_message_size"), nb::arg("max_message_size"))
.def_prop_ro("tags", &Algorithm::tags)
.def_prop_ro("buffer_mode", &Algorithm::bufferMode)
.def_prop_ro("constraint", &Algorithm::constraint)

View File

@@ -114,11 +114,24 @@ class Algorithm:
"""The collective operation this algorithm implements (e.g., "allreduce", "allgather")."""
return self._algorithm.collective
@cached_property
@property
def message_size_range(self) -> Tuple[int, int]:
"""The valid message size range (min_size, max_size) in bytes."""
return (self._algorithm.message_range[0], self._algorithm.message_range[1])
def set_message_size_range(self, min_message_size: int, max_message_size: int):
"""Set the valid message size range in bytes.
Args:
min_message_size: Minimum supported message size in bytes.
max_message_size: Maximum supported message size in bytes.
Only supported for native algorithms. Raises TypeError for DSL algorithms.
"""
if self.is_dsl_algorithm():
raise TypeError("set_message_size_range is only supported for native algorithms")
self._algorithm.set_message_size_range(min_message_size, max_message_size)
@cached_property
def tags(self) -> Dict[str, int]:
"""Dictionary of tag names to tag values for algorithm selection hints."""

View File

@@ -66,6 +66,11 @@ const std::pair<size_t, size_t>& NativeAlgorithm::messageRange() const {
return range;
}
// Override the supported message size range (in bytes, per the header doc)
// reported by messageRange() for this native algorithm.
void NativeAlgorithm::setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) {
  minMessageSize_ = minMessageSize;
  maxMessageSize_ = maxMessageSize;
}
const std::unordered_map<std::string, uint64_t>& NativeAlgorithm::tags() const { return tags_; }
const CollectiveBufferMode& NativeAlgorithm::bufferMode() const { return bufferMode_; }
@@ -143,6 +148,10 @@ const std::pair<size_t, size_t>& DslAlgorithm::messageRange() const {
return range;
}
// Overriding the message range is only supported for native algorithms;
// DSL algorithms always reject the call with InvalidUsage (the Python
// binding guards this with a TypeError as well).
void DslAlgorithm::setMessageSizeRange(size_t, size_t) {
  THROW(EXEC, Error, ErrorCode::InvalidUsage, "setMessageSizeRange is only supported for native algorithms");
}
const std::unordered_map<std::string, uint64_t>& DslAlgorithm::tags() const { return tags_; }
const CollectiveBufferMode& DslAlgorithm::bufferMode() const {

View File

@@ -59,11 +59,15 @@ MSCCLPP_API_CPP Proxy::~Proxy() {
MSCCLPP_API_CPP void Proxy::start(bool blocking) {
pimpl_->running.store(true, std::memory_order_release);
pimpl_->service = std::thread([this] {
// threadInit() is responsible for setting up the runtime context for the thread.
// The default implementation sets the CUDA device and NUMA affinity to match the main thread (see Proxy ctor).
// It should be called before any CUDA API calls to avoid resource allocation on unwanted GPUs.
pimpl_->threadInit();
// never capture in a proxy thread
auto mode = cudaStreamCaptureModeRelaxed;
MSCCLPP_CUDATHROW(cudaThreadExchangeStreamCaptureMode(&mode));
pimpl_->threadInit();
pimpl_->threadStarted.store(true, std::memory_order_release);
ProxyHandler handler = this->pimpl_->handler;