mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK][Examples] Fix for example_grouped_gemm_multiple_d_dl_fp16 - corrected stride for B matrix. (#3104)
Fix for example_elementwise_layernorm_blockwise - corrected cmdline.
Signed-off-by: Michal Kulikowski <Michal.Kulikowski@amd.com>
[ROCm/composable_kernel commit: b0aab85baa]
This commit is contained in:
committed by
GitHub
parent
97c2fb582a
commit
40c4bad35b
@@ -59,4 +59,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
return !run_grouped_gemm_example(argc, argv);
|
||||
}
|
||||
|
||||
@@ -278,19 +278,20 @@ bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default cases
|
||||
}
|
||||
else if(argc == 4 || argc == 6)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 6)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.async_hargs = std::stoi(argv[4]);
|
||||
problem_size.group_count = std::stoi(argv[5]);
|
||||
if(argc == 6)
|
||||
{
|
||||
config.async_hargs = std::stoi(argv[4]);
|
||||
problem_size.group_count = std::stoi(argv[5]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -299,18 +300,33 @@ bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: async hargs (0=n0, 1=yes)\n");
|
||||
printf("arg5: group count (default=16)");
|
||||
exit(0);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
// Lambda to get stride based on layout
|
||||
auto get_stride = [](auto layout, auto row_dim, auto col_dim) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col_dim;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row_dim;
|
||||
}
|
||||
};
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(256 + 256 * i);
|
||||
problem_size.Ns.push_back(128 + 128 * i);
|
||||
problem_size.Ks.push_back(128 + 64 * i);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
problem_size.stride_As.push_back(
|
||||
get_stride(ALayout{}, problem_size.Ms[i], problem_size.Ks[i]));
|
||||
problem_size.stride_Bs.push_back(
|
||||
get_stride(BLayout{}, problem_size.Ks[i], problem_size.Ns[i]));
|
||||
problem_size.stride_Cs.push_back(
|
||||
get_stride(ELayout{}, problem_size.Ms[i], problem_size.Ns[i]));
|
||||
}
|
||||
|
||||
return run_grouped_gemm(problem_size, config);
|
||||
|
||||
@@ -82,37 +82,29 @@ int main(int argc, char* argv[])
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
|
||||
ck::index_t M = 48 * 256;
|
||||
ck::index_t N = 1024;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
}
|
||||
else if(argc == 3)
|
||||
else if(argc == 3 || argc == 5)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
if(argc == 5)
|
||||
{
|
||||
M = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: time kernel (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::index_t M = 48 * 256;
|
||||
ck::index_t N = 1024;
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 3)
|
||||
{
|
||||
M = std::stoi(argv[1]);
|
||||
N = std::stoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1 to 2: M, N" << std::endl;
|
||||
return 1;
|
||||
printf("arg3-4: M, N\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
ck::index_t Stride = N;
|
||||
|
||||
Reference in New Issue
Block a user