mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
* WIP POC of dispatcher
* Dispatcher python workflow setup.
* Dispatcher cleanup and updates.
Further dispatcher cleanup and updates.
Build fixes
Improvements and python to CK example
Improvements to readme
* Fixes to python paths
* Cleaning up code
* Improving dispatcher support for different arch
Fixing typos
* Fix formatting errors
* Cleaning up examples
* Improving codegeneration
* Improving and fixing C++ examples
* Adding conv functionality (fwd,bwd,bwdw) and examples.
* Fixes based on feedback.
* Further fixes based on feedback.
* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.
* Another round of improvements based on feedback.
* Trimming out unnecessary code.
* Fixing the multi-D implementation.
* Using gpu verification for gemms and fixing convolutions tflops calculation.
* Fix counter usage issue and arch filtering per ops.
* Adding changelog and other fixes.
* Improve examples and resolve critical bugs.
* Reduce build time for python examples.
* Fixing minor bug.
* Fix compilation error.
* Improve installation instructions for dispatcher.
* Add docker based installation instructions for dispatcher.
* Fixing arch-based filtering to match tile engine.
* Remove dead code and fix arch filtering.
* Minor bugfix.
* Updates after rebase.
* Trimming code.
* Fix copyright headers.
* Consolidate examples, cut down code.
* Minor fixes.
* Improving python examples.
* Update readmes.
* Remove conv functionality.
* Cleanup following conv removable.
[ROCm/composable_kernel commit: 9e049a32a1]
167 lines
4.6 KiB
C++
167 lines
4.6 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
/// Unit tests for Registry using Google Test
|
|
|
|
#include "ck_tile/dispatcher/registry.hpp"
|
|
#include "test_mock_kernel.hpp"
|
|
#include <gtest/gtest.h>
|
|
|
|
using namespace ck_tile::dispatcher;
|
|
using namespace ck_tile::dispatcher::test;
|
|
|
|
TEST(RegistryTest, Registration)
|
|
{
|
|
Registry& registry = Registry::instance();
|
|
registry.clear();
|
|
|
|
auto key = make_test_key(256);
|
|
auto kernel = std::make_shared<MockKernelInstance>(key, "test_kernel");
|
|
|
|
bool registered = registry.register_kernel(kernel);
|
|
EXPECT_TRUE(registered);
|
|
EXPECT_EQ(registry.size(), 1);
|
|
}
|
|
|
|
TEST(RegistryTest, Lookup)
|
|
{
|
|
Registry& registry = Registry::instance();
|
|
registry.clear();
|
|
|
|
auto key = make_test_key(256);
|
|
auto kernel = std::make_shared<MockKernelInstance>(key, "test_kernel");
|
|
registry.register_kernel(kernel);
|
|
|
|
// Lookup by key
|
|
auto found = registry.lookup(key);
|
|
ASSERT_NE(found, nullptr);
|
|
EXPECT_EQ(found->get_name(), "test_kernel");
|
|
|
|
// Lookup by identifier
|
|
std::string id = key.encode_identifier();
|
|
auto found2 = registry.lookup(id);
|
|
ASSERT_NE(found2, nullptr);
|
|
EXPECT_EQ(found2->get_name(), "test_kernel");
|
|
|
|
// Lookup non-existent
|
|
auto key2 = make_test_key(128);
|
|
auto not_found = registry.lookup(key2);
|
|
EXPECT_EQ(not_found, nullptr);
|
|
}
|
|
|
|
TEST(RegistryTest, Priority)
|
|
{
|
|
Registry& registry = Registry::instance();
|
|
registry.clear();
|
|
|
|
auto key = make_test_key(256);
|
|
auto kernel1 = std::make_shared<MockKernelInstance>(key, "kernel_low");
|
|
auto kernel2 = std::make_shared<MockKernelInstance>(key, "kernel_high");
|
|
|
|
// Register with low priority
|
|
registry.register_kernel(kernel1, Registry::Priority::Low);
|
|
|
|
// Try to register with normal priority (should replace)
|
|
bool replaced = registry.register_kernel(kernel2, Registry::Priority::Normal);
|
|
EXPECT_TRUE(replaced);
|
|
|
|
auto found = registry.lookup(key);
|
|
ASSERT_NE(found, nullptr);
|
|
EXPECT_EQ(found->get_name(), "kernel_high");
|
|
|
|
// Try to register with low priority again (should fail)
|
|
auto kernel3 = std::make_shared<MockKernelInstance>(key, "kernel_low2");
|
|
bool not_replaced = registry.register_kernel(kernel3, Registry::Priority::Low);
|
|
EXPECT_FALSE(not_replaced);
|
|
|
|
found = registry.lookup(key);
|
|
ASSERT_NE(found, nullptr);
|
|
EXPECT_EQ(found->get_name(), "kernel_high");
|
|
}
|
|
|
|
TEST(RegistryTest, GetAll)
|
|
{
|
|
Registry& registry = Registry::instance();
|
|
registry.clear();
|
|
|
|
auto key1 = make_test_key(256);
|
|
auto key2 = make_test_key(128);
|
|
auto kernel1 = std::make_shared<MockKernelInstance>(key1, "kernel1");
|
|
auto kernel2 = std::make_shared<MockKernelInstance>(key2, "kernel2");
|
|
|
|
registry.register_kernel(kernel1);
|
|
registry.register_kernel(kernel2);
|
|
|
|
auto all = registry.get_all();
|
|
EXPECT_EQ(all.size(), 2);
|
|
}
|
|
|
|
TEST(RegistryTest, Filter)
|
|
{
|
|
Registry& registry = Registry::instance();
|
|
registry.clear();
|
|
|
|
// Create kernels with different tile sizes
|
|
for(int tile_m : {128, 256, 512})
|
|
{
|
|
auto key = make_test_key(tile_m);
|
|
auto kernel = std::make_shared<MockKernelInstance>(key, "kernel_" + std::to_string(tile_m));
|
|
registry.register_kernel(kernel);
|
|
}
|
|
|
|
// Filter for large tiles (>= 256)
|
|
auto large_tiles = registry.filter(
|
|
[](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m >= 256; });
|
|
|
|
EXPECT_EQ(large_tiles.size(), 2);
|
|
}
|
|
|
|
TEST(RegistryTest, Clear)
|
|
{
|
|
Registry& registry = Registry::instance();
|
|
registry.clear();
|
|
|
|
auto key = make_test_key(256);
|
|
auto kernel = std::make_shared<MockKernelInstance>(key, "test_kernel");
|
|
registry.register_kernel(kernel);
|
|
|
|
EXPECT_EQ(registry.size(), 1);
|
|
|
|
registry.clear();
|
|
EXPECT_EQ(registry.size(), 0);
|
|
}
|
|
|
|
TEST(RegistryTest, MultipleKernels)
|
|
{
|
|
Registry& registry = Registry::instance();
|
|
registry.clear();
|
|
|
|
// Register multiple kernels
|
|
for(int i = 0; i < 10; ++i)
|
|
{
|
|
auto key = make_test_key(256 + i);
|
|
auto kernel = std::make_shared<MockKernelInstance>(key, "kernel_" + std::to_string(i));
|
|
registry.register_kernel(kernel);
|
|
}
|
|
|
|
EXPECT_EQ(registry.size(), 10);
|
|
|
|
// Verify all can be looked up
|
|
for(int i = 0; i < 10; ++i)
|
|
{
|
|
auto key = make_test_key(256 + i);
|
|
auto found = registry.lookup(key);
|
|
ASSERT_NE(found, nullptr);
|
|
EXPECT_EQ(found->get_name(), "kernel_" + std::to_string(i));
|
|
}
|
|
}
|
|
|
|
TEST(RegistryTest, Singleton)
|
|
{
|
|
Registry& reg1 = Registry::instance();
|
|
Registry& reg2 = Registry::instance();
|
|
|
|
// Should be the same instance
|
|
EXPECT_EQ(®1, ®2);
|
|
}
|