diff --git a/dispatcher/codegen/unified_gemm_codegen.py b/dispatcher/codegen/unified_gemm_codegen.py index a818cec83e..c0fb08aa44 100755 --- a/dispatcher/codegen/unified_gemm_codegen.py +++ b/dispatcher/codegen/unified_gemm_codegen.py @@ -734,7 +734,7 @@ using AccDataType = float; DsLayout, CLayout, ElementWiseFn, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, WarpPerBlock_M, WarpPerBlock_N, WarpTileM, WarpTileN, WarpTileK, - TransposeC, NumWaveGroups, false, 1, false, 1, DoubleSmemBuffer>; + TransposeC, NumWaveGroups, false, 1, 1, DoubleSmemBuffer>; using GemmEpilogue = CShuffleEpilogue;""" elif config.trait.epilogue == "cshuffle": return """ @@ -743,7 +743,7 @@ using AccDataType = float; tuple<>, CLayout, element_wise::PassThrough, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, WarpPerBlock_M, WarpPerBlock_N, WarpTileM, WarpTileN, WarpTileK, - TransposeC, NumWaveGroups, false, 1, false, 1, DoubleSmemBuffer>; + TransposeC, NumWaveGroups, false, 1, 1, DoubleSmemBuffer>; using GemmEpilogue = CShuffleEpilogue;""" else: return """ diff --git a/dispatcher/codegen/unified_grouped_conv_codegen.py b/dispatcher/codegen/unified_grouped_conv_codegen.py index ff40cb4ed4..db0ef79bd3 100644 --- a/dispatcher/codegen/unified_grouped_conv_codegen.py +++ b/dispatcher/codegen/unified_grouped_conv_codegen.py @@ -600,7 +600,7 @@ struct {kernel_name}_Launcher {{ GroupedConvTraitsType::FixedGemmParams::TransposeC, Config::NumWaveGroups, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - Config::VectorSizeC, false, 1, Config::DoubleSmemBuffer>>; + Config::VectorSizeC, 1, Config::DoubleSmemBuffer>>; using Kernel = {kernel_type}< GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>; diff --git a/dispatcher/include/ck_tile/dispatcher/base_registry.hpp b/dispatcher/include/ck_tile/dispatcher/base_registry.hpp index 2bb940c320..f4e7151d24 100644 --- a/dispatcher/include/ck_tile/dispatcher/base_registry.hpp +++ 
b/dispatcher/include/ck_tile/dispatcher/base_registry.hpp @@ -69,14 +69,14 @@ class BaseRegistry BaseRegistry& operator=(const BaseRegistry&) = delete; /// Register a kernel. If the key already exists, the new entry replaces it - /// unless the existing entry has strictly higher priority. - /// Same-priority registration overwrites (last-writer-wins at equal priority). + /// only when its priority is strictly higher than the existing entry's + /// priority. Same-priority registration is rejected (first-writer-wins). bool register_kernel(const KeyType& key, InstancePtr instance, Priority priority = Priority::Normal) { std::lock_guard lock(mutex_); auto it = entries_.find(key); - if(it != entries_.end() && it->second.priority > priority) + if(it != entries_.end() && it->second.priority >= priority) { return false; } diff --git a/dispatcher/src/registry.cpp b/dispatcher/src/registry.cpp index f565885181..cd17fcbd53 100644 --- a/dispatcher/src/registry.cpp +++ b/dispatcher/src/registry.cpp @@ -55,7 +55,10 @@ bool Registry::register_kernel(KernelInstancePtr instance, Priority priority) if(!instance) return false; - if(Base::register_kernel(instance->get_name(), instance, priority)) + // Store under the encoded identifier so Registry::lookup(KernelKey) finds it. + // Previously stored under instance->get_name(), but lookup(KernelKey) queries by + // key.encode_identifier() — those keys never matched, breaking key-based lookup. 
+ if(Base::register_kernel(instance->get_key().encode_identifier(), instance, priority)) { if(auto_export_enabled_ && auto_export_on_every_registration_) { diff --git a/dispatcher/tests/test_grouped_conv_registry.cpp b/dispatcher/tests/test_grouped_conv_registry.cpp index 47d13a9997..f05f2d0476 100644 --- a/dispatcher/tests/test_grouped_conv_registry.cpp +++ b/dispatcher/tests/test_grouped_conv_registry.cpp @@ -19,7 +19,7 @@ void test_grouped_conv_registry_basic() reg.clear(); reg.set_name("test_registry"); - assert(reg.name() == "test_registry"); + assert(reg.get_name() == "test_registry"); assert(reg.size() == 0); assert(reg.empty()); diff --git a/dispatcher/tests/test_kernel_key.cpp b/dispatcher/tests/test_kernel_key.cpp index b35641952a..b44b140db5 100644 --- a/dispatcher/tests/test_kernel_key.cpp +++ b/dispatcher/tests/test_kernel_key.cpp @@ -71,7 +71,12 @@ TEST(KernelKeyTest, EncodeIdentifier) EXPECT_NE(id.find("256x256x32"), std::string::npos); // tile shape EXPECT_NE(id.find("2x2x1"), std::string::npos); // wave shape EXPECT_NE(id.find("32x32x16"), std::string::npos); // warp tile shape - EXPECT_NE(id.find("persist"), std::string::npos); // persistent flag + + // Verify persistent flag is encoded by toggling it and asserting the + // identifier changes. Robust to encoding spelling changes. + KernelKey non_persistent_key = key; + non_persistent_key.algorithm.persistent = false; + EXPECT_NE(id, non_persistent_key.encode_identifier()); } TEST(KernelKeyTest, EncodeIdentifierWithFusion) @@ -97,7 +102,12 @@ TEST(KernelKeyTest, EncodeIdentifierWithFusion) // Check fusion-specific components EXPECT_NE(id.find("Relu"), std::string::npos); EXPECT_NE(id.find("_d2"), std::string::npos); - EXPECT_NE(id.find("nopers"), std::string::npos); + + // Verify persistent flag is encoded by toggling it and asserting the + // identifier changes. Robust to encoding spelling changes. 
+ KernelKey persistent_key = key; + persistent_key.algorithm.persistent = true; + EXPECT_NE(id, persistent_key.encode_identifier()); } TEST(KernelKeyTest, EncodeIdentifierWithSplitK) diff --git a/dispatcher/tests/test_kernel_key_extended.cpp b/dispatcher/tests/test_kernel_key_extended.cpp index 1c6b5bcba0..01b082fa63 100644 --- a/dispatcher/tests/test_kernel_key_extended.cpp +++ b/dispatcher/tests/test_kernel_key_extended.cpp @@ -374,9 +374,9 @@ TEST_F(IdentifierEncodingTest, IdentifierReflectsPersistence) std::string persistent_id = persistent_key.encode_identifier(); std::string non_persistent_id = non_persistent_key.encode_identifier(); + // EXPECT_NE below already verifies persistence affects encoding; + // substring checks for specific spelling were brittle and have been removed. EXPECT_NE(persistent_id, non_persistent_id); - EXPECT_NE(persistent_id.find("persist"), std::string::npos); - EXPECT_NE(non_persistent_id.find("nopers"), std::string::npos); } // ============================================================================= diff --git a/dispatcher/tests/test_tile_backend.cpp b/dispatcher/tests/test_tile_backend.cpp index 4e7c693071..dd17c05520 100644 --- a/dispatcher/tests/test_tile_backend.cpp +++ b/dispatcher/tests/test_tile_backend.cpp @@ -97,8 +97,10 @@ TEST(TileBackendTest, TileKernelIdentifierEncoding) EXPECT_NE(id.find("2x2x1"), std::string::npos); EXPECT_NE(id.find("32x32x16"), std::string::npos); - // Should contain persistent flag - EXPECT_NE(id.find("nopers"), std::string::npos); // persistent = false + // Verify persistent flag affects identifier + KernelKey persistent_key = key; + persistent_key.algorithm.persistent = true; + EXPECT_NE(id, persistent_key.encode_identifier()); } TEST(TileBackendTest, MultipleKernelRegistration)