mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-29 19:47:39 +00:00
* WIP POC of dispatcher * Dispatcher python workflow setup. * Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes Improvements and python to CK example Improvements to readme * Fixes to python paths * Cleaning up code * Improving dispatcher support for different arch Fixing typos * Fix formatting errors * Cleaning up examples * Improving code generation * Improving and fixing C++ examples * Adding conv functionality (fwd,bwd,bwdw) and examples. * Fixes based on feedback. * Further fixes based on feedback. * Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug. * Another round of improvements based on feedback. * Trimming out unnecessary code. * Fixing the multi-D implementation. * Using gpu verification for gemms and fixing convolutions tflops calculation. * Fix counter usage issue and arch filtering per ops. * Adding changelog and other fixes. * Improve examples and resolve critical bugs. * Reduce build time for python examples. * Fixing minor bug. * Fix compilation error. * Improve installation instructions for dispatcher. * Add docker based installation instructions for dispatcher. * Fixing arch-based filtering to match tile engine. * Remove dead code and fix arch filtering. * Minor bugfix. * Updates after rebase. * Trimming code. * Fix copyright headers. * Consolidate examples, cut down code. * Minor fixes. * Improving python examples. * Update readmes. * Remove conv functionality. * Cleanup following conv removal.
458 lines
11 KiB
C++
458 lines
11 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||
// SPDX-License-Identifier: MIT
|
||
|
||
/// Extended unit tests for Problem - covers dimension inference, validation, edge cases
|
||
|
||
#include "ck_tile/dispatcher/problem.hpp"
#include <gtest/gtest.h>
#include <cstdint>
#include <limits>
#include <stdexcept>
|
||
|
||
using namespace ck_tile::dispatcher;
|
||
|
||
// =============================================================================
|
||
// Dimension Inference Tests
|
||
// =============================================================================
|
||
|
||
/// Fixture grouping the Problem dimension-inference tests; it adds no shared state.
class ProblemDimensionInferenceTest : public ::testing::Test
{
};
|
||
|
||
/// Problem::from_ab must derive M, N and K from the A and B extents and yield a valid problem.
TEST_F(ProblemDimensionInferenceTest, FromAB_Basic)
{
    // A: M×K (1024×512), B: K×N (512×2048)
    auto problem = Problem::from_ab(1024, 512, 512, 2048);

    EXPECT_EQ(problem.M, 1024);
    EXPECT_EQ(problem.N, 2048);
    EXPECT_EQ(problem.K, 512);
    EXPECT_TRUE(problem.is_valid());
}
|
||
|
||
/// Problem::from_dimensions with consistent A, B and C extents must produce matching M, N, K.
TEST_F(ProblemDimensionInferenceTest, FromDimensions_Valid)
{
    // A: 1024×512, B: 512×2048, C: 1024×2048
    auto problem = Problem::from_dimensions(1024, 512, 512, 2048, 1024, 2048);

    EXPECT_EQ(problem.M, 1024);
    EXPECT_EQ(problem.N, 2048);
    EXPECT_EQ(problem.K, 512);
    EXPECT_TRUE(problem.is_valid());
}
|
||
|
||
/// Problem::from_shapes with three non-transposed shapes must produce matching M, N, K.
TEST_F(ProblemDimensionInferenceTest, FromShapes_WithC)
{
    TensorShape A{1024, 512, false};
    TensorShape B{512, 2048, false};
    TensorShape C{1024, 2048, false};

    auto problem = Problem::from_shapes(A, B, C);

    EXPECT_EQ(problem.M, 1024);
    EXPECT_EQ(problem.N, 2048);
    EXPECT_EQ(problem.K, 512);
    EXPECT_TRUE(problem.is_valid());
}
|
||
|
||
/// Problem::from_shapes must still infer M=1024, K=512 when A is flagged as transposed.
TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedA)
{
    // A stored as K×M (transposed)
    TensorShape A{512, 1024, true};
    TensorShape B{512, 2048, false};
    TensorShape C{1024, 2048, false};

    auto problem = Problem::from_shapes(A, B, C);

    EXPECT_EQ(problem.M, 1024);
    EXPECT_EQ(problem.N, 2048);
    EXPECT_EQ(problem.K, 512);
    // Same validity check as the non-transposed inference tests.
    EXPECT_TRUE(problem.is_valid());
}
|
||
|
||
/// Problem::from_shapes must still infer N=2048, K=512 when B is flagged as transposed.
TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedB)
{
    TensorShape A{1024, 512, false};
    // B stored as N×K (transposed)
    TensorShape B{2048, 512, true};
    TensorShape C{1024, 2048, false};

    auto problem = Problem::from_shapes(A, B, C);

    EXPECT_EQ(problem.M, 1024);
    EXPECT_EQ(problem.N, 2048);
    EXPECT_EQ(problem.K, 512);
    // Same validity check as the non-transposed inference tests.
    EXPECT_TRUE(problem.is_valid());
}
|
||
|
||
// =============================================================================
|
||
// Validation Tests
|
||
// =============================================================================
|
||
|
||
/// Fixture grouping the Problem::is_valid() tests; it adds no shared state.
class ProblemValidationTest : public ::testing::Test
{
};
|
||
|
||
/// A problem with all-positive dimensions must be reported as valid.
TEST_F(ProblemValidationTest, ValidProblem)
{
    Problem p(1024, 1024, 1024);
    EXPECT_TRUE(p.is_valid());
}
|
||
|
||
/// A problem constructed with M == 0 must be reported as invalid.
TEST_F(ProblemValidationTest, ZeroM)
{
    Problem p(0, 1024, 1024);
    EXPECT_FALSE(p.is_valid());
}
|
||
|
||
/// A problem constructed with N == 0 must be reported as invalid.
TEST_F(ProblemValidationTest, ZeroN)
{
    Problem p(1024, 0, 1024);
    EXPECT_FALSE(p.is_valid());
}
|
||
|
||
/// A problem constructed with K == 0 must be reported as invalid.
TEST_F(ProblemValidationTest, ZeroK)
{
    Problem p(1024, 1024, 0);
    EXPECT_FALSE(p.is_valid());
}
|
||
|
||
/// A negative M must make the problem invalid.
TEST_F(ProblemValidationTest, NegativeM)
{
    // Start from a default-constructed Problem and assign the fields directly.
    Problem p;
    p.M = -1;
    p.N = 1024;
    p.K = 1024;
    EXPECT_FALSE(p.is_valid());
}
|
||
|
||
/// Setting k_batch to 0 on an otherwise valid problem must make it invalid.
TEST_F(ProblemValidationTest, ZeroKBatch)
{
    Problem p(1024, 1024, 1024);
    p.k_batch = 0;
    EXPECT_FALSE(p.is_valid());
}
|
||
|
||
/// A positive k_batch (here 4) must keep the problem valid.
TEST_F(ProblemValidationTest, ValidKBatch)
{
    Problem p(1024, 1024, 1024);
    p.k_batch = 4;
    EXPECT_TRUE(p.is_valid());
}
|
||
|
||
// =============================================================================
|
||
// num_ops Tests
|
||
// =============================================================================
|
||
|
||
/// Fixture grouping the Problem::num_ops() tests; it adds no shared state.
class ProblemNumOpsTest : public ::testing::Test
{
};
|
||
|
||
/// num_ops() on a small problem must equal 2 * M * N * K.
TEST_F(ProblemNumOpsTest, SmallProblem)
{
    Problem p(10, 20, 30);
    // 2 * M * N * K = 2 * 10 * 20 * 30 = 12000
    EXPECT_EQ(p.num_ops(), 12000);
}
|
||
|
||
/// num_ops() on a cubic 1024 problem must equal 2 * 1024^3.
TEST_F(ProblemNumOpsTest, SymmetricProblem)
{
    Problem p(1024, 1024, 1024);
    // 2 * 1024^3 = 2,147,483,648 (computed in long long via the 2LL literal)
    EXPECT_EQ(p.num_ops(), 2LL * 1024 * 1024 * 1024);
}
|
||
|
||
/// num_ops() must equal 2 * M * N * K when the three dimensions differ.
TEST_F(ProblemNumOpsTest, AsymmetricProblem)
{
    Problem p(512, 2048, 256);
    EXPECT_EQ(p.num_ops(), 2LL * 512 * 2048 * 256);
}
|
||
|
||
/// num_ops() on a cubic 4096 problem must match the 64-bit expected value and stay positive.
TEST_F(ProblemNumOpsTest, LargeProblem)
{
    Problem p(4096, 4096, 4096);
    // 2 * 4096^3 = 2^37, which does not fit in 32 bits.
    const std::int64_t expected = 2LL * 4096 * 4096 * 4096;
    EXPECT_EQ(p.num_ops(), expected);
    EXPECT_GT(p.num_ops(), 0); // No overflow
}
|
||
|
||
// =============================================================================
|
||
// Edge Cases
|
||
// =============================================================================
|
||
|
||
/// Fixture grouping the Problem edge-case tests; it adds no shared state.
class ProblemEdgeCasesTest : public ::testing::Test
{
};
|
||
|
||
/// A 1x1x1 problem must be valid and report exactly 2 operations.
TEST_F(ProblemEdgeCasesTest, MinimumValidSize)
{
    Problem p(1, 1, 1);
    EXPECT_TRUE(p.is_valid());
    EXPECT_EQ(p.num_ops(), 2);
}
|
||
|
||
/// A problem with M much larger than N (8192 vs 64) must be valid.
TEST_F(ProblemEdgeCasesTest, NonSquare_TallMatrix)
{
    Problem p(8192, 64, 1024);
    EXPECT_TRUE(p.is_valid());
}
|
||
|
||
/// A problem with N much larger than M (8192 vs 64) must be valid.
TEST_F(ProblemEdgeCasesTest, NonSquare_WideMatrix)
{
    Problem p(64, 8192, 1024);
    EXPECT_TRUE(p.is_valid());
}
|
||
|
||
/// A problem with K much larger than M and N (8192 vs 1024) must be valid.
TEST_F(ProblemEdgeCasesTest, NonSquare_DeepK)
{
    Problem p(1024, 1024, 8192);
    EXPECT_TRUE(p.is_valid());
}
|
||
|
||
/// A problem with a small K (16) must be valid.
TEST_F(ProblemEdgeCasesTest, SmallK)
{
    Problem p(1024, 1024, 16);
    EXPECT_TRUE(p.is_valid());
}
|
||
|
||
/// Dimensions that are not powers of two must be valid and counted as 2 * M * N * K.
TEST_F(ProblemEdgeCasesTest, NonPowerOf2Dimensions)
{
    Problem p(1000, 2000, 300);
    EXPECT_TRUE(p.is_valid());
    EXPECT_EQ(p.num_ops(), 2LL * 1000 * 2000 * 300);
}
|
||
|
||
/// Prime-valued dimensions must be accepted as valid.
TEST_F(ProblemEdgeCasesTest, PrimeDimensions)
{
    Problem p(997, 1009, 1013); // All prime numbers
    EXPECT_TRUE(p.is_valid());
}
|
||
|
||
// =============================================================================
|
||
// Configuration Tests
|
||
// =============================================================================
|
||
|
||
/// Fixture grouping the Problem configuration-field tests; it adds no shared state.
class ProblemConfigurationTest : public ::testing::Test
{
};
|
||
|
||
/// A Problem built from dimensions only must have the expected default configuration.
TEST_F(ProblemConfigurationTest, DefaultConfiguration)
{
    Problem p(1024, 1024, 1024);

    EXPECT_FALSE(p.prefer_persistent);
    EXPECT_FALSE(p.enable_validation);
    EXPECT_EQ(p.smem_budget, 0);
    EXPECT_EQ(p.k_batch, 1);
}
|
||
|
||
/// Enabling prefer_persistent must be stored and must not invalidate the problem.
TEST_F(ProblemConfigurationTest, SetPersistentPreference)
{
    Problem p(1024, 1024, 1024);
    p.prefer_persistent = true;

    EXPECT_TRUE(p.prefer_persistent);
    EXPECT_TRUE(p.is_valid());
}
|
||
|
||
/// Assigning smem_budget must be stored and must not invalidate the problem.
TEST_F(ProblemConfigurationTest, SetSmemBudget)
{
    Problem p(1024, 1024, 1024);
    p.smem_budget = 65536; // 64KB

    EXPECT_EQ(p.smem_budget, 65536);
    EXPECT_TRUE(p.is_valid());
}
|
||
|
||
/// Each positive k_batch value in the list must be stored and keep the problem valid.
TEST_F(ProblemConfigurationTest, SetKBatch)
{
    Problem p(1024, 1024, 1024);

    for(int kb : {1, 2, 4, 8, 16})
    {
        // Tag any failure in this iteration with the k_batch value being exercised.
        SCOPED_TRACE(::testing::Message() << "k_batch = " << kb);
        p.k_batch = kb;
        EXPECT_EQ(p.k_batch, kb);
        EXPECT_TRUE(p.is_valid());
    }
}
|
||
|
||
// =============================================================================
|
||
// Copy and Assignment Tests
|
||
// =============================================================================
|
||
|
||
/// Fixture grouping the Problem copy and assignment tests; it adds no shared state.
class ProblemCopyTest : public ::testing::Test
{
};
|
||
|
||
/// Copy construction must carry over the dimensions and the modified configuration fields.
TEST_F(ProblemCopyTest, CopyConstruction)
{
    Problem p1(1024, 2048, 512);
    p1.prefer_persistent = true;
    p1.k_batch           = 4;

    Problem p2(p1);

    EXPECT_EQ(p2.M, 1024);
    EXPECT_EQ(p2.N, 2048);
    EXPECT_EQ(p2.K, 512);
    EXPECT_TRUE(p2.prefer_persistent);
    EXPECT_EQ(p2.k_batch, 4);
}
|
||
|
||
/// Copy assignment must overwrite the destination's dimensions with the source's.
TEST_F(ProblemCopyTest, Assignment)
{
    Problem p1(1024, 2048, 512);
    Problem p2(256, 256, 256);

    p2 = p1;

    EXPECT_EQ(p2.M, 1024);
    EXPECT_EQ(p2.N, 2048);
    EXPECT_EQ(p2.K, 512);
}
|
||
|
||
// =============================================================================
|
||
// Builder Tests
|
||
// =============================================================================
|
||
|
||
/// Fixture grouping the ProblemBuilder tests; it adds no shared state.
class ProblemBuilderTest : public ::testing::Test
{
};
|
||
|
||
/// ProblemBuilder::dimensions(...).build() must produce a valid problem with those dimensions.
TEST_F(ProblemBuilderTest, BasicBuild)
{
    auto problem = ProblemBuilder().dimensions(1024, 2048, 512).build();

    EXPECT_EQ(problem.M, 1024);
    EXPECT_EQ(problem.N, 2048);
    EXPECT_EQ(problem.K, 512);
    EXPECT_TRUE(problem.is_valid());
}
|
||
|
||
/// ProblemBuilder::split_k(4) must set k_batch to 4 on the built problem.
TEST_F(ProblemBuilderTest, WithSplitK)
{
    auto problem = ProblemBuilder().dimensions(1024, 1024, 1024).split_k(4).build();

    EXPECT_EQ(problem.k_batch, 4);
}
|
||
|
||
/// ProblemBuilder::persistent(true) must set prefer_persistent on the built problem.
TEST_F(ProblemBuilderTest, WithPersistent)
{
    auto problem = ProblemBuilder().dimensions(1024, 1024, 1024).persistent(true).build();

    EXPECT_TRUE(problem.prefer_persistent);
}
|
||
|
||
/// ProblemBuilder::smem_budget(65536) must set smem_budget on the built problem.
TEST_F(ProblemBuilderTest, WithSmemBudget)
{
    auto problem = ProblemBuilder().dimensions(1024, 1024, 1024).smem_budget(65536).build();

    EXPECT_EQ(problem.smem_budget, 65536);
}
|
||
|
||
/// Chaining every builder setter must apply all of them to the built problem.
TEST_F(ProblemBuilderTest, ChainedConfiguration)
{
    auto problem = ProblemBuilder()
                       .dimensions(2048, 2048, 1024)
                       .split_k(2)
                       .persistent(true)
                       .smem_budget(32768)
                       .validate(true)
                       .build();

    EXPECT_EQ(problem.M, 2048);
    EXPECT_EQ(problem.N, 2048);
    EXPECT_EQ(problem.K, 1024);
    EXPECT_EQ(problem.k_batch, 2);
    EXPECT_TRUE(problem.prefer_persistent);
    EXPECT_EQ(problem.smem_budget, 32768);
    EXPECT_TRUE(problem.enable_validation);
}
|
||
|
||
/// ProblemBuilder::from_ab must infer M, N and K the same way Problem::from_ab is expected to.
TEST_F(ProblemBuilderTest, FromAB)
{
    auto problem = ProblemBuilder().from_ab(1024, 512, 512, 2048).build();

    EXPECT_EQ(problem.M, 1024);
    EXPECT_EQ(problem.N, 2048);
    EXPECT_EQ(problem.K, 512);
}
|
||
|
||
// =============================================================================
|
||
// Dimension Mismatch Error Tests
|
||
// =============================================================================
|
||
|
||
/// Fixture grouping the dimension-mismatch error tests; it adds no shared state.
class ProblemDimensionErrorTest : public ::testing::Test
{
};
|
||
|
||
/// Problem::from_ab must throw std::invalid_argument when A and B disagree on K.
TEST_F(ProblemDimensionErrorTest, KMismatchThrows)
{
    // The (void) cast discards the returned Problem; only the exception matters here.
    EXPECT_THROW((void)Problem::from_ab(1024, 512, 256, 2048), // K mismatch: 512 vs 256
                 std::invalid_argument);
}
|
||
|
||
/// Problem::from_shapes must throw std::invalid_argument when A and C disagree on M.
TEST_F(ProblemDimensionErrorTest, MDimensionMismatchThrows)
{
    TensorShape A{1024, 512, false};
    TensorShape B{512, 2048, false};
    TensorShape C{512, 2048, false}; // M mismatch: A says M=1024, C says M=512

    EXPECT_THROW((void)Problem::from_shapes(A, B, C), std::invalid_argument);
}
|
||
|
||
/// Problem::from_shapes must throw std::invalid_argument when B and C disagree on N.
TEST_F(ProblemDimensionErrorTest, NDimensionMismatchThrows)
{
    TensorShape A{1024, 512, false};
    TensorShape B{512, 2048, false};
    TensorShape C{1024, 1024, false}; // N mismatch: B says N=2048, C says N=1024

    EXPECT_THROW((void)Problem::from_shapes(A, B, C), std::invalid_argument);
}
|
||
|
||
// =============================================================================
|
||
// Validate Sizes Tests
|
||
// =============================================================================
|
||
|
||
/// Fixture grouping the Problem::validate_sizes() tests; it adds no shared state.
class ProblemValidateSizesTest : public ::testing::Test
{
};
|
||
|
||
/// validate_sizes must accept element counts that match M*K, K*N and M*N.
TEST_F(ProblemValidateSizesTest, CorrectSizes)
{
    Problem p(1024, 2048, 512);

    // This should not throw
    EXPECT_NO_THROW(p.validate_sizes(1024 * 512, // A size
                                     512 * 2048, // B size
                                     1024 * 2048 // C size
                                     ));
}
|
||
|
||
/// validate_sizes must throw std::invalid_argument when the A element count is wrong.
TEST_F(ProblemValidateSizesTest, WrongASizeThrows)
{
    Problem p(1024, 2048, 512);

    EXPECT_THROW(p.validate_sizes(1024 * 256, // Wrong A size
                                  512 * 2048,
                                  1024 * 2048),
                 std::invalid_argument);
}
|
||
|
||
/// validate_sizes must throw std::invalid_argument when the B element count is wrong.
TEST_F(ProblemValidateSizesTest, WrongBSizeThrows)
{
    Problem p(1024, 2048, 512);

    EXPECT_THROW(p.validate_sizes(1024 * 512,
                                  256 * 2048, // Wrong B size
                                  1024 * 2048),
                 std::invalid_argument);
}
|
||
|
||
/// validate_sizes must throw std::invalid_argument when the C element count is wrong.
TEST_F(ProblemValidateSizesTest, WrongCSizeThrows)
{
    Problem p(1024, 2048, 512);

    EXPECT_THROW(p.validate_sizes(1024 * 512,
                                  512 * 2048,
                                  512 * 1024 // Wrong C size
                                  ),
                 std::invalid_argument);
}
|