mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
WIP
This commit is contained in:
@@ -288,6 +288,7 @@ class NativeAlgorithm : public Algorithm {
|
||||
std::unordered_map<std::string, uint64_t> tags_;
|
||||
Constraint constraint_;
|
||||
std::unordered_map<AlgorithmCtxKey, std::shared_ptr<void>> contexts_;
|
||||
std::unordered_map<int32_t, AlgorithmCtxKey> customKeyMap_;
|
||||
|
||||
bool initialized_ = false;
|
||||
};
|
||||
|
||||
@@ -47,15 +47,26 @@ CommResult NativeAlgorithm::execute(std::shared_ptr<Communicator> comm, const vo
|
||||
initFunc_(comm);
|
||||
initialized_ = true;
|
||||
}
|
||||
AlgorithmCtxKey ctxKey = (contextKey >= 0)
|
||||
? AlgorithmCtxKey(contextKey)
|
||||
: contextKeyGenFunc_(input, output, inputSize, outputSize, dtype, symmetricMemory);
|
||||
auto it = contexts_.find(ctxKey);
|
||||
AlgorithmCtxKey bufferKey;
|
||||
if (contextKey >= 0) {
|
||||
auto mapIt = customKeyMap_.find(contextKey);
|
||||
if (mapIt != customKeyMap_.end()) {
|
||||
// Fast path: reuse the previously cached buffer key for this custom key.
|
||||
bufferKey = mapIt->second;
|
||||
} else {
|
||||
// First time seeing this custom key — generate and cache the buffer key.
|
||||
bufferKey = contextKeyGenFunc_(input, output, inputSize, outputSize, dtype, symmetricMemory);
|
||||
customKeyMap_[contextKey] = bufferKey;
|
||||
}
|
||||
} else {
|
||||
bufferKey = contextKeyGenFunc_(input, output, inputSize, outputSize, dtype, symmetricMemory);
|
||||
}
|
||||
auto it = contexts_.find(bufferKey);
|
||||
if (it == contexts_.end()) {
|
||||
auto ctx = contextInitFunc_(comm, input, output, inputSize, outputSize, dtype);
|
||||
contexts_[ctxKey] = ctx;
|
||||
contexts_[bufferKey] = ctx;
|
||||
}
|
||||
return kernelLaunchFunc_(contexts_[ctxKey], input, output, inputSize, outputSize, dtype, op, stream, nBlocks,
|
||||
return kernelLaunchFunc_(contexts_[bufferKey], input, output, inputSize, outputSize, dtype, op, stream, nBlocks,
|
||||
nThreadsPerBlock, extras);
|
||||
}
|
||||
|
||||
@@ -82,11 +93,15 @@ Algorithm::Constraint NativeAlgorithm::constraint() const { return constraint_;
|
||||
|
||||
void NativeAlgorithm::reset(int32_t contextKey) {
|
||||
if (contextKey >= 0) {
|
||||
// Remove only the context associated with the given custom key
|
||||
AlgorithmCtxKey key(contextKey);
|
||||
contexts_.erase(key);
|
||||
// Remove the context associated with the given custom key
|
||||
auto it = customKeyMap_.find(contextKey);
|
||||
if (it != customKeyMap_.end()) {
|
||||
contexts_.erase(it->second);
|
||||
customKeyMap_.erase(it);
|
||||
}
|
||||
} else {
|
||||
contexts_.clear();
|
||||
customKeyMap_.clear();
|
||||
initialized_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user