Fixing supported data type instances issue

This commit is contained in:
Aleksander Dudek
2025-07-02 09:20:20 -05:00
parent 52388ae514
commit 85d3fd8d27

View File

@@ -275,52 +275,53 @@ int run_gemm_example(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf16")
// if(data_type == "fp16")
//{
// return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
// a_layout, b_layout, argc, argv);
//}
// else
if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "int8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
ck_tile::int8_t,
ck_tile::int8_t,
ck_tile::int32_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
ck_tile::half_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
// else if(data_type == "fp8")
//{
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
// ck_tile::fp8_t,
// ck_tile::fp8_t,
// ck_tile::half_t>(a_layout, b_layout, argc, argv);
//}
// else if(data_type == "bf8")
//{
// return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
// ck_tile::bf8_t,
// ck_tile::bf8_t,
// ck_tile::half_t>(a_layout, b_layout, argc, argv);
//}
// else if(data_type == "int8")
//{
// return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
// ck_tile::int8_t,
// ck_tile::int8_t,
// ck_tile::int32_t>(a_layout, b_layout, argc, argv);
//}
// else if(data_type == "pk_int4_t")
//{
// // TODO: Add support for bhalf_t ADataType
// if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
// {
// return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
// ck_tile::half_t,
// ck_tile::pk_int4_t,
// ck_tile::half_t>(a_layout, b_layout, argc, argv);
// }
// else
// {
// throw std::runtime_error("Unsupported pipeline for this operation !!!");
// }
//}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");