Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm example. (#1845)

* Support bf16/fb8/bf8 datatypes for ck_tile/gemm

* remove commented out code.

* Addressing code review comments and enabling universal_gemm for all the supported data types.

* Merge conflict resolution.

* Solve the memory pipeline compilation error. Merge with the new change of CShuffle

* finish the feature, pass the tests

* Fix the pipeline and add the benchmark script for other data types

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
kylasa
2025-02-06 14:07:38 -08:00
committed by Sam Wu
parent 9b5dfba242
commit ab5d027866
21 changed files with 598 additions and 88 deletions

View File

@@ -12,7 +12,13 @@
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
@@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
@@ -243,24 +249,101 @@ int run_gemm_example(int argc, char* argv[])
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else
{