mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add KBatch support for gemm_ab_scale (#2740)
* Add KBatch support for gemm_ab_scale
* Revert kernel parameters change
* Remove printing
* fix formatting
* fix check
* Use {} in if
---------
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: 9d4bfe3932]
This commit is contained in:
@@ -91,6 +91,8 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
@@ -101,7 +103,7 @@ int main(int argc, char* argv[])
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 8)
|
||||
else if(argc == 8 || argc == 9)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
@@ -113,6 +115,11 @@ int main(int argc, char* argv[])
|
||||
|
||||
flush_cache = std::stoi(argv[7]);
|
||||
|
||||
if(argc == 9)
|
||||
{
|
||||
KBatch = std::stoi(argv[8]);
|
||||
}
|
||||
|
||||
StrideA = K;
|
||||
StrideB = K;
|
||||
StrideE = N;
|
||||
@@ -124,6 +131,7 @@ int main(int argc, char* argv[])
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: M, N, K\n");
|
||||
printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n");
|
||||
printf("arg8: KBatch (default: 1)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
@@ -233,9 +241,9 @@ int main(int argc, char* argv[])
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
@@ -251,6 +259,7 @@ int main(int argc, char* argv[])
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
argument.KBatch = KBatch;
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user