mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Adding dispatcher architecture (#3300)
* WIP POC of dispatcher
* Dispatcher python workflow setup.
* Dispatcher cleanup and updates.
Further dispatcher cleanup and updates.
Build fixes
Improvements and python to CK example
Improvements to readme
* Fixes to python paths
* Cleaning up code
* Improving dispatcher support for different arch
Fixing typos
* Fix formatting errors
* Cleaning up examples
* Improving codegeneration
* Improving and fixing C++ examples
* Adding conv functionality (fwd,bwd,bwdw) and examples.
* Fixes based on feedback.
* Further fixes based on feedback.
* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.
* Another round of improvements based on feedback.
* Trimming out unnecessary code.
* Fixing the multi-D implementation.
* Using gpu verification for gemms and fixing convolutions tflops calculation.
* Fix counter usage issue and arch filtering per ops.
* Adding changelog and other fixes.
* Improve examples and resolve critical bugs.
* Reduce build time for python examples.
* Fixing minor bug.
* Fix compilation error.
* Improve installation instructions for dispatcher.
* Add docker based installation instructions for dispatcher.
* Fixing arch-based filtering to match tile engine.
* Remove dead code and fix arch filtering.
* Minor bugfix.
* Updates after rebase.
* Trimming code.
* Fix copyright headers.
* Consolidate examples, cut down code.
* Minor fixes.
* Improving python examples.
* Update readmes.
* Remove conv functionality.
* Cleanup following conv removable.
[ROCm/composable_kernel commit: 9e049a32a1]
This commit is contained in:
committed by
GitHub
parent
6afa598838
commit
8763bbf6cf
147
dispatcher/tests/test_kernel_key.cpp
Normal file
147
dispatcher/tests/test_kernel_key.cpp
Normal file
@@ -0,0 +1,147 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/// Unit tests for KernelKey using Google Test
|
||||
|
||||
#include "ck_tile/dispatcher/kernel_key.hpp"
|
||||
#include "test_mock_kernel.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::test;
|
||||
|
||||
TEST(KernelKeyTest, Construction)
|
||||
{
|
||||
KernelKey key;
|
||||
key.signature.dtype_a = DataType::FP16;
|
||||
key.signature.dtype_b = DataType::FP16;
|
||||
key.signature.dtype_c = DataType::FP16;
|
||||
key.signature.dtype_acc = DataType::FP32;
|
||||
key.signature.elementwise_op = "PassThrough";
|
||||
key.signature.num_d_tensors = 0;
|
||||
|
||||
key.algorithm.tile_shape.m = 256;
|
||||
key.algorithm.tile_shape.n = 256;
|
||||
key.algorithm.tile_shape.k = 32;
|
||||
|
||||
key.gfx_arch = "gfx942";
|
||||
|
||||
EXPECT_EQ(key.signature.dtype_a, DataType::FP16);
|
||||
EXPECT_EQ(key.algorithm.tile_shape.m, 256);
|
||||
EXPECT_EQ(key.gfx_arch, "gfx942");
|
||||
}
|
||||
|
||||
TEST(KernelKeyTest, Equality)
|
||||
{
|
||||
// Use helper function to ensure all fields are initialized
|
||||
KernelKey key1 = make_test_key(256, 256, 32, "gfx942");
|
||||
KernelKey key2 = make_test_key(256, 256, 32, "gfx942");
|
||||
|
||||
EXPECT_EQ(key1, key2);
|
||||
EXPECT_FALSE(key1 != key2);
|
||||
|
||||
// Change one value
|
||||
KernelKey key3 = make_test_key(128, 256, 32, "gfx942");
|
||||
EXPECT_NE(key1, key3);
|
||||
EXPECT_FALSE(key1 == key3);
|
||||
}
|
||||
|
||||
TEST(KernelKeyTest, EncodeIdentifier)
|
||||
{
|
||||
KernelKey key;
|
||||
key.signature.split_k = 1;
|
||||
key.signature.elementwise_op = "PassThrough";
|
||||
key.signature.num_d_tensors = 0;
|
||||
key.algorithm.tile_shape.m = 256;
|
||||
key.algorithm.tile_shape.n = 256;
|
||||
key.algorithm.tile_shape.k = 32;
|
||||
key.algorithm.wave_shape.m = 2;
|
||||
key.algorithm.wave_shape.n = 2;
|
||||
key.algorithm.wave_shape.k = 1;
|
||||
key.algorithm.warp_tile_shape.m = 32;
|
||||
key.algorithm.warp_tile_shape.n = 32;
|
||||
key.algorithm.warp_tile_shape.k = 16;
|
||||
key.algorithm.persistent = true;
|
||||
key.algorithm.preshuffle = false;
|
||||
key.signature.structured_sparsity = false;
|
||||
|
||||
std::string id = key.encode_identifier();
|
||||
|
||||
// Check that identifier contains expected components
|
||||
EXPECT_NE(id.find("256x256x32"), std::string::npos); // tile shape
|
||||
EXPECT_NE(id.find("2x2x1"), std::string::npos); // wave shape
|
||||
EXPECT_NE(id.find("32x32x16"), std::string::npos); // warp tile shape
|
||||
EXPECT_NE(id.find("persist"), std::string::npos); // persistent flag
|
||||
}
|
||||
|
||||
TEST(KernelKeyTest, EncodeIdentifierWithFusion)
|
||||
{
|
||||
KernelKey key;
|
||||
key.signature.split_k = 1;
|
||||
key.signature.elementwise_op = "Relu";
|
||||
key.signature.num_d_tensors = 2;
|
||||
key.algorithm.tile_shape.m = 128;
|
||||
key.algorithm.tile_shape.n = 128;
|
||||
key.algorithm.tile_shape.k = 64;
|
||||
key.algorithm.wave_shape.m = 2;
|
||||
key.algorithm.wave_shape.n = 2;
|
||||
key.algorithm.wave_shape.k = 1;
|
||||
key.algorithm.warp_tile_shape.m = 16;
|
||||
key.algorithm.warp_tile_shape.n = 16;
|
||||
key.algorithm.warp_tile_shape.k = 32;
|
||||
key.algorithm.persistent = false;
|
||||
key.signature.structured_sparsity = false;
|
||||
|
||||
std::string id = key.encode_identifier();
|
||||
|
||||
// Check fusion-specific components
|
||||
EXPECT_NE(id.find("Relu"), std::string::npos);
|
||||
EXPECT_NE(id.find("_d2"), std::string::npos);
|
||||
EXPECT_NE(id.find("nopers"), std::string::npos);
|
||||
}
|
||||
|
||||
TEST(KernelKeyTest, EncodeIdentifierWithSplitK)
|
||||
{
|
||||
KernelKey key;
|
||||
key.signature.split_k = 4;
|
||||
key.signature.elementwise_op = "PassThrough";
|
||||
key.signature.num_d_tensors = 0;
|
||||
key.algorithm.tile_shape.m = 256;
|
||||
key.algorithm.tile_shape.n = 256;
|
||||
key.algorithm.tile_shape.k = 32;
|
||||
key.algorithm.wave_shape.m = 2;
|
||||
key.algorithm.wave_shape.n = 2;
|
||||
key.algorithm.wave_shape.k = 1;
|
||||
key.algorithm.warp_tile_shape.m = 32;
|
||||
key.algorithm.warp_tile_shape.n = 32;
|
||||
key.algorithm.warp_tile_shape.k = 16;
|
||||
key.algorithm.persistent = false;
|
||||
key.signature.structured_sparsity = false;
|
||||
|
||||
std::string id = key.encode_identifier();
|
||||
|
||||
EXPECT_NE(id.find("_splitk4"), std::string::npos);
|
||||
}
|
||||
|
||||
TEST(KernelKeyTest, EncodeIdentifierWithSparsity)
|
||||
{
|
||||
KernelKey key;
|
||||
key.signature.split_k = 1;
|
||||
key.signature.elementwise_op = "PassThrough";
|
||||
key.signature.num_d_tensors = 0;
|
||||
key.signature.structured_sparsity = true;
|
||||
key.algorithm.tile_shape.m = 256;
|
||||
key.algorithm.tile_shape.n = 256;
|
||||
key.algorithm.tile_shape.k = 32;
|
||||
key.algorithm.wave_shape.m = 2;
|
||||
key.algorithm.wave_shape.n = 2;
|
||||
key.algorithm.wave_shape.k = 1;
|
||||
key.algorithm.warp_tile_shape.m = 32;
|
||||
key.algorithm.warp_tile_shape.n = 32;
|
||||
key.algorithm.warp_tile_shape.k = 16;
|
||||
key.algorithm.persistent = false;
|
||||
|
||||
std::string id = key.encode_identifier();
|
||||
|
||||
EXPECT_NE(id.find("_sparse"), std::string::npos);
|
||||
}
|
||||
Reference in New Issue
Block a user