mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
feat(grouped_gemm_multi_d): add support for bf16
This commit is contained in:
committed by
Aviral Goel
parent
706c2b281c
commit
8d8b49dec2
@@ -15,14 +15,6 @@
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using D0DataType = ck_tile::half_t;
|
||||
using D1DataType = ck_tile::half_t;
|
||||
using EDataType = ck_tile::half_t;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
using AccDataType = float;
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
@@ -173,7 +165,38 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<DsDataType::size()>;
|
||||
template <typename DataType>
|
||||
struct GemmMultiDTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmMultiDTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using D0DataType = ck_tile::half_t;
|
||||
using D1DataType = ck_tile::half_t;
|
||||
using EDataType = ck_tile::half_t;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
using AccDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmMultiDTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using D0DataType = ck_tile::bf16_t;
|
||||
using D1DataType = ck_tile::bf16_t;
|
||||
using EDataType = ck_tile::bf16_t;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
using AccDataType = float;
|
||||
};
|
||||
|
||||
// Deduce the number of D tensors from the DsDataType tuple size
|
||||
// All precision configs have the same number of D tensors, so we can use any one
|
||||
constexpr std::size_t NumDTensor = GemmMultiDTypeConfig<ck_tile::bf16_t>::DsDataType::size();
|
||||
|
||||
using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<NumDTensor>;
|
||||
|
||||
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -190,7 +213,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
.insert("ds_layout", "R", "Ds tensor data layout - Row by default.")
|
||||
.insert("e_layout", "R", "E tensor data layout - Row by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("prec", "fp16", "data type. fp16")
|
||||
.insert("prec", "bf16", "data type. fp16/bf16")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
@@ -204,7 +227,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_multi_d_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<DsDataType::size()>);
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<NumDTensor>);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
|
||||
@@ -19,6 +19,11 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
typename EDataType,
|
||||
typename AccDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
@@ -86,31 +91,31 @@ float invoke_gemm(int n_warmup,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<ck_tile::GemmTransKernelArg<DsDataType::size()>> kargs;
|
||||
std::vector<ck_tile::GemmTransKernelArg<NumDTensor>> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = args[0].k_batch > 1;
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, 2>{{arg.a_ptr},
|
||||
{arg.b_ptr},
|
||||
arg.ds_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
{arg.stride_A},
|
||||
{arg.stride_B},
|
||||
arg.stride_Ds,
|
||||
arg.stride_E,
|
||||
arg.k_batch});
|
||||
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, NumDTensor>{{arg.a_ptr},
|
||||
{arg.b_ptr},
|
||||
arg.ds_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
{arg.stride_A},
|
||||
{arg.stride_B},
|
||||
arg.stride_Ds,
|
||||
arg.stride_E,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(
|
||||
kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<DsDataType::size()>),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<NumDTensor>),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time =
|
||||
grouped_gemm_multi_d_tileloop<GemmConfig,
|
||||
ADataType,
|
||||
@@ -128,6 +133,12 @@ float invoke_gemm(int n_warmup,
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
typename D1DataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
@@ -145,6 +156,7 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
|
||||
|
||||
using CDElementWise = MultiplyMultiply;
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return !(args.empty() || ...) && group_count == (args.size() == ...);
|
||||
@@ -360,7 +372,9 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end());
|
||||
|
||||
const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value);
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<ADataType, BDataType, D0DataType, EDataType, AccDataType>(
|
||||
Ks[i], 1, max_accumulated_value);
|
||||
|
||||
pass &=
|
||||
ck_tile::check_err(e_m_n_tensors[i],
|
||||
@@ -390,6 +404,38 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename PrecType>
|
||||
int run_gemm_multi_d_example_prec_type(
|
||||
std::string a_layout, std::string b_layout, std::string ds_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Types = GemmMultiDTypeConfig<PrecType>;
|
||||
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using D0DataType = typename Types::D0DataType;
|
||||
using D1DataType = typename Types::D1DataType;
|
||||
using AccDataType = typename Types::AccDataType;
|
||||
using EDataType = typename Types::EDataType;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C" && ds_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_multi_d_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
D0DataType,
|
||||
D1DataType,
|
||||
AccDataType,
|
||||
EDataType>(
|
||||
argc, argv, Row{}, Col{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for provided tensors!");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig>
|
||||
int run_grouped_gemm_multi_d_example(int argc, char* argv[])
|
||||
{
|
||||
@@ -401,17 +447,21 @@ int run_grouped_gemm_multi_d_example(int argc, char* argv[])
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string ds_layout = arg_parser.get_str("ds_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C" && ds_layout == "R")
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_gemm_multi_d_example_with_layouts<GemmConfig>(
|
||||
argc, argv, Row{}, Col{}, Row{}, Row{}, Row{});
|
||||
return run_gemm_multi_d_example_prec_type<GemmConfig, ck_tile::half_t>(
|
||||
a_layout, b_layout, ds_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_multi_d_example_prec_type<GemmConfig, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, ds_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for provided tensors!");
|
||||
throw std::runtime_error(
|
||||
"Unsupported data type configuration. Only fp16 and bf16 are supported.");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user