Add KBatch support for gemm_ab_scale (#2740)

* Add KBatch support for gemm_ab_scale

* Revert kernel parameters change

* Remove printing

* fix formatting

* fix check

* Use {} in if

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

[ROCm/composable_kernel commit: 9d4bfe3932]
This commit is contained in:
Sami Remes
2025-10-09 07:33:16 +01:00
committed by GitHub
parent e9ade69185
commit e7ef841a68
5 changed files with 34 additions and 12 deletions

View File

@@ -91,6 +91,8 @@ int main(int argc, char* argv[])
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t KBatch = 1;
if(argc == 1)
{
// use default case
@@ -101,7 +103,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 8)
else if(argc == 8 || argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
@@ -113,6 +115,11 @@ int main(int argc, char* argv[])
flush_cache = std::stoi(argv[7]);
if(argc == 9)
{
KBatch = std::stoi(argv[8]);
}
StrideA = K;
StrideB = K;
StrideE = N;
@@ -124,6 +131,7 @@ int main(int argc, char* argv[])
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: M, N, K\n");
printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n");
printf("arg8: KBatch (default: 1)\n");
exit(0);
}
@@ -233,9 +241,9 @@ int main(int argc, char* argv[])
constexpr ck::index_t NumDTensor = DsDataType::Size();
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{},
e_device_buf.GetDeviceBuffer(),
@@ -251,6 +259,7 @@ int main(int argc, char* argv[])
a_element_op,
b_element_op,
cde_element_op);
argument.KBatch = KBatch;
if(!device_op.IsSupportedArgument(argument))
{