mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 14:59:29 +00:00
Adding Support to Setting Message Size Range in Native Algorithm API (#758)
This commit is contained in:
@@ -84,6 +84,11 @@ class Algorithm {
|
||||
/// @return The Constraint struct specifying worldSize and nRanksPerNode requirements.
|
||||
virtual Constraint constraint() const = 0;
|
||||
|
||||
/// Set the valid message size range for this algorithm.
|
||||
/// @param minMessageSize Minimum supported message size in bytes.
|
||||
/// @param maxMessageSize Maximum supported message size in bytes.
|
||||
virtual void setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) = 0;
|
||||
|
||||
/// Execute the algorithm.
|
||||
/// @param comm The communicator to use.
|
||||
/// @param input Pointer to the input buffer.
|
||||
@@ -233,6 +238,7 @@ class NativeAlgorithm : public Algorithm {
|
||||
const std::string& name() const override;
|
||||
const std::string& collective() const override;
|
||||
const std::pair<size_t, size_t>& messageRange() const override;
|
||||
void setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) override;
|
||||
const std::unordered_map<std::string, uint64_t>& tags() const override;
|
||||
const CollectiveBufferMode& bufferMode() const override;
|
||||
AlgorithmType type() const override { return AlgorithmType::Native; }
|
||||
@@ -273,6 +279,7 @@ class DslAlgorithm : public Algorithm, public AlgorithmBuilder, public std::enab
|
||||
const std::string& name() const override;
|
||||
const std::string& collective() const override;
|
||||
const std::pair<size_t, size_t>& messageRange() const override;
|
||||
void setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) override;
|
||||
const std::unordered_map<std::string, uint64_t>& tags() const override;
|
||||
const CollectiveBufferMode& bufferMode() const override;
|
||||
CommResult execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
|
||||
@@ -60,6 +60,12 @@ void register_algorithm(nb::module_& m) {
|
||||
.def_prop_ro("name", &Algorithm::name)
|
||||
.def_prop_ro("collective", &Algorithm::collective)
|
||||
.def_prop_ro("message_range", &Algorithm::messageRange)
|
||||
.def(
|
||||
"set_message_size_range",
|
||||
[](Algorithm& self, size_t minMessageSize, size_t maxMessageSize) {
|
||||
self.setMessageSizeRange(minMessageSize, maxMessageSize);
|
||||
},
|
||||
nb::arg("min_message_size"), nb::arg("max_message_size"))
|
||||
.def_prop_ro("tags", &Algorithm::tags)
|
||||
.def_prop_ro("buffer_mode", &Algorithm::bufferMode)
|
||||
.def_prop_ro("constraint", &Algorithm::constraint)
|
||||
|
||||
@@ -114,11 +114,24 @@ class Algorithm:
|
||||
"""The collective operation this algorithm implements (e.g., "allreduce", "allgather")."""
|
||||
return self._algorithm.collective
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def message_size_range(self) -> Tuple[int, int]:
|
||||
"""The valid message size range (min_size, max_size) in bytes."""
|
||||
return (self._algorithm.message_range[0], self._algorithm.message_range[1])
|
||||
|
||||
def set_message_size_range(self, min_message_size: int, max_message_size: int):
|
||||
"""Set the valid message size range in bytes.
|
||||
|
||||
Args:
|
||||
min_message_size: Minimum supported message size in bytes.
|
||||
max_message_size: Maximum supported message size in bytes.
|
||||
|
||||
Only supported for native algorithms. Raises TypeError for DSL algorithms.
|
||||
"""
|
||||
if self.is_dsl_algorithm():
|
||||
raise TypeError("set_message_size_range is only supported for native algorithms")
|
||||
self._algorithm.set_message_size_range(min_message_size, max_message_size)
|
||||
|
||||
@cached_property
|
||||
def tags(self) -> Dict[str, int]:
|
||||
"""Dictionary of tag names to tag values for algorithm selection hints."""
|
||||
|
||||
@@ -66,6 +66,11 @@ const std::pair<size_t, size_t>& NativeAlgorithm::messageRange() const {
|
||||
return range;
|
||||
}
|
||||
|
||||
void NativeAlgorithm::setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) {
|
||||
minMessageSize_ = minMessageSize;
|
||||
maxMessageSize_ = maxMessageSize;
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, uint64_t>& NativeAlgorithm::tags() const { return tags_; }
|
||||
|
||||
const CollectiveBufferMode& NativeAlgorithm::bufferMode() const { return bufferMode_; }
|
||||
@@ -143,6 +148,10 @@ const std::pair<size_t, size_t>& DslAlgorithm::messageRange() const {
|
||||
return range;
|
||||
}
|
||||
|
||||
void DslAlgorithm::setMessageSizeRange(size_t, size_t) {
|
||||
THROW(EXEC, Error, ErrorCode::InvalidUsage, "setMessageSizeRange is only supported for native algorithms");
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, uint64_t>& DslAlgorithm::tags() const { return tags_; }
|
||||
|
||||
const CollectiveBufferMode& DslAlgorithm::bufferMode() const {
|
||||
|
||||
Reference in New Issue
Block a user