mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK TILE] Use config name instead of data type in FmhaFwdTypeConfig<config> (#1731)
* Add data type config, Prepare to add mix precision in the future * Fix compile error
This commit is contained in:
@@ -142,7 +142,7 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
template <typename DataTypeConfig>
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
@@ -151,7 +151,7 @@ auto get_elimit(std::string /*init_method*/)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
|
||||
auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
@@ -159,7 +159,7 @@ auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
auto get_elimit<FmhaFwdFp8>(std::string init_method)
|
||||
{
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
@@ -261,7 +261,7 @@ int override_num_splits_if_necessary(
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataTypeConfig>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
@@ -305,8 +305,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
|
||||
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
|
||||
std::is_same_v<DataType, ck_tile::bf16_t>))
|
||||
if constexpr(!(std::is_same_v<DataTypeConfig, FmhaFwdFp16> ||
|
||||
std::is_same_v<DataTypeConfig, FmhaFwdBf16>))
|
||||
{
|
||||
if(0 < rotary_dim)
|
||||
{
|
||||
@@ -428,25 +428,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return atoi(squant_str.c_str()) != 0 ? true : false;
|
||||
}();
|
||||
|
||||
float range_q = arg_parser.get_float("range_q");
|
||||
float range_k = arg_parser.get_float("range_k");
|
||||
float range_v = arg_parser.get_float("range_v");
|
||||
float range_p = arg_parser.get_float("range_p");
|
||||
float range_o = arg_parser.get_float("range_o");
|
||||
|
||||
float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<DataType>::max());
|
||||
|
||||
float scale_p = 1.f;
|
||||
float scale_o = 1.f;
|
||||
|
||||
if(squant)
|
||||
{
|
||||
scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max);
|
||||
scale_p = dtype_max / range_p;
|
||||
// scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)]
|
||||
scale_o = range_p * range_v / range_o / dtype_max;
|
||||
}
|
||||
|
||||
std::string vlayout = arg_parser.get_str("vlayout");
|
||||
bool lse = arg_parser.get_bool("lse");
|
||||
|
||||
@@ -499,7 +480,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
||||
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
|
||||
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataType>;
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
|
||||
|
||||
using QDataType = typename TypeConfig::QDataType;
|
||||
using KDataType = typename TypeConfig::KDataType;
|
||||
@@ -513,6 +494,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using OaccDataType = typename TypeConfig::OaccDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
|
||||
float range_q = arg_parser.get_float("range_q");
|
||||
float range_k = arg_parser.get_float("range_k");
|
||||
float range_v = arg_parser.get_float("range_v");
|
||||
float range_p = arg_parser.get_float("range_p");
|
||||
float range_o = arg_parser.get_float("range_o");
|
||||
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
float p_dtype_max = v_dtype_max; // assume p and v is the same type
|
||||
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
|
||||
|
||||
float scale_p = 1.f;
|
||||
float scale_o = 1.f;
|
||||
|
||||
if(squant)
|
||||
{
|
||||
scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max);
|
||||
scale_p = p_dtype_max / range_p;
|
||||
scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max);
|
||||
}
|
||||
|
||||
// accumulation numbers for performance evaluation
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
auto max_seqlen_q =
|
||||
@@ -709,14 +712,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else if(init_method == "ufq" || init_method == "uf:q" ||
|
||||
init_method == "3") // suitable for fp8 quantization
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(vnew_host);
|
||||
|
||||
// bias_fp8 = qscale_bias * bias_fp32
|
||||
float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
|
||||
float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
|
||||
// Assume bias is in [-1.f, 1.f] in original fp32
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
|
||||
}
|
||||
@@ -1129,14 +1132,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
auto p_compute_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
|
||||
if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
|
||||
return ck_tile::scales{scale_p};
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
|
||||
if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o});
|
||||
else
|
||||
@@ -1186,7 +1189,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
|
||||
|
||||
auto [rotary_cos_slice, rotary_sin_slice] =
|
||||
auto [rotary_cos_slice, rotary_sin_slice] =
|
||||
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
|
||||
|
||||
ck_tile::reference_batched_rotary_position_embedding(
|
||||
@@ -1202,13 +1205,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
k_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]);
|
||||
});
|
||||
} else {
|
||||
} else {
|
||||
k_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]);
|
||||
});
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
#endif
|
||||
{
|
||||
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
@@ -1229,7 +1232,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
|
||||
|
||||
auto [rotary_cos_slice, rotary_sin_slice] =
|
||||
auto [rotary_cos_slice, rotary_sin_slice] =
|
||||
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
|
||||
|
||||
ck_tile::reference_batched_rotary_position_embedding(
|
||||
@@ -1251,19 +1254,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(0 < page_block_size) {
|
||||
if(is_v_rowmajor) {
|
||||
if(i_perm) {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
|
||||
});
|
||||
} else {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]);
|
||||
});
|
||||
}
|
||||
}
|
||||
else
|
||||
else
|
||||
{
|
||||
if(i_perm) {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
if(i_perm) {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size);
|
||||
});
|
||||
} else {
|
||||
@@ -1458,7 +1461,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>(init_method);
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool cur_pass = ck_tile::check_err(
|
||||
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
pass &= cur_pass;
|
||||
@@ -1515,15 +1518,15 @@ int main(int argc, char* argv[])
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
return run<FmhaFwdFp16>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
return run<FmhaFwdBf16>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
|
||||
return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
|
||||
Reference in New Issue
Block a user