mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
update config
This commit is contained in:
@@ -390,6 +390,12 @@ if(ENABLE_ASM_DUMP)
|
||||
message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}")
|
||||
endif()
|
||||
|
||||
option(ENABLE_RESOURCE_USAGE "Enable printing kernel resource usage (VGPR/SGPR/LDS)" OFF)
|
||||
if(ENABLE_RESOURCE_USAGE)
|
||||
add_compile_options(-Rpass-analysis=kernel-resource-usage)
|
||||
message(STATUS "CK compiled with ENABLE_RESOURCE_USAGE - will print VGPR/SGPR/LDS usage")
|
||||
endif()
|
||||
|
||||
if (ENABLE_JSON_DUMP)
|
||||
add_compile_definitions(CK_ENABLE_JSON_DUMP)
|
||||
message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}")
|
||||
|
||||
@@ -98,9 +98,34 @@ struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
{
|
||||
// Optimized config: 256x128x64, LDS=48KB, 16x16 warp tile for better occupancy
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); // 64 for bf16
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2; // 4 warps in M
|
||||
static constexpr ck_tile::index_t N_Warp = 2; // 2 warps in N
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
|
||||
// Optimized config: 256x128x64 with COMPUTE_V3
|
||||
// Smaller N tile for better L2 cache utilization
|
||||
template <typename PrecType>
|
||||
struct GemmConfigOptimized : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); // 64 for bf16
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
@@ -113,6 +138,31 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2; // Better occupancy with smaller tile
|
||||
};
|
||||
|
||||
// Alternative: 128x128x128 config for higher K throughput
|
||||
template <typename PrecType>
|
||||
struct GemmConfigOptimized2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); // 128 for bf16
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -224,9 +274,11 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV6 : public GemmConfigBase
|
||||
{
|
||||
// V6 pipeline: PrefetchStages=3, GlobalBufferNum=2, HotloopUnroll=2
|
||||
// Use 256x128x64 for better balance
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); // 64 for bf16
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
@@ -234,11 +286,13 @@ struct GemmConfigComputeV6 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
@@ -21,85 +21,24 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using Invoker = UniversalInvoker;
|
||||
|
||||
if(data_type == "fp16")
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, Invoker, ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, Invoker, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
Invoker,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
Invoker,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "int8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
|
||||
Invoker,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "fp16i4")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
// Only support RC layout (A=Row, B=Column) to reduce compile time
|
||||
if(a_layout != "R" || b_layout != "C")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
|
||||
Invoker,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else if(data_type == "fp8i4")
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
Invoker,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8i4")
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
Invoker,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
throw std::runtime_error("Only RC layout (A=Row, B=Column) is supported!");
|
||||
}
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
// Directly call with fixed layout to avoid std::visit instantiating all combinations
|
||||
return run_gemm_example_with_layouts<GemmConfig<ck_tile::bf16_t>,
|
||||
Invoker,
|
||||
ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
Row,
|
||||
Col,
|
||||
Row>(arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -117,11 +56,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
|
||||
#else
|
||||
return !run_gemm_example<GemmConfigComputeV3_2>(arg_parser);
|
||||
#endif
|
||||
return !run_gemm_example<GemmConfigComputeV3_1>(arg_parser);
|
||||
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user