mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Add batched attention special kernel instances (#424)
* sanity check
* add attribution
* add irrgular k tile size for batched attention
* format
[ROCm/composable_kernel commit: 7c788e10ce]
This commit is contained in:
@@ -105,6 +105,19 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16_IrregularK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 16},
|
||||
{256, 64, 160, 64, 16},
|
||||
{1024, 1024, 80, 80, 16},
|
||||
{1024, 64, 80, 64, 16},
|
||||
{4096, 4096, 40, 40, 16},
|
||||
{4096, 64, 40, 64, 16}};
|
||||
this->bench_ = true;
|
||||
this->verify_ = false;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
// TODO: enable KPadding tests when it is implemented
|
||||
|
||||
@@ -29,14 +29,19 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
|
||||
using B1Layout = std::tuple_element_t<6, Tuple>;
|
||||
using CLayout = std::tuple_element_t<7, Tuple>;
|
||||
|
||||
std::vector<std::vector<int>> lengths_ = {
|
||||
{256, 256, 64, 64, 4},
|
||||
{256, 256, 128, 128, 4},
|
||||
{512, 512, 64, 64, 2},
|
||||
{512, 512, 128, 128, 2},
|
||||
{1024, 1024, 64, 64, 1},
|
||||
{1024, 1024, 128, 128, 1},
|
||||
};
|
||||
std::vector<std::vector<int>> lengths_ = {{256, 256, 64, 64, 4},
|
||||
{256, 256, 128, 128, 4},
|
||||
{512, 512, 64, 64, 2},
|
||||
{512, 512, 128, 128, 2},
|
||||
{1024, 1024, 64, 64, 1},
|
||||
{1024, 1024, 128, 128, 1},
|
||||
{256, 256, 160, 160, 4},
|
||||
{256, 64, 160, 64, 4},
|
||||
{1024, 1024, 80, 80, 2},
|
||||
{1024, 64, 80, 64, 2},
|
||||
{4096, 4096, 40, 40, 1},
|
||||
{4096, 64, 40, 64, 1}};
|
||||
|
||||
bool bench_ = false;
|
||||
bool verify_ = true;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user