mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-30 03:55:52 +00:00
* WIP POC of dispatcher * Dispatcher python workflow setup. * Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes Improvements and python to CK example Improvements to readme * Fixes to python paths * Cleaning up code * Improving dispatcher support for different arch Fixing typos * Fix formatting errors * Cleaning up examples * Improving codegeneration * Improving and fixing C++ examples * Adding conv functionality (fwd,bwd,bwdw) and examples. * Fixes based on feedback. * Further fixes based on feedback. * Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug. * Another round of improvements based on feedback. * Trimming out unnecessary code. * Fixing the multi-D implementation. * Using gpu verification for gemms and fixing convolutions tflops calculation. * Fix counter usage issue and arch filtering per ops. * Adding changelog and other fixes. * Improve examples and resolve critical bugs. * Reduce build time for python examples. * Fixing minor bug. * Fix compilation error. * Improve installation instructions for dispatcher. * Add docker based installation instructions for dispatcher. * Fixing arch-based filtering to match tile engine. * Remove dead code and fix arch filtering. * Minor bugfix. * Updates after rebase. * Trimming code. * Fix copyright headers. * Consolidate examples, cut down code. * Minor fixes. * Improving python examples. * Update readmes. * Remove conv functionality. * Cleanup following conv removable.
297 lines
8.3 KiB
C++
297 lines
8.3 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
/// Unit tests for Dispatcher using Google Test
|
|
|
|
#include "ck_tile/dispatcher/dispatcher.hpp"
|
|
#include "test_mock_kernel.hpp"
|
|
#include <gtest/gtest.h>
|
|
|
|
using namespace ck_tile::dispatcher;
|
|
using namespace ck_tile::dispatcher::test;
|
|
|
|
class DispatcherTest : public ::testing::Test
|
|
{
|
|
protected:
|
|
void SetUp() override
|
|
{
|
|
// Clear registry before each test
|
|
Registry::instance().clear();
|
|
}
|
|
|
|
void TearDown() override
|
|
{
|
|
// Clean up after each test
|
|
Registry::instance().clear();
|
|
}
|
|
};
|
|
|
|
/// With two kernels registered and no strategy or heuristic configured,
/// select_kernel() must hand back one of them, and that kernel must report
/// support for the problem.
TEST_F(DispatcherTest, SelectKernelFirstFit)
{
    Dispatcher dispatcher;

    // Two mock kernels built from different test keys.
    auto key_256 = make_test_key(256);
    auto key_128 = make_test_key(128);
    auto first   = std::make_shared<MockKernelInstance>(key_256, "kernel1");
    auto second  = std::make_shared<MockKernelInstance>(key_128, "kernel2");

    auto& registry = Registry::instance();
    registry.register_kernel(first);
    registry.register_kernel(second);

    Problem problem(1024, 1024, 1024);
    auto chosen = dispatcher.select_kernel(problem);
    ASSERT_NE(chosen, nullptr);

    // Selection order is not guaranteed, so either registered kernel is an
    // acceptable answer; the test only pins that one of them was picked.
    const auto& chosen_name         = chosen->get_name();
    const bool is_registered_kernel = (chosen_name == "kernel1") || (chosen_name == "kernel2");
    EXPECT_TRUE(is_registered_kernel);
    EXPECT_TRUE(chosen->supports(problem));
}
|
|
|
|
/// A problem with all-zero dimensions is treated as invalid by this test:
/// even with a kernel registered, select_kernel() is expected to return null.
TEST_F(DispatcherTest, SelectKernelInvalidProblem)
{
    Dispatcher dispatcher;

    // One kernel is available, so a null result cannot be blamed on an empty
    // registry.
    auto test_key    = make_test_key(256);
    auto mock_kernel = std::make_shared<MockKernelInstance>(test_key, "kernel1");
    Registry::instance().register_kernel(mock_kernel);

    Problem zero_sized(0, 0, 0);
    EXPECT_EQ(dispatcher.select_kernel(zero_sized), nullptr);
}
|
|
|
|
/// When the only registered kernel does not support the problem,
/// select_kernel() is expected to return null.
TEST_F(DispatcherTest, SelectKernelNoMatch)
{
    Dispatcher dispatcher;

    // NOTE(review): the trailing `false` is presumably the mock's "supports"
    // flag - confirm against test_mock_kernel.hpp.
    auto test_key           = make_test_key(256);
    auto unsupported_kernel = std::make_shared<MockKernelInstance>(test_key, "kernel1", false);
    Registry::instance().register_kernel(unsupported_kernel);

    // 100 is not a multiple of the 256 passed to make_test_key().
    Problem odd_sized(100, 100, 100);
    EXPECT_EQ(dispatcher.select_kernel(odd_sized), nullptr);
}
|
|
|
|
/// A heuristic that lists kernel2's identifier ahead of kernel1's must cause
/// select_kernel() to return kernel2.
TEST_F(DispatcherTest, SelectKernelHeuristic)
{
    Dispatcher dispatcher;

    // Two mock kernels built from different test keys.
    auto key_256 = make_test_key(256);
    auto key_128 = make_test_key(128);
    auto first   = std::make_shared<MockKernelInstance>(key_256, "kernel1");
    auto second  = std::make_shared<MockKernelInstance>(key_128, "kernel2");

    auto& registry = Registry::instance();
    registry.register_kernel(first);
    registry.register_kernel(second);

    // The heuristic ignores the problem and always returns the same ranking:
    // the identifier of the 128 key (kernel2) first, then the 256 key
    // (kernel1). Keys are rebuilt inside the lambda so it needs no captures.
    dispatcher.set_heuristic([](const Problem&) {
        auto preferred_key = make_test_key(128);
        auto fallback_key  = make_test_key(256);

        std::vector<std::string> ranked_ids;
        ranked_ids.push_back(preferred_key.encode_identifier());
        ranked_ids.push_back(fallback_key.encode_identifier());
        return ranked_ids;
    });

    Problem problem(1024, 1024, 1024);
    auto chosen = dispatcher.select_kernel(problem);

    ASSERT_NE(chosen, nullptr);
    EXPECT_EQ(chosen->get_name(), "kernel2");
}
|
|
|
|
/// If the heuristic only names an identifier that is not registered, the
/// dispatcher is expected to fall back and still return the registered kernel.
TEST_F(DispatcherTest, SelectKernelHeuristicFallback)
{
    Dispatcher dispatcher;

    auto test_key    = make_test_key(256);
    auto mock_kernel = std::make_shared<MockKernelInstance>(test_key, "kernel1");
    Registry::instance().register_kernel(mock_kernel);

    // Heuristic whose single candidate matches nothing in the registry.
    auto unknown_only = [](const Problem&) {
        return std::vector<std::string>{"nonexistent_kernel"};
    };
    dispatcher.set_heuristic(unknown_only);

    Problem problem(1024, 1024, 1024);
    auto chosen = dispatcher.select_kernel(problem);

    // The only registered kernel must be picked despite the useless hint.
    ASSERT_NE(chosen, nullptr);
    EXPECT_EQ(chosen->get_name(), "kernel1");
}
|
|
|
|
/// run() with one registered kernel must report a positive time and execute
/// that kernel exactly once.
TEST_F(DispatcherTest, RunBasic)
{
    Dispatcher dispatcher;

    auto test_key    = make_test_key(256);
    auto mock_kernel = std::make_shared<MockKernelInstance>(test_key, "kernel1");
    Registry::instance().register_kernel(mock_kernel);

    Problem problem(1024, 1024, 1024);

    // Placeholder buffers: they only provide valid addresses to pass along.
    // Their contents are never initialised; the original author notes the
    // mock does not actually use them.
    float a[1];
    float b[1];
    float c[1];

    const float elapsed_ms = dispatcher.run(a, b, c, problem);

    EXPECT_GT(elapsed_ms, 0.0f);
    EXPECT_EQ(mock_kernel->get_execution_count(), 1);
}
|
|
|
|
/// run() against an empty registry must throw std::runtime_error.
TEST_F(DispatcherTest, RunNoKernel)
{
    Dispatcher dispatcher;

    // Nothing is registered: the fixture cleared the registry in SetUp().
    Problem problem(1024, 1024, 1024);

    // Placeholder buffers; contents are intentionally left uninitialised.
    float a[1];
    float b[1];
    float c[1];

    // The (void) cast discards run()'s return value explicitly.
    EXPECT_THROW((void)dispatcher.run(a, b, c, problem), std::runtime_error);
}
|
|
|
|
/// run_explicit() with the identifier of a registered kernel must report a
/// positive time and execute that kernel exactly once.
TEST_F(DispatcherTest, RunExplicit)
{
    Dispatcher dispatcher;

    auto test_key    = make_test_key(256);
    auto mock_kernel = std::make_shared<MockKernelInstance>(test_key, "kernel1");
    Registry::instance().register_kernel(mock_kernel);

    Problem problem(1024, 1024, 1024);

    // The kernel is addressed by the identifier encoded from its key.
    std::string target_id = test_key.encode_identifier();

    // Placeholder buffers; contents are intentionally left uninitialised.
    float a[1];
    float b[1];
    float c[1];

    // The fifth argument is passed as nullptr, as in the other tests here.
    const float elapsed_ms = dispatcher.run_explicit(target_id, a, b, c, nullptr, problem);

    EXPECT_GT(elapsed_ms, 0.0f);
    EXPECT_EQ(mock_kernel->get_execution_count(), 1);
}
|
|
|
|
/// run_explicit() with an identifier that is not registered must throw
/// std::runtime_error.
TEST_F(DispatcherTest, RunExplicitNotFound)
{
    Dispatcher dispatcher;

    Problem problem(1024, 1024, 1024);

    // Placeholder buffers; contents are intentionally left uninitialised.
    float a[1];
    float b[1];
    float c[1];

    // No kernel was registered, so "nonexistent" cannot resolve to anything.
    EXPECT_THROW((void)dispatcher.run_explicit("nonexistent", a, b, c, nullptr, problem),
                 std::runtime_error);
}
|
|
|
|
/// run_explicit() must throw std::runtime_error when the named kernel exists
/// but does not support the problem.
TEST_F(DispatcherTest, RunExplicitNotSupported)
{
    Dispatcher dispatcher;

    // NOTE(review): the trailing `false` is presumably the mock's "supports"
    // flag - confirm against test_mock_kernel.hpp.
    auto test_key           = make_test_key(256);
    auto unsupported_kernel = std::make_shared<MockKernelInstance>(test_key, "kernel1", false);
    Registry::instance().register_kernel(unsupported_kernel);

    // 100 is not a multiple of the 256 passed to make_test_key().
    Problem odd_sized(100, 100, 100);
    std::string target_id = test_key.encode_identifier();

    // Placeholder buffers; contents are intentionally left uninitialised.
    float a[1];
    float b[1];
    float c[1];

    EXPECT_THROW((void)dispatcher.run_explicit(target_id, a, b, c, nullptr, odd_sized),
                 std::runtime_error);
}
|
|
|
|
/// validate() must return true when a registered kernel is available for the
/// problem.
TEST_F(DispatcherTest, Validate)
{
    Dispatcher dispatcher;

    auto test_key    = make_test_key(256);
    auto mock_kernel = std::make_shared<MockKernelInstance>(test_key, "kernel1");
    Registry::instance().register_kernel(mock_kernel);

    Problem problem(1024, 1024, 1024);

    // Placeholder buffers; contents are intentionally left uninitialised.
    float a[1];
    float b[1];
    float c[1];

    EXPECT_TRUE(dispatcher.validate(a, b, c, nullptr, problem));
}
|
|
|
|
/// validate() must return false, rather than succeed, when no kernel is
/// registered.
TEST_F(DispatcherTest, ValidateNoKernel)
{
    Dispatcher dispatcher;

    // Nothing is registered: the fixture cleared the registry in SetUp().
    Problem problem(1024, 1024, 1024);

    // Placeholder buffers; contents are intentionally left uninitialised.
    float a[1];
    float b[1];
    float c[1];

    EXPECT_FALSE(dispatcher.validate(a, b, c, nullptr, problem));
}
|
|
|
|
/// Both selection strategies must resolve the single registered kernel.
///
/// Only "kernel1" is registered, so checking for a non-null result alone would
/// not show which kernel came back. Each strategy is therefore also checked by
/// name, matching what SelectKernelHeuristicFallback asserts for the same
/// single-kernel setup.
TEST_F(DispatcherTest, StrategySelection)
{
    Dispatcher dispatcher;

    // Register the only candidate.
    auto key    = make_test_key(256);
    auto kernel = std::make_shared<MockKernelInstance>(key, "kernel1");
    Registry::instance().register_kernel(kernel);

    Problem problem(1024, 1024, 1024);

    // FirstFit strategy, requested explicitly.
    dispatcher.set_strategy(Dispatcher::SelectionStrategy::FirstFit);
    auto selected1 = dispatcher.select_kernel(problem);
    ASSERT_NE(selected1, nullptr);
    EXPECT_EQ(selected1->get_name(), "kernel1");

    // Heuristic strategy with no heuristic function installed: the dispatcher
    // is expected to fall back and still find the registered kernel.
    dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic);
    auto selected2 = dispatcher.select_kernel(problem);
    ASSERT_NE(selected2, nullptr);
    EXPECT_EQ(selected2->get_name(), "kernel1");
}
|
|
|
|
/// A default-constructed Dispatcher must see kernels registered in the
/// singleton Registry.
///
/// Despite the test name, no separate registry object is created here. Per the
/// original author's note, Registry is singleton-only, so this test exercises
/// the dispatcher against Registry::instance().
TEST_F(DispatcherTest, CustomRegistry)
{
    Registry& singleton = Registry::instance();
    singleton.clear();

    auto test_key    = make_test_key(256);
    auto mock_kernel = std::make_shared<MockKernelInstance>(test_key, "kernel1");
    singleton.register_kernel(mock_kernel);

    // Built after registration and without arguments: it is expected to use
    // the singleton registry by default.
    Dispatcher dispatcher;

    Problem problem(1024, 1024, 1024);
    auto chosen = dispatcher.select_kernel(problem);

    ASSERT_NE(chosen, nullptr);
    EXPECT_EQ(chosen->get_name(), "kernel1");
}
|