This commit is contained in:
Binyang Li
2026-03-17 19:16:55 +00:00
parent c777290271
commit a8edfb7cf9
2 changed files with 25 additions and 9 deletions

View File

@@ -288,6 +288,7 @@ class NativeAlgorithm : public Algorithm {
std::unordered_map<std::string, uint64_t> tags_;
Constraint constraint_;
std::unordered_map<AlgorithmCtxKey, std::shared_ptr<void>> contexts_;
std::unordered_map<int32_t, AlgorithmCtxKey> customKeyMap_;
bool initialized_ = false;
};

View File

@@ -47,15 +47,26 @@ CommResult NativeAlgorithm::execute(std::shared_ptr<Communicator> comm, const vo
initFunc_(comm);
initialized_ = true;
}
AlgorithmCtxKey ctxKey = (contextKey >= 0)
? AlgorithmCtxKey(contextKey)
: contextKeyGenFunc_(input, output, inputSize, outputSize, dtype, symmetricMemory);
auto it = contexts_.find(ctxKey);
AlgorithmCtxKey bufferKey;
if (contextKey >= 0) {
auto mapIt = customKeyMap_.find(contextKey);
if (mapIt != customKeyMap_.end()) {
// Fast path: reuse the previously cached buffer key for this custom key.
bufferKey = mapIt->second;
} else {
// First time seeing this custom key — generate and cache the buffer key.
bufferKey = contextKeyGenFunc_(input, output, inputSize, outputSize, dtype, symmetricMemory);
customKeyMap_[contextKey] = bufferKey;
}
} else {
bufferKey = contextKeyGenFunc_(input, output, inputSize, outputSize, dtype, symmetricMemory);
}
auto it = contexts_.find(bufferKey);
if (it == contexts_.end()) {
auto ctx = contextInitFunc_(comm, input, output, inputSize, outputSize, dtype);
contexts_[ctxKey] = ctx;
contexts_[bufferKey] = ctx;
}
return kernelLaunchFunc_(contexts_[ctxKey], input, output, inputSize, outputSize, dtype, op, stream, nBlocks,
return kernelLaunchFunc_(contexts_[bufferKey], input, output, inputSize, outputSize, dtype, op, stream, nBlocks,
nThreadsPerBlock, extras);
}
@@ -82,11 +93,15 @@ Algorithm::Constraint NativeAlgorithm::constraint() const { return constraint_;
// Discards cached algorithm state.
//
// contextKey >= 0: removes only the state tied to that custom key.
// contextKey <  0: clears every cached context and every custom-key mapping,
//                  and marks the algorithm as uninitialized.
void NativeAlgorithm::reset(int32_t contextKey) {
if (contextKey >= 0) {
// Erase any context stored directly under a key constructed from the custom key.
AlgorithmCtxKey key(contextKey);
contexts_.erase(key);
// Erase the context that the custom key maps to through customKeyMap_, then
// drop the mapping itself. An unknown custom key is a no-op here.
// NOTE(review): this text looks like a rendered diff; the direct erase above
// and the mapped erase below may be the pre- and post-change versions of the
// same step rather than code meant to coexist — confirm against the real file.
auto it = customKeyMap_.find(contextKey);
if (it != customKeyMap_.end()) {
// it->second must be read before `it` is erased from customKeyMap_.
contexts_.erase(it->second);
customKeyMap_.erase(it);
}
} else {
// Full reset: drop all contexts and all custom-key mappings.
contexts_.clear();
customKeyMap_.clear();
// NOTE(review): presumably makes the next execute() call run initFunc_ again;
// the guarding condition is not visible in this excerpt — confirm.
initialized_ = false;
}
}