mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
* A few small fixes. * New GroupedGemm instances (BF16) * Unify and refactor GroupedGEMM device API. * Adapt changes to new API. * Adapt grouped gemm profiler. * Accept multiple kbatches for grouped gemm profiler. - delete obsolete two stage as it is now covered by grouped gemm * Update unit test for grouped gemm. * Fix thresholds for BF16 and F8. Unblock tests. * Fix a few instances. * Multiple small fixes. * Adapt to new API, check dynamic casting. * Uncomment a few data types in grouped gemm profiler. * Fix call to SetDeviceArgs. * Fix profile grouped gemm multiply tile loop. * Fix grouped gemm tile loop kernel args in client examples. * Review comments.
76 lines
1.6 KiB
C++
76 lines
1.6 KiB
C++
#pragma once
|
|
|
|
// Runs the fixture with M values {0, 1}; N = 768 and K = 544 are repeated once
// per M entry. Note that the first M entry is zero.
TYPED_TEST(TestGroupedGemm, TinyCases)
{
    constexpr int n_dim = 768;
    constexpr int k_dim = 544;

    const std::vector<int> m_values{0, 1};
    const auto entry_count = m_values.size();

    // One N and one K per M entry, all identical.
    const std::vector<int> n_values(entry_count, n_dim);
    const std::vector<int> k_values(entry_count, k_dim);

    this->Run(m_values, n_values, k_values);
}
|
|
|
|
// Runs the fixture with M values {2, 1, 3, 4, 5, 0}; N = 768 and K = 544 are
// repeated once per M entry. The last M entry is zero.
TYPED_TEST(TestGroupedGemm, SmallCases)
{
    constexpr int n_dim = 768;
    constexpr int k_dim = 544;

    const std::vector<int> m_values{2, 1, 3, 4, 5, 0};
    const auto entry_count = m_values.size();

    // One N and one K per M entry, all identical.
    const std::vector<int> n_values(entry_count, n_dim);
    const std::vector<int> k_values(entry_count, k_dim);

    this->Run(m_values, n_values, k_values);
}
|
|
|
|
// Runs the fixture with M values {167, 183, 177, 153, 139, 204}; N = 768 and
// K = 544 are repeated once per M entry.
TYPED_TEST(TestGroupedGemm, MidCases)
{
    constexpr int n_dim = 768;
    constexpr int k_dim = 544;

    const std::vector<int> m_values{167, 183, 177, 153, 139, 204};
    const auto entry_count = m_values.size();

    // One N and one K per M entry, all identical.
    const std::vector<int> n_values(entry_count, n_dim);
    const std::vector<int> k_values(entry_count, k_dim);

    this->Run(m_values, n_values, k_values);
}
|
|
|
|
// Runs the fixture with M values {64, 128, 256}; N = 768 and K = 320 are
// repeated once per M entry.
TYPED_TEST(TestGroupedGemm, Regular)
{
    constexpr int n_dim = 768;
    constexpr int k_dim = 320;

    const std::vector<int> m_values{64, 128, 256};
    const auto entry_count = m_values.size();

    // One N and one K per M entry, all identical.
    const std::vector<int> n_values(entry_count, n_dim);
    const std::vector<int> k_values(entry_count, k_dim);

    this->Run(m_values, n_values, k_values);
}
|
|
|
|
// Runs the fixture with M values {127, 150, 188, 210}; N = 136 and K = 280 are
// repeated once per M entry.
// NOTE(review): judging by the test name, these sizes are presumably chosen so
// that M, N and K all need padding -- confirm against the instance tile sizes.
TYPED_TEST(TestGroupedGemm, MNKPadded)
{
    constexpr int n_dim = 136;
    constexpr int k_dim = 280;

    const std::vector<int> m_values{127, 150, 188, 210};
    const auto entry_count = m_values.size();

    // One N and one K per M entry, all identical.
    const std::vector<int> n_values(entry_count, n_dim);
    const std::vector<int> k_values(entry_count, k_dim);

    this->Run(m_values, n_values, k_values);
}
|
|
|
|
// Runs the fixture with M values {188, 210}; N = 768 and K = 4096 are repeated
// once per M entry. Before running, the fixture's k_batches_ member is replaced
// with {32, 64}.
TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
{
    constexpr int n_dim = 768;
    constexpr int k_dim = 4096;

    const std::vector<int> m_values{188, 210};
    const auto entry_count = m_values.size();

    // One N and one K per M entry, all identical.
    const std::vector<int> n_values(entry_count, n_dim);
    const std::vector<int> k_values(entry_count, k_dim);

    // Must be assigned before Run so this test uses these values.
    this->k_batches_ = {32, 64};

    this->Run(m_values, n_values, k_values);
}
|