mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Addressing (Post Merge) code review comments for PR 1845 (#1883)
* Addressing code review comments. * Addressing code review comments. * Reorganized code for better readability. * add ck_tile gemms for new types in CI * fix jenkins syntax * fix script syntax * Add the test cases back * Address the review comments * Address review comments * clang format * Solve the merging issues * Addressed the comments * clang format --------- Co-authored-by: illsilin <Illia.Silin@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com> Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -29,8 +29,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// This part comes from the Codegen
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 128;
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
@@ -54,7 +54,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
@@ -99,45 +101,99 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
"BPrecType is ck_tile::pk_int4_t!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user