Merge commit '9d4bfe393276317b5c1f9dda990eb0bd6c1ec3e7' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-09 07:12:56 +00:00
parent 8620b69b7a
commit 64603db299
5 changed files with 34 additions and 12 deletions

View File

@@ -91,6 +91,8 @@ int main(int argc, char* argv[])
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t KBatch = 1;
if(argc == 1)
{
// use default case
@@ -101,7 +103,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 8)
else if(argc == 8 || argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
@@ -113,6 +115,11 @@ int main(int argc, char* argv[])
flush_cache = std::stoi(argv[7]);
if(argc == 9)
{
KBatch = std::stoi(argv[8]);
}
StrideA = K;
StrideB = K;
StrideE = N;
@@ -124,6 +131,7 @@ int main(int argc, char* argv[])
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: M, N, K\n");
printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n");
printf("arg8: KBatch (default: 1)\n");
exit(0);
}
@@ -233,9 +241,9 @@ int main(int argc, char* argv[])
constexpr ck::index_t NumDTensor = DsDataType::Size();
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{},
e_device_buf.GetDeviceBuffer(),
@@ -251,6 +259,7 @@ int main(int argc, char* argv[])
a_element_op,
b_element_op,
cde_element_op);
argument.KBatch = KBatch;
if(!device_op.IsSupportedArgument(argument))
{