Files
composable_kernel/dispatcher/tests/test_dispatcher_extended.cpp
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving codegeneration

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements  based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add docker based  installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removable.
2026-01-22 09:34:33 -08:00

500 lines
15 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/// Extended unit tests for Dispatcher - covers selection strategies, heuristics, edge cases
#include "ck_tile/dispatcher/dispatcher.hpp"
#include "ck_tile/dispatcher/registry.hpp"
#include "test_mock_kernel.hpp"
#include <gtest/gtest.h>
#include <algorithm>
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::test;
using SelectionStrategy = Dispatcher::SelectionStrategy;
// =============================================================================
// Basic Dispatcher Tests
// =============================================================================
class DispatcherBasicTest : public ::testing::Test
{
protected:
void SetUp() override { Registry::instance().clear(); }
void TearDown() override { Registry::instance().clear(); }
};
TEST_F(DispatcherBasicTest, DefaultConstruction)
{
Dispatcher dispatcher;
// Should not crash
SUCCEED();
}
TEST_F(DispatcherBasicTest, SelectKernelEmpty)
{
Dispatcher dispatcher;
Problem problem(1024, 1024, 1024);
auto kernel = dispatcher.select_kernel(problem);
EXPECT_EQ(kernel, nullptr);
}
TEST_F(DispatcherBasicTest, SelectKernelSingle)
{
auto key = make_test_key(256);
auto kernel = std::make_shared<MockKernelInstance>(key, "test_kernel");
Registry::instance().register_kernel(kernel);
Dispatcher dispatcher;
Problem problem(1024, 1024, 1024);
auto selected = dispatcher.select_kernel(problem);
ASSERT_NE(selected, nullptr);
EXPECT_EQ(selected->get_name(), "test_kernel");
}
TEST_F(DispatcherBasicTest, SelectKernelMultiple)
{
// Register multiple kernels
for(int tile : {128, 256, 512})
{
auto key = make_test_key(tile);
auto kernel = std::make_shared<MockKernelInstance>(key, "kernel_" + std::to_string(tile));
Registry::instance().register_kernel(kernel);
}
Dispatcher dispatcher;
Problem problem(1024, 1024, 1024);
auto selected = dispatcher.select_kernel(problem);
ASSERT_NE(selected, nullptr);
// Should select one of the registered kernels
EXPECT_TRUE(selected->get_name() == "kernel_128" || selected->get_name() == "kernel_256" ||
selected->get_name() == "kernel_512");
}
// =============================================================================
// Selection Strategy Tests
// =============================================================================
class SelectionStrategyTest : public ::testing::Test
{
protected:
void SetUp() override
{
Registry::instance().clear();
// Register kernels with different tile sizes
for(int tile : {128, 256, 512})
{
auto key = make_test_key(tile);
auto kernel =
std::make_shared<MockKernelInstance>(key, "kernel_" + std::to_string(tile));
Registry::instance().register_kernel(kernel);
}
}
void TearDown() override { Registry::instance().clear(); }
};
TEST_F(SelectionStrategyTest, FirstFitStrategy)
{
Dispatcher dispatcher;
dispatcher.set_strategy(SelectionStrategy::FirstFit);
Problem problem(1024, 1024, 1024);
auto selected = dispatcher.select_kernel(problem);
ASSERT_NE(selected, nullptr);
// FirstFit returns first matching kernel
}
TEST_F(SelectionStrategyTest, HeuristicStrategy)
{
Dispatcher dispatcher;
// Set heuristic that prefers larger tiles for large problems
dispatcher.set_heuristic([](const Problem& p) -> std::vector<std::string> {
if(p.M >= 1024 && p.N >= 1024)
{
// For large problems, prefer 512 tile
auto key = make_test_key(512);
return {key.encode_identifier()};
}
// For small problems, prefer 128 tile
auto key = make_test_key(128);
return {key.encode_identifier()};
});
dispatcher.set_strategy(SelectionStrategy::Heuristic);
// Large problem should get 512 tile
Problem large_problem(2048, 2048, 2048);
auto selected = dispatcher.select_kernel(large_problem);
ASSERT_NE(selected, nullptr);
EXPECT_EQ(selected->get_name(), "kernel_512");
// Small problem should get 128 tile
Problem small_problem(256, 256, 256);
selected = dispatcher.select_kernel(small_problem);
ASSERT_NE(selected, nullptr);
EXPECT_EQ(selected->get_name(), "kernel_128");
}
TEST_F(SelectionStrategyTest, HeuristicWithFallback)
{
Dispatcher dispatcher;
// Heuristic returns non-existent kernel first, then valid one
dispatcher.set_heuristic([](const Problem& p) -> std::vector<std::string> {
auto key = make_test_key(256);
return {"nonexistent_kernel", key.encode_identifier()};
});
dispatcher.set_strategy(SelectionStrategy::Heuristic);
Problem problem(1024, 1024, 1024);
auto selected = dispatcher.select_kernel(problem);
ASSERT_NE(selected, nullptr);
EXPECT_EQ(selected->get_name(), "kernel_256");
}
TEST_F(SelectionStrategyTest, SwitchBetweenStrategies)
{
Dispatcher dispatcher;
// Start with FirstFit
dispatcher.set_strategy(SelectionStrategy::FirstFit);
Problem problem(1024, 1024, 1024);
auto selected1 = dispatcher.select_kernel(problem);
ASSERT_NE(selected1, nullptr);
// Switch to Heuristic
dispatcher.set_heuristic([](const Problem& p) -> std::vector<std::string> {
auto key = make_test_key(256);
return {key.encode_identifier()};
});
dispatcher.set_strategy(SelectionStrategy::Heuristic);
auto selected2 = dispatcher.select_kernel(problem);
ASSERT_NE(selected2, nullptr);
}
// =============================================================================
// Heuristic Function Tests
// =============================================================================
class HeuristicTest : public ::testing::Test
{
protected:
void SetUp() override
{
Registry::instance().clear();
for(int tile : {64, 128, 256, 512})
{
auto key = make_test_key(tile);
auto kernel =
std::make_shared<MockKernelInstance>(key, "kernel_" + std::to_string(tile));
Registry::instance().register_kernel(kernel);
}
}
void TearDown() override { Registry::instance().clear(); }
};
TEST_F(HeuristicTest, SizeBasedHeuristic)
{
Dispatcher dispatcher;
dispatcher.set_heuristic([](const Problem& p) -> std::vector<std::string> {
std::vector<std::string> candidates;
// Problem-size based selection
int size = p.M * p.N * p.K;
if(size >= 1024 * 1024 * 1024)
{
candidates.push_back(make_test_key(512).encode_identifier());
candidates.push_back(make_test_key(256).encode_identifier());
}
else if(size >= 256 * 256 * 256)
{
candidates.push_back(make_test_key(256).encode_identifier());
candidates.push_back(make_test_key(128).encode_identifier());
}
else
{
candidates.push_back(make_test_key(64).encode_identifier());
candidates.push_back(make_test_key(128).encode_identifier());
}
return candidates;
});
dispatcher.set_strategy(SelectionStrategy::Heuristic);
// Large problem
auto selected = dispatcher.select_kernel(Problem(1024, 1024, 1024));
ASSERT_NE(selected, nullptr);
EXPECT_EQ(selected->get_name(), "kernel_512");
// Medium problem
selected = dispatcher.select_kernel(Problem(256, 256, 256));
ASSERT_NE(selected, nullptr);
EXPECT_EQ(selected->get_name(), "kernel_256");
// Small problem
selected = dispatcher.select_kernel(Problem(64, 64, 64));
ASSERT_NE(selected, nullptr);
EXPECT_EQ(selected->get_name(), "kernel_64");
}
TEST_F(HeuristicTest, EmptyHeuristicFallsBackToFirstFit)
{
Dispatcher dispatcher;
dispatcher.set_heuristic([](const Problem& p) -> std::vector<std::string> {
return {}; // Empty list
});
dispatcher.set_strategy(SelectionStrategy::Heuristic);
Problem problem(1024, 1024, 1024);
auto selected = dispatcher.select_kernel(problem);
// Should fall back to FirstFit
ASSERT_NE(selected, nullptr);
}
TEST_F(HeuristicTest, InvalidHeuristicFallsBackToFirstFit)
{
Dispatcher dispatcher;
dispatcher.set_heuristic([](const Problem& p) -> std::vector<std::string> {
return {"invalid_kernel_1", "invalid_kernel_2"}; // All invalid
});
dispatcher.set_strategy(SelectionStrategy::Heuristic);
Problem problem(1024, 1024, 1024);
auto selected = dispatcher.select_kernel(problem);
// Should fall back to FirstFit
ASSERT_NE(selected, nullptr);
}
// =============================================================================
// Dispatcher with Custom Registry Tests
// =============================================================================
class DispatcherCustomRegistryTest : public ::testing::Test
{
protected:
void TearDown() override { Registry::instance().clear(); }
};
TEST_F(DispatcherCustomRegistryTest, UseCustomRegistry)
{
Registry custom_registry;
custom_registry.set_name("custom");
auto key = make_test_key(256);
auto kernel = std::make_shared<MockKernelInstance>(key, "custom_kernel");
custom_registry.register_kernel(kernel);
Dispatcher dispatcher(&custom_registry);
Problem problem(1024, 1024, 1024);
auto selected = dispatcher.select_kernel(problem);
ASSERT_NE(selected, nullptr);
EXPECT_EQ(selected->get_name(), "custom_kernel");
}
TEST_F(DispatcherCustomRegistryTest, CustomRegistryIsolation)
{
Registry custom_registry;
auto key_custom = make_test_key(256);
auto key_global = make_test_key(512);
custom_registry.register_kernel(
std::make_shared<MockKernelInstance>(key_custom, "custom_kernel"));
Registry::instance().register_kernel(
std::make_shared<MockKernelInstance>(key_global, "global_kernel"));
Dispatcher custom_dispatcher(&custom_registry);
Dispatcher global_dispatcher;
Problem problem(1024, 1024, 1024);
auto custom_selected = custom_dispatcher.select_kernel(problem);
auto global_selected = global_dispatcher.select_kernel(problem);
ASSERT_NE(custom_selected, nullptr);
ASSERT_NE(global_selected, nullptr);
EXPECT_EQ(custom_selected->get_name(), "custom_kernel");
EXPECT_EQ(global_selected->get_name(), "global_kernel");
}
// =============================================================================
// Edge Cases Tests
// =============================================================================
class DispatcherEdgeCasesTest : public ::testing::Test
{
protected:
void SetUp() override { Registry::instance().clear(); }
void TearDown() override { Registry::instance().clear(); }
};
TEST_F(DispatcherEdgeCasesTest, InvalidProblem)
{
auto key = make_test_key(256);
auto kernel = std::make_shared<MockKernelInstance>(key, "kernel");
Registry::instance().register_kernel(kernel);
Dispatcher dispatcher;
// Zero dimensions
Problem invalid(0, 1024, 1024);
EXPECT_FALSE(invalid.is_valid());
// The dispatcher should still attempt selection
// (validation is up to the kernel's supports() method)
}
TEST_F(DispatcherEdgeCasesTest, KernelDoesNotSupportProblem)
{
auto key = make_test_key(256);
auto kernel = std::make_shared<MockKernelInstance>(key, "selective_kernel", false);
Registry::instance().register_kernel(kernel);
Dispatcher dispatcher;
// Problem not divisible by tile size - kernel doesn't support it
Problem problem(1000, 1000, 1000); // Not divisible by 256
auto selected = dispatcher.select_kernel(problem);
// Should return nullptr since kernel doesn't support this problem
EXPECT_EQ(selected, nullptr);
}
TEST_F(DispatcherEdgeCasesTest, MultipleSelectionsConsistent)
{
auto key = make_test_key(256);
auto kernel = std::make_shared<MockKernelInstance>(key, "kernel");
Registry::instance().register_kernel(kernel);
Dispatcher dispatcher;
Problem problem(1024, 1024, 1024);
// Multiple selections should return the same kernel
auto selected1 = dispatcher.select_kernel(problem);
auto selected2 = dispatcher.select_kernel(problem);
auto selected3 = dispatcher.select_kernel(problem);
ASSERT_NE(selected1, nullptr);
EXPECT_EQ(selected1, selected2);
EXPECT_EQ(selected2, selected3);
}
// =============================================================================
// Validate Method Tests
// =============================================================================
class DispatcherValidateTest : public ::testing::Test
{
protected:
void SetUp() override
{
Registry::instance().clear();
auto key = make_test_key(256);
kernel_ = std::make_shared<MockKernelInstance>(key, "kernel");
Registry::instance().register_kernel(kernel_);
}
void TearDown() override { Registry::instance().clear(); }
std::shared_ptr<MockKernelInstance> kernel_;
};
TEST_F(DispatcherValidateTest, ValidateWithMockKernel)
{
Dispatcher dispatcher;
Problem problem(1024, 1024, 1024);
// MockKernelInstance always validates successfully
bool valid = dispatcher.validate(nullptr, nullptr, nullptr, nullptr, problem);
// This depends on implementation - mock returns true
// Real validation would need actual data
}
// =============================================================================
// Run Method Tests (with mock)
// =============================================================================
class DispatcherRunTest : public ::testing::Test
{
protected:
void SetUp() override
{
Registry::instance().clear();
auto key = make_test_key(256);
kernel_ = std::make_shared<MockKernelInstance>(key, "kernel");
Registry::instance().register_kernel(kernel_);
}
void TearDown() override { Registry::instance().clear(); }
std::shared_ptr<MockKernelInstance> kernel_;
};
TEST_F(DispatcherRunTest, RunWithMockKernel)
{
Dispatcher dispatcher;
Problem problem(1024, 1024, 1024);
// Mock run (with null pointers - mock doesn't use them)
float time = dispatcher.run(nullptr, nullptr, nullptr, problem);
// Mock kernel returns 1.0f
EXPECT_FLOAT_EQ(time, 1.0f);
// Verify execution count
EXPECT_EQ(kernel_->get_execution_count(), 1);
}
TEST_F(DispatcherRunTest, MultipleRuns)
{
Dispatcher dispatcher;
Problem problem(1024, 1024, 1024);
for(int i = 0; i < 10; i++)
{
(void)dispatcher.run(nullptr, nullptr, nullptr, problem);
}
EXPECT_EQ(kernel_->get_execution_count(), 10);
}
TEST_F(DispatcherRunTest, RunWithNoKernelThrows)
{
Registry::instance().clear();
Dispatcher dispatcher;
Problem problem(1024, 1024, 1024);
// Should throw when no kernel found
EXPECT_THROW((void)dispatcher.run(nullptr, nullptr, nullptr, problem), std::runtime_error);
}