mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Merge branch 'develop' into ckTileEnginePooling
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
|
||||
set(GEMM_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM (semicolon-separated)")
|
||||
set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
|
||||
@@ -11,12 +11,12 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
#include "gemm_profiler.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
|
||||
// The kernel header is included via the compile command line with -include flag
|
||||
// It defines SelectedKernel struct and KERNEL_NAME
|
||||
// DataTypeTraits are now defined in gemm_common.hpp
|
||||
|
||||
// Create argument parser
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
@@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[])
|
||||
|
||||
void benchmark_single(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Use DataTypeTraits to get the actual type names from the generated header
|
||||
// Use ck_tile::DataTypeTraits to get the actual type names from the generated header
|
||||
// The generated header defines ADataType, BDataType, AccDataType, CDataType
|
||||
std::string dtype_a = DataTypeTraits<ADataType>::name;
|
||||
std::string dtype_b = DataTypeTraits<BDataType>::name;
|
||||
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
|
||||
std::string dtype_c = DataTypeTraits<CDataType>::name;
|
||||
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
|
||||
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
|
||||
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
|
||||
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
|
||||
|
||||
// Layout names from the layout types
|
||||
std::string layout_a = ALayout::name;
|
||||
|
||||
@@ -9,65 +9,6 @@
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
//[TODO] This can be moved to commons
|
||||
// DataTypeTraits for all supported types
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
// Helper function to determine if a layout is row-major
|
||||
template <typename Layout>
|
||||
constexpr auto is_row_major(Layout)
|
||||
|
||||
@@ -337,13 +337,6 @@ class GemmKernelBuilder:
|
||||
"compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
|
||||
}
|
||||
|
||||
# Map pipeline names to base pipeline for hot loop detection
|
||||
base_pipeline_map = {
|
||||
"mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
|
||||
"compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
|
||||
"compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
|
||||
}
|
||||
|
||||
# Map scheduler names to the correct enum values
|
||||
scheduler_type_map = {
|
||||
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
|
||||
@@ -423,33 +416,10 @@ struct SelectedKernel {{
|
||||
|
||||
// Tile partitioner
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<TileShape, 8, 4>;
|
||||
|
||||
// Traits
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, NumWaveGroups>;
|
||||
|
||||
// Pipeline problem
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
TileShape,
|
||||
Traits>;
|
||||
|
||||
// Base pipeline for hot loop detection
|
||||
using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}<GemmPipelineProblem>;
|
||||
|
||||
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
|
||||
const ck_tile::index_t k_grain = args.k_batch * TileK;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{{0}};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
const auto Run = [&](const auto memory_operation_) {{
|
||||
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
|
||||
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
@@ -462,9 +432,7 @@ struct SelectedKernel {{
|
||||
ALayout, BLayout, CLayout, TransposeC,
|
||||
UseStructuredSparsity, UsePersistentKernel,
|
||||
NumWaveGroups, Preshuffle>,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
|
||||
|
||||
@@ -542,28 +510,23 @@ struct SelectedKernel {{
|
||||
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
|
||||
if(args.k_batch == 1) {{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{{}});
|
||||
}} else {{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{{}});
|
||||
}}
|
||||
}};
|
||||
float ave_time = 0.f;
|
||||
|
||||
if(args.k_batch == 1) {{
|
||||
ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{{}});
|
||||
}} else {{
|
||||
ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{{}});
|
||||
}}
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}}
|
||||
}};
|
||||
|
||||
Reference in New Issue
Block a user