diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 3ff3f2f10e..5d21519cad 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -234,7 +234,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, #include "run_grouped_gemm_example.inc" template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, std::string c_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -245,7 +245,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a using AccDataType = typename Types::AccDataType; using CDataType = typename Types::CDataType; - if(a_layout == "R" && b_layout == "C") + // Row major C layout (c_layout == "R") + if(a_layout == "R" && b_layout == "C" && c_layout == "R") { return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } - else if(a_layout == "R" && b_layout == "R") + else if(a_layout == "R" && b_layout == "R" && c_layout == "R") { return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); } - else if(a_layout == "C" && b_layout == "R") + else if(a_layout == "C" && b_layout == "R" && c_layout == "R") { return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); } - else if(a_layout == "C" && b_layout == "C") + // Column major C layout (c_layout == "C") + else if(a_layout == "R" && b_layout == "C" && c_layout == "C") { return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + AccDataType>(argc, argv, Row{}, Col{}, Col{}); } + else { - throw std::runtime_error("Unsupported data layout configuration for A and B tensors!"); + throw std::runtime_error("Unsupported data layout configuration for A, B and C tensors!"); } } @@ -294,22 +297,23 @@ int run_grouped_gemm_example(int argc, char* argv[]) const std::string a_layout = arg_parser.get_str("a_layout"); const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string c_layout = arg_parser.get_str("c_layout"); const std::string data_type = arg_parser.get_str("prec"); if(data_type == "fp16") { return run_gemm_example_prec_type, ck_tile::half_t>( - a_layout, b_layout, argc, argv); + a_layout, b_layout, c_layout, argc, argv); } else if(data_type == "bf16") { return run_gemm_example_prec_type, ck_tile::bf16_t>( - a_layout, b_layout, argc, argv); + a_layout, b_layout, c_layout, argc, argv); } else if(data_type == "fp8") { return run_gemm_example_prec_type, ck_tile::fp8_t>( - a_layout, b_layout, argc, argv); + a_layout, b_layout, c_layout, argc, argv); } else { @@ -322,8 +326,6 @@ int main(int argc, char* argv[]) #if CK_TILE_USE_WMMA return !run_grouped_gemm_example(argc, argv); #else - return !run_grouped_gemm_example(argc, argv) || - !run_grouped_gemm_example(argc, argv) || - !run_grouped_gemm_example(argc, argv); + return run_grouped_gemm_example(argc, argv); #endif } diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 53bfa6041d..b02aa231a3 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -669,8 +669,8 @@ struct CShuffleEpilogue constexpr index_t num_access = SFC::get_num_of_access(); - static_assert(std::is_same_v, - "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); + // static_assert(std::is_same_v, + // "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); // Enable column major C layout using TileEncodingPattern = tile_distribution_encoding_pattern_2d