mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 02:02:46 +00:00)

[CK_TILE] Update flatmm related kernels (#3022)

Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: felix <felix.li@amd.com>
[ROCm/composable_kernel commit: 211d64e18a]
@@ -1,6 +1,32 @@
add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
set(SUPPORTED_GPUS gfx908 gfx90a gfx942 gfx950)

set(has_supported_gpu FALSE)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST SUPPORTED_GPUS)
        set(has_supported_gpu TRUE)
        break()
    endif()
endforeach()

if(has_supported_gpu)
    add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
    add_executable(tile_example_mixed_prec_flatmm EXCLUDE_FROM_ALL mixed_prec/mixed_prec_flatmm.cpp)
    add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
    add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp)
    add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp)

    set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
    set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)

    if(CK_USE_OCP_FP8)
        list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
    endif()

    target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
    target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
    target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
    target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
    target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
endif()

set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
@@ -11,7 +11,102 @@

#include "ck_tile/host.hpp"
#include "flatmm_basic.hpp"
#include "run_flatmm_example.inc"
#include <type_traits>

template <typename T>
constexpr const char* DataTypeToString()
{
    if constexpr(std::is_same_v<T, ck_tile::half_t>)
    {
        return "fp16";
    }
    else if constexpr(std::is_same_v<T, ck_tile::fp8_t>)
    {
        return "fp8";
    }
    else if constexpr(std::is_same_v<T, ck_tile::bf8_t>)
    {
        return "bf8";
    }
    else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
    {
        return "bf16";
    }
    else
    {
        return "unknown";
    }
}

template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
}

// mfma_type, 0:32x32, 1:16x16
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
    assert(t.get_lengths().size() == 2);
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[0];

    constexpr int MaxVecSize = 16 / sizeof(T);
    constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile;
    constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane);

    ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
                                   FlatmmConfig::N_Warp_Tile,
                                   k_ / ItemsPerAccess,
                                   ItemsPerAccess});
    std::copy(t.begin(), t.end(), t_view.begin());
    return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
}

template <typename FlatmmConfig, typename T>
auto shuffle_b_v1(const ck_tile::HostTensor<T>& t)
{
    assert(t.get_lengths().size() == 2);
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[0];

    constexpr int MaxVecSize = 16 / sizeof(T);
    constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile;
    constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane);
    constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp;

    ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Tile,
                                   FlatmmConfig::N_Warp,
                                   FlatmmConfig::N_Warp_Tile,
                                   NRepeat,
                                   k_ / ItemsPerAccess,
                                   ItemsPerAccess});
    std::copy(t.begin(), t.end(), t_view.begin());
    return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5});
}
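
// The permutes above reorder flat views of the preshuffled B operand; for
// shuffle_b, {0, 2, 1, 3} turns the [N0, NLane, K0, KPack] view into
// [N0, K0, NLane, KPack], so each lane keeps KPack contiguous elements while
// lanes interleave along K. A standalone sketch of that same index mapping,
// with hard-coded NLane = 4 and KPack = 2 (hypothetical sizes, illustration
// only; assumes <vector> is available via the headers above):
inline void shuffle_b_sketch(const std::vector<int>& b, std::vector<int>& out, int n, int k)
{
    constexpr int NLane = 4, KPack = 2; // stand-ins for N_Warp_Tile and ItemsPerAccess
    for(int n0 = 0; n0 < n / NLane; ++n0)
        for(int k0 = 0; k0 < k / KPack; ++k0)
            for(int nl = 0; nl < NLane; ++nl)
                for(int kp = 0; kp < KPack; ++kp)
                    // destination laid out as [N0, K0, NLane, KPack]
                    out[((n0 * (k / KPack) + k0) * NLane + nl) * KPack + kp] =
                        // source laid out as [N0, NLane, K0, KPack]
                        b[(n0 * NLane + nl) * k + k0 * KPack + kp];
}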

template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
    // Calculate thresholds
    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
        ck_tile::integer_divide_ceil(K, kbatch));
    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
    // Calculate error due to split_k accumulation
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);
    // Use higher threshold
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
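
// Why both thresholds: rounding error grows with the number of accumulations,
// and split-k adds a second accumulation stage in the epilogue. A simplified
// standalone model of the same idea (hypothetical helper, not the
// ck_tile::get_*_threshold functions, which also weigh the data types involved):
inline double rough_rtol_sketch(int K, int kbatch, double unit_roundoff)
{
    const int in_kernel_adds = (K + kbatch - 1) / kbatch; // integer_divide_ceil(K, kbatch)
    const int split_k_adds   = kbatch;                    // atomic adds across partials
    const int worst_case     = in_kernel_adds > split_k_adds ? in_kernel_adds : split_k_adds;
    return worst_case * unit_roundoff; // e.g. unit_roundoff ~= 2^-11 for fp16 accumulation
}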

template <typename FlatmmConfig,
          typename ADataType,
@@ -23,9 +118,12 @@ template <typename FlatmmConfig,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ScaleM,
          typename ScaleN,
          bool persistent,
          typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s)
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
                  const ck_tile::stream_config& s)
{
    using CodegenFlatmmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
@@ -80,14 +178,14 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
        constexpr auto scheduler = FlatmmConfig::Scheduler;
        constexpr auto memory_operation = memory_operation_.value;

        using CodegenPipelineProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                                             BDataType,
                                                                             AccDataType,
                                                                             CodegenFlatmmShape,
                                                                             CodegenGemmTraits,
                                                                             scheduler,
                                                                             has_hot_loop_v,
                                                                             tail_number_v>;
        using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
                                                                      BDataType,
                                                                      AccDataType,
                                                                      CodegenFlatmmShape,
                                                                      CodegenGemmTraits,
                                                                      scheduler,
                                                                      has_hot_loop_v,
                                                                      tail_number_v>;

        using CodegenFlatmmPipeline =
            ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
@@ -110,7 +208,10 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
                                             FlatmmConfig::K_Warp_Tile,
                                             CodegenPipelineProblem::TransposeC,
                                             memory_operation,
                                             FlatmmConfig::NumWaveGroups>>;
                                             FlatmmConfig::NumWaveGroups,
                                             false,
                                             1,
                                             FlatmmConfig::TiledMMAPermuteN>>;

        // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
        // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
@@ -118,8 +219,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c

        auto kargs = Kernel::MakeKernelArgs(args);

        const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
        const dim3 blocks = Kernel::BlockSize();
        const dim3 grids = Kernel::GridSize(kargs);
        constexpr dim3 blocks = Kernel::BlockSize();

        if(!Kernel::IsSupportedArgument(kargs))
        {
@@ -167,40 +268,145 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
                    hipGetErrorString(hipMemsetAsync(
                        args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
            };
            return ave_time = ck_tile::launch_kernel_time_mask(
                       s,
                       run_flush_cache,
                       ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(
                           Kernel{}, grids, blocks, 0, kargs));
            ave_time = ck_tile::launch_kernel_time_mask(
                s,
                run_flush_cache,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        else
        {
            return ave_time =
                       ck_tile::launch_kernel(s,
                                              ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(
                                                  Kernel{}, grids, blocks, 0, kargs));
            ave_time = ck_tile::launch_kernel(
                s,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        return ave_time;
    };

    const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
        if(args.k_batch == 1)
        {
            return Run(has_hot_loop_,
                       tail_number_,
                       ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                                  ck_tile::memory_operation_enum::set>{});
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::set>{});
        }
        else
        {
            return Run(has_hot_loop_,
                       tail_number_,
                       ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                                  ck_tile::memory_operation_enum::atomic_add>{});
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::atomic_add>{});
        }
    };
    return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    return ave_time;
}
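
// The Run/RunSplitk lambdas above lift runtime decisions (k_batch, hot loop,
// tail number) into template parameters via tag types, so each combination
// compiles to its own kernel instantiation. A minimal standalone analogue of
// that dispatch pattern (illustrative names, not ck_tile APIs):
template <typename MemOpTag>
void launch_variant_sketch(MemOpTag)
{
    // a kernel specialized on MemOpTag::value would be instantiated here
}

inline void dispatch_sketch(int k_batch)
{
    if(k_batch == 1)
        launch_variant_sketch(std::integral_constant<int, /*set*/ 0>{});
    else
        launch_variant_sketch(std::integral_constant<int, /*atomic_add*/ 1>{});
}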

template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ScaleM,
          typename ScaleN,
          bool UsePersistentKernel = false,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
                    ck_tile::DeviceMem& b_shuffle_dev_buf,
                    ck_tile::DeviceMem& c_dev_buf,
                    ck_tile::index_t M,
                    ck_tile::index_t N,
                    ck_tile::index_t K,
                    ck_tile::index_t stride_A,
                    ck_tile::index_t stride_B,
                    ck_tile::index_t stride_C,
                    ck_tile::index_t kbatch,
                    ScaleM scale_m,
                    ScaleN scale_n,
                    int n_warmup,
                    int n_repeat)
{
    ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
                                                         b_shuffle_dev_buf.GetDeviceBuffer(),
                                                         {},
                                                         c_dev_buf.GetDeviceBuffer(),
                                                         kbatch,
                                                         M,
                                                         N,
                                                         K,
                                                         stride_A,
                                                         stride_B,
                                                         {},
                                                         stride_C,
                                                         scale_m,
                                                         scale_n};

    float ave_time = flatmm_calc<FlatmmConfig,
                                 ADataType,
                                 BDataType,
                                 DsDatatype,
                                 AccDataType,
                                 CDataType,
                                 ALayout,
                                 BLayout,
                                 DsLayout,
                                 CLayout,
                                 ScaleM,
                                 ScaleN,
                                 UsePersistentKernel,
                                 CDEElementWise>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});

    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_byte =
        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
              << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
              << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
              << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;

    return ave_time;
}
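
// Sanity check on the printed metrics above, given that ave_time is in
// milliseconds: flop / 1.E9 / ms is TFLOP/s and bytes / 1.E6 / ms is GB/s.
// Worked example (illustrative numbers): M = N = K = 4096 in fp16 gives
// flop = 2 * 4096^3 ~ 1.37e11 and num_byte = 2 * 3 * 4096^2 ~ 1.01e8, so at
// ave_time = 1.0 ms the kernel reports ~137 TFlops and ~101 GB/s.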

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "256", "m dimension")
        .insert("n", "256", "n dimension")
        .insert("k", "128", "k dimension")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "C", "B tensor data layout - Col by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
        .insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)")
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("split_k", "1", "splitK value")
        .insert("init", "0", "0:random, 1:linear, 2:constant(1)")
        .insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
        .insert("persistent", "0", "0: no persistent, 1: persistent kernel")
        .insert("warp_tile",
                "0",
                "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

#include "run_flatmm_example.inc"

template <template <typename PreType> typename FlatmmConfig>
int run_flatmm_example(int argc, char* argv[])
{
@@ -214,20 +420,10 @@ int run_flatmm_example(int argc, char* argv[])
    std::string data_type = arg_parser.get_str("prec");
    std::string a_layout = arg_parser.get_str("a_layout");
    std::string b_layout = arg_parser.get_str("b_layout");

    int k = arg_parser.get_int("k");
    int stride_b = arg_parser.get_int("stride_b");

    if(b_layout == "C" && stride_b > k)
    {
        throw std::runtime_error(
            "For ColumnMajor layout, StrideB must be smaller than or equal to K (" +
            std::to_string(k) + ")");
    }

    int scale_opt = arg_parser.get_int("scale");
    int persistent_opt = arg_parser.get_int("persistent");
    if(a_layout == "R" && b_layout == "C")
    {

        if(data_type == "fp16")
        {
            run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
@@ -240,13 +436,53 @@ int run_flatmm_example(int argc, char* argv[])
        }
        else if(data_type == "fp8")
        {
            run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
                argc, argv, Row{}, Col{}, Row{});
            if(scale_opt == 0)
            {
                if(persistent_opt == 0)
                {
                    run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
                        argc, argv, Row{}, Col{}, Row{});
                }
                else
                {
                    run_flatmm_example_with_layouts<ck_tile::fp8_t,
                                                    FlatmmConfig<ck_tile::fp8_t>,
                                                    -1,
                                                    -1,
                                                    true>(argc, argv, Row{}, Col{}, Row{});
                }
            }
            else
            {
                if(persistent_opt == 0)
                {
                    run_flatmm_example_with_layouts<ck_tile::fp8_t,
                                                    FlatmmConfig<ck_tile::fp8_t>,
                                                    1,
                                                    1>(argc, argv, Row{}, Col{}, Row{});
                }
                else
                {
                    run_flatmm_example_with_layouts<ck_tile::fp8_t,
                                                    FlatmmConfig<ck_tile::fp8_t>,
                                                    1,
                                                    1,
                                                    true>(argc, argv, Row{}, Col{}, Row{});
                }
            }
        }
        else if(data_type == "bf8")
        {
            run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
                argc, argv, Row{}, Col{}, Row{});
            if(scale_opt == 0)
            {
                run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
                    argc, argv, Row{}, Col{}, Row{});
            }
            else
            {
                run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
                    argc, argv, Row{}, Col{}, Row{});
            }
        }
        else
        {
@@ -268,9 +504,6 @@ int main(int argc, char* argv[])

    try
    {
#if defined(CK_TILE_USE_WMMA)
        return !run_flatmm_example<FlatmmConfig16_Wmma>(argc, argv);
#else
        int warp_tile = arg_parser.get_int("warp_tile");
        if(warp_tile == 0)
        {
@@ -288,7 +521,6 @@ int main(int argc, char* argv[])
        {
            return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
        }
#endif
    }
    catch(const std::runtime_error& e)
    {

@@ -35,12 +35,13 @@ struct FlatmmConfig32
    static constexpr bool TransposeC = false;
    static constexpr bool UseStructuredSparsity = false;

    static constexpr int kBlockPerCu = 2;
    static constexpr int kBlockPerCu = 1;
    static constexpr int TileParitionerGroupNum = 8;
    static constexpr int TileParitionerM01 = 4;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
};

template <typename DataType>
@@ -72,26 +73,28 @@ struct FlatmmConfig16
    static constexpr bool TransposeC = false;
    static constexpr bool UseStructuredSparsity = false;

    static constexpr int kBlockPerCu = 2;
    static constexpr int kBlockPerCu = 1;
    static constexpr int TileParitionerGroupNum = 8;
    static constexpr int TileParitionerM01 = 4;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr bool DoubleSmemBuffer = false;

    static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
    static constexpr bool TiledMMAPermuteN = N_Repeat % 4 == 0;
};
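
// Worked example for the two members above (tile values are illustrative; the
// actual ones are defined earlier in this struct): with N_Tile = 256,
// N_Warp_Tile = 16, N_Warp = 4, N_Repeat = 256 / 16 / 4 = 4, so
// TiledMMAPermuteN = (4 % 4 == 0) = true.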

template <typename DataType>
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
{
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
};
    static constexpr int kBlockPerCu = 1;

template <typename DataType>
struct FlatmmConfig16_Wmma : public FlatmmConfig16<DataType>
{
    static constexpr ck_tile::index_t M_Tile = 64;
    static constexpr ck_tile::index_t K_Tile = 64;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
    static constexpr int N_Repeat =
        N_Tile / FlatmmConfig16<DataType>::N_Warp_Tile / FlatmmConfig16<DataType>::N_Warp;
    static constexpr bool TiledMMAPermuteN = N_Repeat % 4 == 0;
};

template <typename ADataType>
@@ -172,42 +175,19 @@ struct is_8bit_type
{
};

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "256", "m dimension")
        .insert("n", "256", "n dimension")
        .insert("k", "128", "k dimension")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "C", "B tensor data layout - Col by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("split_k", "1", "splitK value")
        .insert("init", "0", "0:random, 1:linear, 2:constant(1)")
#if !defined(CK_TILE_USE_WMMA)
        .insert(
            "warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
#endif
        .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
        .insert("jsonfile", "flatmm_basic.json", "json file name to dump results");
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

// host API
template <typename ADataType,
template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename FlatmmConfig,
          typename ALayout,
          typename BLayout,
          typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);
          typename DsLayout,
          typename ELayout,
          typename ScaleM,
          typename ScaleN,
          bool persistent,
          typename CDEElementWise>
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
                  const ck_tile::stream_config& s);

example/ck_tile/18_flatmm/grouped_flatmm.cpp (new file, 364 lines)
@@ -0,0 +1,364 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <hip/hip_runtime.h>

#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>

#include "flatmm_basic.hpp"

#include "ck_tile/host.hpp"

template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
}

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("Ms", "1,1,1", "m dimension")
        .insert("Ns", "5120,5120,5120", "n dimension")
        .insert("Ks", "6144,6144,6144", "k dimension")
        .insert("group_count", "3", "group count")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "C", "B tensor data layout - Col by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
        .insert("mode",
                "masked",
                "grouped gemm mode: [general | contiguous | masked], masked by default")
        .insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)")
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("split_k", "1", "splitK value")
        .insert("init", "0", "0:random, 1:linear, 2:constant(1)")
        .insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
        .insert("warp_tile",
                "0",
                "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          bool persistent,
          typename CDEElementWise,
          typename KernelArguments>
float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& s)
{
    using CodegenFlatmmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
        ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
        ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
                          FlatmmConfig::N_Warp_Tile,
                          FlatmmConfig::K_Warp_Tile>>;

    using TilePartitioner =
        ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
                                                   FlatmmConfig::TileParitionerGroupNum,
                                                   FlatmmConfig::TileParitionerM01>;

    using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
                                           FlatmmConfig::kPadN,
                                           FlatmmConfig::kPadK,
                                           ALayout,
                                           BLayout,
                                           ELayout,
                                           FlatmmConfig::NumWaveGroups>;

    using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
                                                               FlatmmConfig::kPadN,
                                                               FlatmmConfig::kPadK,
                                                               FlatmmConfig::DoubleSmemBuffer,
                                                               ALayout,
                                                               BLayout,
                                                               ELayout,
                                                               FlatmmConfig::TransposeC,
                                                               FlatmmConfig::UseStructuredSparsity,
                                                               persistent,
                                                               FlatmmConfig::NumWaveGroups,
                                                               true>;

    using GemmPipelineProblem =
        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;

    using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;

    const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
    const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
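    // Worked example of the split-K loop math above (illustrative numbers):
    // K = 1000, k_batch = 2, K_Tile = 64 -> k_grain = 128 and
    // K_split = ceil(1000 / 128) * 64 = 8 * 64 = 512, i.e. each split's K range
    // rounded up to whole K_Tile chunks; num_loop and the tail handling are
    // then derived from that per-split range.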
    float ave_time{0};

    const auto Run = [&](const auto has_hot_loop_,
                         const auto tail_number_,
                         const auto memory_operation_) {
        constexpr bool has_hot_loop_v = has_hot_loop_.value;
        constexpr auto tail_number_v = tail_number_.value;
        constexpr auto scheduler = FlatmmConfig::Scheduler;
        constexpr auto memory_operation = memory_operation_.value;

        using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
                                                                      BDataType,
                                                                      AccDataType,
                                                                      CodegenFlatmmShape,
                                                                      CodegenGemmTraits,
                                                                      scheduler,
                                                                      has_hot_loop_v,
                                                                      tail_number_v>;

        using CodegenFlatmmPipeline =
            ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem<ADataType,
                                             BDataType,
                                             DsDatatype,
                                             AccDataType,
                                             CDataType,
                                             DsLayout,
                                             ELayout,
                                             CDEElementWise,
                                             TilePartitioner::MPerBlock,
                                             TilePartitioner::NPerBlock,
                                             FlatmmConfig::M_Warp,
                                             FlatmmConfig::N_Warp,
                                             FlatmmConfig::M_Warp_Tile,
                                             FlatmmConfig::N_Warp_Tile,
                                             FlatmmConfig::K_Warp_Tile,
                                             CodegenPipelineProblem::TransposeC,
                                             memory_operation,
                                             FlatmmConfig::NumWaveGroups>>;

        // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
        // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
        using Kernel =
            ck_tile::GroupedFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;

        auto kargs = Kernel::MakeKernelArgs(args);

        const dim3 grids = Kernel::GridSize(kargs);
        constexpr dim3 blocks = Kernel::BlockSize();

        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }

        if(s.flush_cache_)
        {
            std::cout << "Flushing cache..." << std::endl;
            static constexpr ck_tile::index_t APackedSize =
                std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
            static constexpr ck_tile::index_t BPackedSize =
                std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;

            ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
                args.group_count * args.M, args.K, args.stride_A, is_row_major(ALayout{})));
            ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
                args.K, args.group_count * args.N, args.stride_B, is_row_major(BLayout{})));

            auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
            auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;

            ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
                kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
            rotating_mem.Print();

            auto run_flush_cache = [&]() {
                // flush icache
                ck_tile::flush_icache();
                // rotating mem
                rotating_mem.Next();
                // clear c mem
                if(args.k_batch > 1)
                    hipGetErrorString(
                        hipMemsetAsync(args.e_ptr,
                                       0,
                                       args.group_count * args.M * args.N * sizeof(CDataType),
                                       s.stream_id_));
            };
            ave_time = ck_tile::launch_kernel_time_mask(
                s,
                run_flush_cache,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        else
        {
            ave_time = ck_tile::launch_kernel(
                s,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }

        return ave_time;
    };

    const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
        if(args.k_batch == 1)
        {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::set>{});
        }
        else
        {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::atomic_add>{});
        }
    };
    BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    return ave_time;
}

#include "run_grouped_flatmm_example.inc"

template <template <typename PreType> typename FlatmmConfig>
int run_grouped_flatmm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    std::string data_type = arg_parser.get_str("prec");
    std::string mode = arg_parser.get_str("mode");
    std::string a_layout = arg_parser.get_str("a_layout");
    std::string b_layout = arg_parser.get_str("b_layout");

    if(a_layout == "R" && b_layout == "C")
    {
        if(mode == "contiguous")
        {
            if(data_type == "fp16")
            {
                run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::half_t,
                                                                   FlatmmConfig<ck_tile::half_t>>(
                    argc, argv, Row{}, Col{}, Row{});
            }
            else if(data_type == "bf16")
            {
                run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::bf16_t,
                                                                   FlatmmConfig<ck_tile::bf16_t>>(
                    argc, argv, Row{}, Col{}, Row{});
            }
            else if(data_type == "fp8")
            {
                run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::fp8_t,
                                                                   FlatmmConfig<ck_tile::fp8_t>>(
                    argc, argv, Row{}, Col{}, Row{});
            }
            else if(data_type == "bf8")
            {
                run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::bf8_t,
                                                                   FlatmmConfig<ck_tile::bf8_t>>(
                    argc, argv, Row{}, Col{}, Row{});
            }
            else
            {
                throw std::runtime_error("Unsupported data_type!");
            }
        }
        else if(mode == "masked")
        {

            if(data_type == "fp16")
            {
                run_masked_grouped_flatmm_example_with_layouts<ck_tile::half_t,
                                                               FlatmmConfig<ck_tile::half_t>>(
                    argc, argv, Row{}, Col{}, Row{});
            }
            else if(data_type == "bf16")
            {
                run_masked_grouped_flatmm_example_with_layouts<ck_tile::bf16_t,
                                                               FlatmmConfig<ck_tile::bf16_t>>(
                    argc, argv, Row{}, Col{}, Row{});
            }
            else if(data_type == "fp8")
            {
                run_masked_grouped_flatmm_example_with_layouts<ck_tile::fp8_t,
                                                               FlatmmConfig<ck_tile::fp8_t>>(
                    argc, argv, Row{}, Col{}, Row{});
            }
            else if(data_type == "bf8")
            {
                run_masked_grouped_flatmm_example_with_layouts<ck_tile::bf8_t,
                                                               FlatmmConfig<ck_tile::bf8_t>>(
                    argc, argv, Row{}, Col{}, Row{});
            }
            else
            {
                throw std::runtime_error("Unsupported data_type!");
            }
        }
        else
        {
            throw std::runtime_error("Unsupported mode!");
        }
    }
    else
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    }
    return -1;
}

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return EXIT_FAILURE;

    try
    {
        int warp_tile = arg_parser.get_int("warp_tile");
        if(warp_tile == 0)
        {
            return !run_grouped_flatmm_example<FlatmmConfig16>(argc, argv);
        }
        // else if(warp_tile == 1)
        // {
        //     return !run_grouped_flatmm_example<FlatmmConfig32>(argc, argv);
        // }
        // else if(warp_tile == 2)
        // {
        //     return !run_grouped_flatmm_example<FlatmmConfig16_950>(argc, argv);
        // }
        // else
        // {
        //     return !run_grouped_flatmm_example<FlatmmConfig32_950>(argc, argv);
        // }
    }
    catch(const std::runtime_error& e)
    {
        std::cerr << "Runtime error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
}
example/ck_tile/18_flatmm/mixed_prec/a16w4_flatmm.hpp (new file, 50 lines)
@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

// GEMM config with 16x16 warp tile
struct A16W4_FlatmmConfig16
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 256;

    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 32;

    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    static constexpr bool TransposeC = false;
    static constexpr bool UseStructuredSparsity = false;

    static constexpr int kBlockPerCu = 1;
    static constexpr int TileParitionerGroupNum = 8;
    static constexpr int TileParitionerM01 = 4;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr bool DoubleSmemBuffer = false;

    static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
    static constexpr bool TiledMMAPermuteN = false;
};

struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
{
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr int kBlockPerCu = 1;

    static constexpr int N_Repeat =
        N_Tile / A16W4_FlatmmConfig16::N_Warp_Tile / A16W4_FlatmmConfig16::N_Warp;
    static constexpr bool TiledMMAPermuteN = false;
};
example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp (new file, 511 lines)
@@ -0,0 +1,511 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <hip/hip_runtime.h>

#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>

#include "a16w4_moe_flatmm.hpp"

#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/moe_flatmm.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/reference/reference_moe_gemm.hpp"

template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
}

// gemm1
// operand-A = [num_token, d_model]
// operand-B = [num_expert, hidden, d_model]
// operand-C = [num_token, topk, hidden]

// gemm2
// operand-A = [num_token, topk, hidden]
// operand-B = [num_expert, d_model, hidden]
// operand-C = [num_token, d_model]
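
// Worked shape example for the two stages above (illustrative numbers):
// num_token = 128, topk = 3, d_model = 4096, hidden = 14336, num_expert = 8:
//   gemm1: A [128, 4096] x B [8, 14336, 4096] -> C [128, 3, 14336]
//   gemm2: A [128, 3, 14336] x B [8, 4096, 14336] -> C [128, 4096]
// Each token is routed to its topk experts, so gemm1 expands M by topk and
// gemm2 contracts it back to one row per token.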

template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ck_tile::MoeFlatmmKind moe_kind = ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only,
          typename CDEElementWise = ck_tile::element_wise::PassThrough,
          typename MoeFlatmmHostArgs>
float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config& s)
{
    using CodegenFlatmmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
        ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
        ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
                          FlatmmConfig::N_Warp_Tile,
                          FlatmmConfig::K_Warp_Tile>>;

    using TilePartitioner =
        ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
                                                   FlatmmConfig::TileParitionerGroupNum,
                                                   FlatmmConfig::TileParitionerM01>;

    using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
                                           FlatmmConfig::kPadN,
                                           FlatmmConfig::kPadK,
                                           ALayout,
                                           BLayout,
                                           ELayout,
                                           FlatmmConfig::NumWaveGroups>;

    using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
                                                               FlatmmConfig::kPadN,
                                                               FlatmmConfig::kPadK,
                                                               FlatmmConfig::DoubleSmemBuffer,
                                                               ALayout,
                                                               BLayout,
                                                               ELayout,
                                                               FlatmmConfig::TransposeC,
                                                               FlatmmConfig::UseStructuredSparsity,
                                                               false, // UsePersistentKernel_
                                                               FlatmmConfig::NumWaveGroups,
                                                               true>; // Preshuffle_

    constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;

    if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
    {
        static_assert(
            FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
            "requires NRepeat to be a multiple of 2 for FFN_gemm1_gate_up");
    }

    using ComputeDataType = ADataType;
    static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
                  "mixed_prec_flatmm requires ADataType to be a wider type than BDataType");

    using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
                                                             ComputeDataType,
                                                             AccDataType,
                                                             CodegenFlatmmShape,
                                                             Traits>;

    using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;

    const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
    const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
    float ave_time{0};

    const auto Run = [&](const auto has_hot_loop_,
                         const auto tail_number_,
                         const auto memory_operation_) {
        constexpr bool has_hot_loop_v = has_hot_loop_.value;
        constexpr auto tail_number_v = tail_number_.value;
        constexpr auto scheduler = FlatmmConfig::Scheduler;
        constexpr auto memory_operation = memory_operation_.value;

        using CodegenPipelineProblem =
            std::conditional_t<MXFP4_Pipeline,
                               ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
                                                                      BDataType,
                                                                      AccDataType,
                                                                      CodegenFlatmmShape,
                                                                      CodegenGemmTraits,
                                                                      scheduler,
                                                                      has_hot_loop_v,
                                                                      tail_number_v>,
                               ck_tile::FlatmmPipelineProblem<ADataType,
                                                              BDataType,
                                                              AccDataType,
                                                              CodegenFlatmmShape,
                                                              CodegenGemmTraits,
                                                              scheduler,
                                                              has_hot_loop_v,
                                                              tail_number_v>>;

        constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern

        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem<ComputeDataType,
                                             ComputeDataType,
                                             DsDatatype,
                                             AccDataType,
                                             CDataType,
                                             DsLayout,
                                             ELayout,
                                             CDEElementWise,
                                             TilePartitioner::MPerBlock,
                                             TilePartitioner::NPerBlock,
                                             FlatmmConfig::M_Warp,
                                             FlatmmConfig::N_Warp,
                                             FlatmmConfig::M_Warp_Tile,
                                             FlatmmConfig::N_Warp_Tile,
                                             FlatmmConfig::K_Warp_Tile,
                                             CodegenPipelineProblem::TransposeC,
                                             memory_operation,
                                             FlatmmConfig::NumWaveGroups,
                                             false,
                                             1,
                                             FlatmmConfig::TiledMMAPermuteN,
                                             BlockedXDLN_PerWarp>>;

        using CodegenFlatmmPipeline = std::conditional_t<
            MXFP4_Pipeline,
            ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>,
            ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>>;
        using FusedAct =
            std::conditional_t<MXFP4_Pipeline, ck_tile::moe::Swiglu, ck_tile::moe::MoeSilu>;

        using Kernel = ck_tile::MoeFlatmmKernel<TilePartitioner,
                                                CodegenFlatmmPipeline,
                                                GemmEpilogue,
                                                moe_kind,
                                                FusedAct>;

        auto kargs = Kernel::MakeKernelArgs(args);

        const dim3 grids = Kernel::GridSize(kargs);
        constexpr dim3 blocks = Kernel::BlockSize();

        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
                      << "Shape: " << CodegenFlatmmShape::GetName() << "\n"
                      << "problem: " << CodegenPipelineProblem::GetName() << "\n"
                      << "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        if(s.flush_cache_)
        {
            std::cout << "Flushing cache..." << std::endl;
            static constexpr ck_tile::index_t APackedSize =
                std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
                        std::is_same_v<BDataType, ck_tile::pk_fp4_t>
                    ? 2
                    : 1;
            static constexpr ck_tile::index_t BPackedSize =
                std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
                        std::is_same_v<BDataType, ck_tile::pk_fp4_t>
                    ? 2
                    : 1;

            ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
                moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK
                                                               : args.NumTokens,
                args.K,
                args.stride_A,
                is_row_major(ALayout{})));
            ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
                args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{})));

            const int outputN =
                moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N;

            auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
            auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;

            ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
                kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
            rotating_mem.Print();

            auto run_flush_cache = [&]() {
                // flush icache
                ck_tile::flush_icache();
                // rotating mem
                rotating_mem.Next();
                // clear c mem
                if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2)
                    hipGetErrorString(hipMemsetAsync(
                        args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), s.stream_id_));
                else if(args.k_batch > 1)
                    hipGetErrorString(
                        hipMemsetAsync(args.e_ptr,
                                       0,
                                       args.NumTokens * args.TopK * outputN * sizeof(CDataType),
                                       s.stream_id_));
            };
            ave_time = ck_tile::launch_kernel_time_mask(
                s,
                run_flush_cache,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        else
        {
            ave_time = ck_tile::launch_kernel(
                s,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        return ave_time;
    };

    const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
        if(args.k_batch == 1)
        {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::set>{});
        }
        else
        {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::atomic_add>{});
        }
    };
    BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    return ave_time;
}

template <class FlatmmConfig, ck_tile::MoeFlatmmKind moe_kind, class IterSrc, class IterDst>
void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N, int K)
{
    int KPack = 16;
    int NLane = FlatmmConfig::N_Warp_Tile;
    int KLane = 64 / NLane;
    int K_pk = K / 2;
    int K0 = K_pk / (KLane * KPack);
    // K -> K0 KLane KPack
    // N -> N0 NLane
    // N, K -> N0 K0 KLane NLane KPack
    int tempk;

    if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
    {
        int up_stride = N / 2 / NLane;

        for(long eid = 0; eid < experts_cnt; ++eid)
        {
            for(int n = 0; n < N; ++n)
            {
                for(int k = 0; k < K_pk; ++k)
                {
                    int n0 = n / NLane;
                    int n1 = n % NLane;

                    // interleave the gate and up parts with granularity 16.
                    int n0_interleave = n >= N / 2 ? (n0 - up_stride) * 2 + 1 : // up part
                                            n0 * 2;                             // gate part

                    int k0 = k / (KLane * KPack);
                    tempk  = k % (KLane * KPack);
                    int k1 = tempk / KPack;
                    int k2 = tempk % KPack;

                    long outputIndex = eid * N * K_pk + n0_interleave * KPack * NLane * KLane * K0 +
                                       k0 * KPack * NLane * KLane + k1 * KPack * NLane +
                                       n1 * KPack + k2;

                    dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
                }
            }
        }
    }
    else
    {
        for(long eid = 0; eid < experts_cnt; ++eid)
        {
            for(int n = 0; n < N; ++n)
            {
                for(int k = 0; k < K_pk; ++k)
                {
                    int n0 = n / NLane;
                    int n1 = n % NLane;

                    int k0 = k / (KLane * KPack);
                    tempk  = k % (KLane * KPack);
                    int k1 = tempk / KPack;
                    int k2 = tempk % KPack;

                    long outputIndex = eid * N * K_pk + n0 * KPack * NLane * KLane * K0 +
                                       k0 * KPack * NLane * KLane + k1 * KPack * NLane +
                                       n1 * KPack + k2;

                    dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
                }
            }
        }
    }
}
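
// Index-mapping check for the shuffle above (values follow from the constants
// set at the top of the function when N_Warp_Tile == 16: NLane = 16, KLane = 4,
// KPack = 16; K0 stays symbolic). Taking n = 37, k = 100 gives n0 = 2, n1 = 5,
// k0 = 1, tempk = 36, k1 = 2, k2 = 4, so within one expert
//   outputIndex = ((2 * K0 + 1) * 4 + 2) * 16 * 16 + 5 * 16 + 4,
// i.e. exactly the [N0, K0, KLane, NLane, KPack] order announced in the
// comments at the top of the function.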

template <typename FlatmmConfig, ck_tile::MoeFlatmmKind moe_kind, typename T>
auto shuffle_mxfp4_scale(const ck_tile::HostTensor<T>& scale, int experts_cnt)
{
    assert(scale.get_lengths().size() == 2);
    int n_ = scale.get_lengths()[1];
    int k_ = scale.get_lengths()[0];

    int k_per_expert = k_ / experts_cnt;

    constexpr int K_Pack = 2;        // fixed for mxfp4
    constexpr int N_Pack = 2;        // fixed for mxfp4
    constexpr int GranularityK = 32; // fixed for mxfp4

    constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4

    static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
    static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
    static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);

    if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
    {
        ck_tile::HostTensor<T> shfl_scale({
            experts_cnt,
            k_per_expert / K_Pack / K_Lane,
            K_Pack,
            K_Lane,
            N_Pack, // N_Pack = 2 is composed of Gate + Up.
            n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
            FlatmmConfig::N_Warp_Tile,
        });
        std::copy(scale.begin(), scale.end(), shfl_scale.begin());
        return ck_tile::reference_permute(shfl_scale, {0, 5, 1, 3, 6, 2, 4});
    }
    else
    {
        ck_tile::HostTensor<T> shfl_scale({
            experts_cnt,
            k_per_expert / K_Pack / K_Lane,
            K_Pack,
            K_Lane,
            n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
            N_Pack,
            FlatmmConfig::N_Warp_Tile,
        });
        std::copy(scale.begin(), scale.end(), shfl_scale.begin());
        return ck_tile::reference_permute(shfl_scale, {0, 4, 1, 3, 6, 2, 5});
    }
}

#include "run_a16w4_moe_flatmm_example.inc"

template <typename FlatmmConfig>
int run_a16w4_moe_flatmm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
    {
        return -1;
    }

    const std::string a_layout = arg_parser.get_str("a_layout");
    const std::string b_layout = arg_parser.get_str("b_layout");

    const std::string mixed_prec = arg_parser.get_str("mixed_prec");

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    if(a_layout == "R" && b_layout == "C")
    {
        const std::string gemm_kind = arg_parser.get_str("gemm_kind");
        if(gemm_kind == "gemm1_gate_up")
        {
            if(mixed_prec == "fp16xfp4")
            {
                return run_a16w4_moe_gemm_example_with_layouts<
                    ck_tile::half_t,
                    ck_tile::pk_fp4_t,
                    FlatmmConfig,
                    ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
            }
            else if(mixed_prec == "bf16xfp4")
            {
                return run_a16w4_moe_gemm_example_with_layouts<
                    ck_tile::bfloat16_t,
                    ck_tile::pk_fp4_t,
                    FlatmmConfig,
                    ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
            }
            else
            {
                throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
            }
        }
        else if(gemm_kind == "gemm2")
        {
            if(mixed_prec == "fp16xfp4")
            {
                return run_a16w4_moe_gemm_example_with_layouts<ck_tile::half_t,
                                                               ck_tile::pk_fp4_t,
                                                               FlatmmConfig,
                                                               ck_tile::MoeFlatmmKind::kFFN_gemm2>(
                    argc, argv, Row{}, Col{}, Row{});
            }
            else if(mixed_prec == "bf16xfp4")
            {
                return run_a16w4_moe_gemm_example_with_layouts<ck_tile::bfloat16_t,
                                                               ck_tile::pk_fp4_t,
                                                               FlatmmConfig,
                                                               ck_tile::MoeFlatmmKind::kFFN_gemm2>(
                    argc, argv, Row{}, Col{}, Row{});
            }
            else
            {
                throw std::runtime_error("Unsupported precision type for gemm2!");
            }
        }
        else
        {
            throw std::runtime_error("Unrecognized gemm_kind parameter, accepted values are "
                                     "[gemm1_gate_up | gemm2]");
        }
    }
    else
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    }
    return -1;
}

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return EXIT_FAILURE;

    try
    {
        int warp_tile = arg_parser.get_int("warp_tile");
        if(warp_tile == 0)
        {
            return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
        }
        // else if(warp_tile == 1)
        // {
        //     return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
        // }
    }
    catch(const std::runtime_error& e)
    {
        std::cerr << "Runtime error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
}
|
||||
87
example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp
Normal file
87
example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp
Normal file
@@ -0,0 +1,87 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>
#include <tuple>

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/moe_flatmm.hpp"

// GEMM config with 16x16 warp tile
struct A16W4_FlatmmConfig16
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 256;

    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 32;

    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    static constexpr bool TransposeC          = false;
    static constexpr bool UseStructuredSparsity = false;

    static constexpr int kBlockPerCu            = 1;
    static constexpr int TileParitionerGroupNum = 8;
    static constexpr int TileParitionerM01      = 4;
    static constexpr auto Scheduler             = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr bool DoubleSmemBuffer          = false;

    static constexpr int N_Repeat          = N_Tile / N_Warp_Tile / N_Warp;
    static constexpr bool TiledMMAPermuteN = false;
};
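// For this config the arithmetic gives N_Repeat = 256 / 16 / 4 = 4, i.e. four
// 16-wide MFMA tiles per warp along N.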

struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
{
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr int kBlockPerCu         = 1;

    static constexpr int N_Repeat =
        N_Tile / A16W4_FlatmmConfig16::N_Warp_Tile / A16W4_FlatmmConfig16::N_Warp;
    static constexpr bool TiledMMAPermuteN = false;
};

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("experts", "8", "Num of experts - 8 by default")
        .insert("NumTokens", "128", "M dimension - 128 by default.")
        .insert("TopK", "3", "Top K - 3 by default.")
        .insert("N", "4096", "N dimension - 4096 by default.")
        .insert("K", "4096", "K dimension - 4096 by default.")
        .insert("stride_A", "", "Tensor A strides - empty by default.")
        .insert("stride_B", "", "Tensor B strides - empty by default.")
        .insert("stride_C", "", "Tensor C strides - empty by default.")
        .insert("a_layout", "R", "A tensor data layout - Row by default.")
        .insert("b_layout", "C", "B tensor data layout - Col by default.")
        .insert("c_layout", "R", "C tensor data layout - Row by default.")
        .insert("gemm_kind",
                "gemm1_gate_up",
                "Gemm kind in FFN network [gemm1_gate_up | gemm2] - "
                "gemm1_gate_up by default.")
        .insert("validate", "1", "0: no validation, 1: validation against reference.")
        .insert("warmup", "50", "number of warm-up iterations before benchmarking the kernel")
        .insert("mixed_prec",
                "bf16xfp4",
                "data type for activation and weight, supported: bf16xfp4, fp16xfp4")
        .insert("init", "0", "0: random, 1: constant(1)")
        .insert("warp_tile",
                "0",
                "0: 16x16, 1: 16x16 (950 only, may use a larger tile than warp_tile=0)")
        .insert("repeat", "10", "number of iterations to benchmark the kernel.");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp (new file, 482 lines)
@@ -0,0 +1,482 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <hip/hip_runtime.h>

#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <type_traits>

#include "ck_tile/host.hpp"
#include "mixed_prec_flatmm.hpp"

template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
}

template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ScaleM,
          typename ScaleN,
          bool persistent,
          typename CDEElementWise>
float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
                             const ck_tile::stream_config& s)
{
    using CodegenFlatmmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
        ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
        ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
                          FlatmmConfig::N_Warp_Tile,
                          FlatmmConfig::K_Warp_Tile>>;

    using TilePartitioner =
        ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
                                                   FlatmmConfig::TileParitionerGroupNum,
                                                   FlatmmConfig::TileParitionerM01>;

    using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
                                           FlatmmConfig::kPadN,
                                           FlatmmConfig::kPadK,
                                           ALayout,
                                           BLayout,
                                           ELayout,
                                           FlatmmConfig::NumWaveGroups>;

    using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
                                                               FlatmmConfig::kPadN,
                                                               FlatmmConfig::kPadK,
                                                               FlatmmConfig::DoubleSmemBuffer,
                                                               ALayout,
                                                               BLayout,
                                                               ELayout,
                                                               FlatmmConfig::TransposeC,
                                                               FlatmmConfig::UseStructuredSparsity,
                                                               persistent,
                                                               FlatmmConfig::NumWaveGroups,
                                                               true>;

    using ComputeDataType = ADataType;
    static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
                  "mixed_prec_flatmm requires ADataType to be at least as wide as BDataType");

    using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
                                                             ComputeDataType,
                                                             AccDataType,
                                                             CodegenFlatmmShape,
                                                             Traits>;

    using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;

    const ck_tile::index_t k_grain     = args.k_batch * FlatmmConfig::K_Tile;
    const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
    const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
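    // K_split rounds K up to a whole number of K_Tile chunks per split-K batch,
    // so num_loop and tail_num below describe the per-workgroup K loop.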
    float ave_time{0};

    const auto Run = [&](const auto has_hot_loop_,
                         const auto tail_number_,
                         const auto memory_operation_) {
        constexpr bool has_hot_loop_v   = has_hot_loop_.value;
        constexpr auto tail_number_v    = tail_number_.value;
        constexpr auto scheduler        = FlatmmConfig::Scheduler;
        constexpr auto memory_operation = memory_operation_.value;

        constexpr int BlockedXDLN_PerWarp = 2; // determined by the scale shuffle pattern

        using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
                                                                              BDataType,
                                                                              AccDataType,
                                                                              CodegenFlatmmShape,
                                                                              CodegenGemmTraits,
                                                                              scheduler,
                                                                              has_hot_loop_v,
                                                                              tail_number_v>;

        using CodegenFlatmmPipeline =
            ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem<ComputeDataType,
                                             ComputeDataType,
                                             DsDatatype,
                                             AccDataType,
                                             CDataType,
                                             DsLayout,
                                             ELayout,
                                             CDEElementWise,
                                             TilePartitioner::MPerBlock,
                                             TilePartitioner::NPerBlock,
                                             FlatmmConfig::M_Warp,
                                             FlatmmConfig::N_Warp,
                                             FlatmmConfig::M_Warp_Tile,
                                             FlatmmConfig::N_Warp_Tile,
                                             FlatmmConfig::K_Warp_Tile,
                                             CodegenPipelineProblem::TransposeC,
                                             memory_operation,
                                             FlatmmConfig::NumWaveGroups,
                                             false, // FixedVectorSize
                                             1,     // VectorSizeC
                                             FlatmmConfig::TiledMMAPermuteN,
                                             BlockedXDLN_PerWarp>>;

        using Kernel =
            ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;

        auto kargs = Kernel::MakeKernelArgs(args);

        const dim3 grids      = Kernel::GridSize(kargs);
        constexpr dim3 blocks = Kernel::BlockSize();

        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
                      << "Shape: " << CodegenFlatmmShape::GetName() << "\n"
                      << "problem: " << CodegenPipelineProblem::GetName() << "\n"
                      << "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        if(s.flush_cache_)
        {
            std::cout << "Flushing cache..." << std::endl;
            constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
            constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;

            ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
                args.M, args.K, args.stride_A, is_row_major(ALayout{})));
            ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
                args.K, args.N, args.stride_B, is_row_major(BLayout{})));

            auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
            auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;

            ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
                kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
            rotating_mem.Print();
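            // Rotating through s.rotating_count_ copies of A and B between timed
            // iterations defeats cache reuse, so the measurement reflects
            // cold-cache memory traffic.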

            auto run_flush_cache = [&]() {
                // flush icache
                ck_tile::flush_icache();
                // rotating mem
                rotating_mem.Next();
                // clear c mem
                if(args.k_batch > 1)
                    hipGetErrorString(hipMemsetAsync(
                        args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
            };
            ave_time = ck_tile::launch_kernel_time_mask(
                s,
                run_flush_cache,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        else
        {
            ave_time = ck_tile::launch_kernel(
                s,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        return ave_time;
    };

    const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
        if(args.k_batch == 1)
        {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::set>{});
        }
        else
        {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::atomic_add>{});
        }
    };
    BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    return ave_time;
}

template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ScaleN,
          bool UsePersistentKernel = false,
          typename CDEElementWise  = ck_tile::element_wise::PassThrough>
float invoke_mixed_prec_flatmm(ck_tile::DeviceMem& a_dev_buf,
                               ck_tile::DeviceMem& b_shuffle_dev_buf,
                               ck_tile::DeviceMem& c_dev_buf,
                               ck_tile::index_t M,
                               ck_tile::index_t N,
                               ck_tile::index_t K,
                               ck_tile::index_t stride_A,
                               ck_tile::index_t stride_B,
                               ck_tile::index_t stride_C,
                               ck_tile::index_t kbatch,
                               ScaleN dequant_scale_n,
                               int n_warmup,
                               int n_repeat)
{
    // Activation has no scale
    using ActScaleType = ck_tile::FlatmmScalePointer<-1>;

    ck_tile::ScaleFlatmmHostArgs<ActScaleType, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
                                                               b_shuffle_dev_buf.GetDeviceBuffer(),
                                                               {},
                                                               c_dev_buf.GetDeviceBuffer(),
                                                               kbatch,
                                                               M,
                                                               N,
                                                               K,
                                                               stride_A,
                                                               stride_B,
                                                               {},
                                                               stride_C,
                                                               {},
                                                               dequant_scale_n};

    float ave_time = mixed_prec_flatmm_calc<FlatmmConfig,
                                            ADataType,
                                            BDataType,
                                            DsDatatype,
                                            AccDataType,
                                            CDataType,
                                            ALayout,
                                            BLayout,
                                            DsLayout,
                                            CLayout,
                                            ActScaleType,
                                            ScaleN,
                                            UsePersistentKernel,
                                            CDEElementWise>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});

    constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;

    std::size_t flop     = std::size_t(2) * M * N * K;
    std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K / PackedSize +
                           sizeof(CDataType) * M * N;
    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;
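    // ave_time is in ms, so flop / 1e9 / ms yields TFLOPS and bytes / 1e6 / ms
    // yields GB/s.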

    std::cout << "Run A16W4_Flatmm kernel " << " M =" << M << " N =" << N << " K =" << K
              << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
              << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;

    return ave_time;
}

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "256", "m dimension")
        .insert("n", "256", "n dimension")
        .insert("k", "512", "k dimension")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "C", "B tensor data layout - Col by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "1", "0: no validation, 1: validation on GPU")
        .insert("mixed_prec",
                "bf16xfp4",
                "data type for activation and weight, supported: bf16xfp4, fp16xfp4")
        .insert("warmup", "50", "number of warm-up iterations before benchmarking the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu: gpu timer, cpu: cpu timer")
        .insert("split_k", "1", "splitK value")
        .insert("init", "0", "0: random, 1: constant(1)")
        .insert("persistent", "0", "0: non-persistent kernel, 1: persistent kernel")
        .insert("warp_tile",
                "0",
                "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

template <class FlatmmConfig, class IterSrc, class IterDst>
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
{
    int KPack = 16;
    int NLane = FlatmmConfig::N_Warp_Tile;
    int KLane = 64 / NLane;
    int K_pk  = K / 2;
    int K0    = K_pk / (KLane * KPack);
    // K -> K0 KLane KPack
    // N -> N0 NLane
    // N, K -> N0 K0 KLane NLane KPack
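    // e.g. with N_Warp_Tile == 16: NLane = 16, KLane = 64 / 16 = 4, and each
    // KPack run of 16 packed bytes stays contiguous in dst, so a lane can fetch
    // it with a single 16-byte vector load.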
    int tempk;
    for(int n = 0; n < N; ++n)
    {
        for(int k = 0; k < K_pk; ++k)
        {
            int n0 = n / NLane;
            int n1 = n % NLane;

            int k0 = k / (KLane * KPack);
            tempk  = k % (KLane * KPack);
            int k1 = tempk / KPack;
            int k2 = tempk % KPack;

            int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
                              k1 * KPack * NLane + n1 * KPack + k2;

            dst[outputIndex] = src[n * K_pk + k];
        }
    }
}

template <class FlatmmConfig, class T>
auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
{
    assert(scale.get_lengths().size() == 2);
    int n_ = scale.get_lengths()[1];
    int k_ = scale.get_lengths()[0];

    constexpr int K_Pack       = 2;  // fixed for mxfp4
    constexpr int N_Pack       = 2;  // fixed for mxfp4
    constexpr int GranularityK = 32; // fixed for mxfp4

    constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4

    static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
    static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
    static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);

    ck_tile::HostTensor<T> shfl_scale({
        k_ / K_Pack / K_Lane,
        K_Pack,
        K_Lane,
        n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
        N_Pack,
        FlatmmConfig::N_Warp_Tile,
    });
    std::copy(scale.begin(), scale.end(), shfl_scale.begin());
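    // Reinterpret the (k_, n_) scale tensor as a 6-D view and permute it to
    // {N0, K0, K_Lane, NLane, K_Pack, N_Pack}, which appears to be the order in
    // which the F16xMXF4 pipeline consumes the shuffled scales.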
    return ck_tile::reference_permute(shfl_scale, {3, 0, 2, 5, 1, 4});
}

#include "run_mixed_prec_flatmm.inc"

template <typename FlatmmConfig>
int run_mixed_prec_flatmm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    std::string mixed_prec = arg_parser.get_str("mixed_prec");
    std::string a_layout   = arg_parser.get_str("a_layout");
    std::string b_layout   = arg_parser.get_str("b_layout");
    int persistent_opt     = arg_parser.get_int("persistent");

    if(a_layout == "R" && b_layout == "C")
    {
        if(mixed_prec == "bf16xfp4")
        {
            if(persistent_opt == 0)
            {
                run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
                                                   ck_tile::pk_fp4_t,
                                                   FlatmmConfig,
                                                   false>(argc, argv, Row{}, Col{}, Row{});
            }
            else
            {
                run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
                                                   ck_tile::pk_fp4_t,
                                                   FlatmmConfig,
                                                   true>(argc, argv, Row{}, Col{}, Row{});
            }
        }
        else if(mixed_prec == "fp16xfp4")
        {
            if(persistent_opt == 0)
            {
                run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
                                                   ck_tile::pk_fp4_t,
                                                   FlatmmConfig,
                                                   false>(argc, argv, Row{}, Col{}, Row{});
            }
            else
            {
                run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
                                                   ck_tile::pk_fp4_t,
                                                   FlatmmConfig,
                                                   true>(argc, argv, Row{}, Col{}, Row{});
            }
        }
        else
        {
            throw std::runtime_error("Unsupported data_type!");
        }
    }
    else
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    }
    return -1;
}

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return EXIT_FAILURE;
    try
    {
        int warp_tile = arg_parser.get_int("warp_tile");
        if(warp_tile == 0)
        {
            return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
        }
        else if(warp_tile == 1)
        {
            return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
        }
        else
        {
            throw std::runtime_error("Unsupported warp_tile!");
        }
    }
    catch(const std::runtime_error& e)
    {
        std::cerr << "Runtime error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
}
example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.hpp (new file, 15 lines)
@@ -0,0 +1,15 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm.hpp"

#include "a16w4_flatmm.hpp"
@@ -0,0 +1,353 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ck_tile::MoeFlatmmKind kind,
          typename CDEElementWise = ck_tile::element_wise::PassThrough,
          typename MoeHostArgs>
float invoke_a16w4_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
{
    float ave_time = a16w4_moe_gemm<FlatmmConfig,
                                    ADataType,
                                    BDataType,
                                    DsDatatype,
                                    AccDataType,
                                    CDataType,
                                    ALayout,
                                    BLayout,
                                    DsLayout,
                                    ELayout,
                                    kind,
                                    CDEElementWise>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});

    std::string op_name{"Moe Gemm"};

    constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;

    std::size_t flop     = std::size_t(2) * args.M * args.N * args.K;
    std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
                           sizeof(BDataType) * args.N * args.K / PackedSize +
                           sizeof(CDataType) * args.M * args.N;
    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
              << gb_per_sec << " GB/s, " << op_name << std::endl;

    return ave_time;
}

template <typename PrecActType,
          typename PrecWeightType,
          typename FlatmmConfig,
          ck_tile::MoeFlatmmKind kind,
          typename ALayout,
          typename BLayout,
          typename CLayout>
int run_a16w4_moe_gemm_example_with_layouts(int argc,
                                            char* argv[],
                                            const ALayout a_layout = ALayout{},
                                            const BLayout b_layout = BLayout{},
                                            [[maybe_unused]] const CLayout c_layout = CLayout{})
{
    auto [result, arg_parser] = create_args(argc, argv);

    if(!result)
    {
        return -1;
    }

    using ADataType   = PrecActType;
    using BDataType   = PrecWeightType;
    using CDataType   = PrecActType;
    using AccDataType = float;

    using ScaleType = ck_tile::e8m0_t;

    constexpr int ScaleGranularityN = 1;
    constexpr int ScaleGranularityK = 32;

    const ck_tile::index_t N          = arg_parser.get_int("N");
    const ck_tile::index_t K          = arg_parser.get_int("K");
    ck_tile::index_t stride_A         = arg_parser.get_int("stride_A");
    ck_tile::index_t stride_B         = arg_parser.get_int("stride_B");
    ck_tile::index_t stride_C         = arg_parser.get_int("stride_C");
    ck_tile::index_t init_method      = arg_parser.get_int("init");
    const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
    const ck_tile::index_t topk       = arg_parser.get_int("TopK");
    const ck_tile::index_t warmup     = arg_parser.get_int("warmup");
    const ck_tile::index_t repeat     = arg_parser.get_int("repeat");
    const ck_tile::index_t experts    = arg_parser.get_int("experts");

    // TODO: replace the magic declaration
    const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;

    ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk;
    ck_tile::index_t valid_tile_num  = sorted_tile_num;
    ck_tile::index_t sorted_size     = sorted_tile_num * MPerBlock;

    const ck_tile::index_t M       = sorted_tile_num * MPerBlock;
    const ck_tile::index_t outputN = kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? N / 2 : N;

    static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
    constexpr bool IsInputGemm = kind != ck_tile::MoeFlatmmKind::kFFN_gemm2;

    stride_A = ck_tile::get_default_stride(
        IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout));
    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
    stride_C = ck_tile::get_default_stride(
        IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{}));

    auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
        IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout)));
    auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
        is_row_major(b_layout)
            ? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
            : ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
    auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
        IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{})));

    ck_tile::HostTensor<ScaleType> scale_b(ck_tile::HostTensorDescriptor(
        {K * experts / ScaleGranularityK, N / ScaleGranularityN}, {N / ScaleGranularityN, 1}));

    if(init_method == 0)
    {
        ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
        ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
        ck_tile::FillUniformDistribution<ScaleType>{0.f, 1.f}(scale_b);
    }
    else
    {
        ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f}(a_m_k_tensor);
        ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f}(b_k_n_tensor);
        ck_tile::FillUniformDistribution<ScaleType>{1.0f, 1.0f}(scale_b);
    }

    ck_tile::HostTensor<BDataType> b_shuffle_host(
        ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
    shuffle_mxfp4_weight<FlatmmConfig, kind>(
        b_k_n_tensor.begin(), b_shuffle_host.begin(), experts, N, K);

    ck_tile::HostTensor<ScaleType> scale_b_shuffle =
        shuffle_mxfp4_scale<FlatmmConfig, kind>(scale_b, experts);
    ck_tile::DeviceMem scale_b_shuffle_dev_buf(scale_b_shuffle.get_element_space_size_in_bytes());

    std::cout << "moe_flatmm:" << "\n num_experts: " << experts << "\n num_tokens: " << num_tokens
              << "\n topk: " << topk << "\n sorted_tile_num: " << sorted_tile_num
              << "\n problem_n: " << N << "\n problem_k: " << K
              << "\n a_m_k: " << a_m_k_tensor.mDesc << "\n b_k_n: " << b_k_n_tensor.mDesc
              << "\n b_shuffle: " << b_shuffle_host.mDesc << "\n c_m_n: " << c_m_n_tensor.mDesc
              << std::endl;

    ck_tile::HostTensor<ck_tile::index_t> expert_ids(
        ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
    ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
        ck_tile::HostTensorDescriptor({sorted_size}, {1}));
    ck_tile::HostTensor<AccDataType> expert_weight(
        ck_tile::HostTensorDescriptor({sorted_size}, {1}));
    ck_tile::HostTensor<ck_tile::index_t> max_token_id(
        ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));
    ck_tile::HostTensor<AccDataType> expert_bias(ck_tile::HostTensorDescriptor({experts * N}, {1}));

    if(init_method == 0)
    {
        // for verification only, no need to satisfy weight normalization
        ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);
        ck_tile::FillUniformDistribution<AccDataType>{-1.0f, 1.0f}(expert_bias);
    }
    else
    {
        ck_tile::FillUniformDistribution<AccDataType>{1.0f, 1.0f}(expert_weight);
        ck_tile::FillUniformDistribution<AccDataType>{0.0f, 0.0f}(expert_bias);
    }

    max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
    // int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}

    for(int i = 0; i < sorted_tile_num; i++)
    {
        expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
    }

    int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
    // int token_per_tile = num_tokens * topk / valid_tile_num;
    int tokenid = 0;
    // sorted_token_ids.mData[0] = 0;
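    // Fill each MPerBlock-sized tile with up to token_per_tile real entries and
    // pad the rest with num_tokens as an out-of-range sentinel; each entry packs
    // the token index in the low 24 bits and its top-k slot in the bits above.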
    for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
    {
        int tile_off = i % MPerBlock;
        if(tile_off < token_per_tile && tokenid < num_tokens * topk)
        {
            sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
            tokenid++;
        }
        else
        {
            sorted_token_ids.mData[i] = num_tokens;
        }
    }

    ck_tile::DeviceMem a_m_k_dev_buf{a_m_k_tensor.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem b_origin_dev_buf{b_k_n_tensor.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem b_shuffle_dev_buf{b_shuffle_host.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem c_m_n_dev_buf{c_m_n_tensor.get_element_space_size_in_bytes()};

    a_m_k_dev_buf.ToDevice(a_m_k_tensor.data());
    b_origin_dev_buf.ToDevice(b_k_n_tensor.data());
    b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
    c_m_n_dev_buf.SetZero();
    c_m_n_tensor.SetZero();

    ck_tile::DeviceMem sorted_token_ids_dev{sorted_token_ids.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem expert_ids_dev{expert_ids.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem max_token_id_dev{max_token_id.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem expert_weight_dev{expert_weight.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem expert_bias_dev{expert_bias.get_element_space_size_in_bytes()};

    sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
    expert_ids_dev.ToDevice(expert_ids.data());
    max_token_id_dev.ToDevice(max_token_id.data());
    expert_weight_dev.ToDevice(expert_weight.data());
    expert_bias_dev.ToDevice(expert_bias.data());
    scale_b_shuffle_dev_buf.ToDevice(scale_b_shuffle.data());

    const ck_tile::index_t* p_sorted_token_ids_dev =
        static_cast<ck_tile::index_t*>(sorted_token_ids_dev.GetDeviceBuffer());
    const ck_tile::index_t* p_expert_ids_dev =
        static_cast<ck_tile::index_t*>(expert_ids_dev.GetDeviceBuffer());
    const ck_tile::index_t* p_max_token_id_dev =
        static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
    const AccDataType* p_sorted_expert_weight_dev =
        static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());

    auto scale_b_shuffle_dev_ptr =
        ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
            static_cast<float*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
    auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<1>{
        static_cast<float*>(expert_bias_dev.GetDeviceBuffer()), experts * N};

    using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs<
        ck_tile::FlatmmScalePointer<-1>,
        ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>,
        ck_tile::FlatmmScalePointer<1>>;
    MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
                            p_sorted_expert_weight_dev,
                            p_expert_ids_dev,
                            p_max_token_id_dev,
                            a_m_k_dev_buf.GetDeviceBuffer(),
                            b_shuffle_dev_buf.GetDeviceBuffer(),
                            c_m_n_dev_buf.GetDeviceBuffer(),
                            num_tokens,
                            experts,
                            topk,
                            1, // k_batch
                            M,
                            N,
                            K,
                            stride_A,
                            stride_B,
                            stride_C,
                            nullptr,
                            scale_b_shuffle_dev_ptr,
                            exp_bias_dev_ptr};

    invoke_a16w4_moe_gemm<FlatmmConfig,
                          ADataType,
                          BDataType,
                          ck_tile::tuple<>,
                          AccDataType,
                          CDataType,
                          ALayout,
                          BLayout,
                          ck_tile::tuple<>,
                          CLayout,
                          kind>(warmup, repeat, gemm_desc);

    c_m_n_dev_buf.FromDevice(c_m_n_tensor.data());

    bool pass{true};
    if(arg_parser.get_int("validate"))
    {
        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            ck_tile::host_tensor_descriptor(IsInputGemm ? num_tokens * topk : num_tokens,
                                            outputN,
                                            stride_C,
                                            is_row_major(CLayout{})));
        c_m_n_host_ref.SetZero();

        ck_tile::HostTensor<AccDataType> scale_A(
            ck_tile::HostTensorDescriptor({1, K / ScaleGranularityK}, {1, 1}));

        // scaleA = 1 has no effect on the result
        ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
        ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
        scale_A_dev_buf.ToDevice(scale_A.data());

        // convert scale_b from e8m0 to float
        ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
            {K * experts / ScaleGranularityK, N / ScaleGranularityN}, {N / ScaleGranularityN, 1}));
        std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
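        // e8m0 encodes a pure power-of-two exponent, so this elementwise copy to
        // float is exact and the reference GEMM can consume plain float scales.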
        ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
        scale_b_float_dev_buf.ToDevice(scale_b_float.data());

        std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
            std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
        c_m_n_ref_buf->SetZero();

        ck_tile::reference_moe_gemm_gpu<ADataType,
                                        BDataType,
                                        AccDataType,
                                        CDataType,
                                        ALayout,
                                        BLayout,
                                        CLayout,
                                        static_cast<int>(kind),
                                        ck_tile::moe::Swiglu>(
            p_sorted_token_ids_dev,
            p_expert_ids_dev,
            p_max_token_id_dev,
            static_cast<const ADataType*>(a_m_k_dev_buf.GetDeviceBuffer()),
            static_cast<const BDataType*>(b_origin_dev_buf.GetDeviceBuffer()),
            static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
            p_sorted_expert_weight_dev,
            num_tokens,
            MPerBlock,
            topk,
            M,
            N,
            K,
            stride_A,
            stride_B,
            stride_C,
            M,
            1,
            ScaleGranularityK,
            static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
            static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()),
            static_cast<float*>(expert_bias_dev.GetDeviceBuffer()));

        c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());

        const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
        const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;

        pass = ck_tile::check_err(
            c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);

        std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
                  << std::endl;
        std::cout << "The verification result is: " << (pass ? "correct" : "fail") << std::endl;
    }

    return pass;
}
example/ck_tile/18_flatmm/mixed_prec/run_mixed_prec_flatmm.inc (new file, 180 lines)
@@ -0,0 +1,180 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

template <typename PrecActType,
          typename PrecWeightType,
          typename FlatmmConfig,
          bool UsePersistentKernel = false,
          typename ALayout,
          typename BLayout,
          typename CLayout>
int run_mixed_prec_flatmm_with_layouts(int argc,
                                       char* argv[],
                                       const ALayout a_layout = ALayout{},
                                       const BLayout b_layout = BLayout{},
                                       [[maybe_unused]] const CLayout c_layout = CLayout{})
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    using ADataType   = PrecActType;
    using BDataType   = PrecWeightType;
    using CDataType   = PrecActType;
    using AccDataType = float;

    using ScaleType = ck_tile::e8m0_t;

    constexpr int DequantGranularityN = 1;
    constexpr int DequantGranularityK = 32;

    ck_tile::index_t M = arg_parser.get_int("m");
    ck_tile::index_t N = arg_parser.get_int("n");
    ck_tile::index_t K = arg_parser.get_int("k");

    ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");

    ck_tile::index_t kbatch      = arg_parser.get_int("split_k");
    ck_tile::index_t init_method = arg_parser.get_int("init");
    ck_tile::index_t n_warmup    = arg_parser.get_int("warmup");
    ck_tile::index_t n_repeat    = arg_parser.get_int("repeat");

    stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
    stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));

    ck_tile::HostTensor<ADataType> a_host(
        ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
    ck_tile::HostTensor<BDataType> b_origin_host(
        ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
    ck_tile::HostTensor<CDataType> c_rslt_host(
        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

    ck_tile::HostTensor<ScaleType> scale_b(ck_tile::HostTensorDescriptor(
        {K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));

    if(init_method == 0)
    {
        ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
        ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
        ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
    }
    else if(init_method == 1)
    {
        ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
        ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
        ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
    }

    ck_tile::HostTensor<BDataType> b_shuffle_host(
        ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
    preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffle_host.begin(), N, K);

    ck_tile::HostTensor<ScaleType> scale_b_shuffle = preShuffleScale<FlatmmConfig>(scale_b);

    ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());

    ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffle.get_element_space_size_in_bytes());

    a_dev_buf.ToDevice(a_host.data());
    b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
    c_rslt_host.SetZero();
    scale_b_dev_buf.ToDevice(scale_b_shuffle.data());

    auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK>{
        static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
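    // FlatmmScalePointer encodes the dequant granularity (one scale per 1 x 32
    // N-by-K block here) in its template parameters; the second constructor
    // argument appears to be the number of scales along N per K step.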

    invoke_mixed_prec_flatmm<FlatmmConfig,
                             ADataType,
                             BDataType,
                             ck_tile::tuple<>,
                             AccDataType,
                             CDataType,
                             ALayout,
                             BLayout,
                             ck_tile::tuple<>,
                             CLayout,
                             decltype(scale_b_dev_ptr),
                             UsePersistentKernel>(a_dev_buf,
                                                  b_shuffle_dev_buf,
                                                  c_dev_buf,
                                                  M,
                                                  N,
                                                  K,
                                                  stride_A,
                                                  stride_B,
                                                  stride_C,
                                                  kbatch,
                                                  scale_b_dev_ptr,
                                                  n_warmup,
                                                  n_repeat);

    c_dev_buf.FromDevice(c_rslt_host.data());

    bool pass = true;
    if(arg_parser.get_int("v") == 1)
    {
        ck_tile::DeviceMem b_origin_dev_buf(b_origin_host.get_element_space_size_in_bytes());
        b_origin_dev_buf.ToDevice(b_origin_host.data());

        ck_tile::HostTensor<CDataType> c_gpu_ref_host(
            ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
        ck_tile::DeviceMem c_gpu_ref_dev_buf(c_gpu_ref_host.get_element_space_size_in_bytes());

        ck_tile::HostTensor<AccDataType> scale_A(
            ck_tile::HostTensorDescriptor({1, K / DequantGranularityK}, {1, 1}));

        // scaleA = 1 has no effect on the result
        ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
        ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
        scale_A_dev_buf.ToDevice(scale_A.data());

        // convert scale_b from e8m0 to float
        ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
            {K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
        std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
        ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
        scale_b_float_dev_buf.ToDevice(scale_b_float.data());

        c_gpu_ref_dev_buf.SetZero();
        ck_tile::reference_blockwise_gemm_gpu<ADataType,
                                              BDataType,
                                              AccDataType,
                                              CDataType,
                                              ALayout,
                                              BLayout,
                                              CLayout>(
            static_cast<ADataType*>(a_dev_buf.GetDeviceBuffer()),
            static_cast<BDataType*>(b_origin_dev_buf.GetDeviceBuffer()),
            static_cast<CDataType*>(c_gpu_ref_dev_buf.GetDeviceBuffer()),
            M,
            N,
            K,
            stride_A,
            stride_B,
            stride_C,
            M,
            DequantGranularityN,
            DequantGranularityK,
            static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
            static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()));

        c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());

        const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
        const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;

        pass = ck_tile::check_err(
            c_rslt_host, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);

        std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
                  << std::endl;
        std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
    }

    return pass;
}
example/ck_tile/18_flatmm/moe_flatmm.cpp (new file, 470 lines)
@@ -0,0 +1,470 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <hip/hip_runtime.h>

#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>

#include "moe_flatmm.hpp"

#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/moe_flatmm.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/reference/reference_moe_gemm.hpp"

template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
}

template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
    assert(t.get_lengths().size() == 2);
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[0];

    constexpr int MaxVecSize     = 16 / sizeof(T);
    constexpr int KLane          = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile;
    constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane);
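    // MaxVecSize caps each lane at one 16-byte load; ItemsPerAccess is the
    // contiguous run of K elements a lane owns within one K_Warp_Tile slice, and
    // the permute below keeps those runs adjacent in memory.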

    ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
                                   FlatmmConfig::N_Warp_Tile,
                                   k_ / ItemsPerAccess,
                                   ItemsPerAccess});
    std::copy(t.begin(), t.end(), t_view.begin());
    return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
}

template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
    // Calculate thresholds
    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
        ck_tile::integer_divide_ceil(K, kbatch));
    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
    // Calculate error due to split_k accumulation
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);
    // Use the higher threshold
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}

// gemm1
// operand-A = [num_token, d_model]
// operand-B = [num_expert, hidden, d_model]
// operand-C = [num_token, topk, hidden]

// gemm2
// operand-A = [num_token, topk, hidden]
// operand-B = [num_expert, d_model, hidden]
// operand-C = [num_token, d_model]

template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ck_tile::MoeFlatmmKind moe_kind = ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only,
          typename CDEElementWise         = ck_tile::element_wise::PassThrough,
          typename ScaleM,
          typename ScaleN>
float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
               const ck_tile::stream_config& s)
{
    using CodegenFlatmmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
        ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
        ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
                          FlatmmConfig::N_Warp_Tile,
                          FlatmmConfig::K_Warp_Tile>>;

    using TilePartitioner =
        ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
                                                   FlatmmConfig::TileParitionerGroupNum,
                                                   FlatmmConfig::TileParitionerM01>;

    using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
                                           FlatmmConfig::kPadN,
                                           FlatmmConfig::kPadK,
                                           ALayout,
                                           BLayout,
                                           ELayout,
                                           FlatmmConfig::NumWaveGroups>;

    using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
                                                               FlatmmConfig::kPadN,
                                                               FlatmmConfig::kPadK,
                                                               FlatmmConfig::DoubleSmemBuffer,
                                                               ALayout,
                                                               BLayout,
                                                               ELayout,
                                                               FlatmmConfig::TransposeC,
                                                               FlatmmConfig::UseStructuredSparsity,
                                                               false, // UsePersistentKernel_
                                                               FlatmmConfig::NumWaveGroups,
                                                               true>; // Preshuffle_

    if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
    {
        static_assert(
            FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
            "FFN_gemm1_gate_up requires NRepeat to be a multiple of 2");
    }

    using GemmPipelineProblem =
        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;

    using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;

    const ck_tile::index_t k_grain     = args.k_batch * FlatmmConfig::K_Tile;
    const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
    const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
    float ave_time{0};

    const auto Run = [&](const auto has_hot_loop_,
                         const auto tail_number_,
                         const auto memory_operation_) {
        constexpr bool has_hot_loop_v   = has_hot_loop_.value;
        constexpr auto tail_number_v    = tail_number_.value;
        constexpr auto scheduler        = FlatmmConfig::Scheduler;
        constexpr auto memory_operation = memory_operation_.value;

        using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
                                                                      BDataType,
                                                                      AccDataType,
                                                                      CodegenFlatmmShape,
                                                                      CodegenGemmTraits,
                                                                      scheduler,
                                                                      has_hot_loop_v,
                                                                      tail_number_v>;

        constexpr int BlockedXDLN_PerWarp = moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up
                                                ? 2
                                                : 1; // determined by the scale shuffle pattern

        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem<ADataType,
                                             BDataType,
                                             DsDatatype,
                                             AccDataType,
                                             CDataType,
                                             DsLayout,
                                             ELayout,
                                             CDEElementWise,
                                             TilePartitioner::MPerBlock,
                                             TilePartitioner::NPerBlock,
                                             FlatmmConfig::M_Warp,
                                             FlatmmConfig::N_Warp,
                                             FlatmmConfig::M_Warp_Tile,
                                             FlatmmConfig::N_Warp_Tile,
                                             FlatmmConfig::K_Warp_Tile,
                                             CodegenPipelineProblem::TransposeC,
                                             memory_operation,
                                             FlatmmConfig::NumWaveGroups,
                                             false,
                                             1,
                                             FlatmmConfig::TiledMMAPermuteN,
                                             BlockedXDLN_PerWarp>>;

        using CodegenFlatmmPipeline =
            ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

        using Kernel = ck_tile::
            MoeFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue, moe_kind>;

        auto kargs = Kernel::MakeKernelArgs(args);

        const dim3 grids      = Kernel::GridSize(kargs);
        constexpr dim3 blocks = Kernel::BlockSize();

        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
                      << "Shape: " << CodegenFlatmmShape::GetName() << "\n"
                      << "problem: " << CodegenPipelineProblem::GetName() << "\n"
                      << "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        if(s.flush_cache_)
        {
            std::cout << "Flushing cache..." << std::endl;
            static constexpr ck_tile::index_t APackedSize =
                std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
            static constexpr ck_tile::index_t BPackedSize =
                std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;

            ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
                moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK
                                                               : args.NumTokens,
                args.K,
                args.stride_A,
                is_row_major(ALayout{})));
            ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
                args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{})));

            const int outputN =
                moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N;

            auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
            auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;

            ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
                kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
            rotating_mem.Print();

            auto run_flush_cache = [&]() {
                // flush icache
                ck_tile::flush_icache();
                // rotating mem
                rotating_mem.Next();
                // clear c mem
                if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2)
                    hipGetErrorString(hipMemsetAsync(
                        args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), s.stream_id_));
                else if(args.k_batch > 1)
                    hipGetErrorString(
                        hipMemsetAsync(args.e_ptr,
                                       0,
                                       args.NumTokens * args.TopK * outputN * sizeof(CDataType),
                                       s.stream_id_));
            };
            ave_time = ck_tile::launch_kernel_time_mask(
                s,
                run_flush_cache,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        else
        {
            ave_time = ck_tile::launch_kernel(
                s,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        return ave_time;
    };

    const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
        if(args.k_batch == 1)
        {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::set>{});
        }
        else
        {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::atomic_add>{});
        }
    };
    BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    return ave_time;
}

#include "run_moe_flatmm_example.inc"
|
||||
|
||||
template <template <typename PreType> typename FlatmmConfig>
|
||||
int run_moe_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
const std::string prec_type = arg_parser.get_str("prec");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
const std::string gemm_kind = arg_parser.get_str("gemm_kind");
|
||||
if(gemm_kind == "gemm1_gate_up")
|
||||
{
|
||||
if(prec_type == "fp8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "fp16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm1_gate_only")
|
||||
{
|
||||
if(prec_type == "fp8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "fp16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm2")
|
||||
{
|
||||
if(prec_type == "fp8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::bfloat16_t,
|
||||
FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "fp16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
        else
        {
            throw std::runtime_error("Unrecognized gemm_kind parameter, only accepts values "
                                     "[gemm1_gate_only | gemm1_gate_up | gemm2]");
        }
    }
    else
    {
        throw std::runtime_error("Unsupported data layout configuration for A, B and C tensors!");
    }
    return -1;
}

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return EXIT_FAILURE;

    try
    {
        int warp_tile = arg_parser.get_int("warp_tile");
        if(warp_tile == 0)
        {
            return !run_moe_flatmm_example<FlatmmConfig16>(argc, argv);
        }
        else if(warp_tile == 1)
        {
            return !run_moe_flatmm_example<FlatmmConfig32>(argc, argv);
        }
        else if(warp_tile == 2)
        {
            return !run_moe_flatmm_example<FlatmmConfig16_950>(argc, argv);
        }
        else
        {
            return !run_moe_flatmm_example<FlatmmConfig32_950>(argc, argv);
        }
    }
    catch(const std::runtime_error& e)
    {
        std::cerr << "Runtime error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
}
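// Illustration only - a hypothetical invocation, assuming the executable name
// from the build files and the defaults declared in create_args():
//   ./tile_example_moe_flatmm -prec=fp8 -gemm_kind=gemm1_gate_up \
//       -NumTokens=128 -TopK=3 -N=4096 -K=4096 -warp_tile=1 -validate=1
// warp_tile selects the config instantiated above: 0 -> FlatmmConfig16,
// 1 -> FlatmmConfig32, 2 -> FlatmmConfig16_950, anything else -> FlatmmConfig32_950.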
202
example/ck_tile/18_flatmm/moe_flatmm.hpp
Normal file
@@ -0,0 +1,202 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>
#include <tuple>

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/moe_flatmm.hpp"

template <typename DataType>
struct FlatmmConfig32
{
    static constexpr ck_tile::index_t M_Tile = 64;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);

    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 32;

    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    static constexpr bool TransposeC = false;
    static constexpr bool UseStructuredSparsity = false;

    static constexpr int kBlockPerCu = 1;
    static constexpr int TileParitionerGroupNum = 8;
    static constexpr int TileParitionerM01 = 4;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
};
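// Sanity check of the tiling arithmetic above (illustration only): for a
// 2-byte DataType such as fp16, K_Tile = 128 / 2 = 64 and K_Warp_Tile = 16,
// so one 64x256 block tile walks its K slice in 64 / 16 = 4 MFMA steps, and
// each of the N_Warp = 4 warps covers N_Tile / (N_Warp * N_Warp_Tile) =
// 256 / (4 * 32) = 2 repeats along N.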

template <typename DataType>
struct FlatmmConfig32_950 : public FlatmmConfig32<DataType>
{
    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 64;
};

// GEMM config with 16x16 warp tile
template <typename DataType>
struct FlatmmConfig16
{
    static constexpr ck_tile::index_t M_Tile = 64;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);

    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 64;

    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    static constexpr bool TransposeC = false;
    static constexpr bool UseStructuredSparsity = false;

    static constexpr int kBlockPerCu = 1;
    static constexpr int TileParitionerGroupNum = 8;
    static constexpr int TileParitionerM01 = 4;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr bool DoubleSmemBuffer = false;

    static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
    static constexpr bool TiledMMAPermuteN = false;
};

template <typename DataType>
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
{
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
    static constexpr int kBlockPerCu = 1;

    static constexpr int N_Repeat =
        N_Tile / FlatmmConfig16<DataType>::N_Warp_Tile / FlatmmConfig16<DataType>::N_Warp;
    static constexpr bool TiledMMAPermuteN = false; // N_Repeat % 2 == 0;
};

template <typename ADataType>
struct GemmBasicTypeConfig;

template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
{
    using ADataType = ck_tile::half_t;
    using BDataType = ck_tile::half_t;
    using AccDataType = float;
    using CDataType = ck_tile::half_t;
    // ToDo: Add more bias config to support different categories of GEMM.
};

template <>
struct GemmBasicTypeConfig<ck_tile::bf16_t>
{
    using ADataType = ck_tile::bf16_t;
    using BDataType = ck_tile::bf16_t;
    using AccDataType = float;
    using CDataType = ck_tile::bf16_t;
};

template <>
struct GemmBasicTypeConfig<ck_tile::fp8_t>
{
    using ADataType = ck_tile::fp8_t;
    using BDataType = ck_tile::fp8_t;
    using AccDataType = float;
    using CDataType = ck_tile::half_t;
    // ToDo: Add more bias config to support different categories of GEMM.
};

template <>
struct GemmBasicTypeConfig<ck_tile::bf8_t>
{
    using ADataType = ck_tile::bf8_t;
    using BDataType = ck_tile::bf8_t;
    using AccDataType = float;
    using CDataType = ck_tile::half_t;
};

template <typename T>
struct DataTypeTraits;

template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
    static constexpr const char* name = "fp8";
};

template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
    static constexpr const char* name = "bf8";
};

template <>
struct DataTypeTraits<float>
{
    static constexpr const char* name = "fp32";
};

template <>
struct DataTypeTraits<double>
{
    static constexpr const char* name = "fp64";
};

template <>
struct DataTypeTraits<ck_tile::half_t>
{
    static constexpr const char* name = "fp16";
};

template <typename T>
struct is_8bit_type
    : std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>
{
};

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("experts", "8", "Number of experts - 8 by default.")
        .insert("NumTokens", "128", "M dimension (number of input tokens) - 128 by default.")
        .insert("TopK", "3", "Top K - 3 by default.")
        .insert("N", "4096", "N dimension - 4096 by default.")
        .insert("K", "4096", "K dimension - 4096 by default.")
        .insert("stride_A", "", "Tensor A strides - empty by default.")
        .insert("stride_B", "", "Tensor B strides - empty by default.")
        .insert("stride_C", "", "Tensor C strides - empty by default.")
        .insert("a_layout", "R", "A tensor data layout - Row by default.")
        .insert("b_layout", "C", "B tensor data layout - Col by default.")
        .insert("c_layout", "R", "C tensor data layout - Row by default.")
        .insert("gemm_kind",
                "gemm1_gate_only",
                "Gemm kind in FFN network [gemm1_gate_only | gemm1_gate_up | gemm2] - "
                "gemm1_gate_only by default.")
        .insert("validate", "1", "0: no validation, 1: validation on CPU.")
        .insert("warmup", "50", "Number of iterations before benchmarking the kernel.")
        .insert("prec", "fp16", "Data type: fp16/bf16/fp8/bf8.")
        .insert(
            "warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
        .insert("repeat", "10", "Number of iterations to benchmark the kernel.");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
@@ -1,175 +1,12 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include "ck_tile/utility/json_dump.hpp"
template <typename T>
constexpr const char* DataTypeToString()
{
    if constexpr(std::is_same_v<T, ck_tile::half_t>)
    {
        return "fp16";
    }
    else if constexpr(std::is_same_v<T, ck_tile::fp8_t>)
    {
        return "fp8";
    }
    else if constexpr(std::is_same_v<T, ck_tile::bf8_t>)
    {
        return "bf8";
    }
    else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
    {
        return "bf16";
    }
    else
    {
        return "unknown";
    }
}

template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
}

// mfma_type, 0:32x32, 1:16x16
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
    assert(t.get_lengths().size() == 2);
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[0];

    if(ck_tile::is_gfx12_supported())
    {
        constexpr int divisor = 2;
        constexpr int kABK1PerLane = 8;
        constexpr int kABK0PerLane = FlatmmConfig::K_Warp_Tile / divisor / kABK1PerLane;
        ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
                                       FlatmmConfig::N_Warp_Tile,
                                       k_ / FlatmmConfig::K_Warp_Tile,
                                       kABK0PerLane,
                                       divisor,
                                       kABK1PerLane});
        std::copy(t.begin(), t.end(), t_view.begin());
        return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
    }
    else
    {
        int divisor = 1;
        if(ck_tile::is_gfx11_supported())
        {
            divisor = 1;
        }
        else
        {
            assert(is_wave32() == false);
            divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
        }
        ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
                                       FlatmmConfig::N_Warp_Tile,
                                       k_ / FlatmmConfig::K_Warp_Tile,
                                       divisor,
                                       FlatmmConfig::K_Warp_Tile / divisor});
        std::copy(t.begin(), t.end(), t_view.begin());
        return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
    }
}
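// Layout intuition for the pre-shuffle above (illustration only): B starts as
// a (K, N) tensor and is viewed as
// (N / N_Warp_Tile, N_Warp_Tile, K / K_Warp_Tile, divisor, K_Warp_Tile / divisor);
// the permutation then moves the N_Warp_Tile dimension in between the K-tile
// factors, so each lane's K fragment sits next to its N neighbours in memory.
// That is the "flat" B arrangement the flatmm pipeline can fetch with plain
// vectorized loads instead of strided ones.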
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
    // Calculate thresholds
    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
        ck_tile::integer_divide_ceil(K, kbatch));
    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
    // Calculate error due to split_k accumulation
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);
    // Use higher threshold
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
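// Reading the helper above: the first rtol/atol pair scales with the number of
// accumulation steps per split (K / kbatch); the second pair models the extra
// error of accumulating kbatch partial C tiles. With kbatch = 1 the split-k
// terms degenerate to the base thresholds, so taking the max of each pair is a
// no-op in that case.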
template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          bool persistent,
          typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);

template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
                    ck_tile::DeviceMem& b_shuffle_dev_buf,
                    ck_tile::DeviceMem& c_dev_buf,
                    ck_tile::index_t M,
                    ck_tile::index_t N,
                    ck_tile::index_t K,
                    ck_tile::index_t stride_A,
                    ck_tile::index_t stride_B,
                    ck_tile::index_t stride_C,
                    ck_tile::index_t kbatch,
                    int n_warmup,
                    int n_repeat)
{
    ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(),
                                      b_shuffle_dev_buf.GetDeviceBuffer(),
                                      {},
                                      c_dev_buf.GetDeviceBuffer(),
                                      kbatch,
                                      M,
                                      N,
                                      K,
                                      stride_A,
                                      stride_B,
                                      {},
                                      stride_C};

    float ave_time = flatmm_calc<FlatmmConfig,
                                 ADataType,
                                 BDataType,
                                 DsDatatype,
                                 AccDataType,
                                 CDataType,
                                 ALayout,
                                 BLayout,
                                 DsLayout,
                                 CLayout,
                                 false,
                                 CDEElementWise>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});

    return ave_time;
}

template <typename PrecType,
          typename FlatmmConfig,
          int ScaleGranularityM = -1,
          int ScaleGranularityN = -1,
          bool UsePersistentKernel = false,
          typename ALayout,
          typename BLayout,
          typename CLayout>
@@ -213,31 +50,32 @@ int run_flatmm_example_with_layouts(int argc,
    ck_tile::HostTensor<CDataType> c_rslt_host(
        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

    ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
    ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));

    // TODO: add different init types
    if(init_method == 0)
    {
        ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
        ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
        ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
        ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
    }
    else if(init_method == 1)
    {
        ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
        ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
        ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
        ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
    }
    else if(init_method == 2)
    {
        ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
        ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
    }
    else if(init_method == 3)
    {
        ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
        ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
    }
    else if(init_method == 4)
    {
        ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
        ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
        ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
        ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
    }
    else
    {
@@ -248,52 +86,69 @@ int run_flatmm_example_with_layouts(int argc,
    ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());

    ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
    ck_tile::DeviceMem per_channel_scale_dev_buf(
        per_channel_scale.get_element_space_size_in_bytes());

    a_dev_buf.ToDevice(a_host.data());
    c_rslt_host.SetZero();
    per_token_scale_dev_buf.ToDevice(per_token_scale.data());
    per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());

    // do pre-shuffle
    ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
        if constexpr(FlatmmConfig::TiledMMAPermuteN)
        {
            return shuffle_b_v1<FlatmmConfig>(b_origin_host);
        }
        else
        {
            return shuffle_b<FlatmmConfig>(b_origin_host);
        }
    }();
    ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
    b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());

    float ave_time = invoke_flatmm<FlatmmConfig,
                                   ADataType,
                                   BDataType,
                                   ck_tile::tuple<>,
                                   AccDataType,
                                   CDataType,
                                   ALayout,
                                   BLayout,
                                   ck_tile::tuple<>,
                                   CLayout>(a_dev_buf,
                                            b_shuffle_dev_buf,
                                            c_dev_buf,
                                            M,
                                            N,
                                            K,
                                            stride_A,
                                            stride_B,
                                            stride_C,
                                            kbatch,
                                            n_warmup,
                                            n_repeat);
    auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
        static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
    auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
        static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};

    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_byte =
        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
              << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
              << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
              << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
    invoke_flatmm<FlatmmConfig,
                  ADataType,
                  BDataType,
                  ck_tile::tuple<>,
                  AccDataType,
                  CDataType,
                  ALayout,
                  BLayout,
                  ck_tile::tuple<>,
                  CLayout,
                  decltype(per_token_scale_dev_ptr),
                  decltype(per_channel_scale_dev_ptr),
                  UsePersistentKernel>(a_dev_buf,
                                       b_shuffle_dev_buf,
                                       c_dev_buf,
                                       M,
                                       N,
                                       K,
                                       stride_A,
                                       stride_B,
                                       stride_C,
                                       kbatch,
                                       per_token_scale_dev_ptr,
                                       per_channel_scale_dev_ptr,
                                       n_warmup,
                                       n_repeat);

    c_dev_buf.FromDevice(c_rslt_host.data());

    bool pass = true;

    if(arg_parser.get_int("v") == 1)
    {
        if(ScaleGranularityM != -1 || ScaleGranularityN != -1)
            throw std::runtime_error("ScaleAB is not supported for CPU verification!\n");
        ck_tile::HostTensor<CDataType> c_ref_host(
            ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
        c_ref_host.SetZero();
@@ -341,13 +196,41 @@ int run_flatmm_example_with_layouts(int argc,
                                           N * K * sizeof(BDataType),
                                           hipMemcpyHostToDevice));

        if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1)
        {
            ck_tile::reference_gemm_gpu<ADataType,
                                        BDataType,
                                        AccDataType,
                                        CDataType,
                                        ALayout,
                                        BLayout,
                                        CLayout>(
                d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
        }
        else
        {
            ck_tile::reference_blockwise_gemm_gpu<ADataType,
                                                  BDataType,
                                                  AccDataType,
                                                  CDataType,
                                                  ALayout,
                                                  BLayout,
                                                  CLayout>(
                d_A,
                d_B,
                d_C,
                M,
                N,
                K,
                stride_A,
                stride_B,
                stride_C,
                ScaleGranularityM,
                ScaleGranularityN,
                K,
                static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
                static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));
        }

        ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(),
                                           d_C,
@@ -375,22 +258,5 @@ int run_flatmm_example_with_layouts(int argc,
        std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
    }

    if(arg_parser.get_int("json") == 1)
    {
        dump_flatmm_json_results(arg_parser.get_str("jsonfile"),
                                 DataTypeToString<ADataType>(),
                                 M,
                                 N,
                                 K,
                                 stride_A,
                                 stride_B,
                                 stride_C,
                                 kbatch,
                                 pass,
                                 ave_time,
                                 tflops,
                                 gb_per_sec);
    }

    return pass;
}

605
example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc
Normal file
@@ -0,0 +1,605 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// mfma_type, 0:32x32, 1:16x16
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
    assert(t.get_lengths().size() == 2);
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[0];
    constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
    ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
                                   FlatmmConfig::N_Warp_Tile,
                                   k_ / FlatmmConfig::K_Warp_Tile,
                                   divisor,
                                   FlatmmConfig::K_Warp_Tile / divisor});
    std::copy(t.begin(), t.end(), t_view.begin());
    return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}

template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
    // Calculate thresholds
    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
        ck_tile::integer_divide_ceil(K, kbatch));
    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
    // Calculate error due to split_k accumulation
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);
    // Use higher threshold
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}

template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ScaleM,
          typename ScaleN,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(int n_warmup,
                  int n_repeat,
                  const ck_tile::ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN>& args)
{
    float ave_time = grouped_flatmm<FlatmmConfig,
                                    ADataType,
                                    BDataType,
                                    DsDatatype,
                                    AccDataType,
                                    CDataType,
                                    ALayout,
                                    BLayout,
                                    DsLayout,
                                    CLayout,
                                    false,
                                    CDEElementWise>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

    std::string op_name{"Grouped Gemm"};

    std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
    std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
                           sizeof(BDataType) * args.N * args.K +
                           sizeof(CDataType) * args.M * args.N;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
              << gb_per_sec << " GB/s, " << op_name << std::endl;

    return ave_time;
}

template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ScaleM,
          typename ScaleN,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(int n_warmup,
                  int n_repeat,
                  int val_m,
                  const ck_tile::MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN>& args)
{
    float ave_time = grouped_flatmm<FlatmmConfig,
                                    ADataType,
                                    BDataType,
                                    DsDatatype,
                                    AccDataType,
                                    CDataType,
                                    ALayout,
                                    BLayout,
                                    DsLayout,
                                    CLayout,
                                    false,
                                    CDEElementWise>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

    std::string op_name{"Grouped Gemm"};

    std::size_t flop = std::size_t(2) * val_m * args.N * args.K;
    std::size_t num_byte = sizeof(ADataType) * val_m * args.K +
                           sizeof(BDataType) * args.N * args.K * args.group_count +
                           sizeof(CDataType) * val_m * args.N;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
              << gb_per_sec << " GB/s, " << op_name << std::endl;

    return ave_time;
}

template <typename PrecType,
          typename FlatmmConfig,
          int ScaleGranularityM = -1,
          int ScaleGranularityN = -1,
          typename ALayout,
          typename BLayout,
          typename CLayout>
int run_contiguous_grouped_flatmm_example_with_layouts(
    int argc,
    char* argv[],
    const ALayout a_layout = ALayout{},
    const BLayout b_layout = BLayout{},
    [[maybe_unused]] const CLayout c_layout = CLayout{})
{
    auto [result, arg_parser] = create_args(argc, argv);

    if(!result)
    {
        return -1;
    }

    using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
    using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
    using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
    using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;

    constexpr int BlockM = FlatmmConfig::M_Tile;

    const int group_count = arg_parser.get_int("group_count");
    const int repeat = arg_parser.get_int("repeat");
    const int warmup = arg_parser.get_int("warmup");

    std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
    std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
    std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");

    if(!(int(Ms.size()) == group_count))
    {
        std::cout << "Please check the input data." << std::endl;
        // padding additional Ms if needed
        for(int i = 0; i < group_count; i++)
        {
            Ms.push_back(256 + 64 * i);
        }
    }

    ck_tile::index_t M =
        std::reduce(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) {
            // round up to the multiple of BlockM
            return acc + (group_m + BlockM - 1) / BlockM * BlockM;
        });
    std::cout << "Total M: " << M << std::endl;
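    // Illustration: with BlockM = 64 and Ms = {100, 60}, the per-group sizes
    // round up to 128 and 64, so the fused M computed above would be 192.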
    ck_tile::index_t N = Ns[0];
    ck_tile::index_t K = Ks[0];

    ck_tile::index_t kbatch = arg_parser.get_int("split_k");

    ck_tile::index_t stride_A = 0;
    ck_tile::index_t stride_B = 0;
    ck_tile::index_t stride_C = 0;

    stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
    stride_B = ck_tile::get_default_stride(K, N * group_count, stride_B, is_row_major(b_layout));
    stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));

    ck_tile::HostTensor<ADataType> a_m_k_tensor(
        ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
    ck_tile::HostTensor<BDataType> b_k_n_tensor(ck_tile::HostTensor<BDataType>(
        ck_tile::host_tensor_descriptor(K, N * group_count, stride_B, is_row_major(b_layout))));
    ck_tile::HostTensor<CDataType> c_m_n_tensor(ck_tile::HostTensor<CDataType>(
        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(c_layout))));

    ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
    ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));

    std::vector<ck_tile::index_t> m_indices(M);
    int indices_fill_start = 0;
    for(int i = 0; i < group_count; ++i)
    {
        int group_m = Ms[i];
        int padded_group_m = (group_m + BlockM - 1) / BlockM * BlockM;
        for(int j = 0; j < padded_group_m; j++)
        {
            m_indices[indices_fill_start + j] = j < group_m ? i : -1; // -1 for padding
        }
        indices_fill_start += padded_group_m;
    }
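    // Continuing the example above: rows 0-99 carry group id 0, rows 100-127
    // carry -1 (padding in group 0's last block), rows 128-187 carry group id
    // 1, and rows 188-191 carry -1 again.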

    ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
    ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
    ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
    ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);

    assert(N % FlatmmConfig::N_Warp_Tile == 0 &&
           "N must be divisible by N_Warp_Tile for contiguous grouped gemm");
    ck_tile::HostTensor<BDataType> b_shuffle_host =
        shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensor);

    std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
        std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes()));
    std::unique_ptr<ck_tile::DeviceMem> b_shfl_dev_buf(
        std::make_unique<ck_tile::DeviceMem>(b_shuffle_host.get_element_space_size_in_bytes()));
    std::unique_ptr<ck_tile::DeviceMem> c_m_n_dev_buf(
        std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes()));

    ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
    ck_tile::DeviceMem per_channel_scale_dev_buf(
        per_channel_scale.get_element_space_size_in_bytes());

    c_m_n_dev_buf->SetZero();

    ck_tile::DeviceMem m_indices_dev_buf(M * sizeof(ck_tile::index_t));
    m_indices_dev_buf.ToDevice(m_indices.data());

    a_m_k_dev_buf->ToDevice(a_m_k_tensor.data());
    b_shfl_dev_buf->ToDevice(b_shuffle_host.data());

    per_token_scale_dev_buf.ToDevice(per_token_scale.data());
    per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());

    auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
        static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
    auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
        static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};

    ck_tile::ContiguousGroupedFlatmmHostArgs<decltype(per_token_scale_dev_ptr),
                                             decltype(per_channel_scale_dev_ptr)>
        kernel_args{static_cast<ck_tile::index_t*>(m_indices_dev_buf.GetDeviceBuffer()),
                    M,
                    N,
                    K,
                    a_m_k_dev_buf->GetDeviceBuffer(),
                    stride_A,
                    b_shfl_dev_buf->GetDeviceBuffer(),
                    stride_B,
                    {},
                    {},
                    c_m_n_dev_buf->GetDeviceBuffer(),
                    stride_C,
                    kbatch,
                    static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
                    static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};

    invoke_gemm<FlatmmConfig,
                ADataType,
                BDataType,
                ck_tile::tuple<>,
                AccDataType,
                CDataType,
                ALayout,
                BLayout,
                ck_tile::tuple<>,
                CLayout,
                decltype(per_token_scale_dev_ptr),
                decltype(per_channel_scale_dev_ptr)>(warmup, repeat, kernel_args);
    c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());

    bool pass{true};
    if(arg_parser.get_int("v") == 1)
    {
        throw std::runtime_error(
            "v=1 host verification is not supported in contiguous grouped gemm; use "
            "v=2 device verification instead");
    }
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
|
||||
ck_tile::hip_check_error(hipMemset(d_C, 0, M * N * sizeof(CDataType)));
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::index_t acc_m = 0;
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::index_t padded_M = (Ms[i] + BlockM - 1) / BlockM * BlockM;
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_tensor.data() + i * N * K,
|
||||
N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice));
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + acc_m * K,
|
||||
d_B,
|
||||
d_C + acc_m * N,
|
||||
padded_M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
acc_m += padded_M;
|
||||
}
|
||||
ck_tile::hip_check_error(hipMemcpy(
|
||||
c_gpu_ref_host.data(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
float rtol = 1e-3;
|
||||
float atol = 1e-3;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_tensor, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename PrecType,
          typename FlatmmConfig,
          int ScaleGranularityM = -1,
          int ScaleGranularityN = -1,
          typename ALayout,
          typename BLayout,
          typename CLayout>
int run_masked_grouped_flatmm_example_with_layouts(
    int argc,
    char* argv[],
    const ALayout a_layout = ALayout{},
    const BLayout b_layout = BLayout{},
    [[maybe_unused]] const CLayout c_layout = CLayout{})
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
    {
        return -1;
    }

    using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
    using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
    using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
    using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;

    constexpr int BlockM = FlatmmConfig::M_Tile;

    const int group_count = arg_parser.get_int("group_count");
    const int repeat = arg_parser.get_int("repeat");
    const int warmup = arg_parser.get_int("warmup");

    std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
    std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
    std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");

    if(!(int(Ms.size()) == group_count))
    {
        std::cout << "Please check the input data." << std::endl;
        // padding additional Ms if needed
        for(int i = 0; i < group_count; i++)
        {
            Ms.push_back(256 + 64 * i);
        }
    }

    ck_tile::index_t M = 4096; // Ms[0];
    ck_tile::index_t N = Ns[0];
    ck_tile::index_t K = Ks[0];

    ck_tile::index_t kbatch = arg_parser.get_int("split_k");

    ck_tile::index_t stride_A = K;
    ck_tile::index_t stride_B = K;
    ck_tile::index_t stride_C = N;

    stride_A = ck_tile::get_default_stride(group_count * M, K, stride_A, is_row_major(a_layout));
    stride_B = ck_tile::get_default_stride(K, N * group_count, stride_B, is_row_major(b_layout));
    stride_C = ck_tile::get_default_stride(group_count * M, N, stride_C, is_row_major(c_layout));

    ck_tile::HostTensor<ADataType> a_m_k_tensor(
        ck_tile::host_tensor_descriptor(group_count * M, K, stride_A, is_row_major(a_layout)));
    ck_tile::HostTensor<BDataType> b_k_n_tensor(ck_tile::HostTensor<BDataType>(
        ck_tile::host_tensor_descriptor(K, N * group_count, stride_B, is_row_major(b_layout))));
    ck_tile::HostTensor<CDataType> c_m_n_tensor(ck_tile::HostTensor<CDataType>(
        ck_tile::host_tensor_descriptor(group_count * M, N, stride_C, is_row_major(c_layout))));

    ck_tile::HostTensor<AccDataType> per_token_scale(
        ck_tile::HostTensorDescriptor({group_count * M}, {1}));
    ck_tile::HostTensor<AccDataType> per_channel_scale(
        ck_tile::HostTensorDescriptor({group_count * N}, {1}));

    // masked mode stores the valid (unpadded) row count of each group
    std::vector<ck_tile::index_t> m_indices(group_count);
    for(int i = 0; i < group_count; ++i)
    {
        m_indices[i] = Ms[i];
    }

    ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
    ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
    ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
    ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);

    assert(N % FlatmmConfig::N_Warp_Tile == 0 &&
           "N must be divisible by N_Warp_Tile for masked grouped gemm");
    ck_tile::HostTensor<BDataType> b_shuffle_host =
        shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensor);

    std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
        std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes()));
    std::unique_ptr<ck_tile::DeviceMem> b_shfl_dev_buf(
        std::make_unique<ck_tile::DeviceMem>(b_shuffle_host.get_element_space_size_in_bytes()));
    std::unique_ptr<ck_tile::DeviceMem> c_m_n_dev_buf(
        std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes()));

    ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
    ck_tile::DeviceMem per_channel_scale_dev_buf(
        per_channel_scale.get_element_space_size_in_bytes());
    c_m_n_dev_buf->SetZero();

    ck_tile::DeviceMem m_indices_dev_buf(group_count * sizeof(ck_tile::index_t));
    m_indices_dev_buf.ToDevice(m_indices.data());

    a_m_k_dev_buf->ToDevice(a_m_k_tensor.data());
    b_shfl_dev_buf->ToDevice(b_shuffle_host.data());

    per_token_scale_dev_buf.ToDevice(per_token_scale.data());
    per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());

    auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
        static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
    auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
        static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
    ck_tile::MaskedGroupedFlatmmHostArgs<decltype(per_token_scale_dev_ptr),
                                         decltype(per_channel_scale_dev_ptr)>
        kernel_args{static_cast<ck_tile::index_t*>(m_indices_dev_buf.GetDeviceBuffer()),
                    group_count,
                    M,
                    N,
                    K,
                    a_m_k_dev_buf->GetDeviceBuffer(),
                    stride_A,
                    b_shfl_dev_buf->GetDeviceBuffer(),
                    stride_B,
                    {},
                    {},
                    c_m_n_dev_buf->GetDeviceBuffer(),
                    stride_C,
                    kbatch,
                    static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
                    static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
    int sum_val_m = 0;
    for(int gi = 0; gi < group_count; gi++)
    {
        sum_val_m += m_indices[gi];
    }
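    // sum_val_m is the total number of valid (unmasked) rows across all
    // groups; it only feeds the FLOP/bandwidth accounting in invoke_gemm
    // above, not the kernel launch itself.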

    invoke_gemm<FlatmmConfig,
                ADataType,
                BDataType,
                ck_tile::tuple<>,
                AccDataType,
                CDataType,
                ALayout,
                BLayout,
                ck_tile::tuple<>,
                CLayout,
                decltype(per_token_scale_dev_ptr),
                decltype(per_channel_scale_dev_ptr)>(warmup, repeat, sum_val_m, kernel_args);
    c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());

    bool pass{true};
    if(arg_parser.get_int("v") == 1)
    {
        throw std::runtime_error(
            "v=1 host verification is not supported in masked grouped gemm; use "
            "v=2 device verification instead");
    }
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, group_count * M * N * sizeof(CDataType)));
|
||||
ck_tile::hip_check_error(hipMemset(d_C, 0, group_count * M * N * sizeof(CDataType)));
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
|
||||
ck_tile::host_tensor_descriptor(group_count * M, N, stride_C, is_row_major(CLayout{})));
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_tensor.data() + i * N * K,
|
||||
N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1)
|
||||
{
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + i * M * K,
|
||||
d_B,
|
||||
d_C + i * M * N,
|
||||
m_indices[i],
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_blockwise_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + i * M * K,
|
||||
d_B,
|
||||
d_C + i * M * N,
|
||||
m_indices[i],
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
ScaleGranularityM,
|
||||
ScaleGranularityN,
|
||||
K,
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()) + i * M,
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())) +
|
||||
i* N;
|
||||
}
|
||||
            ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_host.data() + i * M * N,
                                               d_C + i * M * N,
                                               M * N * sizeof(CDataType),
                                               hipMemcpyDeviceToHost));
        }

        ck_tile::hip_check_error(hipFree(d_B));
        ck_tile::hip_check_error(hipFree(d_C));

        float rtol = 1e-3;
        float atol = 1e-3;

        pass = ck_tile::check_err(
            c_m_n_tensor, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);

        std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
                  << std::endl;
        std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
    }

    return pass;
}
323
example/ck_tile/18_flatmm/run_moe_flatmm_example.inc
Normal file
@@ -0,0 +1,323 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ck_tile::MoeFlatmmKind kind,
          typename CDEElementWise = ck_tile::element_wise::PassThrough,
          typename MoeHostArgs>
float invoke_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
{
    float ave_time = moe_gemm<FlatmmConfig,
                              ADataType,
                              BDataType,
                              DsDatatype,
                              AccDataType,
                              CDataType,
                              ALayout,
                              BLayout,
                              DsLayout,
                              ELayout,
                              kind,
                              CDEElementWise>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});

    std::string op_name{"Moe Gemm"};

    std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
    std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
                           sizeof(BDataType) * args.N * args.K +
                           sizeof(CDataType) * args.M * args.N;
    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
              << gb_per_sec << " GB/s, " << op_name << std::endl;

    return ave_time;
}
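// Unit check for the metrics above: ave_time is in ms, so
// flop / 1e9 / ave_time = GFLOP per ms = TFLOP/s, and
// num_byte / 1e6 / ave_time = MB per ms = GB/s.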

template <typename PrecType,
          typename FlatmmConfig,
          ck_tile::MoeFlatmmKind kind,
          typename ALayout,
          typename BLayout,
          typename CLayout>
int run_moe_gemm_example_with_layouts(int argc,
                                      char* argv[],
                                      const ALayout a_layout = ALayout{},
                                      const BLayout b_layout = BLayout{},
                                      [[maybe_unused]] const CLayout c_layout = CLayout{})
{
    auto [result, arg_parser] = create_args(argc, argv);

    if(!result)
    {
        return -1;
    }

    using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
    using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
    using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
    using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;

    constexpr int ScaleGranularityM = 1;
    constexpr int ScaleGranularityN = 1;

    const ck_tile::index_t N = arg_parser.get_int("N");
    const ck_tile::index_t K = arg_parser.get_int("K");
    ck_tile::index_t stride_A = arg_parser.get_int("stride_A");
    ck_tile::index_t stride_B = arg_parser.get_int("stride_B");
    ck_tile::index_t stride_C = arg_parser.get_int("stride_C");
    const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
    const ck_tile::index_t topk = arg_parser.get_int("TopK");
    const ck_tile::index_t warmup = arg_parser.get_int("warmup");
    const ck_tile::index_t repeat = arg_parser.get_int("repeat");
    const ck_tile::index_t experts = arg_parser.get_int("experts");

    // TODO: replace the magic declaration
    const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;

    ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk;
    ck_tile::index_t valid_tile_num = sorted_tile_num;
    ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;

    const ck_tile::index_t M = sorted_tile_num * MPerBlock;
    const ck_tile::index_t outputN = kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? N / 2 : N;

    static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
    constexpr bool IsInputGemm = kind != ck_tile::MoeFlatmmKind::kFFN_gemm2;

    stride_A = ck_tile::get_default_stride(
        IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout));
    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
    stride_C = ck_tile::get_default_stride(
        IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{}));
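    // Row-count bookkeeping: the gemm1 kinds read one A row per token
    // (num_tokens rows) and scatter their output to num_tokens * topk rows,
    // while gemm2 consumes the topk-expanded activations and reduces back to
    // num_tokens rows - exactly what the IsInputGemm selections above encode.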

    auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
        IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout)));
    auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
        is_row_major(b_layout)
            ? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
            : ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
    auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
        IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{})));

    ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
    ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);

    auto b_shuffle_host = shuffle_b<FlatmmConfig>(b_k_n_tensor);

    std::cout << "moe_flatmm:" //
              << "\n num_experts: " << experts << "\n num_tokens: " << num_tokens
              << "\n topk: " << topk << "\n sorted_tile_num: " << sorted_tile_num
              << "\n a_m_k: " << a_m_k_tensor.mDesc << "\n b_k_n: " << b_k_n_tensor.mDesc
              << "\n b_shuffle: " << b_shuffle_host.mDesc << "\n c_m_n: " << c_m_n_tensor.mDesc
              << std::endl;

    ck_tile::DeviceMem a_m_k_dev_buf{a_m_k_tensor.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem b_origin_dev_buf{b_k_n_tensor.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem b_shuffle_dev_buf{b_shuffle_host.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem c_m_n_dev_buf{c_m_n_tensor.get_element_space_size_in_bytes()};

    a_m_k_dev_buf.ToDevice(a_m_k_tensor.data());
    b_origin_dev_buf.ToDevice(b_k_n_tensor.data());
    b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
    c_m_n_dev_buf.SetZero();
    c_m_n_tensor.SetZero();

    const void* p_a = a_m_k_dev_buf.GetDeviceBuffer();
    const void* p_b_origin = b_origin_dev_buf.GetDeviceBuffer();
    const void* p_b_shuffle = b_shuffle_dev_buf.GetDeviceBuffer();
    void* p_c = c_m_n_dev_buf.GetDeviceBuffer();

    // TODO: malloc and init sorted tokens and max tokens buffer

    ck_tile::HostTensor<ck_tile::index_t> expert_ids(
        ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
    ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
        ck_tile::HostTensorDescriptor({sorted_size}, {1}));
    ck_tile::HostTensor<AccDataType> expert_weight(
        ck_tile::HostTensorDescriptor({sorted_size}, {1}));
    ck_tile::HostTensor<ck_tile::index_t> max_token_id(
        ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));

    ck_tile::HostTensor<AccDataType> per_token_scale(
        ck_tile::HostTensorDescriptor({IsInputGemm ? num_tokens : M}, {1}));
    ck_tile::HostTensor<AccDataType> per_channel_scale(
        ck_tile::HostTensorDescriptor({N * experts}, {1}));

    ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_token_scale);
    ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_channel_scale);

    // for verification only, no need to satisfy weight normalization
    ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);

    ck_tile::DeviceMem sorted_token_ids_dev{sorted_token_ids.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem expert_ids_dev{expert_ids.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem max_token_id_dev{max_token_id.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem expert_weight_dev{expert_weight.get_element_space_size_in_bytes()};

    ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
    ck_tile::DeviceMem per_channel_scale_dev_buf(
        per_channel_scale.get_element_space_size_in_bytes());

    max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
    // int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}

    for(int i = 0; i < sorted_tile_num; i++)
    {
        expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
    }
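    // i.e. tiles are handed out to experts in contiguous runs of
    // ceil(valid_tile_num / experts); with 16 tiles and the default 8 experts,
    // expert e would own tiles 2e and 2e + 1 (illustration only).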

    int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
    // int token_per_tile = num_tokens * topk / valid_tile_num;
    int tokenid = 0;
    // sorted_token_ids.mData[0] = 0;
    for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
    {
        int tile_off = i % MPerBlock;
        if(tile_off < token_per_tile && tokenid < num_tokens * topk)
        {
            sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
            tokenid++;
        }
        else
        {
            sorted_token_ids.mData[i] = num_tokens;
        }
    }
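    // Each entry above packs (topk_slot << 24) | token_row, so a consumer
    // would unpack it as (illustration only, valid while token_row < 2^24):
    //   int token_row = sorted_token_ids.mData[i] & 0xFFFFFF;
    //   int topk_slot = sorted_token_ids.mData[i] >> 24;
    // entries set to num_tokens mark padding slots the kernel skips.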

    sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
    expert_ids_dev.ToDevice(expert_ids.data());
    max_token_id_dev.ToDevice(max_token_id.data());
    expert_weight_dev.ToDevice(expert_weight.data());
    per_token_scale_dev_buf.ToDevice(per_token_scale.data());
    per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());

    const ck_tile::index_t* p_sorted_token_ids_dev =
        static_cast<ck_tile::index_t*>(sorted_token_ids_dev.GetDeviceBuffer());
    const ck_tile::index_t* p_expert_ids_dev =
        static_cast<ck_tile::index_t*>(expert_ids_dev.GetDeviceBuffer());
    const ck_tile::index_t* p_max_token_id_dev =
        static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
    const AccDataType* p_sorted_expert_weight_dev =
        static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());

    using MoeFlatmmArgs =
        ck_tile::MoeFlatmmHostArgs<ck_tile::FlatmmScalePointer<ScaleGranularityM>,
                                   ck_tile::FlatmmScalePointer<ScaleGranularityN>>;

    auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
        static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
    auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
        static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
|
||||
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
|
||||
p_sorted_expert_weight_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
p_a,
|
||||
p_b_shuffle,
|
||||
p_c,
|
||||
num_tokens,
|
||||
experts,
|
||||
topk,
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
per_token_scale_dev_ptr,
|
||||
per_channel_scale_dev_ptr};
|
||||
|
||||
invoke_moe_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
kind>(warmup, repeat, gemm_desc);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_tensor.data());
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("validate"))
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(IsInputGemm ? num_tokens * topk : num_tokens,
|
||||
outputN,
|
||||
stride_C,
|
||||
is_row_major(CLayout{})));
|
||||
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
|
||||
|
||||
c_m_n_ref_buf->SetZero();
|
||||
|
||||
ck_tile::reference_moe_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
static_cast<int>(kind),
|
||||
ck_tile::moe::MoeSilu>(
|
||||
p_sorted_token_ids_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b_origin),
|
||||
static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
|
||||
p_sorted_expert_weight_dev,
|
||||
num_tokens,
|
||||
MPerBlock,
|
||||
topk,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
1,
|
||||
1,
|
||||
K,
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1 /*kbatch*/, max_accumulated_value);
|
||||
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -1303,6 +1303,15 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
    index_t soffset,
    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");

// buffer atomic-add bf16
// TODO: Replace with bf16x2_t, but LLVM builtins only accept cktile_bf16x2_t now.
CK_TILE_DEVICE_EXTERN bf16x2_t llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
    bf16x2_t vdata,
    int32x4_t rsrc,
    index_t voffset,
    index_t soffset,
    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16");

// buffer atomic-add i32
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
    int32_t vdata,

@@ -1537,8 +1546,11 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
        (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (std::is_same<T, e8m0_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (std::is_same<T, pk_int4_t>::value &&
         (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
         (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
        (std::is_same<T, pk_fp4_t>::value &&
         (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
        "wrong! not implemented");

    using rtn_type = thread_buffer<T, N>;

@@ -2262,6 +2274,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{
    static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
                  (std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
                  (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
                  (std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
                  "wrong! not implemented");

@@ -2355,6 +2368,39 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
            });
        }
    }
    else if constexpr(std::is_same<T, bf16_t>::value)
    {
        if constexpr(N == 2)
        {
            llvm_amdgcn_raw_buffer_atomic_add_bf16x2(bit_cast<bf16x2_t>(src_thread_data),
                                                     dst_wave_buffer_resource,
                                                     dst_thread_addr_offset,
                                                     dst_wave_addr_offset,
                                                     0);
        }
        else if constexpr(N == 4)
        {
            static_for<0, 2, 1>{}([&](auto i) {
                llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
                    src_thread_data.template get_as<bf16x2_t>()[i],
                    dst_wave_buffer_resource,
                    dst_thread_addr_offset,
                    dst_wave_addr_offset + i * sizeof(bf16x2_t),
                    0);
            });
        }
        else if constexpr(N == 8)
        {
            static_for<0, 4, 1>{}([&](auto i) {
                llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
                    src_thread_data.template get_as<bf16x2_t>()[i],
                    dst_wave_buffer_resource,
                    dst_thread_addr_offset,
                    dst_wave_addr_offset + i * sizeof(bf16x2_t),
                    0);
            });
        }
    }
    else if constexpr(std::is_same<T, int32_t>::value)
    {
        if constexpr(N == 1)
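The N == 4 and N == 8 branches above emulate wider bf16 atomic adds by issuing one v2bf16 atomic per 4-byte pair. A standalone sketch of just the offset arithmetic (raw-bit stand-ins, not the real intrinsic or real bf16 math):

// Each bf16x2 pair is 4 bytes, so pair i lives at byte offset i * sizeof(bf16x2).
#include <cstdint>

using bf16_raw = uint16_t;        // stand-in for bf16 storage
struct bf16x2 { bf16_raw x, y; }; // 4-byte pair, the v2bf16 granularity

// Non-atomic stand-in for the buffer intrinsic; shows only the pair layout.
inline void add_bf16x2(bf16x2* dst, bf16x2 v)
{
    dst->x = static_cast<bf16_raw>(dst->x + v.x);
    dst->y = static_cast<bf16_raw>(dst->y + v.y);
}

inline void add_bf16_vec(bf16_raw* dst, const bf16_raw* src, int n)
{
    for(int i = 0; i < n / 2; ++i) // n == 4 -> 2 ops, n == 8 -> 4 ops
        add_bf16x2(reinterpret_cast<bf16x2*>(dst) + i,
                   reinterpret_cast<const bf16x2*>(src)[i]);
}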
@@ -1171,6 +1171,15 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
    index_t soffset,
    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");

// buffer atomic-add bf16
// TODO: Replace with bf16x2_t, but LLVM builtins only accept cktile_bf16x2_t now.
CK_TILE_DEVICE_EXTERN bf16x2_t llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
    bf16x2_t vdata,
    int32x4_t rsrc,
    index_t voffset,
    index_t soffset,
    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16");

// buffer atomic-add i32
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
    int32_t vdata,

@@ -1405,10 +1414,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
        (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (std::is_same<T, e8m0_bexp_t>::value &&
         (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (std::is_same<T, pk_fp4_raw_t>::value &&
         (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (std::is_same<T, pk_int4_t>::value &&
         (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
         (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
        (std::is_same<T, pk_fp4_t>::value &&
         (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
        "wrong! not implemented");

    using rtn_type = thread_buffer<T, N>;

@@ -2047,6 +2060,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{
    static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
                  (std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
                  (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
                  (std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
                  "wrong! not implemented");

@@ -2140,6 +2154,39 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
            });
        }
    }
    else if constexpr(std::is_same<T, bf16_t>::value)
    {
        if constexpr(N == 2)
        {
            llvm_amdgcn_raw_buffer_atomic_add_bf16x2(bit_cast<bf16x2_t>(src_thread_data),
                                                     dst_wave_buffer_resource,
                                                     dst_thread_addr_offset,
                                                     dst_wave_addr_offset,
                                                     0);
        }
        else if constexpr(N == 4)
        {
            static_for<0, 2, 1>{}([&](auto i) {
                llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
                    src_thread_data.template get_as<bf16x2_t>()[i],
                    dst_wave_buffer_resource,
                    dst_thread_addr_offset,
                    dst_wave_addr_offset + i * sizeof(bf16x2_t),
                    0);
            });
        }
        else if constexpr(N == 8)
        {
            static_for<0, 4, 1>{}([&](auto i) {
                llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
                    src_thread_data.template get_as<bf16x2_t>()[i],
                    dst_wave_buffer_resource,
                    dst_thread_addr_offset,
                    dst_wave_addr_offset + i * sizeof(bf16x2_t),
                    0);
            });
        }
    }
    else if constexpr(std::is_same<T, int32_t>::value)
    {
        if constexpr(N == 1)
@@ -11,6 +11,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/e8m0.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

@@ -88,7 +89,12 @@ template <typename T, typename = void>
struct vector_traits
{
    using scalar_type =
        std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>, int8_t, remove_cvref_t<T>>;
        std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>,
                           int8_t,
                           std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_fp4_t> ||
                                                  std::is_same_v<remove_cvref_t<T>, e8m0_t>,
                                              uint8_t,
                                              remove_cvref_t<T>>>;
    static constexpr index_t vector_size = 1;
};

@@ -96,7 +102,12 @@ struct vector_traits
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N))), void>
{
    using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
    using scalar_type = std::conditional_t<
        std::is_same_v<T, pk_int4_t>,
        int8_t,
        std::conditional_t<std::is_same_v<T, pk_fp4_t> || std::is_same_v<remove_cvref_t<T>, e8m0_t>,
                           uint8_t,
                           T>>;
    static constexpr index_t vector_size = N;
};

@@ -237,4 +248,10 @@ using pk_int4x4_t = int8_t __attribute__((ext_vector_type(4)));
using pk_int4x8_t  = int8_t __attribute__((ext_vector_type(8)));
using pk_int4x16_t = int8_t __attribute__((ext_vector_type(16)));
using pk_int4x32_t = int8_t __attribute__((ext_vector_type(32)));

using pk_fp4x2_t  = uint8_t __attribute__((ext_vector_type(2)));
using pk_fp4x4_t  = uint8_t __attribute__((ext_vector_type(4)));
using pk_fp4x8_t  = uint8_t __attribute__((ext_vector_type(8)));
using pk_fp4x16_t = uint8_t __attribute__((ext_vector_type(16)));
using pk_fp4x32_t = uint8_t __attribute__((ext_vector_type(32)));
} // namespace ck_tile
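Assuming the headers above are in scope, the new scalar_type mappings can be verified at compile time; a sketch:

// Packed 4-bit and e8m0 types now report a byte-sized scalar_type.
#include <type_traits>
using namespace ck_tile;

static_assert(std::is_same_v<vector_traits<pk_int4_t>::scalar_type, int8_t>);
static_assert(std::is_same_v<vector_traits<pk_fp4_t>::scalar_type, uint8_t>);
static_assert(std::is_same_v<vector_traits<e8m0_t>::scalar_type, uint8_t>);
static_assert(vector_traits<pk_fp4x8_t>::vector_size == 8);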
@@ -247,7 +247,7 @@ struct buffer_view<address_space_enum::global,
    : p_data_{p_data},
      buffer_size_{buffer_size / PackedSize},
      cached_buf_res_{0},
      invalid_element_value_{0}
      invalid_element_value_{}
{
}

@@ -631,14 +631,24 @@ struct buffer_view<address_space_enum::global,
bool constexpr use_amd_buffer_addressing =
    std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
    std::is_same_v<remove_cvref_t<scalar_t>, float> ||
    (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
    (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
#if defined(__gfx950__) // only gfx950 supports atomic_pk_add_bf16
    ||
    (std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
#endif
    ;
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing =
    std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
    std::is_same_v<remove_cvref_t<scalar_t>, float> ||
    (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
    (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
#if defined(__gfx950__) // only gfx950 supports atomic_pk_add_bf16
    ||
    (std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
#endif
    ;
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
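A reduced sketch of the gating pattern above: a preprocessor target check folded into a constexpr predicate, so the bfloat16 path only ever exists on gfx950 (bfloat16_t assumed from ck_tile):

#include <type_traits>

template <typename scalar_t, int scalar_per_x_vector>
constexpr bool use_bf16_buffer_atomic()
{
#if defined(__gfx950__) // only gfx950 supports atomic_pk_add_bf16
    return std::is_same_v<scalar_t, bfloat16_t> && scalar_per_x_vector % 2 == 0;
#else
    return false;
#endif
}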
@@ -1,3 +1,4 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

@@ -404,6 +405,100 @@ struct tile_scatter_gather
        });
    }

    template <typename LdsTileWindow_,
              index_t i_access_unsupport_ = -1,
              bool oob_conditional_check  = true>
    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
                                   number<i_access_unsupport_> = {},
                                   bool_constant<oob_conditional_check> = {}) const
    {
        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
        using LdsDataType   = typename LdsTileWindow::DataType;
        using Traits        = load_store_traits;
        using vector_t      = typename Traits::vector_t;
        using SFC_Ys        = typename Traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // Precompute invariant values outside loops
        const auto window_origin       = lds_tile.get_window_origin();
        const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
        const auto& tensor_descriptor  = bottom_tensor_view.get_tensor_descriptor();
        auto smem_base_ptr             = bottom_tensor_view.get_buffer_view().p_data_;

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structured bindings (to be captured later) when compiling as C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            auto lds_window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto lds_bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // Use precomputed window origin
                auto lds_bottom_tensor_thread_idx =
                    window_origin + lds_window_adaptor_thread_coord.get_bottom_index();
                // Use precomputed tensor descriptor
                const auto lds_coord =
                    make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
                // Calculate SMEM address using base pointer
                CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
                constexpr auto idx_gather   = idx_ys_start[number<YsGatherDim>{}];
                const auto page_offset      = page_idx_[idx_gather];

                // merge page_offset into bottom_coord
                auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
                mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;

                // read from bottom tensor
                if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
                    this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                        smem,
                        mixed_bottom_thread_coord,
                        number<0>{},
                        bool_constant<oob_conditional_check>{});
                else
                    this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                        smem,
                        mixed_bottom_thread_coord,
                        number<0>{},
                        valids_[idx_gather],
                        bool_constant<oob_conditional_check>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto forward_step_scatter = generate_tuple(
                        [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
                        number<NDimY>{});

                    constexpr auto idx_diff_ps_ys = container_concat(
                        generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
                        forward_step_scatter);
                    // lds_diff doesn't need to mask the difference of the gather-dim.
                    constexpr auto lds_idx_diff_ps_ys = container_concat(
                        generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
                        idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        lds_window_adaptor_thread_coord,
                        lds_bottom_tensor_thread_coord,
                        lds_idx_diff_ps_ys);
                }
            });
        });
    }

    // TODO: currently async load only implemented in inline asm
    template <typename LdsTileWindow_,
              index_t i_access_unsupport_ = -1,

@@ -508,6 +603,88 @@ struct tile_scatter_gather
        });
    }

    template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                               number<i_access_unsupport_> = {},
                               bool_constant<oob_conditional_check> = {}) const
    {
        using Traits = load_store_traits;

        // using vector_type_t = typename Traits::vector_type_t;
        using vector_t = typename Traits::vector_t;
        using SFC_Ys   = typename Traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
                constexpr auto idx_gather   = idx_ys_start[number<0>{}];
                const auto page_offset      = page_idx_[idx_gather];

                // read from distributed tensor
                vector_t vec_value;

                static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
                    constexpr auto idx_ys = generate_tuple(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        number<NDimY>{});

                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                        Traits::PackedSize;

                    vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
                        dstr_tensor.get_thread_buffer().template at<d>();
                });

                // write into bottom tensor
                if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
                {
                    get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
                        bottom_tensor_thread_coord,
                        page_offset,
                        vec_value,
                        bool_constant<oob_conditional_check>{});
                }
                else
                {
                    get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
                        bottom_tensor_thread_coord,
                        page_offset,
                        valids_[idx_gather],
                        vec_value,
                        bool_constant<oob_conditional_check>{});
                }

                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto forward_step_scatter = generate_tuple(
                        [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
                        number<NDimY>{});

                    constexpr auto idx_diff_ps_ys = container_concat(
                        generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
                        forward_step_scatter);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }

    template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                              number<i_access_unsupport_> = {},

@@ -855,4 +1032,29 @@ CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
        number<HsGatherDim>{});
}

template <typename NewTensorView_,
          typename OldTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename StaticPageIndexArray_,
          typename StaticValidArray_,
          index_t HsGatherDim = 0,
          index_t NumCoord    = 1>
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
                                               const tile_scatter_gather<OldTensorView_,
                                                                         WindowLengths_,
                                                                         StaticTileDistribution_,
                                                                         StaticPageIndexArray_,
                                                                         StaticValidArray_,
                                                                         HsGatherDim,
                                                                         NumCoord>& tile_window)
{
    return make_tile_scatter_gather(new_tensor_view,
                                    tile_window.window_lengths_,
                                    tile_window.window_origin_,
                                    tile_window.tile_dstr_,
                                    tile_window.page_idx_,
                                    tile_window.valids_);
}

} // namespace ck_tile
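A standalone toy model (illustrative names only) of what replace_bottom_tensor_view does: a window is a data view plus iteration state, and replacing the view re-binds the data source while lengths, origin, distribution, page indices, and valid mask carry over unchanged.

#include <array>

template <typename View>
struct ToyWindow
{
    View view;                   // data source ("bottom tensor view")
    std::array<int, 2> lengths;  // window lengths
    std::array<int, 2> origin;   // window origin
    std::array<int, 4> page_idx; // gather page offsets
};

template <typename NewView, typename OldView>
ToyWindow<NewView> replace_view(const NewView& nv, const ToyWindow<OldView>& w)
{
    return {nv, w.lengths, w.origin, w.page_idx}; // only the view changes
}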
@@ -1153,6 +1153,33 @@ CK_TILE_DEVICE void move_tile_window(
    window.move(step);
}

template <typename NewTensorView_,
          typename OldTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord = 1>
CK_TILE_DEVICE auto
replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
                           const tile_window_with_static_distribution<OldTensorView_,
                                                                      WindowLengths_,
                                                                      StaticTileDistribution_,
                                                                      NumCoord>& tile_window)
{
    return make_tile_window(new_tensor_view,
                            tile_window.get_window_lengths(),
                            tile_window.get_window_origin(),
                            tile_window.get_tile_distribution());
}

template <typename NewTensorView_, typename OldTensorView_, typename WindowLengths_>
CK_TILE_DEVICE auto replace_bottom_tensor_view(
    const NewTensorView_& new_tensor_view,
    const tile_window_with_static_lengths<OldTensorView_, WindowLengths_>& tile_window)
{
    return make_tile_window(
        new_tensor_view, tile_window.get_window_lengths(), tile_window.get_window_origin());
}

/**
 * @brief Type trait to determine if a type is a tile window with static distribution.
 *

@@ -35,6 +35,7 @@
#include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_moe_gemm.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"

@@ -480,6 +480,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
            else
                v_a = fp32_val.lo;
        }
        else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
        {
            const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
            if(k % 2 == 1)
                v_a = fp32_val.hi;
            else
                v_a = fp32_val.lo;
        }
        else
        {
            v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
@@ -492,6 +500,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
            else
                v_b = fp32_val.lo;
        }
        else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
        {
            const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
            if(k % 2 == 1)
                v_b = fp32_val.hi;
            else
                v_b = fp32_val.lo;
        }
        else
        {
            v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
@@ -506,6 +522,121 @@ __global__ void naive_gemm_kernel(ADataType* A,
    }
}

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
__global__ void blockwise_gemm_kernel(ADataType* A,
                                      BDataType* B,
                                      CDataType* C,
                                      ck_tile::index_t M,
                                      ck_tile::index_t N,
                                      ck_tile::index_t K,
                                      ck_tile::index_t strideA,
                                      ck_tile::index_t strideB,
                                      ck_tile::index_t strideC,
                                      ck_tile::index_t scale_granularity_m,
                                      ck_tile::index_t scale_granularity_n,
                                      ck_tile::index_t scale_granularity_k,
                                      float* scale_A_ptr,
                                      float* scale_B_ptr)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int row = idx / N; // Compute row index
    int col = idx % N; // Compute column index

    if(row < M && col < N)
    {
        AccDataType acc = 0.0, acc_temp = 0.0;

        index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
        index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;

        float scale_A = 0;
        float scale_B = 0;

        for(int k = 0; k < K; ++k)
        {
            if(k % scale_granularity_k == 0)
            {
                // update acc
                acc += acc_temp * scale_A * scale_B;
                acc_temp = 0.0;
                // update scale factors
                scale_A = scale_A_ptr[(row / scale_granularity_m) +
                                      (k / scale_granularity_k) * scale_A_stride];
                scale_B = scale_B_ptr[(col / scale_granularity_n) +
                                      (k / scale_granularity_k) * scale_B_stride];
            }

            constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
            constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
            // Adjust indexing based on matrix layout
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k
                              : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;

            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }

            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
            }
            acc_temp += v_a * v_b;
        }
        // final accumulation
        acc += acc_temp * scale_A * scale_B;

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}
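The kernel above applies one (scale_A, scale_B) pair per K-group of scale_granularity_k elements, accumulating each group's partial sum before scaling. A host-side sketch computing the same group-scaled dot product (assumes non-empty inputs; names are illustrative):

#include <cstddef>
#include <vector>

float blockwise_scaled_dot(const std::vector<float>& a,
                           const std::vector<float>& b,
                           const std::vector<float>& scale_a, // one entry per K-group
                           const std::vector<float>& scale_b, // one entry per K-group
                           int granularity_k)
{
    float acc = 0.f, acc_temp = 0.f;
    for(std::size_t k = 0; k < a.size(); ++k)
    {
        if(k % granularity_k == 0 && k != 0) // flush the finished group
        {
            const std::size_t g = k / granularity_k - 1;
            acc += acc_temp * scale_a[g] * scale_b[g];
            acc_temp = 0.f;
        }
        acc_temp += a[k] * b[k];
    }
    const std::size_t g = (a.size() - 1) / granularity_k; // last (possibly partial) group
    return acc + acc_temp * scale_a[g] * scale_b[g];
}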
template <typename ADataType,
          typename BDataType,
          typename AccDataType,

@@ -534,6 +665,51 @@ void reference_gemm_gpu(ADataType* a_ptr,
    return;
}

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_blockwise_gemm_gpu(ADataType* a_ptr,
                                  BDataType* b_ptr,
                                  CDataType* c_ptr,
                                  index_t M,
                                  index_t N,
                                  index_t K,
                                  index_t stride_a,
                                  index_t stride_b,
                                  index_t stride_c,
                                  index_t scale_granularity_m,
                                  index_t scale_granularity_n,
                                  index_t scale_granularity_k,
                                  float* scale_A_ptr,
                                  float* scale_B_ptr)
{
    int totalElements      = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    blockwise_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(a_ptr,
                                            b_ptr,
                                            c_ptr,
                                            M,
                                            N,
                                            K,
                                            stride_a,
                                            stride_b,
                                            stride_c,
                                            scale_granularity_m,
                                            scale_granularity_n,
                                            scale_granularity_k,
                                            scale_A_ptr,
                                            scale_B_ptr);

    return;
}

template <typename ADataType,
          typename BDataType,
          typename AccDataType,

@@ -571,4 +747,5 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,

    return;
}

} // namespace ck_tile
include/ck_tile/host/reference/reference_moe_gemm.hpp (new file, 316 lines)
@@ -0,0 +1,316 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC,
          int MoeGemmKind       = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
          typename ActivationOp = identity>
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
                                const ck_tile::index_t* p_sorted_expert_ids_,
                                const ck_tile::index_t* p_max_token_id_,
                                const ADataType* A,
                                const BDataType* B,
                                CDataType* C,
                                const AccDataType* expert_weight_ptr,
                                ck_tile::index_t Num_tokens,
                                ck_tile::index_t TokensPerBlock,
                                ck_tile::index_t TopK,
                                ck_tile::index_t M,
                                ck_tile::index_t N,
                                ck_tile::index_t K,
                                ck_tile::index_t strideA,
                                ck_tile::index_t strideB,
                                ck_tile::index_t strideC,
                                index_t scale_granularity_m,
                                index_t scale_granularity_n,
                                index_t scale_granularity_k,
                                float* scale_A_ptr,
                                float* scale_B_ptr,
                                float* expert_bias_ptr)
{
    int idx       = blockIdx.x * blockDim.x + threadIdx.x;
    int problem_N = MoeGemmKind == 1 ? N / 2 : N;
    int row       = idx / problem_N; // Compute row index
    int col       = idx % problem_N; // Compute column index

    index_t gather_token_id  = 0;
    index_t scatter_token_id = 0;
    index_t expert_id        = 0;

    if(row < p_max_token_id_[0])
    {
        expert_id        = p_sorted_expert_ids_[row / TokensPerBlock];
        gather_token_id  = p_sorted_token_ids_[row] & 0xff'ffff;
        scatter_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
        if(gather_token_id >= Num_tokens)
        {
            return;
        }
        if(MoeGemmKind == 2)
        {
            gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
        }
        else
        {
            scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
        }
    }
    else
    {
        return;
    }

    if(row < M)
    {
        AccDataType acc    = 0.0;
        AccDataType acc_up = 0.0;

        AccDataType acc_temp    = 0.0;
        AccDataType acc_up_temp = 0.0;

        float scale_A    = 0;
        float scale_B    = 0;
        float scale_B_up = 0;

        index_t scale_A_stride        = (M + scale_granularity_m - 1) / scale_granularity_m;
        index_t scale_B_stride        = (N + scale_granularity_n - 1) / scale_granularity_n;
        index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;

        for(int k = 0; k < K; ++k)
        {
            if(k % scale_granularity_k == 0)
            {
                // update acc
                acc += acc_temp * scale_A * scale_B;
                acc_up += acc_up_temp * scale_A * scale_B_up;
                // reset acc temp
                acc_temp    = 0.0;
                acc_up_temp = 0.0;
                // update scale factors
                scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
                                      (k / scale_granularity_k) * scale_A_stride];
                scale_B =
                    scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
                                (k / scale_granularity_k) * scale_B_stride];
                if constexpr(MoeGemmKind == 1)
                    scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
                                             (col + problem_N) / scale_granularity_n +
                                             (k / scale_granularity_k) * scale_B_stride];
            }

            constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
            constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
            // Adjust indexing based on matrix layout
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? gather_token_id * strideA + k
                              : k * strideA + gather_token_id;

            long b_index =
                long(expert_id) * N * K +
                ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
                                                                             : k * strideB + col);
            long b_index_up;
            if constexpr(MoeGemmKind == 1)
                b_index_up = long(expert_id) * N * K +
                             ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                                  ? (col + problem_N) * strideB + k
                                  : k * strideB + col + problem_N);

            AccDataType v_a;
            AccDataType v_b;
            AccDataType v_b_up;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
                if constexpr(MoeGemmKind == 1)
                {
                    const fp32x2_t fp32_val_up =
                        pk_int4_t_to_fp32x2_t(B[b_index_up / packed_size_b]);
                    if(k % 2 == 1)
                        v_b_up = fp32_val_up.hi;
                    else
                        v_b_up = fp32_val_up.lo;
                }
            }
            else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
                if constexpr(MoeGemmKind == 1)
                {
                    const fp32x2_t fp32_val_up =
                        pk_fp4_to_fp32x2(B[b_index_up / packed_size_b], 1.0f);
                    if(k % 2 == 1)
                        v_b_up = fp32_val_up.hi;
                    else
                        v_b_up = fp32_val_up.lo;
                }
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
                if constexpr(MoeGemmKind == 1)
                    v_b_up = ck_tile::type_convert<AccDataType>(B[b_index_up]);
            }
            acc_temp += v_a * v_b;
            if constexpr(MoeGemmKind == 1)
                acc_up_temp += v_a * v_b_up;
        }

        acc += acc_temp * scale_A * scale_B;
        acc_up += acc_up_temp * scale_A * scale_B_up;

        float bias = 0.f, bias_up = 0.f;
        if(expert_bias_ptr != nullptr)
        {
            bias = expert_bias_ptr[expert_id * N + col];
            if constexpr(MoeGemmKind == 1)
                bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
        }

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? scatter_token_id * strideC + col
                          : col * strideC + scatter_token_id;
        if constexpr(MoeGemmKind < 2)
        {
            C[c_index] = ck_tile::type_convert<CDataType>(
                ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
        }
        else
        {
            // moe gemm2 doesn't use an activation.
            CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
            using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
                                                 ck_tile::fp16x2_t,
                                                 ck_tile::bf16x2_t>;
            ResV2Type add_v{0, 0};
            if(c_index % 2)
            {
                // result is the second value of the fp16 pair.
                add_v.y = res;
            }
            else
            {
                // result is the first value of the fp16 pair.
                add_v.x = res;
            }
            // mask the last bit to make sure the atomicAdd pointer is DWORD-aligned.
            atomic_add<ResV2Type>(reinterpret_cast<ResV2Type*>(C + (c_index & 0xffff'fffe)), add_v);
        }
    }
}

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC,
          int MoeGemmKind       = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
          typename ActivationOp = identity>
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
                            const index_t* p_sorted_expert_ids_,
                            const index_t* p_max_token_id_,
                            const ADataType* a_ptr,
                            const BDataType* b_ptr,
                            CDataType* c_ptr,
                            const AccDataType* expert_weight_ptr,
                            index_t Num_tokens,
                            index_t TokensPerBlock,
                            index_t TopK,
                            index_t M,
                            index_t N,
                            index_t K,
                            index_t stride_a,
                            index_t stride_b,
                            index_t stride_c,
                            index_t scale_granularity_m,
                            index_t scale_granularity_n,
                            index_t scale_granularity_k,
                            float* scale_A_ptr,
                            float* scale_B_ptr,
                            float* exp_bias = nullptr)
{
    int problem_N          = MoeGemmKind == 1 ? N / 2 : N;
    int totalElements      = M * problem_N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    moe_gemm_kernel<ADataType,
                    BDataType,
                    AccDataType,
                    CDataType,
                    LayoutA,
                    LayoutB,
                    LayoutC,
                    MoeGemmKind,
                    ActivationOp><<<numBlocks, numThreadsPerBlock>>>(p_sorted_token_ids_,
                                                                     p_sorted_expert_ids_,
                                                                     p_max_token_id_,
                                                                     a_ptr,
                                                                     b_ptr,
                                                                     c_ptr,
                                                                     expert_weight_ptr,
                                                                     Num_tokens,
                                                                     TokensPerBlock,
                                                                     TopK,
                                                                     M,
                                                                     N,
                                                                     K,
                                                                     stride_a,
                                                                     stride_b,
                                                                     stride_c,
                                                                     scale_granularity_m,
                                                                     scale_granularity_n,
                                                                     scale_granularity_k,
                                                                     scale_A_ptr,
                                                                     scale_B_ptr,
                                                                     exp_bias);

    return;
}

} // namespace ck_tile
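In the gate+up case (MoeGemmKind == 1) the first problem_N columns are the gate and the next problem_N the up projection, and the two accumulators meet in the two-argument ActivationOp. A sketch of a SiLU-style combiner with the assumed semantics (the real ck_tile::moe::MoeSilu may differ in detail):

#include <cmath>

// For MoeGemmKind == 0 the kernel passes 1 as `up`, so this degenerates to silu(gate).
struct ToyMoeSilu
{
    float operator()(float gate, float up) const
    {
        const float silu = gate / (1.f + std::exp(-gate)); // silu(x) = x * sigmoid(x)
        return silu * up;
    }
};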
@@ -9,9 +9,9 @@
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"

#include <optional>
#include <type_traits>

namespace ck_tile {

template <typename AsDataType_,
          typename BsDataType_,
          typename DsDataType_,

@@ -29,10 +29,11 @@ template <typename AsDataType_,
          index_t KPerXdl_,
          bool isCTransposed_,
          memory_operation_enum MemoryOperation_,
          index_t kNumWaveGroups_ = 1,
          bool FixedVectorSize_   = false,
          index_t VectorSizeC_    = 1,
          bool TiledMMAPermuteN_  = false>
          index_t kNumWaveGroups_ = 1,
          bool FixedVectorSize_   = false,
          index_t VectorSizeC_    = 1,
          bool TiledMMAPermuteN_  = false,
          index_t BlockedXDLN_PerWarp_ = 1> // The number of contiguous xdl_output per warp
struct CShuffleEpilogueProblem
{
    using AsDataType = remove_cvref_t<AsDataType_>;

@@ -55,6 +56,7 @@ struct CShuffleEpilogueProblem
    static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
    static constexpr bool FixedVectorSize                  = FixedVectorSize_;
    static constexpr index_t VectorSizeC                   = VectorSizeC_;
    static constexpr index_t BlockedXDLN_PerWarp           = BlockedXDLN_PerWarp_;
    static constexpr bool TiledMMAPermuteN                 = TiledMMAPermuteN_;
    static constexpr index_t kNumWaveGroups                = kNumWaveGroups_;
    static constexpr index_t NumDTensor                    = DsDataType::size();

@@ -107,6 +109,7 @@ struct CShuffleEpilogue
    static constexpr index_t isCTransposed       = Problem::isCTransposed;
    static constexpr bool FixedVectorSize        = Problem::FixedVectorSize;
    static constexpr bool TiledMMAPermuteN       = Problem::TiledMMAPermuteN;
    static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
    static constexpr index_t VectorSizeC         = Problem::VectorSizeC;
    static constexpr index_t MPerIteration       = MPerXdl * MWave;
    static constexpr index_t NPerIteration       = NPerXdl * NWave;

@@ -212,7 +215,8 @@ struct CShuffleEpilogue
        }
    }();
    static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
    static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
    static constexpr index_t NumNXdlPerWavePerShuffle =
        max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));

    static constexpr auto MNPerIterationShuffle = [] {
        constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;

@@ -265,14 +269,31 @@ struct CShuffleEpilogue

    CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
    {
        constexpr auto block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
                                             sequence<NumNXdlPerWavePerShuffle, NWave>>,
                                       tuple<sequence<1, 2>>,
                                       tuple<sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
        constexpr auto block_outer_dstr_encoding = [] {
            if constexpr(BlockedXDLN_PerWarp == 1)
            {
                return tile_distribution_encoding<sequence<>,
                                                  tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
                                                        sequence<NumNXdlPerWavePerShuffle, NWave>>,
                                                  tuple<sequence<1, 2>>,
                                                  tuple<sequence<1, 1>>,
                                                  sequence<1, 2>,
                                                  sequence<0, 0>>{};
            }
            else
            {
                constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
                // BlockedLayout
                return tile_distribution_encoding<
                    sequence<>,
                    tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
                          sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
                    tuple<sequence<1, 2>>,
                    tuple<sequence<1, 1>>,
                    sequence<1, 2, 2>,
                    sequence<0, 0, 2>>{};
            }
        }();
        constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
            block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});

@@ -437,7 +458,6 @@ struct CShuffleEpilogue

    static_assert(MPerXdl % RowsPerLane == 0,
                  "CShuffle (permuteN): MPerXdl must be divisible by per-lane row count.");

    constexpr int kM0 = MWave;
    constexpr int kM2 = RowsPerLane;
    constexpr int kM1 = MPerXdl / kM2;

@@ -527,6 +547,7 @@ struct CShuffleEpilogue
    const int src = n_idx * plane + m_lane;   // source row in this N-plane
    const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
    AccDataType v = shuffle_acc.get_thread_buffer()[src];

    if constexpr(has_scalar_scales)
    {
        v = static_cast<AccDataType>(v * scale_m * scale_n);
@@ -537,6 +558,7 @@ struct CShuffleEpilogue
        const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
        v = static_cast<AccDataType>(v * sm * sn);
    }

    c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
    });
    });

@@ -10,8 +10,14 @@
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp"
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"

@@ -113,6 +113,7 @@ struct BlockFlatmmASmemBSmemCRegV1
                    merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                    c_warp_tensor.get_thread_buffer());
                __builtin_amdgcn_sched_barrier(0x7F6);
            });
        });
    });

@@ -11,23 +11,138 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

namespace ck_tile {
struct FlatmmProblem
{
    CK_TILE_HOST FlatmmProblem() = default;
    CK_TILE_HOST FlatmmProblem(
        index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
        : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
    {
    }

    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
};

template <int SharedGranularityMN, int SharedGranularityK = 0>
struct FlatmmScalePointer
{
    static constexpr int GranularityMN = SharedGranularityMN;
    static constexpr int GranularityK  = SharedGranularityK;

    const float* ptr;

    CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
    CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
    CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
        : ptr(ptr_)
    {
    }

    CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
    {
        FlatmmScalePointer ret;
        if constexpr(GranularityMN == 0)
        {
            ret.ptr = ptr + offset / GranularityK;
        }
        else
        {
            ret.ptr = ptr + offset / GranularityMN / GranularityK;
        }
        return ret;
    }

    CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
};

template <int SharedGranularityMN>
struct FlatmmScalePointer<SharedGranularityMN, 0>
{
    static constexpr int GranularityMN = SharedGranularityMN;
    static constexpr int GranularityK  = 0;

    static_assert(GranularityMN != 0);

    const float* ptr;
    index_t length;

    CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
    CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
    CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_)
        : ptr(ptr_), length(length_)
    {
    }

    CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
    {
        FlatmmScalePointer ret;
        if constexpr(GranularityMN == 1)
        {
            ret.ptr    = ptr + offset;
            ret.length = length - offset;
        }
        else
        {
            ret.ptr    = ptr + offset / GranularityMN;
            ret.length = length - offset / GranularityMN;
        }
        return ret;
    }

    CK_TILE_HOST_DEVICE float operator[](index_t i) const
    {
        // with an additional out-of-bounds check
        if constexpr(GranularityMN == 1)
            return i < length ? ptr[i] : 0;
        else
            return i / GranularityMN < length ? ptr[i / GranularityMN] : 0;
    }
};

// shared granularityMN = -1 means no scale
template <>
struct FlatmmScalePointer<-1, 0>
{
    static constexpr int GranularityMN = -1;
    static constexpr int GranularityK  = 0;

    const float* ptr = nullptr;

    CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
    CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
    CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {}

    CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
    {
        return FlatmmScalePointer{};
    }
    CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
    {
        return 1; // always return 1; it doesn't change the result
    }
};
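Read together, the three specializations above encode one scaling mode per type: K-blocked scales (primary template), per-row/column scales with a bounds check (GranularityK == 0), and no scaling (GranularityMN == -1, a stateless pointer whose operator[] is the constant 1). An illustrative set of aliases (granularity values here are examples, not fixed by this commit):

using PerTokenScale = ck_tile::FlatmmScalePointer<1>;      // one scale per row, OOB-checked reads
using BlockScale    = ck_tile::FlatmmScalePointer<0, 128>; // one scale per 128-deep K block
using NoScale       = ck_tile::FlatmmScalePointer<-1>;     // operator[] always yields 1

// NoScale lets unscaled kernels share the scaled code path: multiplying by
// no_scale[i] multiplies by a compile-time 1, which the compiler folds away.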
template <index_t NumDTensor = 0>
struct FlatmmHostArgs
struct BaseFlatmmHostArgs
{
    CK_TILE_HOST FlatmmHostArgs() = default;
    CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_,
                                const void* b_ptr_,
                                const std::array<const void*, NumDTensor>& ds_ptr_,
                                void* e_ptr_,
                                index_t k_batch_,
                                index_t M_,
                                index_t N_,
                                index_t K_,
                                index_t stride_A_,
                                index_t stride_B_,
                                const std::array<index_t, NumDTensor>& stride_Ds_,
                                index_t stride_E_)
    CK_TILE_HOST BaseFlatmmHostArgs() = default;
    CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_,
                                    const void* b_ptr_,
                                    const std::array<const void*, NumDTensor>& ds_ptr_,
                                    void* e_ptr_,
                                    index_t k_batch_,
                                    index_t M_,
                                    index_t N_,
                                    index_t K_,
                                    index_t stride_A_,
                                    index_t stride_B_,
                                    const std::array<index_t, NumDTensor>& stride_Ds_,
                                    index_t stride_E_)
        : a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          ds_ptr(ds_ptr_),

@@ -65,8 +180,51 @@ struct FlatmmHostArgs

    index_t k_batch;
};
template <class ScaleM = FlatmmScalePointer<-1>,
          class ScaleN = FlatmmScalePointer<-1>,
          index_t NumDTensor = 0>
struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<>
{
    CK_TILE_HOST ScaleFlatmmHostArgs() = default;
    CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_,
                                     const void* b_shuffle_ptr_,
                                     const std::array<const void*, NumDTensor>& ds_ptr_,
                                     void* c_ptr_,
                                     index_t k_batch_,
                                     index_t M_,
                                     index_t N_,
                                     index_t K_,
                                     index_t stride_A_,
                                     index_t stride_B_,
                                     const std::array<index_t, NumDTensor>& stride_Ds_,
                                     index_t stride_C_,
                                     ScaleM scale_m_ = nullptr,
                                     ScaleN scale_n_ = nullptr)
        : BaseFlatmmHostArgs(a_ptr_,
                             b_shuffle_ptr_,
                             ds_ptr_,
                             c_ptr_,
                             k_batch_,
                             M_,
                             N_,
                             K_,
                             stride_A_,
                             stride_B_,
                             stride_Ds_,
                             stride_C_),
          scale_m(scale_m_),
          scale_n(scale_n_)
    {
    }
    ScaleM scale_m = nullptr;
    ScaleN scale_n = nullptr;
};

template <index_t NumDTensor = 0>
template <int NumberTensor = 0>
using FlatmmHostArgs =
    ScaleFlatmmHostArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, NumberTensor>;

template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
struct FlatmmKernelArgs
{
    const void* a_ptr;

@@ -82,6 +240,8 @@ struct FlatmmKernelArgs
    std::array<index_t, NumDTensor> stride_Ds;
    index_t stride_E;
    index_t k_batch;
    ScaleM scale_m_ptr = nullptr;
    ScaleN scale_n_ptr = nullptr;
};

template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>

@@ -98,6 +258,7 @@ struct FlatmmKernel
    using DsLayout   = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
    using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
    static constexpr index_t kBlockSize       = FlatmmPipeline::BlockSize;
    static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;

    using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;

@@ -113,7 +274,7 @@ struct FlatmmKernel

    static_assert(DsLayout::size() == DsDataType::size(),
                  "The size of DsLayout and DsDataType should be the same");
    using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
    // using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {

@@ -124,40 +285,85 @@ struct FlatmmKernel

    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
    {
        assert(!UsePersistentKernel);
        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
    }

    CK_TILE_HOST static constexpr auto BlockSize()
    template <class ScaleM, class ScaleN>
    CK_TILE_HOST static constexpr auto
    GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
    {
        return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
        if constexpr(UsePersistentKernel)
        {
            hipDeviceProp_t prop;
            int deviceId = 0; // default device

            constexpr int block_size = FlatmmKernel::BlockSize().x;
            int dync_smem_size       = 0;
            int maxActiveBlocksPerCU = 0;

            [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);

            e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &maxActiveBlocksPerCU,
                reinterpret_cast<void*>(
                    kentry<1, FlatmmKernel, FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
                block_size,
                dync_smem_size);

            const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
            const int total_work_tile_cnt   = TilePartitioner::GridSize(kargs.M, kargs.N);

            // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
            //           << ", persistent_block_size: " << persistent_block_size
            //           << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;

            assert(kargs.k_batch == 1);
            return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
        }
        else
        {
            return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
        }
    }
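The persistent branch above sizes the grid as (CU count) x (max resident blocks per CU), capped by the tile count. The same occupancy query in a reduced standalone HIP sketch that mirrors the call pattern above (kernel and block size are placeholders):

#include <hip/hip_runtime.h>
#include <algorithm>

__global__ void dummy_kernel() {}

int persistent_grid_size(int total_work_tiles, int block_size = 256)
{
    hipDeviceProp_t prop{};
    int max_blocks_per_cu = 0;
    (void)hipGetDeviceProperties(&prop, /*deviceId=*/0);
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks_per_cu,
        reinterpret_cast<void*>(dummy_kernel),
        block_size,
        /*dynSharedMemPerBlk=*/0);
    const int persistent_blocks = prop.multiProcessorCount * max_blocks_per_cu;
    return std::min(persistent_blocks, total_work_tiles); // never exceed the work
}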
CK_TILE_HOST static constexpr KernelArgs
|
||||
MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_HOST static constexpr FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>
|
||||
MakeKernelArgs(const ScaleFlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
|
||||
{
|
||||
return KernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E,
|
||||
hostArgs.k_batch};
|
||||
return {hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E,
|
||||
hostArgs.k_batch,
|
||||
hostArgs.scale_m,
|
||||
hostArgs.scale_n};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
|
||||
{
|
||||
return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
|
||||
{
|
||||
return FlatmmPipeline::GetSmemSize();
|
||||
}
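Splitting the old GetSmemSize() into ping and pong halves means the total LDS footprint is now the sum of the two: the ping buffer is sized to also host epilogue staging, while the pong buffer only ever holds pipeline data. A minimal sketch of the sizing rule, with hypothetical byte counts:

// Hypothetical sizes; the real values come from the pipeline/epilogue policies.
constexpr int pipeline_smem_bytes = 32 * 1024;
constexpr int epilogue_smem_bytes = 16 * 1024;

constexpr int ping_bytes = pipeline_smem_bytes > epilogue_smem_bytes
                               ? pipeline_smem_bytes
                               : epilogue_smem_bytes; // max(pipeline, epilogue)
constexpr int pong_bytes = pipeline_smem_bytes;      // pipeline only

// Both buffers live in LDS at the same time, so the sum must fit
// (64 KiB per workgroup on gfx9-class GPUs).
static_assert(ping_bytes + pong_bytes <= 64 * 1024);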

struct SplitKBatchOffset
{
template <class KernelArgs>
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
{
constexpr auto N1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<1>{});
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
@@ -173,11 +379,11 @@ struct FlatmmKernel

if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * KRead * kargs.stride_B;
b_k_split_offset = k_id * KRead * kargs.stride_B * N1;
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = k_id * KRead;
b_k_split_offset = k_id * KRead * N1;
}

if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
@@ -195,6 +401,7 @@ struct FlatmmKernel
index_t splitted_k;
};
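The constructor above rounds the per-batch K up to a multiple of the warp-tile K (K1) before computing offsets; the tail handling after the `if` is elided by the hunk. A standalone sketch of the arithmetic under the usual remainder convention (assumed, since the diff does not show it):

#include <cassert>

struct SplitK
{
    int offset;     // element offset into the K dimension for this k-batch
    int splitted_k; // how much K this k-batch actually computes
};

SplitK make_split_k(int K, int k_batch, int K1 /*warp-tile K*/, int k_id)
{
    assert(k_id < k_batch);
    const int K_t   = k_batch * K1;
    const int KRead = (K + K_t - 1) / K_t * K1; // per-batch K, rounded up to K1

    return SplitK{k_id * KRead,
                  k_id < k_batch - 1 ? KRead                        // full slice
                                     : K - KRead * (k_batch - 1)};  // remainder
}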

template <class KernelArgs>
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
{
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
@@ -206,6 +413,14 @@ struct FlatmmKernel
return false;
}
}
if constexpr(UsePersistentKernel)
{
if(kargs.k_batch != 1)
{
std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
return false;
}
}

if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
@@ -340,7 +555,7 @@ struct FlatmmKernel
return DTesnorIsValid;
}

template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
@@ -370,9 +585,9 @@ struct FlatmmKernel
}
}();

index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k /
BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
index_t kFlatK =
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
@@ -411,7 +626,7 @@ struct FlatmmKernel
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_E, 1),
@@ -420,7 +635,7 @@ struct FlatmmKernel
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_E, 1),
@@ -429,7 +644,45 @@ struct FlatmmKernel
}
}();

return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view);
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;

constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;

auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
: 1; // per-token scale
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
: 1; // per-channel scale

static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
"only support per-tensor or per-row scaling");
static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
"only support per-tensor or per-column scaling");

const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m_ptr.ptr,
make_tuple(
kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKA),
make_tuple(scale_stride_m, 0),
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
number<1>{});
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_n_ptr.ptr,
make_tuple(
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
kargs.N / ScaleGranularityN),
make_tuple(0, scale_stride_n),
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
number<1>{});

return make_tuple(a_tensor_view,
b_flat_tensor_view,
ds_tensor_view,
e_tensor_view,
scale_m_view,
scale_n_view);
}
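The granularity encoding those static_asserts enforce is: -1 disables a scale entirely, 0 means one scale per tensor (stride 0, broadcast), and 1 means one scale per row of A (per token) or per column of B (per channel). A tiny sketch of the stride selection, under that assumed encoding:

// Assumed encoding: -1 = scale disabled, 0 = per-tensor, 1 = per-row/column.
constexpr int scale_stride(int granularity_mn)
{
    return granularity_mn == 0 ? 0  // per-tensor: every element reads one value
                               : 1; // per-token / per-channel: one value each
}

static_assert(scale_stride(0) == 0);
static_assert(scale_stride(1) == 1);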

template <typename TensorView>
@@ -495,7 +748,12 @@ struct FlatmmKernel
}
}();

return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view);
return make_tuple(a_pad_view,
b_flat_tensor_view,
ds_pad_view,
e_pad_view,
views.at(number<4>{}),
views.at(number<5>{}));
}

template <typename PadView>
@@ -555,19 +813,42 @@ struct FlatmmKernel
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});

return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK;
constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK;

auto scale_m_window = make_tile_window(views.at(number<4>{}),
make_tuple(number<TilePartitioner::MPerBlock>{},
number < ScaleGranularityKA == 0
? TilePartitioner::NPerBlock
: TilePartitioner::KPerBlock > {}),
{i_m, 0});
auto scale_n_window = make_tile_window(views.at(number<5>{}),
make_tuple(number < ScaleGranularityKB == 0
? TilePartitioner::MPerBlock
: TilePartitioner::KPerBlock > {},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});

return make_tuple(a_block_window,
b_flat_block_window,
ds_block_window,
e_block_window,
scale_m_window,
scale_n_window);
}

template <bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void
RunFlatmm(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
@@ -583,50 +864,77 @@ struct FlatmmKernel
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
a_block_window, b_flat_block_window, num_loop, smem_ptr);
if(UseDefaultScheduler || (get_warp_id() == 0))
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);

auto scale_m_window = gemm_tile_windows.at(number<4>{});
auto scale_n_window = gemm_tile_windows.at(number<5>{});

// Run Epilogue Pipeline
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
scale_m_window,
scale_n_window);
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);

EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr);
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
}
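Whether the scaled epilogue runs is decided at compile time from the scale types alone: if either GranularityMN is not -1, the scale windows are passed through; otherwise the original unscaled path (gated on the scheduler or warp 0) is kept. A compile-time sketch of that dispatch condition:

template <int GranularityMN_A, int GranularityMN_B>
constexpr bool use_scaled_epilogue()
{
    // Scaled path whenever at least one side carries a live scale.
    return GranularityMN_A != -1 || GranularityMN_B != -1;
}

static_assert(use_scaled_epilogue<1, -1>());   // per-token A scale only
static_assert(!use_scaled_epilogue<-1, -1>()); // both disabled: plain epilogue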

CK_TILE_DEVICE void operator()(KernelArgs kargs) const
template <class ScaleM, class ScaleN>
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
int partition_idx = blockIdx.x) const
{
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);

const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);

// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];

if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
do
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
const auto [iM, iN] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);

// allocate LDS
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];

if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
partition_idx += gridDim.x;
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
}
};
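The do/while wrapping operator() is the persistent-kernel idiom: the same fixed grid strides over every output tile, and a non-persistent build degenerates to a single iteration because the loop condition is false. A standalone HIP sketch of the loop shape (hypothetical `process_tile`):

#include <hip/hip_runtime.h>

template <bool UsePersistent>
__global__ void persistent_work_loop(int total_work_tile_cnt)
{
    int partition_idx = blockIdx.x;
    do
    {
        // process_tile(partition_idx); // hypothetical per-tile body
        partition_idx += gridDim.x;     // grid-stride to the next tile
    } while(UsePersistent && partition_idx < total_work_tile_cnt);
}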

478 include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp Normal file
@@ -0,0 +1,478 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"

namespace ck_tile {

template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
struct GroupedFlatmmHostArgs
{
CK_TILE_HOST GroupedFlatmmHostArgs() = default;
CK_TILE_HOST GroupedFlatmmHostArgs(index_t group_count_,
index_t* M_,
index_t* N_,
index_t* K_,
const void** a_ptr_,
index_t* stride_A_,
const void** b_shuffle_ptr_,
index_t* stride_B_,
const std::array<const void*, NumDTensor>& ds_ptr_,
const std::array<index_t, NumDTensor>& stride_Ds_,
void** c_ptr_,
index_t* stride_C_,
index_t k_batch_,
ScaleM* scale_m_ = nullptr,
ScaleN* scale_n_ = nullptr)
: group_count(group_count_),
M(M_),
N(N_),
K(K_),
a_ptr(a_ptr_),
stride_A(stride_A_),
b_shuffle_ptr(b_shuffle_ptr_),
stride_B(stride_B_),
ds_ptr(ds_ptr_),
stride_Ds(stride_Ds_),
c_ptr(c_ptr_),
stride_C(stride_C_),
k_batch(k_batch_),
scale_m(scale_m_),
scale_n(scale_n_)
{
}

index_t group_count;
index_t* M;
index_t* N;
index_t* K;
const void** a_ptr;
index_t* stride_A;
const void** b_shuffle_ptr;
index_t* stride_B;
const std::array<const void*, NumDTensor> ds_ptr;
const std::array<index_t, NumDTensor> stride_Ds;
union
{
void** e_ptr;
void** c_ptr;
};
index_t* stride_C;
index_t k_batch;
ScaleM* scale_m = nullptr;
ScaleN* scale_n = nullptr;
};

template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
struct ContiguousGroupedFlatmmHostArgs
{
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs() = default;
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs(index_t* M_indices_,
index_t M_,
index_t N_,
index_t K_,
const void* a_ptr_,
index_t stride_A_,
const void* b_shuffle_ptr_,
index_t stride_B_,
const std::array<const void*, NumDTensor>& ds_ptr_,
const std::array<index_t, NumDTensor>& stride_Ds_,
void* c_ptr_,
index_t stride_C_,
index_t k_batch_,
ScaleM scale_m_ = nullptr,
ScaleN scale_n_ = nullptr)
: group_count(1),
M_indices(M_indices_),
M(M_),
N(N_),
K(K_),
a_ptr(a_ptr_),
stride_A(stride_A_),
b_shuffle_ptr(b_shuffle_ptr_),
stride_B(stride_B_),
ds_ptr(ds_ptr_),
stride_Ds(stride_Ds_),
c_ptr(c_ptr_),
stride_C(stride_C_),
k_batch(k_batch_),
scale_m(scale_m_),
scale_n(scale_n_)
{
}
index_t group_count;
index_t* M_indices;
index_t M;
index_t N;
index_t K;
const void* a_ptr;
index_t stride_A;
const void* b_shuffle_ptr;
index_t stride_B;
const std::array<const void*, NumDTensor> ds_ptr;
const std::array<index_t, NumDTensor> stride_Ds;
union
{
void* e_ptr;
void* c_ptr;
};
index_t stride_C;
index_t k_batch;
ScaleM scale_m = nullptr;
ScaleN scale_n = nullptr;
};

template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
struct MaskedGroupedFlatmmHostArgs
{
CK_TILE_HOST MaskedGroupedFlatmmHostArgs() = default;
CK_TILE_HOST MaskedGroupedFlatmmHostArgs(index_t* M_indices_,
index_t group_count_,
index_t Max_M_,
index_t N_,
index_t K_,
const void* a_ptr_,
index_t stride_A_,
const void* b_shuffle_ptr_,
index_t stride_B_,
const std::array<const void*, NumDTensor>& ds_ptr_,
const std::array<index_t, NumDTensor>& stride_Ds_,
void* c_ptr_,
index_t stride_C_,
index_t k_batch_,
ScaleM scale_m_ = nullptr,
ScaleN scale_n_ = nullptr)
: M_indices(M_indices_),
group_count(group_count_),
M(Max_M_),
N(N_),
K(K_),
a_ptr(a_ptr_),
stride_A(stride_A_),
b_shuffle_ptr(b_shuffle_ptr_),
stride_B(stride_B_),
ds_ptr(ds_ptr_),
stride_Ds(stride_Ds_),
c_ptr(c_ptr_),
stride_C(stride_C_),
k_batch(k_batch_),
scale_m(scale_m_),
scale_n(scale_n_)
{
}

index_t* M_indices;
index_t group_count;
index_t M;
index_t N;
index_t K;
const void* a_ptr;
index_t stride_A;
const void* b_shuffle_ptr;
index_t stride_B;
const std::array<const void*, NumDTensor> ds_ptr;
const std::array<index_t, NumDTensor> stride_Ds;
union
{
void* e_ptr;
void* c_ptr;
};
index_t stride_C;
index_t k_batch;
ScaleM scale_m = nullptr;
ScaleN scale_n = nullptr;
};

template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
{
using UnderlyingGemmKernel = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;
using BlockGemmShape = typename UnderlyingGemmKernel::BlockGemmShape;

using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;

using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;

static constexpr index_t NumDTensor = DsDataType::size();
static constexpr index_t kBlockSize = FlatmmPipeline_::BlockSize;

static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();

static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");

CK_TILE_HOST static const std::string GetName()
{
return concat(
'_', "grouped_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
}

template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_HOST_DEVICE static auto
GridSize([[maybe_unused]] const GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>& kernelArgs)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device

constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU;

[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);

e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry<1, GroupedFlatmmKernel, GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>>),
block_size,
dync_smem_size);

const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;

// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
// << ", persistent_block_size: " << persistent_block_size << std::endl;

assert(kernelArgs.k_batch == 1);
return dim3(persistent_block_size, 1, kernelArgs.k_batch);
}

template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_HOST_DEVICE static auto
GridSize([[maybe_unused]] const ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>&
kernelArgs)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device

constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU;

[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);

e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry<1,
GroupedFlatmmKernel,
ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>>),
block_size,
dync_smem_size);

const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N);

// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
// << ", persistent_block_size: " << persistent_block_size
// << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;

assert(kernelArgs.k_batch == 1);
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kernelArgs.k_batch);
}

template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_HOST_DEVICE static auto GridSize(
[[maybe_unused]] const MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>& kernelArgs)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device

constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU;

[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);

e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry<1,
GroupedFlatmmKernel,
MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>>),
block_size,
dync_smem_size);

const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
// const int total_work_tile_cnt = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N);

// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
// << ", persistent_block_size: " << persistent_block_size << std::endl;

assert(kernelArgs.k_batch == 1);
return dim3(persistent_block_size, 1, kernelArgs.k_batch);
}

template <typename HostArgs>
CK_TILE_HOST static constexpr auto MakeKernelArgs(const HostArgs& hostArgs)
{
return hostArgs;
}
// CK_TILE_HOST static constexpr auto
// MakeKernelArgs(const ContiguousGroupedFlatmmHostArgs& hostArgs)
// {
// return hostArgs;
// }
// CK_TILE_HOST static constexpr auto
// MakeKernelArgs(const MaskedGroupedFlatmmHostArgs& hostArgs)
// {
// return hostArgs;
// }

template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_DEVICE void operator()(GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor> kargs) const
{
int group_idx = 0;
int block_linear_idx = blockIdx.x;
int total_block_cnt = gridDim.x;

UnderlyingGemmKernel underlying_kernel{};
for(; group_idx < kargs.group_count; ++group_idx)
{
const index_t M = kargs.M[group_idx];
const index_t N = kargs.N[group_idx];
const index_t group_block_cnt = TilePartitioner::GridSize(M, N);

while(block_linear_idx < group_block_cnt)
{
// Found the group this block belongs to
// create the kernel args for the underlying flatmm kernel
FlatmmKernelArgs<ScaleM, ScaleN, NumDTensor> impl_kargs{
kargs.a_ptr[group_idx],
kargs.b_shuffle_ptr[group_idx],
kargs.ds_ptr,
kargs.c_ptr[group_idx],
kargs.M[group_idx],
kargs.N[group_idx],
kargs.K[group_idx],
kargs.stride_A[group_idx],
kargs.stride_B[group_idx],
kargs.stride_Ds,
kargs.stride_C[group_idx],
kargs.k_batch,
kargs.scale_m[group_idx],
kargs.scale_n[group_idx]};
// call the underlying flatmm kernel
underlying_kernel(impl_kargs, block_linear_idx);
block_linear_idx += total_block_cnt;
}
block_linear_idx -= group_block_cnt;
}
}
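The loop above is effectively a search: every block walks the groups in order, claims the tiles of the current group whose linear index matches its grid-stride position, then re-bases its index into the next group's range, so every tile of every group is covered exactly once. A host-side sketch of the same traversal (hypothetical `run_tile`):

// Standalone model of the group walk.
void walk_groups(int group_count, const int* tiles_per_group,
                 int block_id, int grid_size)
{
    int idx = block_id;
    for(int g = 0; g < group_count; ++g)
    {
        while(idx < tiles_per_group[g])
        {
            // run_tile(g, idx); // hypothetical per-tile body
            idx += grid_size;    // grid-stride within this group
        }
        idx -= tiles_per_group[g]; // carry the leftover into the next group
    }
}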

template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_DEVICE void
operator()(ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor> kargs) const
{
int block_linear_idx = blockIdx.x;
int total_block_cnt = gridDim.x;
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);

UnderlyingGemmKernel underlying_kernel{};
for(; block_linear_idx < total_work_tile_cnt; block_linear_idx += total_block_cnt)
{
auto [block_m_idx, block_n_idx] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(block_linear_idx);
// get the group index from the M_indices
int group_idx = kargs.M_indices[block_m_idx * BlockGemmShape::kM];

FlatmmKernelArgs<ScaleM, ScaleN, NumDTensor> impl_kargs{
kargs.a_ptr,
static_cast<const BDataType*>(kargs.b_shuffle_ptr) + group_idx * kargs.N * kargs.K,
kargs.ds_ptr,
kargs.c_ptr,
kargs.M,
kargs.N,
kargs.K,
kargs.stride_A,
kargs.stride_B,
kargs.stride_Ds,
kargs.stride_C,
kargs.k_batch,
kargs.scale_m,
kargs.scale_n};
// call the underlying flatmm kernel
underlying_kernel(impl_kargs, block_linear_idx);
}
}
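In the contiguous variant the rows of all groups are packed into one M dimension, so the owning expert is read back from M_indices at the block's first row; because a row block never straddles two groups, one lookup per tile suffices. A sketch of that lookup, under the same alignment assumption:

// Assumes M_indices[row] stores the group id of that row, and row blocks are
// aligned so that one block never spans two groups.
int group_of_row_block(const int* M_indices, int block_m_idx, int rows_per_block)
{
    return M_indices[block_m_idx * rows_per_block];
}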

template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
CK_TILE_DEVICE void
operator()(MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor> kargs) const
{
int group_idx = 0;
int block_linear_idx = blockIdx.x;
int total_block_cnt = gridDim.x;

UnderlyingGemmKernel underlying_kernel{};
for(; group_idx < kargs.group_count; ++group_idx)
{
const index_t valid_M = kargs.M_indices[group_idx];
const index_t N = kargs.N;
const index_t group_block_cnt = TilePartitioner::GridSize(valid_M, N);

while(block_linear_idx < group_block_cnt)
{
// Found the group this block belongs to
// create the kernel args for the underlying flatmm kernel
FlatmmKernelArgs<ScaleM, ScaleN, NumDTensor> impl_kargs{
static_cast<const ADataType*>(kargs.a_ptr) + group_idx * kargs.M * kargs.K,
static_cast<const BDataType*>(kargs.b_shuffle_ptr) +
group_idx * kargs.N * kargs.K,
kargs.ds_ptr,
static_cast<CDataType*>(kargs.c_ptr) + group_idx * kargs.M * kargs.N,
valid_M,
kargs.N,
kargs.K,
kargs.stride_A,
kargs.stride_B,
kargs.stride_Ds,
kargs.stride_C,
kargs.k_batch,
kargs.scale_m + group_idx * kargs.M,
kargs.scale_n + group_idx * kargs.N};
// call the underlying flatmm kernel
underlying_kernel(impl_kargs, block_linear_idx);
block_linear_idx += total_block_cnt;
}
block_linear_idx -= group_block_cnt;
}
}
};

} // namespace ck_tile

458 include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp Normal file
@@ -0,0 +1,458 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"

namespace ck_tile {

template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
{
using Underlying = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;

using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
using BlockGemmShape =
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;

using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

static constexpr int QuantPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr int N_Pack = 2;

static constexpr index_t NumDTensor = DsDataType::size();

static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
static constexpr auto I4 = number<4>();

static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
// using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;

[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "mixed_prec_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
// clang-format on
}

template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr auto
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
{
if constexpr(UsePersistentKernel)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device

constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;

[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);

e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry<1,
F16xMXF4FlatmmKernel,
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
block_size,
dync_smem_size);

const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);

// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
// << ", persistent_block_size: " << persistent_block_size
// << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;

assert(kargs.k_batch == 1);
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
}
else
{
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
}
}

using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;

template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();

index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
index_t kFlatN = kargs.N * kargs.K / kFlatK;

const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
}();

const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
},
number<NumDTensor>{});

// TODO: enable vector write for C in ColMajor
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_E, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_E, 1),
number<1>{},
number<1>{});
}
}();

auto scale_n = kargs.scale_n_ptr;

index_t FlatScaleK =
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);

const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const e8m0_t*>(scale_n.ptr),
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});

return make_tuple(
a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
}
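The FlatScaleN/FlatScaleK arithmetic only regroups the (N, K/GranularityK) grid of e8m0 scales so that each flat row serves one N_Pack-wide bundle of warp tiles; the element count never changes. A sketch with assumed MX parameters (GranularityK = 32 as is typical for MX block scales, warp-tile N = 16, N_Pack = 2):

constexpr int flat_scale_k(int K, int granularity_k, int n_pack, int warp_n)
{
    return (K / granularity_k) * n_pack * warp_n;
}

constexpr int flat_scale_n(int N, int n_pack, int warp_n)
{
    return N / n_pack / warp_n;
}

// The regrouping is a pure reshape: total scale count is preserved.
static_assert(flat_scale_k(256, 32, 2, 16) * flat_scale_n(128, 2, 16)
              == (256 / 32) * 128);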

template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
}();

const auto& b_flat_tensor_view = views.at(I1);

const auto& ds_pad_view = generate_tuple(
[&](auto i) {
const auto& d_tensor_view = views.at(I2);
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});

// TODO vector write in for C in ColMajor
const auto& e_pad_view = [&]() {
const auto& e_tensor_view = views.at(I3);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<FlatmmPipeline::kPadM, false>{});
}
}();

return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4));
}

template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& b_flat_pad_view = views.at(I1);
const auto& ds_pad_view = views.at(I2);
const auto& e_pad_view = views.at(I3);

const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();

const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});

const auto ds_block_window = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{i_n, i_m});
}
},
number<NumDTensor>{});

auto e_block_window = make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});

auto scale_block_window =
make_tile_window(views.at(I4),
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
{i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});

return make_tuple(a_block_window,
b_flat_block_window,
ds_block_window,
e_block_window,
scale_block_window);
}

template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void
RunFlatmm(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& scale_block_window = gemm_tile_windows.at(I4);

static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
"ScaleM and ScaleN should have the same GranularityK");
constexpr bool DoEpiScale =
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel

auto a_block_window_with_distr =
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
a_block_window.get_window_lengths(),
a_block_window.get_window_origin(),
FlatmmPipeline::GetADramTileDistribution());
const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
b_flat_block_window,
scale_block_window,
num_loop,
smem_ptr_ping,
smem_ptr_pong);

// Run Epilogue Pipeline
if constexpr(DoEpiScale)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}(c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
}

template <class ScaleM, class ScaleN>
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
int partition_idx = blockIdx.x) const
{
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);

do
{
const auto [iM, iN] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset / QuantPackedSize;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);

// allocate LDS
__shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
__shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];

if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
partition_idx += gridDim.x;
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
}
};

} // namespace ck_tile

1325 include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp Normal file
File diff suppressed because it is too large
@@ -6,6 +1,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"

namespace ck_tile {

@@ -238,22 +239,47 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
#if defined(__gfx11__)
constexpr index_t scale = 4;
#else
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
#endif
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) * scale / 2;
return TileShape::WarpTile::at(I2) / 2;
}
else
{
static_assert(TileShape::WarpTile::at(I1) == 16);
return TileShape::WarpTile::at(I2) * scale / 4;
return TileShape::WarpTile::at(I2) / 4;
}
}
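One plausible reading of the 32-vs-16 split: with a 64-lane wave covering the MFMA tile, a 32-wide N extent leaves 64/32 = 2 lanes stacked along K, a 16-wide one leaves 4, so each lane owns K/2 or K/4 elements of the warp tile. A constexpr sketch of that sizing (assumption, not taken from the source):

// Assumes a 64-lane wave; warp_n is the MFMA tile's N extent (32 or 16).
constexpr int kb_per_load(int warp_n, int warp_k)
{
    const int lanes_along_k = 64 / warp_n; // 2 for 32x32, 4 for 16x16
    return warp_k / lanes_along_k;
}

static_assert(kb_per_load(32, 64) == 32);
static_assert(kb_per_load(16, 128) == 32);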

template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALDS_WarpTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
using ADataType = remove_cvref_t<typename Problem::ADataType>;

static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");

constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);

constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});

constexpr int KLane = get_warp_size() / MPerXdl;
constexpr int KPerThread = KPerXdl / KLane;

constexpr int MaxVecSize = 16 / sizeof(ADataType);
constexpr int KItemsPerLoad = min(MaxVecSize, KPerThread);
constexpr int KFragment = KPerThread / KItemsPerLoad;

return make_static_tile_distribution(
tile_distribution_encoding<
sequence<Repeat>,
tuple<sequence<MPerXdl>, sequence<KFragment, KLane, KItemsPerLoad>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{});
}

template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
@@ -307,10 +333,10 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
// coalesce reading for each blocks
if constexpr(get_warp_size() % (M2 * K0) == 0)
if constexpr(get_warp_size() % K0 == 0)
{
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
@@ -329,24 +355,54 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
constexpr index_t KWave = K0 / get_warp_size();
constexpr index_t M0 = BlockSize / get_warp_size() / KWave;
constexpr index_t M1 = MPerBlock / M0;

return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M0, M1>, sequence<KWave, get_warp_size(), K1>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<0, 0>, sequence<1>>,
sequence<1, 2>,
sequence<1, 2>>{});
}
}
}

template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;

constexpr index_t BlockSize = Problem::kBlockSize;

// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");

return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2>,
sequence<1>>{});
}

template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape

@@ -355,15 +411,16 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t WaveNum = BlockSize / WaveSize;

constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
#if defined(__gfx11__)
constexpr index_t KRepeatInWave = 2;
#else
constexpr index_t KRepeatInWave = 1;
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim

constexpr index_t MaxVecSize = 16 / sizeof(typename Problem::BDataType);
constexpr index_t KItemsPerLoad = min(KBPerLoad, MaxVecSize);
constexpr index_t KFragment = KBPerLoad / KItemsPerLoad;
static_assert(KFragment * KItemsPerLoad == KBPerLoad);

constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
static_assert(TileShape::BlockWarps::at(number<2>{}) == 1, "Requires K_Warp == 1");

constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
@@ -371,15 +428,17 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t NRepeat = 1;

constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;

return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat, KRepeatInWave>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
sequence<WaveRepeat>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KFragment, KWavePerBlk, KThdPerWave, KItemsPerLoad>>, // first
// direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});

File diff suppressed because it is too large
@@ -0,0 +1,239 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"

namespace ck_tile {

#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE 0

#if defined(__gfx950__)
#define CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4 1
#else
#define CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4 0
#endif

#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS             \
    (CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && \
     CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4)
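The macro chain above gates the direct global-to-LDS path on both an explicit opt-in (currently off) and architecture support. The same logic restated as standalone constexpr bools, for illustration only:

// Illustrative restatement of the macro gating; not part of the diff.
constexpr bool requested_buffer_load_lds = false; // the opt-in switch, currently off
#if defined(__gfx950__)
constexpr bool arch_supports_dwordx4_lds = true;  // gfx950 can buffer_load DWORDx4 into LDS
#else
constexpr bool arch_supports_dwordx4_lds = false;
#endif
constexpr bool use_buffer_load_lds = requested_buffer_load_lds && arch_supports_dwordx4_lds;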

struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};
    static constexpr auto I2 = number<2>{};

    static constexpr index_t KBPerLoad = 32;
    static constexpr index_t N_Pack = 2; // fixed for fp4
    static constexpr index_t K_Pack = 2; // fixed for fp4

    template <typename Problem, typename NativeADramTensorView>
    CK_TILE_HOST_DEVICE static constexpr auto
    TransformF16xF4_ATensorView(const NativeADramTensorView& a_dram_view)
    {
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
        constexpr int DynamicTileOffsetFlag = 0;

        constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
        constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);

        static_assert(MPerXdl == 16 && NPerXdl == 16);

        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t KPack = GetSmemPackA<Problem>();

        constexpr int ContiguousThreadsCntInDS_READ_16B = 4;

        // Implement the swizzle pattern on the global-memory side, because the
        // ds_write pattern of BUFFER_LOAD_LDS cannot be adjusted.
        auto swizzle_a_dram_view_1 = transform_tensor_view(
            a_dram_view,
            make_tuple(
                // the M dim is not affected by the swizzle pattern
                make_unmerge_transform(
                    make_tuple(number<DynamicTileOffsetFlag>{}, number<MPerBlock>{})),
                // the K dim is the swizzled dimension
                make_unmerge_transform(make_tuple(number<DynamicTileOffsetFlag>{},
                                                  number<KPerBlock / KPack>{},
                                                  number<KPack>{}))),
            make_tuple(sequence<0>{}, sequence<1>{}),
            make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}));

        auto swizzle_a_dram_view_2 = transform_tensor_view(
            swizzle_a_dram_view_1,
            make_tuple(make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
                       make_xor_transform(make_tuple(number<MPerBlock>{},
                                                     number<ContiguousThreadsCntInDS_READ_16B>{})),
                       make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
                       make_pass_through_transform(number<KPack>{})),
            make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
            make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));

        return transform_tensor_view(
            swizzle_a_dram_view_2,
            make_tuple(
                make_merge_transform_v3_division_mod(
                    make_tuple(number<DynamicTileOffsetFlag>{}, number<MPerBlock>{})),
                make_merge_transform_v3_division_mod(make_tuple(number<DynamicTileOffsetFlag>{},
                                                                number<KPerBlock / KPack>{},
                                                                number<KPack>{}))),
            make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));
#else
        return a_dram_view;
#endif
    }
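The make_xor_transform over (MPerBlock, ContiguousThreadsCntInDS_READ_16B) is what keeps the four lanes of one ds_read_b128 group on distinct LDS banks. A minimal standalone sketch of the index math, assuming the usual xor-low-dim-by-(high-dim mod width) behavior of the transform:

// Minimal sketch of the XOR swizzle; the exact transform semantics are assumed.
constexpr int ContiguousChunks = 4; // lanes sharing one ds_read_b128 group
constexpr int swizzled_chunk(int row, int chunk)
{
    return chunk ^ (row % ContiguousChunks); // rotate chunk slots per row
}
// Four rows that would otherwise all hit chunk 0 now hit four distinct chunks:
static_assert(swizzled_chunk(0, 0) == 0);
static_assert(swizzled_chunk(1, 0) == 1);
static_assert(swizzled_chunk(2, 0) == 2);
static_assert(swizzled_chunk(3, 0) == 3);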

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ReadALdsBlockDescriptor()
    {
        constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
        constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);

        static_assert(MPerXdl == 16 && NPerXdl == 16);

        /* fewer transform layers compared with the old CK implementation */
        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t KPack = GetSmemPackA<Problem>();

        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
            make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
            number<KPack>{},
            number<1>{});

        constexpr int ContiguousThreadsCntInDS_READ_16B = 4;

        constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
            a_lds_block_desc_0,
            make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
                                                     number<ContiguousThreadsCntInDS_READ_16B>{})),
                       make_pass_through_transform(number<KPack>{})),
            make_tuple(sequence<1, 0>{}, sequence<2>{}),
            make_tuple(sequence<1, 0>{}, sequence<2>{}));

        constexpr auto a_lds_block_desc = transform_tensor_descriptor(
            a_lds_block_desc_permuted,
            make_tuple(make_pass_through_transform(number<MPerBlock>{}),
                       make_merge_transform_v3_division_mod(
                           make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
            make_tuple(sequence<1>{}, sequence<0, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return a_lds_block_desc;
    }
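For intuition, the naive descriptor above stores A as (KPerBlock/KPack, MPerBlock, KPack) with strides (KPack, KPerBlock, 1): each KPack-wide K group is contiguous, and M rows sit KPerBlock elements apart. A standalone check with assumed tile sizes:

// Standalone sketch; KPerBlock/KPack values are assumed, not from the diff.
constexpr int KPerBlock = 64, KPack = 16;
constexpr int offset(int k0, int m, int kp) // mirrors strides (KPack, KPerBlock, 1)
{
    return k0 * KPack + m * KPerBlock + kp;
}
static_assert(offset(0, 0, 15) == 15); // one KPack group is contiguous
static_assert(offset(1, 0, 0) == 16);  // the next K group follows immediately
static_assert(offset(0, 1, 0) == 64);  // the next M row is KPerBlock elements away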

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_WriteALdsBlockDescriptor()
    {
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t KPack = GetSmemPackA<Problem>();
        return make_naive_tensor_descriptor(make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                            make_tuple(number<KPerBlock>{}, number<1>{}),
                                            number<KPack>{},
                                            number<1>{});
#else
        return MakeF16xF4_ReadALdsBlockDescriptor<Problem>();
#endif
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ALDS_TileDistribution()
    {
        using TileShape = typename Problem::BlockGemmShape;

        static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
        static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");

        constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});
        constexpr int M0 = TileShape::WarpTile::at(I0);

        constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4

        constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 8
        constexpr int XDL_PerThreadK = KBPerLoad / K2;           // 4
        constexpr int K0 = K_Lane;                               // 4

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<Repeat>,
                                       tuple<sequence<M0>, sequence<K0, XDL_PerThreadK, K2>>,
                                       tuple<sequence<0>, sequence<2, 1>>,
                                       tuple<sequence<0>, sequence<0, 0>>,
                                       sequence<2>,
                                       sequence<2>>{});
    }
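With the enforced 16x16 XDL tile and a 64-lane wave, the annotated constants above fall out of simple division; a standalone recomputation, with the warp-tile K extent assumed to be 32:

// Standalone sketch; XDL_K and the wave width are assumed values.
constexpr int XDL_N = 16, XDL_K = 32, KBPerLoad = 32, WaveSize = 64;
constexpr int K_Lane = WaveSize / XDL_N;       // 64 / 16 = 4 lanes stacked along K
constexpr int K2 = XDL_K / K_Lane;             // 32 / 4  = 8 elements per lane per XDL
constexpr int XDL_PerThreadK = KBPerLoad / K2; // 32 / 8  = 4 XDL blocks covered per load
static_assert(K_Lane * K2 == XDL_K);
static_assert(XDL_PerThreadK * K2 == KBPerLoad);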

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
    {
        using TileShape = typename Problem::BlockGemmShape;

        static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");

        constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t WaveSize = get_warp_size();
        constexpr index_t WaveNum = BlockSize / WaveSize;

        constexpr index_t KThdPerWave = WaveSize; // thread count in the K dim
        constexpr index_t KWavePerBlk = 1;

        constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp

        constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;

        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<WaveRepeat>, // R: replication lengths
                tuple<sequence<NWavePerBlk, N_Pack>,                  // second direction
                      sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
                // wave in blk, // thd in wave
                // <M, K>      // <M, K>
                tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
                tuple<sequence<0, 0, 0>, sequence<1>>, // which index
                // <repeat, vec_load>
                sequence<2>,
                sequence<2>>{});
    }
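N_Pack and K_Pack are pinned to 2 because fp4 stores two 4-bit values per byte, so every byte-granular access moves element pairs along N and K; a trivial standalone check:

// Standalone sketch of the fp4 packing arithmetic.
constexpr int bits_per_byte = 8;
constexpr int bits_per_fp4  = 4;
constexpr int fp4_per_byte  = bits_per_byte / bits_per_fp4; // == 2, hence N_Pack == K_Pack == 2
static_assert(fp4_per_byte == 2, "two fp4 values pack into one byte");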

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeFp4ScaleBFlatDramTileDistribution()
    {
        using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape

        constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t WaveSize = get_warp_size();
        [[maybe_unused]] constexpr index_t WaveNum = BlockSize / WaveSize;

        constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});

        [[maybe_unused]] constexpr index_t XDLPerBlock =
            TileShape::kK / TileShape::WarpTile::at(I2);
        constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
        constexpr index_t N_Lane = TileShape::WarpTile::at(I1);

        constexpr index_t NWavePerBlk = N_Warp;

        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<>, // R: no replication
                tuple<sequence<NWavePerBlk>,                      // second direction
                      sequence<K_Lane, N_Lane, N_Pack * K_Pack>>, // first direction
                // wave in blk, // thd in wave
                // <M, K>      // <M, K>
                tuple<sequence<1>, sequence<2, 2>>, // which direction
                tuple<sequence<0>, sequence<0, 1>>, // which index
                // <repeat, vec_load>
                sequence<2>,
                sequence<2>>{});
    }
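With XDL_N == 16, the scale read above resolves to 4 K-lanes by 16 N-lanes per wave, each lane fetching an N_Pack * K_Pack = 4-element scale group; standalone arithmetic with the 16-wide warp tile and 64-lane wave assumed:

// Standalone sketch; XDL_N and the wave width are assumed values.
constexpr int WaveSize = 64, XDL_N = 16;
constexpr int K_Lane = WaveSize / XDL_N; // 4 lanes along K
constexpr int N_Lane = XDL_N;            // 16 lanes along N
constexpr int ScaleGroup = 2 * 2;        // N_Pack * K_Pack
static_assert(K_Lane * N_Lane == WaveSize);         // lanes exactly tile the wave
static_assert(K_Lane * N_Lane * ScaleGroup == 256); // scale elements touched per wave read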
};

} // namespace ck_tile
File diff suppressed because it is too large
@@ -310,4 +310,147 @@ struct UniversalGemmPipelineProblem
    }
};

template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
          typename Traits_,
          GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
          bool HasHotLoop_ = true,
          TailNumber TailNum_ = TailNumber::Full,
          typename ComputeDataType_ = ADataType_>
struct FlatmmPipelineProblem
{
    using Traits = remove_cvref_t<Traits_>;

    using ADataType = remove_cvref_t<ADataType_>;
    using BDataType = remove_cvref_t<BDataType_>;
    using CDataType = remove_cvref_t<CDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;

    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    using ALayout = remove_cvref_t<typename Traits::AsLayout>;
    using BLayout = remove_cvref_t<typename Traits::BsLayout>;
    using CLayout = remove_cvref_t<typename Traits::CLayout>;

    static constexpr bool TransposeC = Traits::TransposeC;
    static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
    static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;

    static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();

    static constexpr bool kPadM = Traits::kPadM;
    static constexpr bool kPadN = Traits::kPadN;
    static constexpr bool kPadK = Traits::kPadK;

    static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;

    static constexpr auto Scheduler = GemmPipelineScheduler::Default;
    static constexpr index_t VectorLoadSize = Traits::_VectorSize;

    static constexpr auto HasHotLoop = HasHotLoop_;
    static constexpr auto TailNum = TailNum_;

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "gemm_problem",
                      concat('x', VectorLoadSize, kBlockSize),
                      concat('x', kPadM, kPadN, kPadK),
                      Scheduler);
        // clang-format on
    }
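For illustration, assuming concat joins its arguments with the given separator and the scheduler enum prints as its name, a problem with VectorLoadSize = 16, kBlockSize = 256, and no padding would be named roughly as in this standalone sketch:

// Standalone sketch of the naming scheme; separator and enum rendering are assumed.
#include <string>
std::string make_name()
{
    const int vec = 16, blk = 256;                    // assumed VectorLoadSize / kBlockSize
    const bool padM = false, padN = false, padK = false;
    return "gemm_problem_" + std::to_string(vec) + "x" + std::to_string(blk) + "_" +
           std::to_string(padM) + "x" + std::to_string(padN) + "x" + std::to_string(padK) +
           "_Default"; // e.g. "gemm_problem_16x256_0x0x0_Default"
}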

    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
    {
        constexpr index_t PackedSize =
            ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
        {
            constexpr index_t pixels_per_thread =
                BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
            return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
                       ? pixels_per_thread
                       : PackedSize * VectorLoadSize / sizeof(ADataType);
        }
        else
        {
            return VectorLoadSize / sizeof(ADataType);
        }
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
    {
        constexpr index_t PackedSize =
            ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
        {
            constexpr index_t pixels_per_thread =
                BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
            return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
                       ? pixels_per_thread
                       : PackedSize * VectorLoadSize / sizeof(BDataType);
        }
        else
        {
            return PackedSize * VectorLoadSize / sizeof(BDataType);
        }
    }
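The alignment is the per-thread contiguous run, capped by the vector-load width in elements; PackedSize accounts for sub-byte types. A standalone worked instance with assumed numbers (fp4 B operand, 16-byte loads, 256-thread block):

// Standalone sketch; all values are assumed, not taken from the diff.
constexpr int PackedSize = 2;      // fp4: two elements packed per byte
constexpr int VectorLoadSize = 16; // bytes per vector load
constexpr int elem_bytes = 1;      // packed fp4 is addressed per byte
constexpr int kN = 128, kK = 64, kBlockSize = 256;
constexpr int pixels_per_thread = kN * kK / kBlockSize;              // 32
constexpr int vec_elems = PackedSize * VectorLoadSize / elem_bytes;  // 32
constexpr int alignment = pixels_per_thread < vec_elems ? pixels_per_thread : vec_elems;
static_assert(alignment == 32, "per-thread run and vector width agree in this instance");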

    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
    {
        if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
        {
            constexpr index_t N1 = kBlockSize / get_warp_size();
            constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
            constexpr index_t M0 = get_warp_size() / N2;
            constexpr index_t M1 = BlockGemmShape::kM / M0;

            return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
        }
        else
        {
            constexpr index_t M1 = kBlockSize / get_warp_size();
            constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
            constexpr index_t N0 = get_warp_size() / M2;
            constexpr index_t N1 = BlockGemmShape::kN / N0;

            return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
        }
    }

    static constexpr index_t VectorSizeA = []() {
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            return kPadK ? 1 : GetAlignmentA();
        }
        else
        {
            return kPadM ? 1 : GetAlignmentA();
        }
    }();

    static constexpr index_t VectorSizeB = []() {
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
        {
            return kPadN ? 1 : GetAlignmentB();
        }
        else
        {
            return kPadK ? 1 : GetAlignmentB();
        }
    }();

    static constexpr index_t VectorSizeC = []() {
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            return kPadN ? 1 : GetAlignmentC();
        }
        else
        {
            return kPadM ? 1 : GetAlignmentC();
        }
    }();
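The VectorSize* lambdas drop to scalar (size 1) accesses whenever the contiguous dimension of that operand may be padded, since a padded edge is not guaranteed to be vector-aligned; a standalone restatement of the rule:

// Standalone restatement of the padding rule; the alignment value is assumed.
constexpr int pick_vector_size(bool contiguous_dim_padded, int alignment)
{
    return contiguous_dim_padded ? 1 : alignment; // a padded edge forces scalar access
}
static_assert(pick_vector_size(true, 8) == 1);
static_assert(pick_vector_size(false, 8) == 8);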
};

} // namespace ck_tile
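A hypothetical instantiation sketch of the new problem type; MyTileShape and MyTraits are stand-in names for user-side types, not part of this diff:

// Hypothetical usage; MyTileShape and MyTraits must be defined by the client.
using Problem = ck_tile::FlatmmPipelineProblem<ck_tile::half_t, // ADataType
                                               ck_tile::half_t, // BDataType
                                               ck_tile::half_t, // CDataType
                                               MyTileShape,     // BlockGemmShape
                                               MyTraits>;       // Traits
static_assert(Problem::kBlockSize % 64 == 0, "block is a whole number of waves");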
include/ck_tile/ops/moe_flatmm.hpp (new file, 10 lines)
@@ -0,0 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"