feat(grouped_gemm_multi_d): add support for bf16

This commit is contained in:
AviralGoelAMD
2025-10-09 17:10:20 +00:00
committed by Aviral Goel
parent 706c2b281c
commit 8d8b49dec2
2 changed files with 111 additions and 38 deletions

View File

@@ -15,14 +15,6 @@
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using D0DataType = ck_tile::half_t;
using D1DataType = ck_tile::half_t;
using EDataType = ck_tile::half_t;
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
using AccDataType = float;
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
@@ -173,7 +165,38 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<DsDataType::size()>;
template <typename DataType>
struct GemmMultiDTypeConfig;
template <>
struct GemmMultiDTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using D0DataType = ck_tile::half_t;
using D1DataType = ck_tile::half_t;
using EDataType = ck_tile::half_t;
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
using AccDataType = float;
};
template <>
struct GemmMultiDTypeConfig<ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using D0DataType = ck_tile::bf16_t;
using D1DataType = ck_tile::bf16_t;
using EDataType = ck_tile::bf16_t;
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
using AccDataType = float;
};
// Deduce the number of D tensors from the DsDataType tuple size.
// Every precision config defines the same two-element Ds tuple
// (tuple<D0DataType, D1DataType>), so any instantiation yields the same
// count; bf16 is used here arbitrarily.
constexpr std::size_t NumDTensor = GemmMultiDTypeConfig<ck_tile::bf16_t>::DsDataType::size();
// Host-side kernel-argument type, parameterized on the D-tensor count.
using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<NumDTensor>;
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
{
@@ -190,7 +213,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
.insert("ds_layout", "R", "Ds tensor data layout - Row by default.")
.insert("e_layout", "R", "E tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("prec", "fp16", "data type. fp16")
.insert("prec", "bf16", "data type. fp16/bf16")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
@@ -204,7 +227,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_multi_d_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<DsDataType::size()>);
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<NumDTensor>);
}
template <typename GemmConfig,

View File

@@ -19,6 +19,11 @@ static constexpr inline auto is_row_major(Layout layout_)
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType,
typename BDataType,
typename D0DataType,
typename EDataType,
typename AccDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
@@ -86,31 +91,31 @@ float invoke_gemm(int n_warmup,
}
else
{
std::vector<ck_tile::GemmTransKernelArg<DsDataType::size()>> kargs;
std::vector<ck_tile::GemmTransKernelArg<NumDTensor>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
const bool splitk = args[0].k_batch > 1;
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, 2>{{arg.a_ptr},
{arg.b_ptr},
arg.ds_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
{arg.stride_A},
{arg.stride_B},
arg.stride_Ds,
arg.stride_E,
arg.k_batch});
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, NumDTensor>{{arg.a_ptr},
{arg.b_ptr},
arg.ds_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
{arg.stride_A},
{arg.stride_B},
arg.stride_Ds,
arg.stride_E,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(
kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<DsDataType::size()>),
hipMemcpyHostToDevice,
stream.stream_id_));
HIP_CHECK_ERROR(
hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<NumDTensor>),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time =
grouped_gemm_multi_d_tileloop<GemmConfig,
ADataType,
@@ -128,6 +133,12 @@ float invoke_gemm(int n_warmup,
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename D0DataType,
typename D1DataType,
typename AccDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename D0Layout,
@@ -145,6 +156,7 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
using CDElementWise = MultiplyMultiply;
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
auto valid_input_data = [&](int group_count, const auto&... args) {
return !(args.empty() || ...) && group_count == (args.size() == ...);
@@ -360,7 +372,9 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
const float max_accumulated_value =
*std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end());
const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value);
const auto rtol_atol =
calculate_rtol_atol<ADataType, BDataType, D0DataType, EDataType, AccDataType>(
Ks[i], 1, max_accumulated_value);
pass &=
ck_tile::check_err(e_m_n_tensors[i],
@@ -390,6 +404,38 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
return pass;
}
template <typename GemmConfig, typename PrecType>
int run_gemm_multi_d_example_prec_type(
std::string a_layout, std::string b_layout, std::string ds_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmMultiDTypeConfig<PrecType>;
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using D0DataType = typename Types::D0DataType;
using D1DataType = typename Types::D1DataType;
using AccDataType = typename Types::AccDataType;
using EDataType = typename Types::EDataType;
if(a_layout == "R" && b_layout == "C" && ds_layout == "R")
{
return run_grouped_gemm_multi_d_example_with_layouts<GemmConfig,
ADataType,
BDataType,
D0DataType,
D1DataType,
AccDataType,
EDataType>(
argc, argv, Row{}, Col{}, Row{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for provided tensors!");
}
}
template <typename GemmConfig>
int run_grouped_gemm_multi_d_example(int argc, char* argv[])
{
@@ -401,17 +447,21 @@ int run_grouped_gemm_multi_d_example(int argc, char* argv[])
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string ds_layout = arg_parser.get_str("ds_layout");
const std::string data_type = arg_parser.get_str("prec");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C" && ds_layout == "R")
if(data_type == "fp16")
{
return run_grouped_gemm_multi_d_example_with_layouts<GemmConfig>(
argc, argv, Row{}, Col{}, Row{}, Row{}, Row{});
return run_gemm_multi_d_example_prec_type<GemmConfig, ck_tile::half_t>(
a_layout, b_layout, ds_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_gemm_multi_d_example_prec_type<GemmConfig, ck_tile::bf16_t>(
a_layout, b_layout, ds_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data layout configuration for provided tensors!");
throw std::runtime_error(
"Unsupported data type configuration. Only fp16 and bf16 are supported.");
}
}