Files
composable_kernel/dispatcher/tests/test_kernel_key.cpp
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving code generation

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add Docker-based installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removal.
2026-01-22 09:34:33 -08:00

148 lines
5.0 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/// Unit tests for KernelKey using Google Test
#include "ck_tile/dispatcher/kernel_key.hpp"
#include "test_mock_kernel.hpp"
#include <gtest/gtest.h>
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::test;
/// Populates a default-constructed KernelKey field by field and checks that a
/// sample of the written values (one signature field, one tile dimension and
/// the architecture string) reads back unchanged.
TEST(KernelKeyTest, Construction)
{
    KernelKey key;

    // Signature part of the key: data types, element-wise op, D-tensor count.
    auto& sig          = key.signature;
    sig.dtype_a        = DataType::FP16;
    sig.dtype_b        = DataType::FP16;
    sig.dtype_c        = DataType::FP16;
    sig.dtype_acc      = DataType::FP32;
    sig.elementwise_op = "PassThrough";
    sig.num_d_tensors  = 0;

    // Algorithm part of the key: only the 256x256x32 tile shape is set here.
    auto& tile = key.algorithm.tile_shape;
    tile.m     = 256;
    tile.n     = 256;
    tile.k     = 32;

    // Target architecture string.
    key.gfx_arch = "gfx942";

    // Read the values back through the full member path.
    EXPECT_EQ(key.signature.dtype_a, DataType::FP16);
    EXPECT_EQ(key.algorithm.tile_shape.m, 256);
    EXPECT_EQ(key.gfx_arch, "gfx942");
}
/// Verifies operator== and operator!= on KernelKey: two keys built from the
/// same arguments compare equal, and a key built with a different first
/// argument compares unequal.
TEST(KernelKeyTest, Equality)
{
    // All keys come from make_test_key() so that every field is initialized
    // before the comparison operators look at it.
    KernelKey reference_key = make_test_key(256, 256, 32, "gfx942");
    KernelKey identical_key = make_test_key(256, 256, 32, "gfx942");
    // Same as the reference except for the first argument (128 instead of 256).
    KernelKey different_key = make_test_key(128, 256, 32, "gfx942");

    // Identical inputs: == holds and != does not.
    EXPECT_EQ(reference_key, identical_key);
    EXPECT_FALSE(reference_key != identical_key);

    // One differing value: != holds and == does not.
    EXPECT_NE(reference_key, different_key);
    EXPECT_FALSE(reference_key == different_key);
}
/// Checks that encode_identifier() embeds the tile shape, wave shape, warp tile
/// shape and the persistent flag of a key in the returned string.
TEST(KernelKeyTest, EncodeIdentifier)
{
    // Value-initialize the key: this test sets only a subset of the fields, and
    // KernelKey's member defaults are not visible from this file. With `{}` any
    // field left untouched holds a defined value instead of an indeterminate
    // one when encode_identifier() reads it.
    KernelKey key{};

    // Signature fields.
    key.signature.split_k             = 1;
    key.signature.elementwise_op      = "PassThrough";
    key.signature.num_d_tensors       = 0;
    key.signature.structured_sparsity = false;

    // Algorithm fields.
    key.algorithm.tile_shape.m      = 256;
    key.algorithm.tile_shape.n      = 256;
    key.algorithm.tile_shape.k      = 32;
    key.algorithm.wave_shape.m      = 2;
    key.algorithm.wave_shape.n      = 2;
    key.algorithm.wave_shape.k      = 1;
    key.algorithm.warp_tile_shape.m = 32;
    key.algorithm.warp_tile_shape.n = 32;
    key.algorithm.warp_tile_shape.k = 16;
    key.algorithm.persistent        = true;
    key.algorithm.preshuffle        = false;

    const std::string id = key.encode_identifier();

    // Check that identifier contains expected components
    EXPECT_NE(id.find("256x256x32"), std::string::npos); // tile shape
    EXPECT_NE(id.find("2x2x1"), std::string::npos);      // wave shape
    EXPECT_NE(id.find("32x32x16"), std::string::npos);   // warp tile shape
    EXPECT_NE(id.find("persist"), std::string::npos);    // persistent flag
}
/// Checks that encode_identifier() reflects a non-default element-wise op, a
/// non-zero D-tensor count and a cleared persistent flag in the returned
/// string.
TEST(KernelKeyTest, EncodeIdentifierWithFusion)
{
    // Value-initialize the key: this test sets only a subset of the fields, and
    // KernelKey's member defaults are not visible from this file. With `{}` any
    // field left untouched holds a defined value instead of an indeterminate
    // one when encode_identifier() reads it.
    KernelKey key{};

    // Signature fields: "Relu" element-wise op with two D tensors.
    key.signature.split_k             = 1;
    key.signature.elementwise_op      = "Relu";
    key.signature.num_d_tensors       = 2;
    key.signature.structured_sparsity = false;

    // Algorithm fields.
    key.algorithm.tile_shape.m      = 128;
    key.algorithm.tile_shape.n      = 128;
    key.algorithm.tile_shape.k      = 64;
    key.algorithm.wave_shape.m      = 2;
    key.algorithm.wave_shape.n      = 2;
    key.algorithm.wave_shape.k      = 1;
    key.algorithm.warp_tile_shape.m = 16;
    key.algorithm.warp_tile_shape.n = 16;
    key.algorithm.warp_tile_shape.k = 32;
    key.algorithm.persistent        = false;
    // Set explicitly, matching the EncodeIdentifier test, so the encoded
    // string does not depend on an unset flag.
    key.algorithm.preshuffle = false;

    const std::string id = key.encode_identifier();

    // Check fusion-specific components
    EXPECT_NE(id.find("Relu"), std::string::npos);
    EXPECT_NE(id.find("_d2"), std::string::npos);
    EXPECT_NE(id.find("nopers"), std::string::npos);
}
/// Checks that a key with split_k set to 4 produces an identifier containing
/// the "_splitk4" component.
TEST(KernelKeyTest, EncodeIdentifierWithSplitK)
{
    // Value-initialize the key: this test sets only a subset of the fields, and
    // KernelKey's member defaults are not visible from this file. With `{}` any
    // field left untouched holds a defined value instead of an indeterminate
    // one when encode_identifier() reads it.
    KernelKey key{};

    // Signature fields: split_k is the value under test.
    key.signature.split_k             = 4;
    key.signature.elementwise_op      = "PassThrough";
    key.signature.num_d_tensors       = 0;
    key.signature.structured_sparsity = false;

    // Algorithm fields.
    key.algorithm.tile_shape.m      = 256;
    key.algorithm.tile_shape.n      = 256;
    key.algorithm.tile_shape.k      = 32;
    key.algorithm.wave_shape.m      = 2;
    key.algorithm.wave_shape.n      = 2;
    key.algorithm.wave_shape.k      = 1;
    key.algorithm.warp_tile_shape.m = 32;
    key.algorithm.warp_tile_shape.n = 32;
    key.algorithm.warp_tile_shape.k = 16;
    key.algorithm.persistent        = false;
    // Set explicitly, matching the EncodeIdentifier test, so the encoded
    // string does not depend on an unset flag.
    key.algorithm.preshuffle = false;

    const std::string id = key.encode_identifier();

    EXPECT_NE(id.find("_splitk4"), std::string::npos);
}
/// Checks that a key with structured_sparsity enabled produces an identifier
/// containing the "_sparse" component.
TEST(KernelKeyTest, EncodeIdentifierWithSparsity)
{
    // Value-initialize the key: this test sets only a subset of the fields, and
    // KernelKey's member defaults are not visible from this file. With `{}` any
    // field left untouched holds a defined value instead of an indeterminate
    // one when encode_identifier() reads it.
    KernelKey key{};

    // Signature fields: structured_sparsity is the value under test.
    key.signature.split_k             = 1;
    key.signature.elementwise_op      = "PassThrough";
    key.signature.num_d_tensors       = 0;
    key.signature.structured_sparsity = true;

    // Algorithm fields.
    key.algorithm.tile_shape.m      = 256;
    key.algorithm.tile_shape.n      = 256;
    key.algorithm.tile_shape.k      = 32;
    key.algorithm.wave_shape.m      = 2;
    key.algorithm.wave_shape.n      = 2;
    key.algorithm.wave_shape.k      = 1;
    key.algorithm.warp_tile_shape.m = 32;
    key.algorithm.warp_tile_shape.n = 32;
    key.algorithm.warp_tile_shape.k = 16;
    key.algorithm.persistent        = false;
    // Set explicitly, matching the EncodeIdentifier test, so the encoded
    // string does not depend on an unset flag.
    key.algorithm.preshuffle = false;

    const std::string id = key.encode_identifier();

    EXPECT_NE(id.find("_sparse"), std::string::npos);
}