update RCC layout

This commit is contained in:
kyle-256
2025-12-18 06:50:22 +00:00
committed by kyle-256
parent 518d02b925
commit 3053fb50ef
2 changed files with 17 additions and 15 deletions

View File

@@ -234,7 +234,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
#include "run_grouped_gemm_example.inc"
template <typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, std::string c_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
@@ -245,7 +245,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
if(a_layout == "R" && b_layout == "C")
// Row major C layout (c_layout == "R")
if(a_layout == "R" && b_layout == "C" && c_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
@@ -253,7 +254,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
CDataType,
AccDataType>(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
else if(a_layout == "R" && b_layout == "R" && c_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
@@ -261,7 +262,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
CDataType,
AccDataType>(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
else if(a_layout == "C" && b_layout == "R" && c_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
@@ -269,17 +270,19 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
CDataType,
AccDataType>(argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
// Column major C layout (c_layout == "C")
else if(a_layout == "R" && b_layout == "C" && c_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Col{}, Row{});
AccDataType>(argc, argv, Row{}, Col{}, Col{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
throw std::runtime_error("Unsupported data layout configuration for A, B and C tensors!");
}
}
@@ -294,22 +297,23 @@ int run_grouped_gemm_example(int argc, char* argv[])
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string c_layout = arg_parser.get_str("c_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
a_layout, b_layout, c_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
a_layout, b_layout, c_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, argc, argv);
a_layout, b_layout, c_layout, argc, argv);
}
else
{
@@ -322,8 +326,6 @@ int main(int argc, char* argv[])
#if CK_TILE_USE_WMMA
return !run_grouped_gemm_example<GemmConfigComputeV4_Wmma>(argc, argv);
#else
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv) ||
!run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv) ||
!run_grouped_gemm_example<GemmConfigComputeV4_V2>(argc, argv);
return run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv);
#endif
}

View File

@@ -669,8 +669,8 @@ struct CShuffleEpilogue
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
// static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
// "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); // Assert disabled to allow a column major C layout
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<kBlockSize,