mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
update RCC layout
This commit is contained in:
@@ -234,7 +234,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
template <typename GemmConfig, typename PrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, std::string c_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -245,7 +245,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
using AccDataType = typename Types::AccDataType;
|
||||
using CDataType = typename Types::CDataType;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
// Row major C layout (c_layout == "R")
|
||||
if(a_layout == "R" && b_layout == "C" && c_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
@@ -253,7 +254,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
else if(a_layout == "R" && b_layout == "R" && c_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
@@ -261,7 +262,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
else if(a_layout == "C" && b_layout == "R" && c_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
@@ -269,17 +270,19 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
// Column major C layout (c_layout == "C")
|
||||
else if(a_layout == "R" && b_layout == "C" && c_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Col{}, Col{}, Row{});
|
||||
AccDataType>(argc, argv, Row{}, Col{}, Col{});
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
|
||||
throw std::runtime_error("Unsupported data layout configuration for A, B and C tensors!");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -294,22 +297,23 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string c_layout = arg_parser.get_str("c_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
a_layout, b_layout, c_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
a_layout, b_layout, c_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
a_layout, b_layout, c_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -322,8 +326,6 @@ int main(int argc, char* argv[])
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_gemm_example<GemmConfigComputeV4_Wmma>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv) ||
|
||||
!run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv) ||
|
||||
!run_grouped_gemm_example<GemmConfigComputeV4_V2>(argc, argv);
|
||||
return run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -669,8 +669,8 @@ struct CShuffleEpilogue
|
||||
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
|
||||
// static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
// "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); // Enable column major C layout
|
||||
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<kBlockSize,
|
||||
|
||||
Reference in New Issue
Block a user