make changes so that ck_tile gemm example runs on ctr-navi4x-aj50-ws02 / Navi4x

This commit is contained in:
pnikolic-amd
2025-06-12 11:43:15 -04:00
parent 7ea1508b59
commit aea72fabf7
2 changed files with 13 additions and 8 deletions

View File

@@ -186,7 +186,8 @@ int run_gemm_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
// return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
throw std::runtime_error("Unsupported data type for this operation !!!");
}
else if(data_type == "bf16")
{
@@ -194,21 +195,24 @@ int run_gemm_example(int argc, char* argv[])
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
// return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
// a_layout, b_layout, argc, argv);
throw std::runtime_error("Unsupported data type for this operation !!!");
}
else if(data_type == "bf8")
{
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
// return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
// a_layout, b_layout, argc, argv);
throw std::runtime_error("Unsupported data type for this operation !!!");
}
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
// return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
// a_layout, b_layout, argc, argv);
throw std::runtime_error("Unsupported data type for this operation !!!");
}
#endif
else

View File

@@ -248,7 +248,8 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
// ck_tile::index_t kbatch = arg_parser.get_int("split_k");
ck_tile::index_t kbatch = 1;
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");