Addressing (Post Merge) code review comments for PR 1845 (#1883)

* Addressing code review comments.

* Addressing code review comments.

* Reorganized code for better readability.

* add ck_tile gemms for new types in CI

* fix jenkins syntax

* fix script syntax

* Add the test cases back

* Address the review comments

* Address review comments

* clang format

* Solve the merging issues

* Addressed the comments

* clang format

---------

Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
kylasa
2025-03-06 11:40:30 -08:00
committed by GitHub
parent c12fb0a624
commit 66c5f5b0b6
32 changed files with 511 additions and 245 deletions

View File

@@ -29,8 +29,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr int kBlockPerCu = 1;
// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
@@ -54,7 +54,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem::kBlockSize,
@@ -99,45 +101,99 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc"
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices when "
"BPrecType is ck_tile::pk_int4_t!");
}
}
else
{
if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
}
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
if(data_type == "fp16")
{
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
#endif
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}