Readd naive normalization in mhc v3

This commit is contained in:
Damien Lejeune
2026-02-06 09:44:20 +00:00
parent 053aed9402
commit e7ebd6c288
4 changed files with 704 additions and 543 deletions

View File

@@ -37,49 +37,46 @@ TYPED_TEST_SUITE(TestCkTileMHC, TestTypes);
// TYPED_TEST(TestCkTileMHC, TestArbitraryBatchSize) { this->RunArbitraryBatchSizeTest(); }
// Explicit test cases for specific batch sizes (using template syntax)
TYPED_TEST(TestCkTileMHC, TestBatchSize1) { this->template RunBatchSizeTest<1>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize1) { this->template RunGemmPipeline<1>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize16) { this->template RunBatchSizeTest<16>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize16) { this->template RunGemmPipeline<16>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize17) { this->template RunBatchSizeTest<17>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize17) { this->template RunGemmPipeline<17>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize32) { this->template RunBatchSizeTest<32>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize32) { this->template RunGemmPipeline<32>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize64) { this->template RunBatchSizeTest<64>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize64) { this->template RunGemmPipeline<64>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize100) { this->template RunBatchSizeTest<100>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize100) { this->template RunGemmPipeline<100>(); }
// Test with different expansion factors (keeping nC <= 256)
TYPED_TEST(TestCkTileMHC, TestBatchSize16N2C128) { this->template RunBatchSizeTest<16, 2, 128>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize16N2C128) { this->template RunGemmPipeline<16, 2, 128>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize32N3C85) { this->template RunBatchSizeTest<32, 3, 85>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize32N3C85) { this->template RunGemmPipeline<32, 3, 85>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize8N8C32) { this->template RunBatchSizeTest<8, 8, 32>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize8N8C32) { this->template RunGemmPipeline<8, 8, 32>(); }
// Test with large C values that require K-tiling (nC > 256)
TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C512) { this->template RunBatchSizeTest<2, 4, 512>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C512) { this->template RunGemmPipeline<2, 4, 512>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C1024) { this->template RunBatchSizeTest<2, 4, 1024>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C1024) { this->template RunGemmPipeline<2, 4, 1024>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C4096) { this->template RunBatchSizeTest<2, 4, 4096>(); }
TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C4096) { this->template RunGemmPipeline<2, 4, 4096>(); }
// Test with different activation functions
TYPED_TEST(TestCkTileMHC, TestBatchSize16WithTanh)
{
this->template RunBatchSizeTestWithActivation<16, 4, 64, ck_tile::element_wise::TanH>();
this->template RunGemmPipeline<16, 4, 64, ck_tile::element_wise::TanH>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize16WithRelu)
{
this->template RunBatchSizeTestWithActivation<16, 4, 64, ck_tile::element_wise::Relu>();
this->template RunGemmPipeline<16, 4, 64, ck_tile::element_wise::Relu>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize16WithSilu)
{
this->template RunBatchSizeTestWithActivation<16, 4, 64, ck_tile::element_wise::Silu>();
this->template RunGemmPipeline<16, 4, 64, ck_tile::element_wise::Silu>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize16N4C4096)
{
this->template RunBatchSizeTest<16, 4, 4096>();
}
TYPED_TEST(TestCkTileMHC, TestBatchSize16N4C4096) { this->template RunGemmPipeline<16, 4, 4096>(); }

File diff suppressed because it is too large Load Diff