diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index ef9bbc3ab2..435ab17687 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -275,52 +275,53 @@ int run_gemm_example(int argc, char* argv[]) std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); - if(data_type == "fp16") - { - return run_gemm_example_prec_type, ck_tile::half_t>( - a_layout, b_layout, argc, argv); - } - else if(data_type == "bf16") + // if(data_type == "fp16") + //{ + // return run_gemm_example_prec_type, ck_tile::half_t>( + // a_layout, b_layout, argc, argv); + //} + // else + if(data_type == "bf16") { return run_gemm_example_prec_type, ck_tile::bf16_t>( a_layout, b_layout, argc, argv); } - else if(data_type == "fp8") - { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::fp8_t, - ck_tile::half_t>(a_layout, b_layout, argc, argv); - } - else if(data_type == "bf8") - { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::bf8_t, - ck_tile::half_t>(a_layout, b_layout, argc, argv); - } - else if(data_type == "int8") - { - return run_gemm_example_prec_type, - ck_tile::int8_t, - ck_tile::int8_t, - ck_tile::int32_t>(a_layout, b_layout, argc, argv); - } - else if(data_type == "pk_int4_t") - { - // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) - { - return run_gemm_example_prec_type, - ck_tile::half_t, - ck_tile::pk_int4_t, - ck_tile::half_t>(a_layout, b_layout, argc, argv); - } - else - { - throw std::runtime_error("Unsupported pipeline for this operation !!!"); - } - } + // else if(data_type == "fp8") + //{ + // return run_gemm_example_prec_type, + // ck_tile::fp8_t, + // ck_tile::fp8_t, + // ck_tile::half_t>(a_layout, b_layout, argc, argv); + //} + // else if(data_type == "bf8") + //{ + // return run_gemm_example_prec_type, + // ck_tile::bf8_t, + // ck_tile::bf8_t, + // ck_tile::half_t>(a_layout, b_layout, argc, argv); + //} + // else if(data_type == "int8") + //{ + // return run_gemm_example_prec_type, + // ck_tile::int8_t, + // ck_tile::int8_t, + // ck_tile::int32_t>(a_layout, b_layout, argc, argv); + //} + // else if(data_type == "pk_int4_t") + //{ + // // TODO: Add support for bhalf_t ADataType + // if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + // { + // return run_gemm_example_prec_type, + // ck_tile::half_t, + // ck_tile::pk_int4_t, + // ck_tile::half_t>(a_layout, b_layout, argc, argv); + // } + // else + // { + // throw std::runtime_error("Unsupported pipeline for this operation !!!"); + // } + //} else { throw std::runtime_error("Unsupported data type for this operation !!!");