diff --git a/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp index 3b6dd9a2a9..c758720e10 100644 --- a/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp +++ b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp @@ -60,14 +60,13 @@ int main() int sum_of_m = 0; - Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - - int group_count = Ms.size(); + const int group_count = 16; for(int i = 0; i < group_count; ++i) { - Ns.push_back(768); - Ks.push_back(4608); + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp index 3503ae8b24..b16fe90387 100644 --- a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp @@ -57,15 +57,13 @@ int main() int sum_of_m = 0; - // Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; - - int group_count = Ms.size(); + const int group_count = 16; for(int i = 0; i < group_count; ++i) { - Ns.push_back(768); - Ks.push_back(4608); + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp index b288550b74..045fe47c4f 100644 --- a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp @@ -58,14 +58,13 @@ int main() int sum_of_m = 0; - Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - - int group_count = Ms.size(); + const int group_count = 16; for(int i = 0; i < group_count; ++i) { - Ns.push_back(768); - Ks.push_back(4608); + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp index c60daa3b36..8f82140f3f 100644 --- a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp @@ -58,14 +58,13 @@ int main() int sum_of_m = 0; - Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - - int group_count = Ms.size(); + const int group_count = 16; for(int i = 0; i < group_count; ++i) { - Ns.push_back(768); - Ks.push_back(4608); + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 89d4789c12..95b8526094 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -296,13 +296,11 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; - problem_size.Ms = { - 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ns.push_back(768); - problem_size.Ks.push_back(4608); + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(128 + 64 * i); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp index 1c50dc051b..84abe1d1db 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp @@ -297,13 +297,11 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; - problem_size.Ms = { - 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ns.push_back(768); - problem_size.Ks.push_back(4608); + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(128 + 64 * i); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp index 743ab96be6..9f8f6cb1e4 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp @@ -66,13 +66,11 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; - problem_size.Ms = { - 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ns.push_back(768); - problem_size.Ks.push_back(4608); + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(128 + 64 * i); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]);