mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK TILE] Use config name instead of data type in FmhaFwdTypeConfig<config> (#1731)
* Add data type config, Prepare to add mix precision in the future * Fix compile error
This commit is contained in:
@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
template <typename DataTypeConfig>
|
||||
auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
|
||||
auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataTypeConfig>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
|
||||
const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);
|
||||
|
||||
using TypeConfig = FmhaBwdTypeConfig<DataType>;
|
||||
using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
|
||||
|
||||
using QDataType = typename TypeConfig::QDataType;
|
||||
using KDataType = typename TypeConfig::KDataType;
|
||||
@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
|
||||
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
|
||||
dq_host_ref,
|
||||
std::string("Error: QGrad Incorrect results!"),
|
||||
@@ -986,11 +986,11 @@ int main(int argc, char* argv[])
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
|
||||
Reference in New Issue
Block a user