mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
* WIP POC of dispatcher * Dispatcher python workflow setup. * Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes Improvements and python to CK example Improvements to readme * Fixes to python paths * Cleaning up code * Improving dispatcher support for different arch Fixing typos * Fix formatting errors * Cleaning up examples * Improving codegeneration * Improving and fixing C++ examples * Adding conv functionality (fwd,bwd,bwdw) and examples. * Fixes based on feedback. * Further fixes based on feedback. * Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug. * Another round of improvements based on feedback. * Trimming out unnecessary code. * Fixing the multi-D implementation. * Using gpu verification for gemms and fixing convolutions tflops calculation. * Fix counter usage issue and arch filtering per ops. * Adding changelog and other fixes. * Improve examples and resolve critical bugs. * Reduce build time for python examples. * Fixing minor bug. * Fix compilation error. * Improve installation instructions for dispatcher. * Add docker based installation instructions for dispatcher. * Fixing arch-based filtering to match tile engine. * Remove dead code and fix arch filtering. * Minor bugfix. * Updates after rebase. * Trimming code. * Fix copyright headers. * Consolidate examples, cut down code. * Minor fixes. * Improving python examples. * Update readmes. * Remove conv functionality. * Cleanup following conv removable.
454 lines
14 KiB
C++
454 lines
14 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
/// Extended unit tests for KernelKey - covers all data types, layouts, pipelines
|
|
|
|
#include "ck_tile/dispatcher/kernel_key.hpp"
|
|
#include "test_mock_kernel.hpp"
|
|
#include <gtest/gtest.h>
|
|
#include <set>
|
|
#include <sstream>
|
|
|
|
using namespace ck_tile::dispatcher;
|
|
using namespace ck_tile::dispatcher::test;
|
|
|
|
// =============================================================================
|
|
// DataType Tests
|
|
// =============================================================================
|
|
|
|
class DataTypeTest : public ::testing::Test
|
|
{
|
|
protected:
|
|
void SetUp() override {}
|
|
};
|
|
|
|
TEST_F(DataTypeTest, AllDataTypesExist)
|
|
{
|
|
// Every DataType should be accessible
|
|
std::vector<DataType> all_types = {DataType::FP16,
|
|
DataType::BF16,
|
|
DataType::FP32,
|
|
DataType::FP64,
|
|
DataType::INT8,
|
|
DataType::INT4,
|
|
DataType::INT32,
|
|
DataType::FP8,
|
|
DataType::BF8,
|
|
DataType::UNKNOWN};
|
|
|
|
EXPECT_EQ(all_types.size(), 10);
|
|
}
|
|
|
|
TEST_F(DataTypeTest, DataTypesAreDifferent)
|
|
{
|
|
EXPECT_NE(DataType::FP16, DataType::BF16);
|
|
EXPECT_NE(DataType::FP16, DataType::FP32);
|
|
EXPECT_NE(DataType::INT8, DataType::INT4);
|
|
}
|
|
|
|
// =============================================================================
|
|
// LayoutTag Tests
|
|
// =============================================================================
|
|
|
|
class LayoutTagTest : public ::testing::Test
|
|
{
|
|
};
|
|
|
|
TEST_F(LayoutTagTest, AllLayoutsExist)
|
|
{
|
|
std::vector<LayoutTag> all_layouts = {
|
|
LayoutTag::RowMajor, LayoutTag::ColMajor, LayoutTag::PackedExternal};
|
|
|
|
EXPECT_EQ(all_layouts.size(), 3);
|
|
}
|
|
|
|
TEST_F(LayoutTagTest, LayoutsAreDifferent) { EXPECT_NE(LayoutTag::RowMajor, LayoutTag::ColMajor); }
|
|
|
|
// =============================================================================
|
|
// Pipeline Tests
|
|
// =============================================================================
|
|
|
|
class PipelineTest : public ::testing::Test
|
|
{
|
|
};
|
|
|
|
TEST_F(PipelineTest, AllPipelinesExist)
|
|
{
|
|
std::vector<Pipeline> all_pipelines = {Pipeline::Mem,
|
|
Pipeline::CompV1,
|
|
Pipeline::CompV2,
|
|
Pipeline::CompV3,
|
|
Pipeline::CompV4,
|
|
Pipeline::CompV5,
|
|
Pipeline::PreShuffleV1,
|
|
Pipeline::PreShuffleV2};
|
|
|
|
EXPECT_EQ(all_pipelines.size(), 8);
|
|
}
|
|
|
|
TEST_F(PipelineTest, PipelinesAreDifferent)
|
|
{
|
|
EXPECT_NE(Pipeline::Mem, Pipeline::CompV4);
|
|
EXPECT_NE(Pipeline::CompV3, Pipeline::CompV4);
|
|
}
|
|
|
|
// =============================================================================
|
|
// Scheduler Tests
|
|
// =============================================================================
|
|
|
|
class SchedulerTest : public ::testing::Test
|
|
{
|
|
};
|
|
|
|
TEST_F(SchedulerTest, AllSchedulersExist)
|
|
{
|
|
std::vector<Scheduler> all_schedulers = {
|
|
Scheduler::Auto, Scheduler::Intrawave, Scheduler::Interwave};
|
|
|
|
EXPECT_EQ(all_schedulers.size(), 3);
|
|
}
|
|
|
|
// =============================================================================
|
|
// Epilogue Tests
|
|
// =============================================================================
|
|
|
|
class EpilogueTest : public ::testing::Test
|
|
{
|
|
};
|
|
|
|
TEST_F(EpilogueTest, AllEpiloguesExist)
|
|
{
|
|
std::vector<Epilogue> all_epilogues = {Epilogue::None,
|
|
Epilogue::Default,
|
|
Epilogue::CShuffle,
|
|
Epilogue::Bias,
|
|
Epilogue::Activation,
|
|
Epilogue::BiasActivation};
|
|
|
|
EXPECT_EQ(all_epilogues.size(), 6);
|
|
}
|
|
|
|
// =============================================================================
|
|
// KernelKey::Signature Tests
|
|
// =============================================================================
|
|
|
|
class SignatureTest : public ::testing::Test
|
|
{
|
|
protected:
|
|
KernelKey::Signature CreateDefaultSignature()
|
|
{
|
|
KernelKey::Signature sig;
|
|
sig.dtype_a = DataType::FP16;
|
|
sig.dtype_b = DataType::FP16;
|
|
sig.dtype_c = DataType::FP16;
|
|
sig.dtype_acc = DataType::FP32;
|
|
sig.layout_a = LayoutTag::RowMajor;
|
|
sig.layout_b = LayoutTag::ColMajor;
|
|
sig.layout_c = LayoutTag::RowMajor;
|
|
sig.transpose_a = false;
|
|
sig.transpose_b = false;
|
|
sig.grouped = false;
|
|
sig.split_k = 1;
|
|
sig.elementwise_op = "PassThrough";
|
|
sig.num_d_tensors = 0;
|
|
sig.structured_sparsity = false;
|
|
return sig;
|
|
}
|
|
};
|
|
|
|
TEST_F(SignatureTest, DefaultValuesAreReasonable)
|
|
{
|
|
KernelKey::Signature sig = CreateDefaultSignature();
|
|
EXPECT_EQ(sig.split_k, 1);
|
|
EXPECT_FALSE(sig.grouped);
|
|
EXPECT_FALSE(sig.structured_sparsity);
|
|
}
|
|
|
|
TEST_F(SignatureTest, AllDataTypeCombinations)
|
|
{
|
|
// Test various data type combinations that should be valid
|
|
std::vector<std::tuple<DataType, DataType, DataType, DataType>> valid_combos = {
|
|
{DataType::FP16, DataType::FP16, DataType::FP16, DataType::FP32},
|
|
{DataType::BF16, DataType::BF16, DataType::BF16, DataType::FP32},
|
|
{DataType::FP32, DataType::FP32, DataType::FP32, DataType::FP32},
|
|
{DataType::INT8, DataType::INT8, DataType::INT8, DataType::INT32},
|
|
};
|
|
|
|
for(const auto& [a, b, c, acc] : valid_combos)
|
|
{
|
|
KernelKey::Signature sig;
|
|
sig.dtype_a = a;
|
|
sig.dtype_b = b;
|
|
sig.dtype_c = c;
|
|
sig.dtype_acc = acc;
|
|
|
|
EXPECT_EQ(sig.dtype_a, a);
|
|
EXPECT_EQ(sig.dtype_b, b);
|
|
EXPECT_EQ(sig.dtype_c, c);
|
|
EXPECT_EQ(sig.dtype_acc, acc);
|
|
}
|
|
}
|
|
|
|
TEST_F(SignatureTest, AllLayoutCombinations)
|
|
{
|
|
std::vector<std::string> layout_codes = {
|
|
"rrr", "rcr", "crr", "ccr", "rrc", "rcc", "crc", "ccc"};
|
|
|
|
for(const std::string& code : layout_codes)
|
|
{
|
|
KernelKey::Signature sig = CreateDefaultSignature();
|
|
sig.layout_a = (code[0] == 'r') ? LayoutTag::RowMajor : LayoutTag::ColMajor;
|
|
sig.layout_b = (code[1] == 'r') ? LayoutTag::RowMajor : LayoutTag::ColMajor;
|
|
sig.layout_c = (code[2] == 'r') ? LayoutTag::RowMajor : LayoutTag::ColMajor;
|
|
|
|
// Just verify assignment works
|
|
EXPECT_TRUE(sig.layout_a == LayoutTag::RowMajor || sig.layout_a == LayoutTag::ColMajor);
|
|
}
|
|
}
|
|
|
|
TEST_F(SignatureTest, SplitKValues)
|
|
{
|
|
KernelKey::Signature sig = CreateDefaultSignature();
|
|
|
|
std::vector<std::uint8_t> valid_split_k = {1, 2, 4, 8, 16};
|
|
for(auto sk : valid_split_k)
|
|
{
|
|
sig.split_k = sk;
|
|
EXPECT_EQ(sig.split_k, sk);
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// KernelKey::Algorithm Tests
|
|
// =============================================================================
|
|
|
|
class AlgorithmTest : public ::testing::Test
|
|
{
|
|
protected:
|
|
KernelKey::Algorithm CreateDefaultAlgorithm()
|
|
{
|
|
KernelKey::Algorithm algo;
|
|
algo.tile_shape = {256, 256, 32};
|
|
algo.wave_shape = {2, 2, 1};
|
|
algo.warp_tile_shape = {32, 32, 16};
|
|
algo.pipeline = Pipeline::CompV4;
|
|
algo.scheduler = Scheduler::Intrawave;
|
|
algo.epilogue = Epilogue::CShuffle;
|
|
algo.block_size = 256;
|
|
algo.double_buffer = true;
|
|
algo.persistent = false;
|
|
algo.preshuffle = false;
|
|
algo.transpose_c = false;
|
|
algo.num_wave_groups = 1;
|
|
return algo;
|
|
}
|
|
};
|
|
|
|
TEST_F(AlgorithmTest, CommonTileShapes)
|
|
{
|
|
std::vector<std::tuple<int, int, int>> valid_tiles = {
|
|
{64, 64, 32},
|
|
{128, 128, 32},
|
|
{128, 128, 64},
|
|
{256, 256, 32},
|
|
{256, 256, 64},
|
|
{256, 128, 32},
|
|
{128, 256, 32},
|
|
};
|
|
|
|
for(const auto& [m, n, k] : valid_tiles)
|
|
{
|
|
KernelKey::Algorithm algo = CreateDefaultAlgorithm();
|
|
algo.tile_shape = {static_cast<std::uint16_t>(m),
|
|
static_cast<std::uint16_t>(n),
|
|
static_cast<std::uint16_t>(k)};
|
|
|
|
EXPECT_EQ(algo.tile_shape.m, m);
|
|
EXPECT_EQ(algo.tile_shape.n, n);
|
|
EXPECT_EQ(algo.tile_shape.k, k);
|
|
}
|
|
}
|
|
|
|
TEST_F(AlgorithmTest, CommonWarpConfigs)
|
|
{
|
|
std::vector<std::tuple<int, int, int>> valid_warps = {
|
|
{1, 4, 1},
|
|
{2, 2, 1},
|
|
{4, 1, 1},
|
|
{1, 2, 1},
|
|
{2, 1, 1},
|
|
};
|
|
|
|
for(const auto& [m, n, k] : valid_warps)
|
|
{
|
|
KernelKey::Algorithm algo = CreateDefaultAlgorithm();
|
|
algo.wave_shape = {static_cast<std::uint8_t>(m),
|
|
static_cast<std::uint8_t>(n),
|
|
static_cast<std::uint8_t>(k)};
|
|
|
|
EXPECT_EQ(algo.wave_shape.m, m);
|
|
EXPECT_EQ(algo.wave_shape.n, n);
|
|
EXPECT_EQ(algo.wave_shape.k, k);
|
|
}
|
|
}
|
|
|
|
TEST_F(AlgorithmTest, AllPipelines)
|
|
{
|
|
KernelKey::Algorithm algo = CreateDefaultAlgorithm();
|
|
|
|
std::vector<Pipeline> pipelines = {Pipeline::Mem,
|
|
Pipeline::CompV3,
|
|
Pipeline::CompV4,
|
|
Pipeline::PreShuffleV1,
|
|
Pipeline::PreShuffleV2};
|
|
|
|
for(Pipeline p : pipelines)
|
|
{
|
|
algo.pipeline = p;
|
|
EXPECT_EQ(algo.pipeline, p);
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// KernelKey Identifier Encoding Tests
|
|
// =============================================================================
|
|
|
|
class IdentifierEncodingTest : public ::testing::Test
|
|
{
|
|
};
|
|
|
|
TEST_F(IdentifierEncodingTest, UniqueIdentifiersForDifferentConfigs)
|
|
{
|
|
std::set<std::string> identifiers;
|
|
|
|
// Generate multiple configurations
|
|
for(int tile_m : {128, 256})
|
|
{
|
|
for(int wave_m : {1, 2, 4})
|
|
{
|
|
for(bool persistent : {true, false})
|
|
{
|
|
KernelKey key = make_test_key(tile_m);
|
|
key.algorithm.wave_shape.m = wave_m;
|
|
key.algorithm.persistent = persistent;
|
|
|
|
std::string id = key.encode_identifier();
|
|
EXPECT_TRUE(identifiers.find(id) == identifiers.end())
|
|
<< "Duplicate identifier: " << id;
|
|
identifiers.insert(id);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Should have generated 2 * 3 * 2 = 12 unique identifiers
|
|
EXPECT_EQ(identifiers.size(), 12);
|
|
}
|
|
|
|
TEST_F(IdentifierEncodingTest, IdentifierContainsTileShape)
|
|
{
|
|
KernelKey key = make_test_key(256, 128, 64);
|
|
std::string id = key.encode_identifier();
|
|
|
|
EXPECT_NE(id.find("256x128x64"), std::string::npos)
|
|
<< "Identifier should contain tile shape: " << id;
|
|
}
|
|
|
|
TEST_F(IdentifierEncodingTest, IdentifierContainsWarpConfig)
|
|
{
|
|
KernelKey key = make_test_key(256);
|
|
key.algorithm.wave_shape = {4, 2, 1};
|
|
std::string id = key.encode_identifier();
|
|
|
|
EXPECT_NE(id.find("4x2x1"), std::string::npos)
|
|
<< "Identifier should contain warp config: " << id;
|
|
}
|
|
|
|
TEST_F(IdentifierEncodingTest, IdentifierReflectsPersistence)
|
|
{
|
|
KernelKey persistent_key = make_test_key(256);
|
|
persistent_key.algorithm.persistent = true;
|
|
|
|
KernelKey non_persistent_key = make_test_key(256);
|
|
non_persistent_key.algorithm.persistent = false;
|
|
|
|
std::string persistent_id = persistent_key.encode_identifier();
|
|
std::string non_persistent_id = non_persistent_key.encode_identifier();
|
|
|
|
EXPECT_NE(persistent_id, non_persistent_id);
|
|
EXPECT_NE(persistent_id.find("persist"), std::string::npos);
|
|
EXPECT_NE(non_persistent_id.find("nopers"), std::string::npos);
|
|
}
|
|
|
|
// =============================================================================
|
|
// KernelKey Equality Tests
|
|
// =============================================================================
|
|
|
|
class KeyEqualityTest : public ::testing::Test
|
|
{
|
|
};
|
|
|
|
TEST_F(KeyEqualityTest, IdenticalKeysAreEqual)
|
|
{
|
|
KernelKey key1 = make_test_key(256, 256, 32, "gfx942");
|
|
KernelKey key2 = make_test_key(256, 256, 32, "gfx942");
|
|
|
|
EXPECT_EQ(key1, key2);
|
|
EXPECT_FALSE(key1 != key2);
|
|
}
|
|
|
|
TEST_F(KeyEqualityTest, DifferentTileShapesNotEqual)
|
|
{
|
|
KernelKey key1 = make_test_key(256, 256, 32);
|
|
KernelKey key2 = make_test_key(128, 128, 32);
|
|
|
|
EXPECT_NE(key1, key2);
|
|
}
|
|
|
|
TEST_F(KeyEqualityTest, DifferentDataTypesNotEqual)
|
|
{
|
|
KernelKey key1 = make_test_key(256);
|
|
KernelKey key2 = make_test_key(256);
|
|
key2.signature.dtype_a = DataType::BF16;
|
|
|
|
EXPECT_NE(key1, key2);
|
|
}
|
|
|
|
TEST_F(KeyEqualityTest, DifferentLayoutsNotEqual)
|
|
{
|
|
KernelKey key1 = make_test_key(256);
|
|
KernelKey key2 = make_test_key(256);
|
|
key2.signature.layout_a = LayoutTag::ColMajor;
|
|
|
|
EXPECT_NE(key1, key2);
|
|
}
|
|
|
|
TEST_F(KeyEqualityTest, DifferentGfxArchNotEqual)
|
|
{
|
|
KernelKey key1 = make_test_key(256, 256, 32, "gfx942");
|
|
KernelKey key2 = make_test_key(256, 256, 32, "gfx90a");
|
|
|
|
EXPECT_NE(key1, key2);
|
|
}
|
|
|
|
// =============================================================================
|
|
// ElementwiseOps Tests
|
|
// =============================================================================
|
|
|
|
class ElementwiseOpsTest : public ::testing::Test
|
|
{
|
|
};
|
|
|
|
TEST_F(ElementwiseOpsTest, CanUseInKernelKey)
|
|
{
|
|
KernelKey key = make_test_key(256);
|
|
|
|
key.signature.elementwise_op = "Relu";
|
|
EXPECT_EQ(key.signature.elementwise_op, "Relu");
|
|
|
|
key.signature.elementwise_op = "Gelu";
|
|
EXPECT_EQ(key.signature.elementwise_op, "Gelu");
|
|
|
|
key.signature.elementwise_op = "PassThrough";
|
|
EXPECT_EQ(key.signature.elementwise_op, "PassThrough");
|
|
}
|