update config

This commit is contained in:
Your Name
2026-01-06 08:10:37 +00:00
committed by kyle-256
parent 644cdbe3c9
commit 84f4255e9e
3 changed files with 81 additions and 85 deletions

View File

@@ -390,6 +390,12 @@ if(ENABLE_ASM_DUMP)
message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}")
endif()
# Opt-in diagnostic switch (default OFF). When enabled, the compiler is asked to
# emit kernel resource-usage remarks (VGPR/SGPR/LDS, per the option description).
option(ENABLE_RESOURCE_USAGE "Enable printing kernel resource usage (VGPR/SGPR/LDS)" OFF)
if(ENABLE_RESOURCE_USAGE)
# add_compile_options() affects targets created after this point in the current
# directory scope and below, so the flag is not limited to a single target.
add_compile_options(-Rpass-analysis=kernel-resource-usage)
message(STATUS "CK compiled with ENABLE_RESOURCE_USAGE - will print VGPR/SGPR/LDS usage")
endif()
if (ENABLE_JSON_DUMP)
add_compile_definitions(CK_ENABLE_JSON_DUMP)
message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}")

View File

@@ -98,9 +98,34 @@ struct GemmConfigComputeV3 : public GemmConfigBase
// GEMM tile configuration that selects the COMPUTE_V3 pipeline.
// All members are compile-time constants consumed by the GEMM example templates.
template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
{
// Block tile is 256 (M) x 256 (N) x 128/sizeof(PrecType) (K); warp tile is 32x32.
// NOTE(review): a previous comment here described "256x128x64, LDS=48KB, 16x16 warp
// tile", which does not match the values below -- confirm which one is intended.
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
// NOTE(review): K_Tile appears twice in this view (same value, differing only in the
// trailing comment); this looks like a before/after pair -- only one may remain.
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); // 64 for bf16
static constexpr ck_tile::index_t M_Warp = 2; // 2 warps in M
static constexpr ck_tile::index_t N_Warp = 2; // 2 warps in N
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
// K extent of the warp tile is derived from the element type and M_Warp_Tile by
// ck_tile::get_k_warp_tile (helper defined elsewhere; exact rule not visible here).
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
// presumably the number of thread blocks targeted per compute unit -- TODO confirm
static constexpr int kBlockPerCu = 1;
};
// Optimized config: 256x128x64 with COMPUTE_V3
// Smaller N tile for better L2 cache utilization
template <typename PrecType>
struct GemmConfigOptimized : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); // 64 for bf16
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
@@ -113,6 +138,31 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr int kBlockPerCu = 2; // Better occupancy with smaller tile
};
// Alternative: 128x128x128 config for higher K throughput
// (K is 256/sizeof(PrecType) elements, i.e. 128 only for 2-byte types such as bf16).
// All members are compile-time constants; the struct selects the COMPUTE_V3 pipeline.
template <typename PrecType>
struct GemmConfigOptimized2 : public GemmConfigBase
{
// Block tile extents in M, N and K.
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); // 128 for bf16
// Warp counts along M, N and K (2 x 2 x 1).
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
// Warp tile is 32x32 in M/N; the K extent is derived from the element type and
// M_Warp_Tile by ck_tile::get_k_warp_tile (helper defined elsewhere).
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
// presumably the number of thread blocks targeted per compute unit -- TODO confirm
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
@@ -224,9 +274,11 @@ struct GemmConfigComputeV5 : public GemmConfigBase
template <typename PrecType>
struct GemmConfigComputeV6 : public GemmConfigBase
{
// V6 pipeline: PrefetchStages=3, GlobalBufferNum=2, HotloopUnroll=2
// Use 256x128x64 for better balance
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); // 64 for bf16
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
@@ -234,11 +286,13 @@ struct GemmConfigComputeV6 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr int kBlockPerCu = 1;
};
template <typename PrecType>

View File

@@ -21,85 +21,24 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
using Invoker = UniversalInvoker;
if(data_type == "fp16")
if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, Invoker, ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, Invoker, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
Invoker,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "bf8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
Invoker,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "int8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
Invoker,
ck_tile::int8_t,
ck_tile::int8_t,
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "fp16i4")
{
// TODO: Add support for bhalf_t ADataType
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
// Only support RC layout (A=Row, B=Column) to reduce compile time
if(a_layout != "R" || b_layout != "C")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
Invoker,
ck_tile::half_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
else if(data_type == "fp8i4")
{
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
Invoker,
ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
else if(data_type == "bf8i4")
{
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
Invoker,
ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
throw std::runtime_error("Only RC layout (A=Row, B=Column) is supported!");
}
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Directly call with fixed layout to avoid std::visit instantiating all combinations
return run_gemm_example_with_layouts<GemmConfig<ck_tile::bf16_t>,
Invoker,
ck_tile::bf16_t,
ck_tile::bf16_t,
ck_tile::bf16_t,
Row,
Col,
Row>(arg_parser, Row{}, Col{}, Row{});
}
else
{
@@ -117,11 +56,8 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
#else
return !run_gemm_example<GemmConfigComputeV3_2>(arg_parser);
#endif
return !run_gemm_example<GemmConfigComputeV3_1>(arg_parser);
}
catch(const std::runtime_error& e)
{