mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
## Motivation CPP tests in dispatcher were failing due to a mismatch in registry key and string representation. ## Technical Details Bug 1 - Registry key mismatch: The registry stored kernels using get_name() but lookups used encode_identifier(), causing all registry lookups to fail. Fixed by changing registry.cpp:58 to use encode_identifier() for storage. Bug 2 - String representation changes: Tests checked for "persist"/"nopers" substrings, but the code emits "True"/"False". Fixed by replacing brittle substring checks with comparison-based assertions in test_kernel_key.cpp and test_kernel_key_extended.cpp. ## Test Plan Tested with CPP tests in dispatcher ## Test Result Validation: All three core cpp tests now pass: - test_kernel_key - 6/6 tests passing - test_kernel_key_extended - 25/25 tests passing - test_registry - 8/8 tests passing ## Submission Checklist - [ x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
158 lines
5.4 KiB
C++
158 lines
5.4 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
/// Unit tests for KernelKey using Google Test
|
|
|
|
#include "ck_tile/dispatcher/kernel_key.hpp"
|
|
#include "test_mock_kernel.hpp"
|
|
#include <gtest/gtest.h>
|
|
|
|
using namespace ck_tile::dispatcher;
|
|
using namespace ck_tile::dispatcher::test;
|
|
|
|
TEST(KernelKeyTest, Construction)
|
|
{
|
|
KernelKey key;
|
|
key.signature.dtype_a = DataType::FP16;
|
|
key.signature.dtype_b = DataType::FP16;
|
|
key.signature.dtype_c = DataType::FP16;
|
|
key.signature.dtype_acc = DataType::FP32;
|
|
key.signature.elementwise_op = "PassThrough";
|
|
key.signature.num_d_tensors = 0;
|
|
|
|
key.algorithm.tile_shape.m = 256;
|
|
key.algorithm.tile_shape.n = 256;
|
|
key.algorithm.tile_shape.k = 32;
|
|
|
|
key.gfx_arch = "gfx942";
|
|
|
|
EXPECT_EQ(key.signature.dtype_a, DataType::FP16);
|
|
EXPECT_EQ(key.algorithm.tile_shape.m, 256);
|
|
EXPECT_EQ(key.gfx_arch, "gfx942");
|
|
}
|
|
|
|
TEST(KernelKeyTest, Equality)
|
|
{
|
|
// Use helper function to ensure all fields are initialized
|
|
KernelKey key1 = make_test_key(256, 256, 32, "gfx942");
|
|
KernelKey key2 = make_test_key(256, 256, 32, "gfx942");
|
|
|
|
EXPECT_EQ(key1, key2);
|
|
EXPECT_FALSE(key1 != key2);
|
|
|
|
// Change one value
|
|
KernelKey key3 = make_test_key(128, 256, 32, "gfx942");
|
|
EXPECT_NE(key1, key3);
|
|
EXPECT_FALSE(key1 == key3);
|
|
}
|
|
|
|
TEST(KernelKeyTest, EncodeIdentifier)
|
|
{
|
|
KernelKey key;
|
|
key.signature.split_k = 1;
|
|
key.signature.elementwise_op = "PassThrough";
|
|
key.signature.num_d_tensors = 0;
|
|
key.algorithm.tile_shape.m = 256;
|
|
key.algorithm.tile_shape.n = 256;
|
|
key.algorithm.tile_shape.k = 32;
|
|
key.algorithm.wave_shape.m = 2;
|
|
key.algorithm.wave_shape.n = 2;
|
|
key.algorithm.wave_shape.k = 1;
|
|
key.algorithm.warp_tile_shape.m = 32;
|
|
key.algorithm.warp_tile_shape.n = 32;
|
|
key.algorithm.warp_tile_shape.k = 16;
|
|
key.algorithm.persistent = true;
|
|
key.algorithm.preshuffle = false;
|
|
key.signature.structured_sparsity = false;
|
|
|
|
std::string id = key.encode_identifier();
|
|
|
|
// Check that identifier contains expected components
|
|
EXPECT_NE(id.find("256x256x32"), std::string::npos); // tile shape
|
|
EXPECT_NE(id.find("2x2x1"), std::string::npos); // wave shape
|
|
EXPECT_NE(id.find("32x32x16"), std::string::npos); // warp tile shape
|
|
|
|
// Verify persistent flag is encoded by toggling it and asserting the
|
|
// identifier changes. Robust to encoding spelling changes.
|
|
KernelKey non_persistent_key = key;
|
|
non_persistent_key.algorithm.persistent = false;
|
|
EXPECT_NE(id, non_persistent_key.encode_identifier());
|
|
}
|
|
|
|
TEST(KernelKeyTest, EncodeIdentifierWithFusion)
|
|
{
|
|
KernelKey key;
|
|
key.signature.split_k = 1;
|
|
key.signature.elementwise_op = "Relu";
|
|
key.signature.num_d_tensors = 2;
|
|
key.algorithm.tile_shape.m = 128;
|
|
key.algorithm.tile_shape.n = 128;
|
|
key.algorithm.tile_shape.k = 64;
|
|
key.algorithm.wave_shape.m = 2;
|
|
key.algorithm.wave_shape.n = 2;
|
|
key.algorithm.wave_shape.k = 1;
|
|
key.algorithm.warp_tile_shape.m = 16;
|
|
key.algorithm.warp_tile_shape.n = 16;
|
|
key.algorithm.warp_tile_shape.k = 32;
|
|
key.algorithm.persistent = false;
|
|
key.signature.structured_sparsity = false;
|
|
|
|
std::string id = key.encode_identifier();
|
|
|
|
// Check fusion-specific components
|
|
EXPECT_NE(id.find("Relu"), std::string::npos);
|
|
EXPECT_NE(id.find("_d2"), std::string::npos);
|
|
|
|
// Verify persistent flag is encoded by toggling it and asserting the
|
|
// identifier changes. Robust to encoding spelling changes.
|
|
KernelKey persistent_key = key;
|
|
persistent_key.algorithm.persistent = true;
|
|
EXPECT_NE(id, persistent_key.encode_identifier());
|
|
}
|
|
|
|
TEST(KernelKeyTest, EncodeIdentifierWithSplitK)
|
|
{
|
|
KernelKey key;
|
|
key.signature.split_k = 4;
|
|
key.signature.elementwise_op = "PassThrough";
|
|
key.signature.num_d_tensors = 0;
|
|
key.algorithm.tile_shape.m = 256;
|
|
key.algorithm.tile_shape.n = 256;
|
|
key.algorithm.tile_shape.k = 32;
|
|
key.algorithm.wave_shape.m = 2;
|
|
key.algorithm.wave_shape.n = 2;
|
|
key.algorithm.wave_shape.k = 1;
|
|
key.algorithm.warp_tile_shape.m = 32;
|
|
key.algorithm.warp_tile_shape.n = 32;
|
|
key.algorithm.warp_tile_shape.k = 16;
|
|
key.algorithm.persistent = false;
|
|
key.signature.structured_sparsity = false;
|
|
|
|
std::string id = key.encode_identifier();
|
|
|
|
EXPECT_NE(id.find("_splitk4"), std::string::npos);
|
|
}
|
|
|
|
TEST(KernelKeyTest, EncodeIdentifierWithSparsity)
|
|
{
|
|
KernelKey key;
|
|
key.signature.split_k = 1;
|
|
key.signature.elementwise_op = "PassThrough";
|
|
key.signature.num_d_tensors = 0;
|
|
key.signature.structured_sparsity = true;
|
|
key.algorithm.tile_shape.m = 256;
|
|
key.algorithm.tile_shape.n = 256;
|
|
key.algorithm.tile_shape.k = 32;
|
|
key.algorithm.wave_shape.m = 2;
|
|
key.algorithm.wave_shape.n = 2;
|
|
key.algorithm.wave_shape.k = 1;
|
|
key.algorithm.warp_tile_shape.m = 32;
|
|
key.algorithm.warp_tile_shape.n = 32;
|
|
key.algorithm.warp_tile_shape.k = 16;
|
|
key.algorithm.persistent = false;
|
|
|
|
std::string id = key.encode_identifier();
|
|
|
|
EXPECT_NE(id.find("_sparse"), std::string::npos);
|
|
}
|