diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp index fb047ae364..236e5e4fa2 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -59,4 +59,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl #include "run_grouped_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 0; + } + return !run_grouped_gemm_example(argc, argv); +} diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 4ef6074f4a..87ccebc3c4 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -278,19 +278,20 @@ bool run_grouped_gemm_example(int argc, char* argv[]) problem_size.group_count = 16; - if(argc == 4) + if(argc == 1) + { + // use default cases + } + else if(argc == 4 || argc == 6) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.async_hargs = std::stoi(argv[4]); - problem_size.group_count = std::stoi(argv[5]); + if(argc == 6) + { + config.async_hargs = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } } else { @@ -299,18 +300,33 @@ bool run_grouped_gemm_example(int argc, char* argv[]) printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4: async hargs (0=n0, 1=yes)\n"); printf("arg5: group count (default=16)"); - exit(0); + exit(1); } + // Lambda to get stride based on layout + auto get_stride = [](auto layout, auto row_dim, auto col_dim) { + if constexpr(std::is_same_v) + { + return col_dim; + } + else + { + return row_dim; + } + }; + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); problem_size.Ns.push_back(128 + 128 * i); problem_size.Ks.push_back(128 + 64 * i); - problem_size.stride_As.push_back(problem_size.Ks[i]); - problem_size.stride_Bs.push_back(problem_size.Ks[i]); - problem_size.stride_Cs.push_back(problem_size.Ns[i]); + problem_size.stride_As.push_back( + get_stride(ALayout{}, problem_size.Ms[i], problem_size.Ks[i])); + problem_size.stride_Bs.push_back( + get_stride(BLayout{}, problem_size.Ks[i], problem_size.Ns[i])); + problem_size.stride_Cs.push_back( + get_stride(ELayout{}, problem_size.Ms[i], problem_size.Ns[i])); } return run_grouped_gemm(problem_size, config); diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp index 8064809123..21c5ff8d5a 100644 --- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp +++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -82,37 +82,29 @@ int main(int argc, char* argv[]) bool do_verification = true; bool time_kernel = true; + ck::index_t M = 48 * 256; + ck::index_t N = 1024; + if(argc == 1) { // use default } - else if(argc == 3) + else if(argc == 3 || argc == 5) { do_verification = std::stoi(argv[1]); time_kernel = std::stoi(argv[2]); + if(argc == 5) + { + M = std::stoi(argv[3]); + N = std::stoi(argv[4]); + } } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: time kernel (0=no, 1=yes)\n"); - exit(0); - } - - ck::index_t M = 48 * 256; - ck::index_t N = 1024; - if(argc == 1) - { - // use default case - } - else if(argc == 3) - { - M = std::stoi(argv[1]); - N = std::stoi(argv[2]); - } - else - { - std::cerr << "arg1 to 2: M, N" << std::endl; - return 1; + printf("arg3-4: M, N\n"); + exit(1); } ck::index_t Stride = N;