Adding Support to Setting Message Size Range in Native Algorithm API (#758)

This commit is contained in:
Caio Rocha
2026-02-27 17:50:43 -08:00
committed by GitHub
parent ab49386839
commit 4bc1999001
4 changed files with 36 additions and 1 deletions

View File

@@ -84,6 +84,11 @@ class Algorithm {
/// @return The Constraint struct specifying worldSize and nRanksPerNode requirements.
virtual Constraint constraint() const = 0;
/// Set the valid message size range for this algorithm.
/// @param minMessageSize Minimum supported message size in bytes.
/// @param maxMessageSize Maximum supported message size in bytes.
virtual void setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) = 0;
/// Execute the algorithm.
/// @param comm The communicator to use.
/// @param input Pointer to the input buffer.
@@ -233,6 +238,7 @@ class NativeAlgorithm : public Algorithm {
const std::string& name() const override;
const std::string& collective() const override;
const std::pair<size_t, size_t>& messageRange() const override;
void setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) override;
const std::unordered_map<std::string, uint64_t>& tags() const override;
const CollectiveBufferMode& bufferMode() const override;
AlgorithmType type() const override { return AlgorithmType::Native; }
@@ -273,6 +279,7 @@ class DslAlgorithm : public Algorithm, public AlgorithmBuilder, public std::enab
const std::string& name() const override;
const std::string& collective() const override;
const std::pair<size_t, size_t>& messageRange() const override;
void setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) override;
const std::unordered_map<std::string, uint64_t>& tags() const override;
const CollectiveBufferMode& bufferMode() const override;
CommResult execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,

View File

@@ -60,6 +60,12 @@ void register_algorithm(nb::module_& m) {
.def_prop_ro("name", &Algorithm::name)
.def_prop_ro("collective", &Algorithm::collective)
.def_prop_ro("message_range", &Algorithm::messageRange)
.def(
"set_message_size_range",
[](Algorithm& self, size_t minMessageSize, size_t maxMessageSize) {
self.setMessageSizeRange(minMessageSize, maxMessageSize);
},
nb::arg("min_message_size"), nb::arg("max_message_size"))
.def_prop_ro("tags", &Algorithm::tags)
.def_prop_ro("buffer_mode", &Algorithm::bufferMode)
.def_prop_ro("constraint", &Algorithm::constraint)

View File

@@ -114,11 +114,24 @@ class Algorithm:
"""The collective operation this algorithm implements (e.g., "allreduce", "allgather")."""
return self._algorithm.collective
@cached_property
@property
def message_size_range(self) -> Tuple[int, int]:
"""The valid message size range (min_size, max_size) in bytes."""
return (self._algorithm.message_range[0], self._algorithm.message_range[1])
def set_message_size_range(self, min_message_size: int, max_message_size: int):
"""Set the valid message size range in bytes.
Args:
min_message_size: Minimum supported message size in bytes.
max_message_size: Maximum supported message size in bytes.
Only supported for native algorithms. Raises TypeError for DSL algorithms.
"""
if self.is_dsl_algorithm():
raise TypeError("set_message_size_range is only supported for native algorithms")
self._algorithm.set_message_size_range(min_message_size, max_message_size)
@cached_property
def tags(self) -> Dict[str, int]:
"""Dictionary of tag names to tag values for algorithm selection hints."""

View File

@@ -66,6 +66,11 @@ const std::pair<size_t, size_t>& NativeAlgorithm::messageRange() const {
return range;
}
void NativeAlgorithm::setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) {
minMessageSize_ = minMessageSize;
maxMessageSize_ = maxMessageSize;
}
const std::unordered_map<std::string, uint64_t>& NativeAlgorithm::tags() const { return tags_; }
const CollectiveBufferMode& NativeAlgorithm::bufferMode() const { return bufferMode_; }
@@ -143,6 +148,10 @@ const std::pair<size_t, size_t>& DslAlgorithm::messageRange() const {
return range;
}
void DslAlgorithm::setMessageSizeRange(size_t, size_t) {
THROW(EXEC, Error, ErrorCode::InvalidUsage, "setMessageSizeRange is only supported for native algorithms");
}
const std::unordered_map<std::string, uint64_t>& DslAlgorithm::tags() const { return tags_; }
const CollectiveBufferMode& DslAlgorithm::bufferMode() const {