mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
merge nkpad
This commit is contained in:
@@ -1,6 +1,15 @@
|
||||
add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
|
||||
add_executable(tile_example_mixed_prec_flatmm EXCLUDE_FROM_ALL mixed_prec/mixed_prec_flatmm.cpp)
|
||||
|
||||
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
|
||||
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
|
||||
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef -Wno-unused-variable -Wno-unused-parameter)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps -Wno-nrvo)
|
||||
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0")
|
||||
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --disable-schedmodel-in-sched-mi=1 -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental -mllvm --misched-bottomup=1")
|
||||
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -11,7 +11,96 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "flatmm_basic.hpp"
|
||||
#include "run_flatmm_example.inc"
|
||||
#include <type_traits>
|
||||
|
||||
template <typename T>
|
||||
constexpr const char* DataTypeToString()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck_tile::half_t>)
|
||||
{
|
||||
return "fp16";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::fp8_t>)
|
||||
{
|
||||
return "fp8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::bf8_t>)
|
||||
{
|
||||
return "bf8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
return "bf16";
|
||||
}
|
||||
else
|
||||
{
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// mfma_type, 0:32x32, 1:16x16
|
||||
template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
FlatmmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b_v1(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp;
|
||||
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Tile,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
FlatmmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
@@ -23,9 +112,12 @@ template <typename FlatmmConfig,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise>
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s)
|
||||
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
@@ -80,14 +172,14 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
@@ -111,7 +203,10 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups>>;
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN>>;
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
@@ -119,7 +214,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
@@ -204,6 +299,111 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool UsePersistentKernel = false,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
ck_tile::DeviceMem& b_shuffle_dev_buf,
|
||||
ck_tile::DeviceMem& c_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
ScaleM scale_m,
|
||||
ScaleN scale_n,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
|
||||
b_shuffle_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C,
|
||||
scale_m,
|
||||
scale_n};
|
||||
|
||||
float ave_time = flatmm_calc<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ScaleM,
|
||||
ScaleN,
|
||||
UsePersistentKernel,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
|
||||
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
|
||||
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "256", "m dimension")
|
||||
.insert("n", "256", "n dimension")
|
||||
.insert("k", "128", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
|
||||
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
template <template <typename PreType> typename FlatmmConfig>
|
||||
int run_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
@@ -217,6 +417,8 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
int scale_opt = arg_parser.get_int("scale");
|
||||
int persistent_opt = arg_parser.get_int("persistent");
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
@@ -231,13 +433,53 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
if(scale_opt == 0)
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
-1,
|
||||
-1,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
1,
|
||||
1>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
1,
|
||||
1,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
if(scale_opt == 0)
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -35,12 +35,13 @@ struct FlatmmConfig32
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
@@ -72,18 +73,28 @@ struct FlatmmConfig16
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr int N_Repeat =
|
||||
N_Tile / FlatmmConfig16<DataType>::N_Warp_Tile / FlatmmConfig16<DataType>::N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename ADataType>
|
||||
@@ -164,39 +175,19 @@ struct is_8bit_type
|
||||
{
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "256", "m dimension")
|
||||
.insert("n", "256", "n dimension")
|
||||
.insert("k", "128", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename FlatmmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise>
|
||||
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s);
|
||||
|
||||
50
example/ck_tile/18_flatmm/mixed_prec/a16w4_flatmm.hpp
Normal file
50
example/ck_tile/18_flatmm/mixed_prec/a16w4_flatmm.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr int N_Repeat =
|
||||
N_Tile / A16W4_FlatmmConfig16::N_Warp_Tile / A16W4_FlatmmConfig16::N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
485
example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp
Normal file
485
example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp
Normal file
@@ -0,0 +1,485 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mixed_prec_flatmm.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise>
|
||||
float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
persistent,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>;
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
|
||||
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
|
||||
using Kernel =
|
||||
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ScaleN,
|
||||
bool UsePersistentKernel = false,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_mixed_prec_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
ck_tile::DeviceMem& b_shuffle_dev_buf,
|
||||
ck_tile::DeviceMem& c_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
ScaleN dequant_scale_n,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
// Activation has no scale
|
||||
using ActScaleType = ck_tile::FlatmmScalePointer<-1>;
|
||||
|
||||
ck_tile::ScaleFlatmmHostArgs<ActScaleType, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
|
||||
b_shuffle_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C,
|
||||
{},
|
||||
dequant_scale_n};
|
||||
|
||||
float ave_time = mixed_prec_flatmm_calc<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ActScaleType,
|
||||
ScaleN,
|
||||
UsePersistentKernel,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K / PackedSize +
|
||||
sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run A16W4_Flatmm kernel "
|
||||
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
|
||||
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "256", "m dimension")
|
||||
.insert("n", "256", "n dimension")
|
||||
.insert("k", "512", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on GPU")
|
||||
.insert("mixed_prec",
|
||||
"bf16xfp4",
|
||||
"data type for activation and weight, support: bf16xfp4, fp16xfp4")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, class IterSrc, class IterDst>
|
||||
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = FlatmmConfig::N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K_pk; ++k)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, class T>
|
||||
auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
|
||||
{
|
||||
assert(scale.get_lengths().size() == 2);
|
||||
int n_ = scale.get_lengths()[1];
|
||||
int k_ = scale.get_lengths()[0];
|
||||
|
||||
constexpr int K_Pack = 2; // fixed for mxfp4
|
||||
constexpr int N_Pack = 2; // fixed for mxfp4
|
||||
constexpr int GranularityK = 32; // fixed for mxfp4
|
||||
|
||||
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
|
||||
|
||||
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
|
||||
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
|
||||
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
|
||||
|
||||
ck_tile::HostTensor<T> shfl_scale({
|
||||
k_ / K_Pack / K_Lane,
|
||||
K_Pack,
|
||||
K_Lane,
|
||||
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
|
||||
N_Pack,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
});
|
||||
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
|
||||
return ck_tile::reference_permute(shfl_scale, {3, 0, 2, 5, 1, 4});
|
||||
}
|
||||
|
||||
#include "run_mixed_prec_flatmm.inc"
|
||||
|
||||
template <typename FlatmmConfig>
|
||||
int run_mixed_prec_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string mixed_prec = arg_parser.get_str("mixed_prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
int persistent_opt = arg_parser.get_int("persistent");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
{
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
{
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported warp_tile!");
|
||||
}
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
15
example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.hpp
Normal file
15
example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.hpp
Normal file
@@ -0,0 +1,15 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include "a16w4_flatmm.hpp"
|
||||
180
example/ck_tile/18_flatmm/mixed_prec/run_mixed_prec_flatmm.inc
Normal file
180
example/ck_tile/18_flatmm/mixed_prec/run_mixed_prec_flatmm.inc
Normal file
@@ -0,0 +1,180 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
template <typename PrecActType,
|
||||
typename PrecWeightType,
|
||||
typename FlatmmConfig,
|
||||
bool UsePersistentKernel = false,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using CDataType = PrecActType;
|
||||
using AccDataType = float;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
constexpr int DequantGranularityN = 1;
|
||||
constexpr int DequantGranularityK = 32;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
ck_tile::index_t n_warmup = arg_parser.get_int("warmup");
|
||||
ck_tile::index_t n_repeat = arg_parser.get_int("repeat");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_host(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_origin_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_rslt_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::HostTensorDescriptor(
|
||||
{K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffle_host.begin(), N, K);
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_b_shuffle = preShuffleScale<FlatmmConfig>(scale_b);
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffle.get_element_space_size_in_bytes());
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
c_rslt_host.SetZero();
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffle.data());
|
||||
|
||||
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK>{
|
||||
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
|
||||
|
||||
invoke_mixed_prec_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
decltype(scale_b_dev_ptr),
|
||||
UsePersistentKernel>(a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
scale_b_dev_ptr,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
|
||||
bool pass = true;
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::DeviceMem b_origin_dev_buf(b_origin_host.get_element_space_size_in_bytes());
|
||||
b_origin_dev_buf.ToDevice(b_origin_host.data());
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::DeviceMem c_gpu_ref_dev_buf(c_gpu_ref_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::HostTensor<AccDataType> scale_A(
|
||||
ck_tile::HostTensorDescriptor({1, K / DequantGranularityK}, {1, 1}));
|
||||
|
||||
// scaleA = 1 has no effect on the result
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
|
||||
ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
|
||||
scale_A_dev_buf.ToDevice(scale_A.data());
|
||||
|
||||
// convert scale_b from e8m0 to float
|
||||
ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
|
||||
{K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
|
||||
std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
|
||||
ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
|
||||
scale_b_float_dev_buf.ToDevice(scale_b_float.data());
|
||||
|
||||
c_gpu_ref_dev_buf.SetZero();
|
||||
ck_tile::reference_blockwise_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
static_cast<ADataType*>(a_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_origin_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_gpu_ref_dev_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
M,
|
||||
DequantGranularityN,
|
||||
DequantGranularityK,
|
||||
static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()));
|
||||
|
||||
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_rslt_host, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -154,6 +154,9 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
|
||||
template <typename PrecType,
|
||||
typename FlatmmConfig,
|
||||
int ScaleGranularityM = -1,
|
||||
int ScaleGranularityN = -1,
|
||||
bool UsePersistentKernel = false,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
@@ -197,21 +200,32 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
ck_tile::HostTensor<CDataType> c_rslt_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
|
||||
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));
|
||||
|
||||
// TODO: add different init types
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
// ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
// ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -222,14 +236,34 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem per_channel_scale_dev_buf(
|
||||
per_channel_scale.get_element_space_size_in_bytes());
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
c_rslt_host.SetZero();
|
||||
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
|
||||
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
|
||||
|
||||
// do pre-shuffle
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig>(b_origin_host);
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
|
||||
if constexpr(FlatmmConfig::TiledMMAPermuteN)
|
||||
{
|
||||
return shuffle_b_v1<FlatmmConfig>(b_origin_host);
|
||||
}
|
||||
else
|
||||
{
|
||||
return shuffle_b<FlatmmConfig>(b_origin_host);
|
||||
}
|
||||
}();
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
|
||||
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
|
||||
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
|
||||
|
||||
invoke_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -239,24 +273,32 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
CLayout,
|
||||
decltype(per_token_scale_dev_ptr),
|
||||
decltype(per_channel_scale_dev_ptr),
|
||||
UsePersistentKernel>(a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
per_token_scale_dev_ptr,
|
||||
per_channel_scale_dev_ptr,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
if(ScaleGranularityM != -1 || ScaleGranularityN != -1)
|
||||
throw std::runtime_error("ScaleAB is not supported for CPU verification!\n");
|
||||
ck_tile::HostTensor<CDataType> c_ref_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_ref_host.SetZero();
|
||||
@@ -304,13 +346,41 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1)
|
||||
{
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_blockwise_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
d_A,
|
||||
d_B,
|
||||
d_C,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
ScaleGranularityM,
|
||||
ScaleGranularityN,
|
||||
K,
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(),
|
||||
d_C,
|
||||
|
||||
14
example/ck_tile/21_moe_flatmm/CMakeLists.txt
Normal file
14
example/ck_tile/21_moe_flatmm/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
|
||||
add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp)
|
||||
|
||||
|
||||
set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
|
||||
|
||||
list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS -Wno-nrvo -Wno-unused-variable -Wno-unused-parameter -Wno-unused-local-typedef -Wno-float-equal)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS --save-temps)
|
||||
target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS})
|
||||
514
example/ck_tile/21_moe_flatmm/mixed_prec/a16w4_moe_flatmm.cpp
Normal file
514
example/ck_tile/21_moe_flatmm/mixed_prec/a16w4_moe_flatmm.cpp
Normal file
@@ -0,0 +1,514 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "a16w4_moe_flatmm.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/moe_flatmm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/reference/reference_moe_gemm.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// gemm1
|
||||
// operand-A = [num_token, d_model]
|
||||
// operand-B = [num_expert, hidden, d_model]
|
||||
// operand-C = [num_token, topk, hidden]
|
||||
|
||||
// gemm2
|
||||
// operand-A = [num_token, topk, hidden]
|
||||
// operand-B = [num_expert, d_model, hidden]
|
||||
// operand-C = [num_token, d_model]
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ck_tile::MoeFlatmmKind moe_kind = ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename MoeFlatmmHostArgs>
|
||||
float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
false, // UsePersistentKernel_
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>; // Preshuffle_
|
||||
|
||||
constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;
|
||||
|
||||
if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
static_assert(
|
||||
FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
|
||||
"requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
|
||||
}
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
|
||||
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using CodegenPipelineProblem =
|
||||
std::conditional_t<MXFP4_Pipeline,
|
||||
ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
|
||||
using CodegenFlatmmPipeline = std::conditional_t<
|
||||
MXFP4_Pipeline,
|
||||
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>,
|
||||
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>>;
|
||||
using FusedAct =
|
||||
std::conditional_t<MXFP4_Pipeline, ck_tile::moe::Swiglu, ck_tile::moe::MoeSilu>;
|
||||
|
||||
using Kernel = ck_tile::MoeFlatmmKernel<TilePartitioner,
|
||||
CodegenFlatmmPipeline,
|
||||
GemmEpilogue,
|
||||
moe_kind,
|
||||
FusedAct>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>
|
||||
? 2
|
||||
: 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>
|
||||
? 2
|
||||
: 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK
|
||||
: args.NumTokens,
|
||||
args.K,
|
||||
args.stride_A,
|
||||
is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
const int outputN =
|
||||
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N;
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), s.stream_id_));
|
||||
else if(args.k_batch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(args.e_ptr,
|
||||
0,
|
||||
args.NumTokens * args.TopK * outputN * sizeof(CDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, ck_tile::MoeFlatmmKind moe_kind, class IterSrc, class IterDst>
|
||||
void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N, int K)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = FlatmmConfig::N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
|
||||
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
int up_stride = N / 2 / NLane;
|
||||
|
||||
for(long eid = 0; eid < experts_cnt; ++eid)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K_pk; ++k)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
// interleave gate and up part with granularity is 16.
|
||||
int n0_interleave = n >= N / 2 ? (n0 - up_stride) * 2 + 1 : // up part
|
||||
n0 * 2; // gate part
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
long outputIndex = eid * N * K_pk + n0_interleave * KPack * NLane * KLane * K0 +
|
||||
k0 * KPack * NLane * KLane + k1 * KPack * NLane +
|
||||
n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(long eid = 0; eid < experts_cnt; ++eid)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K_pk; ++k)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
long outputIndex = eid * N * K_pk + n0 * KPack * NLane * KLane * K0 +
|
||||
k0 * KPack * NLane * KLane + k1 * KPack * NLane +
|
||||
n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig, ck_tile::MoeFlatmmKind moe_kind, typename T>
|
||||
auto shuffle_mxfp4_scale(const ck_tile::HostTensor<T>& scale, int experts_cnt)
|
||||
{
|
||||
assert(scale.get_lengths().size() == 2);
|
||||
int n_ = scale.get_lengths()[1];
|
||||
int k_ = scale.get_lengths()[0];
|
||||
|
||||
int k_per_expert = k_ / experts_cnt;
|
||||
|
||||
constexpr int K_Pack = 2; // fixed for mxfp4
|
||||
constexpr int N_Pack = 2; // fixed for mxfp4
|
||||
constexpr int GranularityK = 32; // fixed for mxfp4
|
||||
|
||||
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
|
||||
|
||||
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
|
||||
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
|
||||
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
|
||||
|
||||
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
ck_tile::HostTensor<T> shfl_scale({
|
||||
experts_cnt,
|
||||
k_per_expert / K_Pack / K_Lane,
|
||||
K_Pack,
|
||||
K_Lane,
|
||||
N_Pack, // N_Pack = 2 is composed of Gate + Up.
|
||||
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
});
|
||||
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
|
||||
return ck_tile::reference_permute(shfl_scale, {0, 5, 1, 3, 6, 2, 4});
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::HostTensor<T> shfl_scale({
|
||||
experts_cnt,
|
||||
k_per_expert / K_Pack / K_Lane,
|
||||
K_Pack,
|
||||
K_Lane,
|
||||
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
|
||||
N_Pack,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
});
|
||||
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
|
||||
return ck_tile::reference_permute(shfl_scale, {0, 4, 1, 3, 6, 2, 5});
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_a16w4_moe_flatmm_example.inc"
|
||||
|
||||
template <typename FlatmmConfig>
|
||||
int run_a16w4_moe_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
const std::string mixed_prec = arg_parser.get_str("mixed_prec");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
const std::string gemm_kind = arg_parser.get_str("gemm_kind");
|
||||
if(gemm_kind == "gemm1_gate_up")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm2")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<ck_tile::half_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<ck_tile::bfloat16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm2!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
|
||||
"[gemm1_gate_up | gemm2]");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
// else if(warp_tile == 1)
|
||||
// {
|
||||
// return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
|
||||
// }
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/moe_flatmm.hpp"
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr int N_Repeat =
|
||||
N_Tile / A16W4_FlatmmConfig16::N_Warp_Tile / A16W4_FlatmmConfig16::N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("experts", "8", "Num of experts - 8 by default")
|
||||
.insert("NumTokens", "128", "M dimensions - 128 by default.")
|
||||
.insert("TopK", "3", "Top K - 3 by default.")
|
||||
.insert("N", "4096", "N dimensions - 4096 by default.")
|
||||
.insert("K", "4096", "K dimensions - 4096 by default.")
|
||||
.insert("stride_A", "", "Tensor A strides - it is empty by default.")
|
||||
.insert("stride_B", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_C", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Col by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("gemm_kind",
|
||||
"gemm1_gate_up",
|
||||
"Gemm kind in FFN network [gemm1_gate_up | gemm2] - "
|
||||
"gemm1_gate_up by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("mixed_prec",
|
||||
"bf16xfp4",
|
||||
"data type for activation and weight, support: bf16xfp4, fp16xfp4")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 16x16 (950 only, may use a larger tile than warp_tile=0)")
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
@@ -0,0 +1,356 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ck_tile::MoeFlatmmKind kind,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename MoeHostArgs>
|
||||
float invoke_a16w4_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
|
||||
{
|
||||
float ave_time = a16w4_moe_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
kind,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::string op_name{"Moe Gemm"};
|
||||
|
||||
constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
|
||||
sizeof(BDataType) * args.N * args.K / PackedSize +
|
||||
sizeof(CDataType) * args.M * args.N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename PrecActType,
|
||||
typename PrecWeightType,
|
||||
typename FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind kind,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using CDataType = PrecActType;
|
||||
using AccDataType = float;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
constexpr int ScaleGranularityN = 1;
|
||||
constexpr int ScaleGranularityK = 32;
|
||||
|
||||
const ck_tile::index_t N = arg_parser.get_int("N");
|
||||
const ck_tile::index_t K = arg_parser.get_int("K");
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_A");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_B");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_C");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
|
||||
const ck_tile::index_t topk = arg_parser.get_int("TopK");
|
||||
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
|
||||
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
|
||||
const ck_tile::index_t experts = arg_parser.get_int("experts");
|
||||
|
||||
// TODO: replace the magic declaration
|
||||
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
|
||||
|
||||
ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk;
|
||||
ck_tile::index_t valid_tile_num = sorted_tile_num;
|
||||
ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
|
||||
const ck_tile::index_t M = sorted_tile_num * MPerBlock;
|
||||
const ck_tile::index_t outputN = kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? N / 2 : N;
|
||||
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr bool IsInputGemm = kind != ck_tile::MoeFlatmmKind::kFFN_gemm2;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(
|
||||
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(
|
||||
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
|
||||
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout)));
|
||||
auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
|
||||
is_row_major(b_layout)
|
||||
? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
|
||||
: ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
|
||||
auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
|
||||
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::HostTensorDescriptor(
|
||||
{K * experts / ScaleGranularityK, N / ScaleGranularityN}, {N / ScaleGranularityN, 1}));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.f, 1.f}(scale_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f}(b_k_n_tensor);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.0f, 1.0f}(scale_b);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host(
|
||||
ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
|
||||
shuffle_mxfp4_weight<FlatmmConfig, kind>(
|
||||
b_k_n_tensor.begin(), b_shuffle_host.begin(), experts, N, K);
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_b_shuffle =
|
||||
shuffle_mxfp4_scale<FlatmmConfig, kind>(scale_b, experts);
|
||||
ck_tile::DeviceMem scale_b_shuffle_dev_buf(scale_b_shuffle.get_element_space_size_in_bytes());
|
||||
|
||||
std::cout << "moe_flatmm:"
|
||||
<< "\n num_experts: " << experts << "\n num_tokens: " << num_tokens
|
||||
<< "\n topk: " << topk << "\n sorted_tile_num: " << sorted_tile_num
|
||||
<< "\n problem_n: " << N << "\n problem_k: " << K
|
||||
<< "\n a_m_k: " << a_m_k_tensor.mDesc << "\n b_k_n: " << b_k_n_tensor.mDesc
|
||||
<< "\n b_shuffle: " << b_shuffle_host.mDesc << "\n c_m_n: " << c_m_n_tensor.mDesc
|
||||
<< std::endl;
|
||||
|
||||
ck_tile::HostTensor<ck_tile::index_t> expert_ids(
|
||||
ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
|
||||
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
|
||||
ck_tile::HostTensor<AccDataType> expert_weight(
|
||||
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
|
||||
ck_tile::HostTensor<ck_tile::index_t> max_token_id(
|
||||
ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
ck_tile::HostTensor<AccDataType> expert_bias(ck_tile::HostTensorDescriptor({experts * N}, {1}));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
// for verification only, no need to satify weight normalization
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.0f, 1.0f}(expert_bias);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.0f, 1.0f}(expert_weight);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 0.0f}(expert_bias);
|
||||
}
|
||||
|
||||
max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
|
||||
// int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
|
||||
}
|
||||
|
||||
int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
// int token_per_tile = num_tokens * topk / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
// sorted_token_ids.mData[0] = 0;
|
||||
for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile && tokenid < num_tokens * topk)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = num_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf{a_m_k_tensor.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem b_origin_dev_buf{b_k_n_tensor.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf{b_shuffle_host.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem c_m_n_dev_buf{c_m_n_tensor.get_element_space_size_in_bytes()};
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_tensor.data());
|
||||
b_origin_dev_buf.ToDevice(b_k_n_tensor.data());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_tensor.SetZero();
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_dev{sorted_token_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_ids_dev{expert_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem max_token_id_dev{max_token_id.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_weight_dev{expert_weight.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_bias_dev{expert_bias.get_element_space_size_in_bytes()};
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.data());
|
||||
expert_weight_dev.ToDevice(expert_weight.data());
|
||||
expert_bias_dev.ToDevice(expert_bias.data());
|
||||
scale_b_shuffle_dev_buf.ToDevice(scale_b_shuffle.data());
|
||||
|
||||
const ck_tile::index_t* p_sorted_token_ids_dev =
|
||||
static_cast<ck_tile::index_t*>(sorted_token_ids_dev.GetDeviceBuffer());
|
||||
const ck_tile::index_t* p_expert_ids_dev =
|
||||
static_cast<ck_tile::index_t*>(expert_ids_dev.GetDeviceBuffer());
|
||||
const ck_tile::index_t* p_max_token_id_dev =
|
||||
static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
|
||||
const AccDataType* p_sorted_expert_weight_dev =
|
||||
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
|
||||
|
||||
auto scale_b_shuffle_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<1>{
|
||||
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()), experts * N};
|
||||
|
||||
using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs<
|
||||
ck_tile::FlatmmScalePointer<-1>,
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>,
|
||||
ck_tile::FlatmmScalePointer<1>>;
|
||||
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
|
||||
p_sorted_expert_weight_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_shuffle_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
num_tokens,
|
||||
experts,
|
||||
topk,
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
nullptr,
|
||||
scale_b_shuffle_dev_ptr,
|
||||
exp_bias_dev_ptr};
|
||||
|
||||
invoke_a16w4_moe_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
kind>(warmup, repeat, gemm_desc);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_tensor.data());
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("validate"))
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(IsInputGemm ? num_tokens * topk : num_tokens,
|
||||
outputN,
|
||||
stride_C,
|
||||
is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::HostTensor<AccDataType> scale_A(
|
||||
ck_tile::HostTensorDescriptor({1, K / ScaleGranularityK}, {1, 1}));
|
||||
|
||||
// scaleA = 1 has no effect on the result
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
|
||||
ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
|
||||
scale_A_dev_buf.ToDevice(scale_A.data());
|
||||
|
||||
// convert scale_b from e8m0 to float
|
||||
ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
|
||||
{K * experts / ScaleGranularityK, N / ScaleGranularityN}, {N / ScaleGranularityN, 1}));
|
||||
std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
|
||||
ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
|
||||
scale_b_float_dev_buf.ToDevice(scale_b_float.data());
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
|
||||
c_m_n_ref_buf->SetZero();
|
||||
|
||||
ck_tile::reference_moe_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
static_cast<int>(kind),
|
||||
ck_tile::moe::Swiglu>(
|
||||
p_sorted_token_ids_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
static_cast<const ADataType*>(a_m_k_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<const BDataType*>(b_origin_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
|
||||
p_sorted_expert_weight_dev,
|
||||
num_tokens,
|
||||
MPerBlock,
|
||||
topk,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
M,
|
||||
1,
|
||||
ScaleGranularityK,
|
||||
static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()));
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
474
example/ck_tile/21_moe_flatmm/moe_flatmm.cpp
Normal file
474
example/ck_tile/21_moe_flatmm/moe_flatmm.cpp
Normal file
@@ -0,0 +1,474 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "moe_flatmm.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/moe_flatmm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/reference/reference_moe_gemm.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
constexpr int N_Warp_Tile = FlatmmConfig::N_Warp_Tile;
|
||||
constexpr int N_Warp = FlatmmConfig::N_Warp;
|
||||
constexpr int KPerLane = FlatmmConfig::K_Warp_Tile / (64 / N_Warp_Tile);
|
||||
|
||||
ck_tile::HostTensor<T> t_view({n_ / N_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
k_ / (64 * KPerLane / N_Warp_Tile),
|
||||
64 / N_Warp_Tile,
|
||||
KPerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
// gemm1
|
||||
// operand-A = [num_token, d_model]
|
||||
// operand-B = [num_expert, hidden, d_model]
|
||||
// operand-C = [num_token, topk, hidden]
|
||||
|
||||
// gemm2
|
||||
// operand-A = [num_token, topk, hidden]
|
||||
// operand-B = [num_expert, d_model, hidden]
|
||||
// operand-C = [num_token, d_model]
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ck_tile::MoeFlatmmKind moe_kind = ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename ScaleM,
|
||||
typename ScaleN>
|
||||
float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
false, // UsePersistentKernel_
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>; // Preshuffle_
|
||||
|
||||
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
static_assert(
|
||||
FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
|
||||
"requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
|
||||
}
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up
|
||||
? 2
|
||||
: 1; // determined by scale shuffle pattern
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using Kernel = ck_tile::
|
||||
MoeFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue, moe_kind>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK
|
||||
: args.NumTokens,
|
||||
args.K,
|
||||
args.stride_A,
|
||||
is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
const int outputN =
|
||||
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N;
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), s.stream_id_));
|
||||
else if(args.k_batch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(args.e_ptr,
|
||||
0,
|
||||
args.NumTokens * args.TopK * outputN * sizeof(CDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_moe_flatmm_example.inc"
|
||||
|
||||
template <template <typename PreType> typename FlatmmConfig>
|
||||
int run_moe_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
const std::string prec_type = arg_parser.get_str("prec");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
const std::string gemm_kind = arg_parser.get_str("gemm_kind");
|
||||
if(gemm_kind == "gemm1_gate_up")
|
||||
{
|
||||
if(prec_type == "fp8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "fp16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm1_gate_only")
|
||||
{
|
||||
if(prec_type == "fp8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "fp16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm2")
|
||||
{
|
||||
if(prec_type == "fp8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::bfloat16_t,
|
||||
FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "fp16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
|
||||
"[gemm1_gate_only | gemm1_gate_up | gemm2]");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_moe_flatmm_example<FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
return !run_moe_flatmm_example<FlatmmConfig32>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 2)
|
||||
{
|
||||
return !run_moe_flatmm_example<FlatmmConfig16_950>(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return !run_moe_flatmm_example<FlatmmConfig32_950>(argc, argv);
|
||||
}
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
202
example/ck_tile/21_moe_flatmm/moe_flatmm.hpp
Normal file
202
example/ck_tile/21_moe_flatmm/moe_flatmm.hpp
Normal file
@@ -0,0 +1,202 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/moe_flatmm.hpp"
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig32
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 32;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig32_950 : public FlatmmConfig32<DataType>
|
||||
{
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 64;
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 64;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr int N_Repeat =
|
||||
N_Tile / FlatmmConfig16<DataType>::N_Warp_Tile / FlatmmConfig16<DataType>::N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false; // N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename ADataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
// ToDo: Add more bias config to support different categories of GEMM.
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::bf16_t;
|
||||
};
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
// ToDo: Add more bias config to support different categories of GEMM.
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_8bit_type
|
||||
: std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>
|
||||
{
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("experts", "8", "Num of experts - 8 by default")
|
||||
.insert("NumTokens", "128", "M dimensions - 128 by default.")
|
||||
.insert("TopK", "3", "Top K - 3 by default.")
|
||||
.insert("N", "4096", "N dimensions - 4096 by default.")
|
||||
.insert("K", "4096", "K dimensions - 4096 by default.")
|
||||
.insert("stride_A", "", "Tensor A strides - it is empty by default.")
|
||||
.insert("stride_B", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_C", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Col by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("gemm_kind",
|
||||
"gemm1_gate_only",
|
||||
"Gemm kind in FFN network [gemm1_gate_only | gemm1_gate_up | gemm2] - "
|
||||
"gemm1_gate_only by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert(
|
||||
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
323
example/ck_tile/21_moe_flatmm/run_moe_flatmm_example.inc
Normal file
323
example/ck_tile/21_moe_flatmm/run_moe_flatmm_example.inc
Normal file
@@ -0,0 +1,323 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ck_tile::MoeFlatmmKind kind,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename MoeHostArgs>
|
||||
float invoke_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
|
||||
{
|
||||
float ave_time = moe_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
kind,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::string op_name{"Moe Gemm"};
|
||||
|
||||
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
|
||||
sizeof(BDataType) * args.N * args.K +
|
||||
sizeof(CDataType) * args.M * args.N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename PrecType,
|
||||
typename FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind kind,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_moe_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
constexpr int ScaleGranularityM = 1;
|
||||
constexpr int ScaleGranularityN = 1;
|
||||
|
||||
const ck_tile::index_t N = arg_parser.get_int("N");
|
||||
const ck_tile::index_t K = arg_parser.get_int("K");
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_A");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_B");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_C");
|
||||
const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
|
||||
const ck_tile::index_t topk = arg_parser.get_int("TopK");
|
||||
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
|
||||
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
|
||||
const ck_tile::index_t experts = arg_parser.get_int("experts");
|
||||
|
||||
// TODO: replace the magic declaration
|
||||
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
|
||||
|
||||
ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk;
|
||||
ck_tile::index_t valid_tile_num = sorted_tile_num;
|
||||
ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
|
||||
const ck_tile::index_t M = sorted_tile_num * MPerBlock;
|
||||
const ck_tile::index_t outputN = kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? N / 2 : N;
|
||||
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr bool IsInputGemm = kind != ck_tile::MoeFlatmmKind::kFFN_gemm2;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(
|
||||
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(
|
||||
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
|
||||
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout)));
|
||||
auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
|
||||
is_row_major(b_layout)
|
||||
? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
|
||||
: ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
|
||||
auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
|
||||
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
|
||||
|
||||
auto b_shuffle_host = shuffle_b<FlatmmConfig>(b_k_n_tensor);
|
||||
|
||||
std::cout << "moe_flatmm:"
|
||||
<< "\n num_experts: " << experts << "\n num_tokens: " << num_tokens
|
||||
<< "\n topk: " << topk << "\n sorted_tile_num: " << sorted_tile_num
|
||||
<< "\n a_m_k: " << a_m_k_tensor.mDesc << "\n b_k_n: " << b_k_n_tensor.mDesc
|
||||
<< "\n b_shuffle: " << b_shuffle_host.mDesc << "\n c_m_n: " << c_m_n_tensor.mDesc
|
||||
<< std::endl;
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf{a_m_k_tensor.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem b_origin_dev_buf{b_k_n_tensor.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf{b_shuffle_host.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem c_m_n_dev_buf{c_m_n_tensor.get_element_space_size_in_bytes()};
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_tensor.data());
|
||||
b_origin_dev_buf.ToDevice(b_k_n_tensor.data());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_tensor.SetZero();
|
||||
|
||||
const void* p_a = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
const void* p_b_origin = b_origin_dev_buf.GetDeviceBuffer();
|
||||
const void* p_b_shuffle = b_shuffle_dev_buf.GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
|
||||
// TODO: malloc and init sorted tokens and max tokens buffer
|
||||
|
||||
ck_tile::HostTensor<ck_tile::index_t> expert_ids(
|
||||
ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
|
||||
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
|
||||
ck_tile::HostTensor<AccDataType> expert_weight(
|
||||
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
|
||||
ck_tile::HostTensor<ck_tile::index_t> max_token_id(
|
||||
ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
|
||||
ck_tile::HostTensor<AccDataType> per_token_scale(
|
||||
ck_tile::HostTensorDescriptor({IsInputGemm ? num_tokens : M}, {1}));
|
||||
ck_tile::HostTensor<AccDataType> per_channel_scale(
|
||||
ck_tile::HostTensorDescriptor({N * experts}, {1}));
|
||||
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_channel_scale);
|
||||
|
||||
// for verification only, no need to satify weight normalization
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_dev{sorted_token_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_ids_dev{expert_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem max_token_id_dev{max_token_id.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_weight_dev{expert_weight.get_element_space_size_in_bytes()};
|
||||
|
||||
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem per_channel_scale_dev_buf(
|
||||
per_channel_scale.get_element_space_size_in_bytes());
|
||||
|
||||
max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
|
||||
// int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
|
||||
}
|
||||
|
||||
int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
// int token_per_tile = num_tokens * topk / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
// sorted_token_ids.mData[0] = 0;
|
||||
for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile && tokenid < num_tokens * topk)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = num_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.data());
|
||||
expert_weight_dev.ToDevice(expert_weight.data());
|
||||
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
|
||||
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
|
||||
|
||||
const ck_tile::index_t* p_sorted_token_ids_dev =
|
||||
static_cast<ck_tile::index_t*>(sorted_token_ids_dev.GetDeviceBuffer());
|
||||
const ck_tile::index_t* p_expert_ids_dev =
|
||||
static_cast<ck_tile::index_t*>(expert_ids_dev.GetDeviceBuffer());
|
||||
const ck_tile::index_t* p_max_token_id_dev =
|
||||
static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
|
||||
const AccDataType* p_sorted_expert_weight_dev =
|
||||
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
|
||||
|
||||
using MoeFlatmmArgs =
|
||||
ck_tile::MoeFlatmmHostArgs<ck_tile::FlatmmScalePointer<ScaleGranularityM>,
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN>>;
|
||||
|
||||
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
|
||||
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
|
||||
|
||||
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
|
||||
p_sorted_expert_weight_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
p_a,
|
||||
p_b_shuffle,
|
||||
p_c,
|
||||
num_tokens,
|
||||
experts,
|
||||
topk,
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
per_token_scale_dev_ptr,
|
||||
per_channel_scale_dev_ptr};
|
||||
|
||||
invoke_moe_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
kind>(warmup, repeat, gemm_desc);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_tensor.data());
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("validate"))
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(IsInputGemm ? num_tokens * topk : num_tokens,
|
||||
outputN,
|
||||
stride_C,
|
||||
is_row_major(CLayout{})));
|
||||
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
|
||||
|
||||
c_m_n_ref_buf->SetZero();
|
||||
|
||||
ck_tile::reference_moe_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
static_cast<int>(kind),
|
||||
ck_tile::moe::MoeSilu>(
|
||||
p_sorted_token_ids_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b_origin),
|
||||
static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
|
||||
p_sorted_expert_weight_dev,
|
||||
num_tokens,
|
||||
MPerBlock,
|
||||
topk,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
1,
|
||||
1,
|
||||
K,
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1 /*kbatch*/, max_accumulated_value);
|
||||
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -21,6 +21,7 @@ add_subdirectory(18_flatmm)
|
||||
add_subdirectory(19_gemm_multi_d)
|
||||
add_subdirectory(20_grouped_convolution)
|
||||
add_subdirectory(21_elementwise)
|
||||
add_subdirectory(21_moe_flatmm)
|
||||
add_subdirectory(35_batched_transpose)
|
||||
add_subdirectory(37_transpose)
|
||||
add_subdirectory(38_block_scale_gemm)
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/e8m0.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
|
||||
@@ -41,10 +41,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
{
|
||||
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
|
||||
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
|
||||
r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
// r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
// r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
// r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
// r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -302,12 +302,12 @@ struct buffer_load_if<16, pre_nop>
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
|
||||
static_assert(sizeof(mbuf_t) == sizeof(T));
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
@@ -336,12 +336,12 @@ struct buffer_load_if<8, pre_nop>
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
@@ -369,12 +369,12 @@ struct buffer_load_if<4, pre_nop>
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
@@ -402,12 +402,12 @@ struct buffer_load_if<2, pre_nop>
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
@@ -435,12 +435,12 @@ struct buffer_load_if<1, pre_nop>
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
@@ -624,7 +624,7 @@ struct buffer_store_if<16>
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = fp32x4_t;
|
||||
using mbuf_t = fp32x4_t;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
@@ -681,7 +681,7 @@ struct buffer_store_if<4>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = float;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
@@ -709,7 +709,7 @@ struct buffer_store_if<2>
|
||||
{
|
||||
static_assert(sizeof(T) == 2);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = short;
|
||||
using mbuf_t = short;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
@@ -737,7 +737,7 @@ struct buffer_store_if<1>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = float;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
@@ -1246,6 +1246,14 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
|
||||
|
||||
// buffer atomic-add bf16
|
||||
CK_TILE_DEVICE_EXTERN cktile_llvm_bf16x2_t llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
cktile_llvm_bf16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16");
|
||||
|
||||
// buffer atomic-add i32
|
||||
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
int32_t vdata,
|
||||
@@ -1448,9 +1456,12 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
((std::is_same<T, int8_t>::value || std::is_same<T, e8m0_t>::value) &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, pk_int4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
|
||||
(std::is_same<T, pk_fp4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
|
||||
"wrong! not implemented");
|
||||
|
||||
using rtn_type = thread_buffer<T, N>;
|
||||
@@ -2194,6 +2205,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
|
||||
{
|
||||
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
|
||||
"wrong! not implemented");
|
||||
|
||||
@@ -2287,6 +2299,40 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
|
||||
});
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value)
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
bit_cast<cktile_llvm_bf16x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
src_thread_data.template get_as<cktile_llvm_bf16x2_t>()[i],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + i * sizeof(bf16x2_t),
|
||||
0);
|
||||
});
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
src_thread_data.template get_as<cktile_llvm_bf16x2_t>()[i],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + i * sizeof(bf16x2_t),
|
||||
0);
|
||||
});
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
@@ -2835,9 +2881,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
__attribute__((address_space(3))) cktile_llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) cktile_llvm_bf16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
|
||||
@@ -1308,15 +1308,20 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
static_assert(
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, bf16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, int32_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
((std::is_same<T, int8_t>::value || std::is_same<T, e8m0_t>::value) &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, pk_int4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
|
||||
(std::is_same<T, pk_fp4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
|
||||
"wrong! not implemented");
|
||||
|
||||
using rtn_type = thread_buffer<T, N>;
|
||||
@@ -1964,6 +1969,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
|
||||
{
|
||||
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
|
||||
"wrong! not implemented");
|
||||
|
||||
@@ -2057,6 +2063,40 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
|
||||
});
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value)
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
bit_cast<cktile_llvm_bf16x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
src_thread_data.template get_as<cktile_llvm_bf16x2_t>()[i],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + i * sizeof(bf16x2_t),
|
||||
0);
|
||||
});
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
src_thread_data.template get_as<cktile_llvm_bf16x2_t>()[i],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + i * sizeof(bf16x2_t),
|
||||
0);
|
||||
});
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
@@ -2605,9 +2645,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
__attribute__((address_space(3))) cktile_llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) cktile_llvm_bf16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
|
||||
@@ -14,6 +14,14 @@ CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
|
||||
return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp16x2_t add_fp16x2_t(const fp16x2_t& a, const fp16x2_t& b)
|
||||
{
|
||||
fp16x2_t rtn;
|
||||
rtn[0] = add<fp16_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<fp16_t, float>(a[1], b[1]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
|
||||
{
|
||||
bf16x2_t rtn;
|
||||
@@ -87,6 +95,37 @@ CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
|
||||
template <typename X>
|
||||
CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, const fp16x2_t& x)
|
||||
{
|
||||
union U32FP162_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
fp16x2_t* fp162_a;
|
||||
};
|
||||
|
||||
union U32FP162
|
||||
{
|
||||
uint32_t u32;
|
||||
fp16x2_t fp162;
|
||||
};
|
||||
|
||||
U32FP162_ADDR dword_addr;
|
||||
U32FP162 cur_v;
|
||||
U32FP162 new_;
|
||||
uint32_t old_v, new_v;
|
||||
dword_addr.fp162_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.fp162 = add_fp16x2_t(cur_v.fp162, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
|
||||
{
|
||||
|
||||
@@ -165,6 +165,9 @@ struct sequence
|
||||
return sequence<Is..., Xs...>{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto sum() { return (Is + ... + 0); }
|
||||
CK_TILE_HOST_DEVICE static constexpr auto product() { return (Is * ... * 1); }
|
||||
|
||||
// pickup element at index <Ids...>
|
||||
template <index_t... Ids>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto extract(number<Ids>...)
|
||||
@@ -1236,9 +1239,8 @@ constexpr auto reverse_slice_sequence(Seq,
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
|
||||
constexpr auto slice_sequence(Seq,
|
||||
number<SliceSize>,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
constexpr auto
|
||||
slice_sequence(Seq, number<SliceSize>, Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
constexpr auto r =
|
||||
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());
|
||||
|
||||
101
include/ck_tile/core/numeric/e8m0.hpp
Normal file
101
include/ck_tile/core/numeric/e8m0.hpp
Normal file
@@ -0,0 +1,101 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Unsigned representation of a conventional biased Float32 exponent.
|
||||
*
|
||||
* bias = 127;
|
||||
*
|
||||
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
|
||||
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
|
||||
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
|
||||
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
|
||||
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
|
||||
* E8M0_MIN = 0b00000000; => 2^-127
|
||||
* E8M0_MAX = 0b11111110; => 2^127
|
||||
* E8M0_NAN = 0b11111111; => NaN
|
||||
*/
|
||||
|
||||
struct e8m0_bexp_t
|
||||
{
|
||||
using raw_type = uint8_t;
|
||||
using type = raw_type;
|
||||
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale) : data(0)
|
||||
{
|
||||
data = numeric_utils<float>::get_exponent(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const;
|
||||
|
||||
constexpr bool operator==(const e8m0_bexp_t& other) const { return data == other.data; }
|
||||
|
||||
constexpr bool operator!=(const e8m0_bexp_t& other) const { return data != other.data; }
|
||||
};
|
||||
|
||||
using e8m0_t = e8m0_bexp_t;
|
||||
using e8m0_raw_t = typename e8m0_t::raw_type;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<e8m0_t>
|
||||
{
|
||||
using bitwise_type = e8m0_raw_t;
|
||||
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 0;
|
||||
static constexpr int bias = 127;
|
||||
static constexpr int PackedSize = 1;
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<e8m0_t>
|
||||
{
|
||||
static constexpr e8m0_raw_t binary_min = 0b00000000; // 2^-127
|
||||
static constexpr e8m0_raw_t binary_max = 0b11111110; // 2^127
|
||||
static constexpr e8m0_raw_t binary_nan = 0b11111111;
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t min() { return e8m0_t{binary_min}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t max() { return e8m0_t{binary_max}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t quiet_NaN() { return e8m0_t{binary_nan}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t signaling_NaN() { return e8m0_t{binary_nan}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t epsilon() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t round_error() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t zero() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t infinity() { return signaling_NaN(); }
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const
|
||||
{
|
||||
using traits = numeric_traits<float>;
|
||||
if(data == numeric<e8m0_t>::binary_nan)
|
||||
{
|
||||
return traits::NaN;
|
||||
}
|
||||
else if(data == 0)
|
||||
{
|
||||
return std::numeric_limits<float>::min();
|
||||
}
|
||||
else
|
||||
{
|
||||
return bit_cast<float>(static_cast<traits::bitwise_type>(data) << traits::mant);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -12,15 +12,19 @@ struct numeric_utils : numeric_traits<T>
|
||||
|
||||
using traits = numeric_traits<T>;
|
||||
using _numeric = numeric<T>;
|
||||
using raw_type = typename T::raw_type;
|
||||
using raw_type = typename traits::bitwise_type;
|
||||
|
||||
static constexpr int exp_mask = (1 << traits::exp) - 1;
|
||||
|
||||
static constexpr int get_exponent(raw_type x)
|
||||
static constexpr raw_type get_exponent(raw_type x)
|
||||
{
|
||||
// TODO: check if repeated calls are optimized.
|
||||
return (x >> traits::mant) & exp_mask;
|
||||
}
|
||||
static constexpr raw_type get_exponent(const T& x)
|
||||
{
|
||||
return get_exponent(bit_cast<raw_type>(x));
|
||||
}
|
||||
static constexpr bool is_positive(raw_type x)
|
||||
{
|
||||
return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
|
||||
@@ -33,7 +37,7 @@ struct numeric_utils : numeric_traits<T>
|
||||
static constexpr double get_mantissa(raw_type x)
|
||||
{
|
||||
double mantissa = is_subnormal(x) ? 0.0f : 1.0f;
|
||||
for(uint32_t i = 0; i < traits::mant; ++i)
|
||||
for(raw_type i = 0; i < traits::mant; ++i)
|
||||
{
|
||||
mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
|
||||
x >>= 1;
|
||||
@@ -43,22 +47,23 @@ struct numeric_utils : numeric_traits<T>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp = 127)
|
||||
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale = 1.f)
|
||||
{
|
||||
using utils = numeric_utils<T>;
|
||||
static constexpr int e8m0_bias = 127; // TODO: make it generic.
|
||||
float sign = utils::is_positive(data) ? 1.0 : -1.0;
|
||||
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
|
||||
float mant = utils::get_mantissa(data);
|
||||
using utils = numeric_utils<T>;
|
||||
float sign = utils::is_positive(data) ? 1.0 : -1.0;
|
||||
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
|
||||
float mant = utils::get_mantissa(data);
|
||||
|
||||
return std::ldexp(sign * mant, exp + scale_exp - e8m0_bias);
|
||||
return std::ldexp(sign * mant * scale, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value)
|
||||
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value, float scale = 1.f)
|
||||
{
|
||||
using bitwise_type = typename numeric_traits<T>::bitwise_type;
|
||||
|
||||
value /= scale;
|
||||
|
||||
if(std::abs(value) > float(numeric<T>::max()))
|
||||
{
|
||||
float max_value = numeric<T>::max();
|
||||
|
||||
@@ -23,14 +23,11 @@ using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float);
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f);
|
||||
|
||||
// TODO: Add stochastic method
|
||||
struct pk_float4_e2m1_t
|
||||
{
|
||||
static constexpr int exponent = 2;
|
||||
static constexpr int mantissa = 1;
|
||||
static constexpr int bias = 1;
|
||||
// TODO: Can we merge raw_type and type?
|
||||
using raw_type = uint8_t;
|
||||
using type = raw_type;
|
||||
@@ -41,18 +38,27 @@ struct pk_float4_e2m1_t
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init) : data{float_to_e2m1(init)}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f)
|
||||
: data{float_to_e2m1(init, scale)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t to_fp16(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const { return to_fp16(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE raw_type unpack(number<I>) const;
|
||||
@@ -193,131 +199,160 @@ CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
|
||||
} // namespace impl
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<bf16_t>(data);
|
||||
return impl::_from_f4<bf16_t>(data, scale);
|
||||
#else
|
||||
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
|
||||
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16x2_t() const
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<bf16x2_t>(data);
|
||||
return impl::_from_f4<bf16x2_t>(data, scale);
|
||||
#else
|
||||
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
|
||||
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
|
||||
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
|
||||
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
|
||||
// TODO: make float_to_e2m1 generic so that we can convert from directrly.
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return convert_to_type<pk_fp4_t>(x);
|
||||
return convert_to_type<pk_fp4_t>(x, scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x) { return fp32x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x) { return fp16x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x) { return bf16x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x) { return float_to_e2m1(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale = 1.0f)
|
||||
{
|
||||
return float_to_e2m1(x, scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return float_to_e2m1(type_convert<float>(x));
|
||||
return float_to_e2m1(type_convert<float>(x), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return float_to_e2m1(type_convert<float>(x));
|
||||
return float_to_e2m1(type_convert<float>(x), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
|
||||
float_to_e2m1(type_convert<float>(x[1])));
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
|
||||
float_to_e2m1(type_convert<float>(x[1]), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
|
||||
float_to_e2m1(type_convert<float>(x[1])));
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
|
||||
float_to_e2m1(type_convert<float>(x[1]), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(x[0]), float_to_e2m1(x[1]));
|
||||
return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_fp32x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_fp16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_bf16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_float(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_fp16(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_bf16(scale);
|
||||
}
|
||||
|
||||
#if TEST_convert_with_table == 0
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp32_t>(data);
|
||||
return impl::_from_f4<fp32_t>(data, scale);
|
||||
#else
|
||||
return convert_to_float<pk_fp4_t>(unpack(number<0>{}));
|
||||
return convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp32x2_t>(data);
|
||||
return impl::_from_f4<fp32x2_t>(data, scale);
|
||||
#else
|
||||
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{})),
|
||||
convert_to_float<pk_fp4_t>(unpack(number<1>{}))};
|
||||
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale),
|
||||
convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale)};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp16_t>(data);
|
||||
return impl::_from_f4<fp16_t>(data, scale);
|
||||
#else
|
||||
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
|
||||
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp16x2_t>(data);
|
||||
return impl::_from_f4<fp16x2_t>(data, scale);
|
||||
#else
|
||||
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
|
||||
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
|
||||
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
|
||||
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
|
||||
{
|
||||
return e2m1_to_fp32_table[data & 0xf];
|
||||
return e2m1_to_fp32_table[unpack(number<0>{})] * scale;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
|
||||
{
|
||||
return fp32x2_t{e2m1_to_fp32_table[data & 0xf], e2m1_to_fp32_table[data >> 4]};
|
||||
return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale, e2m1_to_fp32_table[unpack(number<1>{}] * scale};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
|
||||
{
|
||||
return e2m1_to_fp16_table[data & 0xf];
|
||||
return type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
|
||||
{
|
||||
return fp16x2_t{e2m1_to_fp16_table[data & 0xf], e2m1_to_fp16_table[data >> 4]};
|
||||
return fp16x2_t{
|
||||
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale),
|
||||
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)};
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -64,6 +64,7 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
|
||||
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -71,16 +72,36 @@ CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
|
||||
CK_TILE_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
|
||||
CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
|
||||
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale);
|
||||
|
||||
#define CK_TILE_SCALED_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ scaled_type_convert<dtype_, stype_>(stype_ x, \
|
||||
float scale) \
|
||||
{ \
|
||||
return sname_##_to_##dname_(x, scale); \
|
||||
} \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return sname_##_to_##dname_(x, 1.f); \
|
||||
}
|
||||
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(float, float, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(bf16_t, bf16, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp16_t, fp16, pk_fp4_t, pk_fp4)
|
||||
#undef CK_TILE_SCALED_TYPE_CONVERT
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/numeric/e8m0.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -88,7 +89,13 @@ template <typename T, typename>
|
||||
struct vector_traits
|
||||
{
|
||||
using scalar_type =
|
||||
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>, int8_t, remove_cvref_t<T>>;
|
||||
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>,
|
||||
int8_t,
|
||||
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_fp4_t> ||
|
||||
std::is_same_v<remove_cvref_t<T>, e8m0_t>,
|
||||
uint8_t,
|
||||
remove_cvref_t<T>>>;
|
||||
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
@@ -96,7 +103,13 @@ struct vector_traits
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<T __attribute__((ext_vector_type(N)))>
|
||||
{
|
||||
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
|
||||
using scalar_type = std::conditional_t<
|
||||
std::is_same_v<T, pk_int4_t>,
|
||||
int8_t,
|
||||
std::conditional_t<std::is_same_v<T, pk_fp4_t> || std::is_same_v<remove_cvref_t<T>, e8m0_t>,
|
||||
uint8_t,
|
||||
T>>;
|
||||
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
@@ -138,6 +151,9 @@ using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
|
||||
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
using cktile_llvm_bf16x2_t = __bf16 __attribute__((ext_vector_type(2)));
|
||||
using cktile_llvm_bf16x4_t = __bf16 __attribute__((ext_vector_type(4)));
|
||||
|
||||
// i32
|
||||
// using int32_t = ...
|
||||
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
|
||||
@@ -237,4 +253,13 @@ using pk_int4x4_t = int8_t __attribute((ext_vector_type(4)));
|
||||
using pk_int4x8_t = int8_t __attribute((ext_vector_type(8)));
|
||||
using pk_int4x16_t = int8_t __attribute((ext_vector_type(16)));
|
||||
using pk_int4x32_t = int8_t __attribute((ext_vector_type(32)));
|
||||
|
||||
// pk_fp4_t
|
||||
// using pk_fp4_t
|
||||
using pk_fp4x2_t = uint8_t __attribute((ext_vector_type(2)));
|
||||
using pk_fp4x4_t = uint8_t __attribute((ext_vector_type(4)));
|
||||
using pk_fp4x8_t = uint8_t __attribute((ext_vector_type(8)));
|
||||
using pk_fp4x16_t = uint8_t __attribute((ext_vector_type(16)));
|
||||
using pk_fp4x32_t = uint8_t __attribute((ext_vector_type(32)));
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -653,14 +653,24 @@ struct buffer_view<address_space_enum::global,
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
|
||||
#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
|
||||
||
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
|
||||
#endif
|
||||
;
|
||||
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
|
||||
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
|
||||
#elif (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
|
||||
#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
|
||||
||
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
|
||||
#endif
|
||||
;
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
|
||||
@@ -599,6 +599,88 @@ struct tile_scatter_gather
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<0>{}];
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
valids_[idx_gather],
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
forward_step_scatter);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// move thread's botom tensor coordiante
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
|
||||
@@ -409,7 +409,13 @@ struct HostTensor
|
||||
}
|
||||
|
||||
// void SetZero() { ck_tile::ranges::fill<T>(mData, 0); }
|
||||
void SetZero() { std::fill(mData.begin(), mData.end(), 0); }
|
||||
void SetZero()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, e8m0_t>)
|
||||
std::fill(mData.begin(), mData.end(), e8m0_t{1.f});
|
||||
else
|
||||
std::fill(mData.begin(), mData.end(), 0);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
|
||||
|
||||
@@ -32,6 +32,11 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int MaxThreadPerBlock, typename Kernel, typename... Args>
|
||||
__launch_bounds__(MaxThreadPerBlock) __global__ void kentry2(Args... args)
|
||||
{
|
||||
Kernel{}(args...);
|
||||
}
|
||||
//
|
||||
// return a anonymous functor(lambda) to be called later
|
||||
// the KernelImpl should be a class without non-static data member, or let's say
|
||||
|
||||
@@ -273,6 +273,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
@@ -285,6 +293,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
@@ -299,6 +315,121 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
__global__ void blockwise_gemm_kernel(ADataType* A,
|
||||
BDataType* B,
|
||||
CDataType* C,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t strideA,
|
||||
ck_tile::index_t strideB,
|
||||
ck_tile::index_t strideC,
|
||||
ck_tile::index_t scale_granularity_m,
|
||||
ck_tile::index_t scale_granularity_n,
|
||||
ck_tile::index_t scale_granularity_k,
|
||||
float* scale_A_ptr,
|
||||
float* scale_B_ptr)
|
||||
{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int row = idx / N; // Compute row index
|
||||
int col = idx % N; // Compute column index
|
||||
|
||||
if(row < M && col < N)
|
||||
{
|
||||
AccDataType acc = 0.0, acc_temp = 0.0;
|
||||
|
||||
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
|
||||
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
|
||||
|
||||
float scale_A = 0;
|
||||
float scale_B = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
if(k % scale_granularity_k == 0)
|
||||
{
|
||||
// update acc
|
||||
acc += acc_temp * scale_A * scale_B;
|
||||
acc_temp = 0.0;
|
||||
// update scale factors
|
||||
scale_A = scale_A_ptr[(row / scale_granularity_m) +
|
||||
(k / scale_granularity_k) * scale_A_stride];
|
||||
scale_B = scale_B_ptr[(col / scale_granularity_n) +
|
||||
(k / scale_granularity_k) * scale_B_stride];
|
||||
}
|
||||
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideA + k
|
||||
: k * strideA + row;
|
||||
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? col * strideB + k
|
||||
: k * strideB + col;
|
||||
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
}
|
||||
acc_temp += v_a * v_b;
|
||||
}
|
||||
// final accumulation
|
||||
acc += acc_temp * scale_A * scale_B;
|
||||
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideC + col
|
||||
: col * strideC + row;
|
||||
C[c_index] = ck_tile::type_convert<CDataType>(acc);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -327,6 +458,51 @@ void reference_gemm_gpu(ADataType* a_ptr,
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_blockwise_gemm_gpu(ADataType* a_ptr,
|
||||
BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_a,
|
||||
index_t stride_b,
|
||||
index_t stride_c,
|
||||
index_t scale_granularity_m,
|
||||
index_t scale_granularity_n,
|
||||
index_t scale_granularity_k,
|
||||
float* scale_A_ptr,
|
||||
float* scale_B_ptr)
|
||||
{
|
||||
int totalElements = M * N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
blockwise_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
|
||||
<<<numBlocks, numThreadsPerBlock>>>(a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
scale_granularity_m,
|
||||
scale_granularity_n,
|
||||
scale_granularity_k,
|
||||
scale_A_ptr,
|
||||
scale_B_ptr);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -364,4 +540,5 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
315
include/ck_tile/host/reference/reference_moe_gemm.hpp
Normal file
315
include/ck_tile/host/reference/reference_moe_gemm.hpp
Normal file
@@ -0,0 +1,315 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
|
||||
typename ActivationOp = identity>
|
||||
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
const ck_tile::index_t* p_sorted_expert_ids_,
|
||||
const ck_tile::index_t* p_max_token_id_,
|
||||
const ADataType* A,
|
||||
const BDataType* B,
|
||||
CDataType* C,
|
||||
const AccDataType* expert_weight_ptr,
|
||||
ck_tile::index_t Num_tokens,
|
||||
ck_tile::index_t TokensPerBlock,
|
||||
ck_tile::index_t TopK,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t strideA,
|
||||
ck_tile::index_t strideB,
|
||||
ck_tile::index_t strideC,
|
||||
index_t scale_granularity_m,
|
||||
index_t scale_granularity_n,
|
||||
index_t scale_granularity_k,
|
||||
float* scale_A_ptr,
|
||||
float* scale_B_ptr,
|
||||
float* expert_bias_ptr)
|
||||
{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
|
||||
int row = idx / problem_N; // Compute row index
|
||||
int col = idx % problem_N; // Compute column index
|
||||
|
||||
index_t gather_token_id = 0;
|
||||
index_t scatter_token_id = 0;
|
||||
index_t expert_id = 0;
|
||||
|
||||
if(row < p_max_token_id_[0])
|
||||
{
|
||||
expert_id = p_sorted_expert_ids_[row / TokensPerBlock];
|
||||
gather_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
|
||||
scatter_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
|
||||
if(gather_token_id >= Num_tokens)
|
||||
{
|
||||
return;
|
||||
}
|
||||
if(MoeGemmKind == 2)
|
||||
{
|
||||
gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
|
||||
}
|
||||
else
|
||||
{
|
||||
scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if(row < M)
|
||||
{
|
||||
AccDataType acc = 0.0;
|
||||
AccDataType acc_up = 0.0;
|
||||
|
||||
AccDataType acc_temp = 0.0;
|
||||
AccDataType acc_up_temp = 0.0;
|
||||
|
||||
float scale_A = 0;
|
||||
float scale_B = 0;
|
||||
float scale_B_up = 0;
|
||||
|
||||
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
|
||||
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
|
||||
index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
if(k % scale_granularity_k == 0)
|
||||
{
|
||||
// update acc
|
||||
acc += acc_temp * scale_A * scale_B;
|
||||
acc_up += acc_up_temp * scale_A * scale_B_up;
|
||||
// reset acc temp
|
||||
acc_temp = 0.0;
|
||||
acc_up_temp = 0.0;
|
||||
// update scale factors
|
||||
scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
|
||||
(k / scale_granularity_k) * scale_A_stride];
|
||||
scale_B =
|
||||
scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
|
||||
(k / scale_granularity_k) * scale_B_stride];
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
|
||||
(col + problem_N) / scale_granularity_n +
|
||||
(k / scale_granularity_k) * scale_B_stride];
|
||||
}
|
||||
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? gather_token_id * strideA + k
|
||||
: k * strideA + gather_token_id;
|
||||
|
||||
long b_index =
|
||||
long(expert_id) * N * K +
|
||||
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
|
||||
: k * strideB + col);
|
||||
long b_index_up;
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
b_index_up = long(expert_id) * N * K +
|
||||
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? (col + problem_N) * strideB + k
|
||||
: k * strideB + col + problem_N);
|
||||
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
AccDataType v_b_up;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
{
|
||||
const fp32x2_t fp32_val_up =
|
||||
pk_int4_t_to_fp32x2_t(B[b_index_up / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b_up = fp32_val_up.hi;
|
||||
else
|
||||
v_b_up = fp32_val_up.lo;
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
{
|
||||
const fp32x2_t fp32_val_up = pk_fp4_to_fp32x2(B[b_index_up / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b_up = fp32_val_up.hi;
|
||||
else
|
||||
v_b_up = fp32_val_up.lo;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
v_b_up = ck_tile::type_convert<AccDataType>(B[b_index_up]);
|
||||
}
|
||||
acc_temp += v_a * v_b;
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
acc_up_temp += v_a * v_b_up;
|
||||
}
|
||||
|
||||
acc += acc_temp * scale_A * scale_B;
|
||||
acc_up += acc_up_temp * scale_A * scale_B_up;
|
||||
|
||||
float bias = 0.f, bias_up = 0.f;
|
||||
if(expert_bias_ptr != nullptr)
|
||||
{
|
||||
bias = expert_bias_ptr[expert_id * N + col];
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
|
||||
}
|
||||
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? scatter_token_id * strideC + col
|
||||
: col * strideC + scatter_token_id;
|
||||
if constexpr(MoeGemmKind < 2)
|
||||
{
|
||||
C[c_index] = ck_tile::type_convert<CDataType>(
|
||||
ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
|
||||
}
|
||||
else
|
||||
{
|
||||
// moe gemm2 don't use activation.
|
||||
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
|
||||
using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
|
||||
ck_tile::fp16x2_t,
|
||||
ck_tile::bf16x2_t>;
|
||||
ResV2Type add_v{0, 0};
|
||||
if(c_index % 2)
|
||||
{
|
||||
// result is the second value of fp16 pair.
|
||||
add_v.y = res;
|
||||
}
|
||||
else
|
||||
{
|
||||
// result is the first value of fp16 pair.
|
||||
add_v.x = res;
|
||||
}
|
||||
// mask last bit to make sure atomicAdd pointer is aligned of DWORD.
|
||||
atomic_add<ResV2Type>(reinterpret_cast<ResV2Type*>(C + (c_index & 0xffff'fffe)), add_v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
|
||||
typename ActivationOp = identity>
|
||||
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
|
||||
const index_t* p_sorted_expert_ids_,
|
||||
const index_t* p_max_token_id_,
|
||||
const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const AccDataType* expert_weight_ptr,
|
||||
index_t Num_tokens,
|
||||
index_t TokensPerBlock,
|
||||
index_t TopK,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_a,
|
||||
index_t stride_b,
|
||||
index_t stride_c,
|
||||
index_t scale_granularity_m,
|
||||
index_t scale_granularity_n,
|
||||
index_t scale_granularity_k,
|
||||
float* scale_A_ptr,
|
||||
float* scale_B_ptr,
|
||||
float* exp_bias = nullptr)
|
||||
{
|
||||
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
|
||||
int totalElements = M * problem_N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
moe_gemm_kernel<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
LayoutA,
|
||||
LayoutB,
|
||||
LayoutC,
|
||||
MoeGemmKind,
|
||||
ActivationOp><<<numBlocks, numThreadsPerBlock>>>(p_sorted_token_ids_,
|
||||
p_sorted_expert_ids_,
|
||||
p_max_token_id_,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
expert_weight_ptr,
|
||||
Num_tokens,
|
||||
TokensPerBlock,
|
||||
TopK,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
scale_granularity_m,
|
||||
scale_granularity_n,
|
||||
scale_granularity_k,
|
||||
scale_A_ptr,
|
||||
scale_B_ptr,
|
||||
exp_bias);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -953,7 +953,14 @@ struct Silu
|
||||
std::is_same_v<T, int32_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = x * (one / (one + ck_tile::exp(-x)));
|
||||
if constexpr(std::is_same_v<T, float>)
|
||||
{
|
||||
y = x * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x));
|
||||
}
|
||||
else
|
||||
{
|
||||
y = x * (one / (one + ck_tile::exp(-x)));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -1308,7 +1315,7 @@ struct Swish
|
||||
|
||||
struct SoftRelu
|
||||
{
|
||||
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
|
||||
SoftRelu(float alpha = 1.f) : alpha_(alpha) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1327,7 +1334,7 @@ struct SoftRelu
|
||||
struct Power
|
||||
{
|
||||
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
|
||||
: alpha_(alpha), beta_(beta), gamma_(gamma){};
|
||||
: alpha_(alpha), beta_(beta), gamma_(gamma) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1349,7 +1356,7 @@ struct Power
|
||||
|
||||
struct ClippedRelu
|
||||
{
|
||||
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
|
||||
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1368,7 +1375,7 @@ struct ClippedRelu
|
||||
|
||||
struct LeakyRelu
|
||||
{
|
||||
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
|
||||
LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1385,7 +1392,7 @@ struct LeakyRelu
|
||||
|
||||
struct Elu
|
||||
{
|
||||
Elu(float alpha = 1.f) : alpha_(alpha){};
|
||||
Elu(float alpha = 1.f) : alpha_(alpha) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1402,7 +1409,7 @@ struct Elu
|
||||
|
||||
struct Logistic
|
||||
{
|
||||
Logistic(float alpha = 1.f) : alpha_(alpha){};
|
||||
Logistic(float alpha = 1.f) : alpha_(alpha) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
|
||||
@@ -27,9 +27,11 @@ template <typename ADataType_,
|
||||
index_t KPerXdl_,
|
||||
bool isCTransposed_,
|
||||
memory_operation_enum MemoryOperation_,
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1>
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1,
|
||||
bool TiledMMAPermuteN_ = false,
|
||||
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
@@ -52,6 +54,8 @@ struct CShuffleEpilogueProblem
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
|
||||
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -87,10 +91,14 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t KPerXdl = Problem::KPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
|
||||
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
@@ -190,7 +198,10 @@ struct CShuffleEpilogue
|
||||
}
|
||||
}();
|
||||
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
|
||||
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
|
||||
static constexpr index_t NumNXdlPerWavePerShuffle =
|
||||
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
|
||||
|
||||
static_assert(NumNXdlPerWavePerShuffle % BlockedXDLN_PerWarp == 0);
|
||||
|
||||
static constexpr auto MNPerIterationShuffle = [] {
|
||||
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
|
||||
@@ -239,14 +250,31 @@ struct CShuffleEpilogue
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
|
||||
{
|
||||
constexpr auto block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<NumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto block_outer_dstr_encoding = [] {
|
||||
if constexpr(BlockedXDLN_PerWarp == 1)
|
||||
{
|
||||
return tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<NumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
|
||||
// BlockedLayout
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{};
|
||||
}
|
||||
}();
|
||||
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
|
||||
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
|
||||
|
||||
@@ -258,7 +286,87 @@ struct CShuffleEpilogue
|
||||
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
|
||||
}
|
||||
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
{
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = 4;
|
||||
constexpr int kM1 = MPerXdl / kM2;
|
||||
|
||||
constexpr int kN0 = NWave;
|
||||
constexpr int kN1 = NPerXdl;
|
||||
constexpr int kN2 = NRepeat;
|
||||
|
||||
using IntrThreadShuffleEncode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>;
|
||||
constexpr auto dram_tile_distribution =
|
||||
make_static_tile_distribution(IntrThreadShuffleEncode{});
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
// auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 0 * NRepeat] = type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0]);
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 1 * NRepeat] = type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1]);
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 2 * NRepeat] = type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2]);
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 3 * NRepeat] = type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3]);
|
||||
});
|
||||
|
||||
// c_out_tensor = cast_tile<ODataType>(c_out_tensor_fp32);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<!EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
@@ -284,8 +392,8 @@ struct CShuffleEpilogue
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
@@ -336,8 +444,336 @@ struct CShuffleEpilogue
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie(
|
||||
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
|
||||
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
|
||||
number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem,
|
||||
ScaleM scale_m,
|
||||
ScaleN scale_n)
|
||||
{
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = 4;
|
||||
constexpr int kM1 = MPerXdl / kM2;
|
||||
static_assert(MPerXdl == 16, "TiledMMAPermuteN only supports MPerXdl = 16 now");
|
||||
|
||||
constexpr int kN0 = NWave;
|
||||
constexpr int kN1 = NPerXdl;
|
||||
constexpr int kN2 = NRepeat;
|
||||
|
||||
using IntrThreadShuffleEncode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>;
|
||||
constexpr auto dram_tile_distribution =
|
||||
make_static_tile_distribution(IntrThreadShuffleEncode{});
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
using ShuffleAcc =
|
||||
decltype(make_static_distributed_tensor<AccDataType>(dram_tile_distribution));
|
||||
ShuffleAcc shuffle_acc[MRepeat];
|
||||
auto c_out_tensor_fp32 =
|
||||
make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
const index_t iMLane = get_lane_id() / NPerXdl;
|
||||
const index_t iNLane = get_lane_id() % NPerXdl;
|
||||
|
||||
float vec_scale_A[kM2 * MRepeat];
|
||||
float vec_scale_B[NRepeat];
|
||||
|
||||
_Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
|
||||
{
|
||||
vec_scale_B[i] = scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl];
|
||||
}
|
||||
_Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
|
||||
{
|
||||
vec_scale_A[i * kM2 + 0] =
|
||||
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 1] =
|
||||
scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 2] =
|
||||
scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 3] =
|
||||
scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
}
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
shuffle_acc[mIter].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 0] *= vec_scale_B[n_idx];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 1] *= vec_scale_B[n_idx];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 2] *= vec_scale_B[n_idx];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 3] *= vec_scale_B[n_idx];
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 0 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] *
|
||||
vec_scale_A[mIter * kM2 + 0];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 1 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] *
|
||||
vec_scale_A[mIter * kM2 + 1];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 2 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] *
|
||||
vec_scale_A[mIter * kM2 + 2];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 3 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] *
|
||||
vec_scale_A[mIter * kM2 + 3];
|
||||
});
|
||||
|
||||
c_out_tensor = cast_tile<ODataType>(c_out_tensor_fp32);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<!EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem,
|
||||
ScaleM scale_m,
|
||||
ScaleN scale_n)
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
|
||||
using LDSTileTensor = decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
|
||||
LDSTileTensor lds_tile[2];
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem), lds_block_desc);
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0},
|
||||
LdsTileDistr);
|
||||
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
|
||||
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr int kM2 = 4; // Val
|
||||
constexpr int kM1 = (64 / NPerXdl); // Thr
|
||||
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
const index_t iMLane = get_lane_id() / NPerXdl;
|
||||
const index_t iNLane = get_lane_id() % NPerXdl;
|
||||
|
||||
float vec_scale_A[kM0 * kM2 * MRepeat];
|
||||
float vec_scale_B[NRepeat];
|
||||
|
||||
_Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
|
||||
{
|
||||
vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
}
|
||||
_Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
|
||||
{
|
||||
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||
{
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 0] =
|
||||
scale_m[0 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 1] =
|
||||
scale_m[1 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 2] =
|
||||
scale_m[2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 3] =
|
||||
scale_m[3 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
}
|
||||
}
|
||||
|
||||
lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||
{
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 0] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 1] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 2] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 3] * vec_scale_B[n_xdl];
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
constexpr int read_stage = iAccess % 2;
|
||||
constexpr int write_stage = read_stage ^ 1;
|
||||
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(number<iAccess.value + 1>{});
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile[read_stage]);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
|
||||
if constexpr(iAccess < num_access - 1)
|
||||
{
|
||||
lds_tile[write_stage].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter * NumMXdlPerWavePerShuffle,
|
||||
nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||
{
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 0] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 1] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 2] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 3] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
|
||||
number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
|
||||
@@ -10,8 +10,11 @@
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
|
||||
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
|
||||
#include "ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v0.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -113,6 +113,7 @@ struct BlockFlatmmASmemBSmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,23 +11,171 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
struct FlatmmProblem
|
||||
{
|
||||
CK_TILE_HOST FlatmmProblem() = default;
|
||||
CK_TILE_HOST FlatmmProblem(
|
||||
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
|
||||
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
|
||||
{
|
||||
}
|
||||
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
template <int SharedGranularityMN, int SharedGranularityK = 0>
|
||||
struct FlatmmScalePointer
|
||||
{
|
||||
static constexpr int GranularityMN = SharedGranularityMN;
|
||||
static constexpr int GranularityK = SharedGranularityK;
|
||||
|
||||
const float* ptr;
|
||||
index_t scale_stride = 1;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t stride)
|
||||
: ptr(ptr_), scale_stride(stride)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
|
||||
{
|
||||
FlatmmScalePointer ret;
|
||||
// if constexpr(GranularityMN == 0)
|
||||
// {
|
||||
// ret.scalar = scalar;
|
||||
// }
|
||||
// else if constexpr(GranularityMN == 1)
|
||||
// {
|
||||
// ret.ptr = ptr + offset;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// ret.ptr = ptr + offset / GranularityMN;
|
||||
// }
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const
|
||||
{
|
||||
if constexpr(GranularityMN == 1)
|
||||
{
|
||||
return ptr[i];
|
||||
}
|
||||
else
|
||||
{
|
||||
return ptr[i / GranularityMN];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int SharedGranularityMN>
|
||||
struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
{
|
||||
static constexpr int GranularityMN = SharedGranularityMN;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
static_assert(GranularityMN != 0);
|
||||
|
||||
union
|
||||
{
|
||||
const float* ptr;
|
||||
float scalar; // if shared granularity is 0, all rows/columns use the same scale value
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(float scalar_) : scalar(scalar_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t stride)
|
||||
: ptr(ptr_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
|
||||
{
|
||||
FlatmmScalePointer ret;
|
||||
if constexpr(GranularityMN == 0)
|
||||
{
|
||||
ret.scalar = scalar;
|
||||
}
|
||||
else if constexpr(GranularityMN == 1)
|
||||
{
|
||||
ret.ptr = ptr + offset;
|
||||
}
|
||||
else
|
||||
{
|
||||
ret.ptr = ptr + offset / GranularityMN;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer& advance() { return *this; }
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const
|
||||
{
|
||||
if constexpr(GranularityMN == 0)
|
||||
{
|
||||
return scalar;
|
||||
}
|
||||
else if constexpr(GranularityMN == 1)
|
||||
{
|
||||
return ptr[i];
|
||||
}
|
||||
else
|
||||
{
|
||||
return ptr[i / GranularityMN];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// shared granularityMN = -1 means no scale
|
||||
template <>
|
||||
struct FlatmmScalePointer<-1, 0>
|
||||
{
|
||||
static constexpr int GranularityMN = -1;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(float) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, [[maybe_unused]] index_t stride)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer& advance() { return *this; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
|
||||
{
|
||||
return FlatmmScalePointer{};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
|
||||
{
|
||||
return 1; // alway return 1, it doesn't change the result
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t NumDTensor = 0>
|
||||
struct FlatmmHostArgs
|
||||
struct BaseFlatmmHostArgs
|
||||
{
|
||||
CK_TILE_HOST FlatmmHostArgs() = default;
|
||||
CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
CK_TILE_HOST BaseFlatmmHostArgs() = default;
|
||||
CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
@@ -65,8 +213,51 @@ struct FlatmmHostArgs
|
||||
|
||||
index_t k_batch;
|
||||
};
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
index_t NumDTensor = 0>
|
||||
struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<>
|
||||
{
|
||||
CK_TILE_HOST ScaleFlatmmHostArgs() = default;
|
||||
CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_,
|
||||
const void* b_shuffle_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* c_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_C_,
|
||||
ScaleM scale_m_ = nullptr,
|
||||
ScaleN scale_n_ = nullptr)
|
||||
: BaseFlatmmHostArgs(a_ptr_,
|
||||
b_shuffle_ptr_,
|
||||
ds_ptr_,
|
||||
c_ptr_,
|
||||
k_batch_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
stride_A_,
|
||||
stride_B_,
|
||||
stride_Ds_,
|
||||
stride_C_),
|
||||
scale_m(scale_m_),
|
||||
scale_n(scale_n_)
|
||||
{
|
||||
}
|
||||
ScaleM scale_m = nullptr;
|
||||
ScaleN scale_n = nullptr;
|
||||
};
|
||||
|
||||
template <index_t NumDTensor = 0>
|
||||
template <int NumberTensor = 0>
|
||||
using FlatmmHostArgs =
|
||||
ScaleFlatmmHostArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, NumberTensor>;
|
||||
|
||||
template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
|
||||
struct FlatmmKernelArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
@@ -82,6 +273,8 @@ struct FlatmmKernelArgs
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
index_t stride_E;
|
||||
index_t k_batch;
|
||||
ScaleM scale_m_ptr = nullptr;
|
||||
ScaleN scale_n_ptr = nullptr;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
|
||||
@@ -97,7 +290,8 @@ struct FlatmmKernel
|
||||
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
|
||||
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
@@ -113,7 +307,7 @@ struct FlatmmKernel
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
|
||||
// using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -124,37 +318,87 @@ struct FlatmmKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
assert(!UsePersistentKernel);
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
|
||||
{
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = FlatmmKernel::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
|
||||
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
|
||||
|
||||
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size,
|
||||
FlatmmKernel,
|
||||
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
// << ", persistent_block_size: " << persistent_block_size
|
||||
// << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
|
||||
|
||||
assert(kargs.k_batch == 1);
|
||||
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr KernelArgs
|
||||
MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_HOST static constexpr FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>
|
||||
MakeKernelArgs(const ScaleFlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
|
||||
{
|
||||
return KernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E,
|
||||
hostArgs.k_batch};
|
||||
return {hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E,
|
||||
hostArgs.k_batch,
|
||||
hostArgs.scale_m,
|
||||
hostArgs.scale_n};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
|
||||
{
|
||||
return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
|
||||
{
|
||||
return FlatmmPipeline::GetSmemSize();
|
||||
}
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
template <class KernelArgs>
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto N1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<1>{});
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
@@ -170,11 +414,11 @@ struct FlatmmKernel
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead * kargs.stride_B;
|
||||
b_k_split_offset = k_id * KRead * kargs.stride_B * N1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead;
|
||||
b_k_split_offset = k_id * KRead * N1;
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
@@ -192,6 +436,7 @@ struct FlatmmKernel
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
template <class KernelArgs>
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
|
||||
{
|
||||
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
@@ -203,6 +448,14 @@ struct FlatmmKernel
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -337,7 +590,7 @@ struct FlatmmKernel
|
||||
return DTesnorIsValid;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
@@ -367,9 +620,9 @@ struct FlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k /
|
||||
BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
index_t kFlatK =
|
||||
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
@@ -408,7 +661,7 @@ struct FlatmmKernel
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
@@ -417,7 +670,7 @@ struct FlatmmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
@@ -555,16 +808,18 @@ struct FlatmmKernel
|
||||
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
|
||||
}
|
||||
|
||||
template <bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void
|
||||
RunFlatmm(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_ping,
|
||||
void* smem_ptr_pong,
|
||||
const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
@@ -580,50 +835,74 @@ struct FlatmmKernel
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr);
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
kargs.scale_m_ptr + block_idx_m,
|
||||
kargs.scale_n_ptr + block_idx_n);
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr);
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
|
||||
int partition_idx = blockIdx.x) const
|
||||
{
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_flat_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
do
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<scheduler_type>(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_flat_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_ping[GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[GetSmemPongSize()];
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
partition_idx += gridDim.x;
|
||||
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
458
include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp
Normal file
458
include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp
Normal file
@@ -0,0 +1,458 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
|
||||
struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
|
||||
{
|
||||
using Underlying = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
|
||||
using BlockGemmShape =
|
||||
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
static constexpr int QuantPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
static constexpr int N_Pack = 2;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
static constexpr auto I4 = number<4>();
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
// using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "mixed_prec_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
|
||||
{
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
|
||||
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
|
||||
|
||||
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size,
|
||||
F16xMXF4FlatmmKernel,
|
||||
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
// << ", persistent_block_size: " << persistent_block_size
|
||||
// << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
|
||||
|
||||
assert(kargs.k_batch == 1);
|
||||
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto scale_n = kargs.scale_n_ptr;
|
||||
|
||||
index_t FlatScaleK =
|
||||
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const e8m0_t*>(scale_n.ptr),
|
||||
make_tuple(FlatScaleN, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
return make_tuple(
|
||||
a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_tensor_view = views.at(I1);
|
||||
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& d_tensor_view = views.at(I2);
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<FlatmmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_flat_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_block_window =
|
||||
make_tile_window(b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_n, i_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
auto scale_block_window =
|
||||
make_tile_window(views.at(I4),
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
|
||||
{i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
|
||||
return make_tuple(a_block_window,
|
||||
b_flat_block_window,
|
||||
ds_block_window,
|
||||
e_block_window,
|
||||
scale_block_window);
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void
|
||||
RunFlatmm(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_ping,
|
||||
void* smem_ptr_pong,
|
||||
const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& scale_block_window = gemm_tile_windows.at(I4);
|
||||
|
||||
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|
||||
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|
||||
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
|
||||
"ScaleM and ScaleN should have the same GranularityK");
|
||||
constexpr bool DoEpiScale =
|
||||
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
|
||||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
|
||||
|
||||
auto a_block_window_with_distr =
|
||||
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
|
||||
a_block_window.get_window_lengths(),
|
||||
a_block_window.get_window_origin(),
|
||||
FlatmmPipeline::GetADramTileDistribution());
|
||||
const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
|
||||
b_flat_block_window,
|
||||
scale_block_window,
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(DoEpiScale)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
kargs.scale_m_ptr + block_idx_m,
|
||||
kargs.scale_n_ptr + block_idx_n);
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
|
||||
int partition_idx = blockIdx.x) const
|
||||
{
|
||||
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
do
|
||||
{
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
|
||||
splitk_batch_offset.b_k_split_offset / QuantPackedSize;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
partition_idx += gridDim.x;
|
||||
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,883 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseFlatmmPipelineAGmemBGmemCRegV0
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
|
||||
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
|
||||
{
|
||||
if (TailNumber::Even == tail_num)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
else if (TailNumber::Odd == tail_num)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
// assert(false);
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
// return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
|
||||
struct FlatmmPipelineAGmemBGmemCRegV0
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using BlockFlatmm =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
|
||||
|
||||
static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
|
||||
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
|
||||
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto idxM = I0;
|
||||
static constexpr auto idxN = I1;
|
||||
static constexpr auto idxK = I2;
|
||||
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
|
||||
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
|
||||
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
|
||||
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
|
||||
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
|
||||
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
static constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1;
|
||||
static constexpr index_t ACopyLoadNumPerK = ACopyLoadNum / KIterPerWarp;
|
||||
static constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum;
|
||||
static constexpr index_t BloadGap = MIterPerWarp / 2;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
static constexpr auto warp_m = WarpTile::at(idxM);
|
||||
static constexpr auto warp_n = WarpTile::at(idxN);
|
||||
static constexpr auto warp_k = WarpTile::at(idxK);
|
||||
|
||||
/*
|
||||
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1
|
||||
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1
|
||||
defined(USING_MFMA_16x16x16) && defined(ENABLE_FP16) // mi300 fp16 16c 0.5*K1
|
||||
defined(USING_MFMA_32x32x8) && defined(ENABLE_FP16) // mi300 fp16 32c 0.5*K1
|
||||
|
||||
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP8) // mi350 fp8 32c 2*K1
|
||||
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP8) // mi350 fp8 64c 2*K1
|
||||
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP16) // mi350 fp16 16c 1*K1
|
||||
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP16) // mi350 fp16 32c 1*K1
|
||||
|
||||
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1
|
||||
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
*/
|
||||
struct MfmaConfig
|
||||
{
|
||||
int mfma_per_wg;
|
||||
int dsread_per_wg;
|
||||
};
|
||||
static constexpr MfmaConfig GetMfmaConfig()
|
||||
{
|
||||
|
||||
// K1 per Mfma = 0.5 cases: mfma_per_wg = 2, dsread_per_wg = 1
|
||||
if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 &&
|
||||
std::is_same_v<ADataType, fp8_t>) ||
|
||||
(warp_m == 32 && warp_n == 32 && warp_k == 16 &&
|
||||
std::is_same_v<ADataType, fp8_t>) ||
|
||||
(warp_m == 16 && warp_n == 16 && warp_k == 16 &&
|
||||
std::is_same_v<ADataType, fp16_t>) ||
|
||||
(warp_m == 32 && warp_n == 32 && warp_k == 8 &&
|
||||
std::is_same_v<ADataType, fp16_t>))
|
||||
{
|
||||
return {2, 1};
|
||||
}
|
||||
// K1 per Mfma = 2 cases: mfma_per_wg = 1, dsread_per_wg = 2
|
||||
else if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 128 &&
|
||||
std::is_same_v<ADataType, fp8_t>) ||
|
||||
(warp_m == 32 && warp_n == 32 && warp_k == 64 &&
|
||||
std::is_same_v<ADataType, fp8_t>))
|
||||
{
|
||||
return {1, 2};
|
||||
}
|
||||
// K1 per Mfma = 1 cases: mfma_per_wg = 1, dsread_per_wg = 1
|
||||
else if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 &&
|
||||
std::is_same_v<ADataType, fp16_t>) ||
|
||||
(warp_m == 32 && warp_n == 32 && warp_k == 16 &&
|
||||
std::is_same_v<ADataType, fp16_t>) ||
|
||||
(warp_m == 16 && warp_n == 16 && warp_k == 128 /*&&
|
||||
std::is_same_v<ADataType, fp4_t> */) ||
|
||||
(warp_m == 32 && warp_n == 32 && warp_k == 64 /*&&
|
||||
std::is_same_v<ADataType, fp4_t> */))
|
||||
{
|
||||
return {1, 1};
|
||||
}
|
||||
// Default configuration
|
||||
else
|
||||
{
|
||||
return {1, 1};
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr auto mfma_config = GetMfmaConfig();
|
||||
static constexpr auto mfma_per_wg = mfma_config.mfma_per_wg;
|
||||
static constexpr auto dsread_per_wg = mfma_config.dsread_per_wg;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegV1",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return PipelinePolicy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
// Keypoint of pipeline optimize is workload balance in time
|
||||
// instruction schedule example(128X256X256, 1X4, 16X16X128):
|
||||
// Iter MNK MFMA ds_read ds_write A_load b_load
|
||||
// -1 M6N0: 57 - 8 - -
|
||||
// -1 M6N1: 58 1 - - -
|
||||
// -1 M6N2: 59 - - 7 -
|
||||
// -1 M6N3: 60 2 - - -
|
||||
// -1 M7N0: 61 - - - -
|
||||
// -1 M7N1: 62 3 - - -
|
||||
// -1 M7N2: 63 - - 8 -
|
||||
// -1 M7N3: 64 4 - - -
|
||||
// 0 M0N0K0: 1 - - - -
|
||||
// 0 M0N1: 2 5 - - 2
|
||||
// 0 M0N2: 3 - - - -
|
||||
// 0 M0N3: 4 6 - - -
|
||||
// 0 M1N0: 5 - - - -
|
||||
// 0 M1N1: 6 7 - - 4
|
||||
// 0 M1N2: 7 - - - -
|
||||
// 0 M1N3: 8 8 - - -
|
||||
// 0 M2N0: 9 - - - -
|
||||
// 0 M2N1: 10 9 - - 6
|
||||
// 0 M2N2: 11 - - - -
|
||||
// 0 M2N3: 12 10 - - -
|
||||
// 0 M3N0: 13 - 1 - -
|
||||
// 0 M3N1: 14 11 - - 8
|
||||
// 0 M3N2: 15 - - - -
|
||||
// 0 M3N3: 16 12 - - -
|
||||
// 0 M4N0: 17 - 2 - -
|
||||
// 0 M4N1: 18 13 - - -
|
||||
// 0 M4N2: 19 - - 1 -
|
||||
// 0 M4N3: 20 14 - - -
|
||||
// 0 M5N0: 21 - 3 - -
|
||||
// 0 M5N1: 22 15 - - -
|
||||
// 0 M5N2: 23 - - 2 -
|
||||
// 0 M5N3: 24 16 - - -
|
||||
// 0 M6N0: 25 - 4 - -
|
||||
// 0 M6N1: 26 17 - - -
|
||||
// 0 M6N2: 27 - - 3 -
|
||||
// 0 M6N3: 28 18 - - -
|
||||
// 0 M7N0: 29 - - - -
|
||||
// 0 M7N1: 30 19 - - -
|
||||
// 0 M7N2: 31 - - 4 -
|
||||
// 0 M7N3: 32 20 - - -
|
||||
// 0 M0N0K1: 33 - - - -
|
||||
// 0 M0N1: 34 21 - - 10
|
||||
// 0 M0N2: 35 - - - -
|
||||
// 0 M0N3: 36 22 - - -
|
||||
// 0 M1N0: 37 - - - -
|
||||
// 0 M1N1: 38 23 - - 12
|
||||
// 0 M1N2: 39 - - - -
|
||||
// 0 M1N3: 40 24 - - -
|
||||
// 0 M2N0: 41 - - - -
|
||||
// 0 M2N1: 42 25 - - 14
|
||||
// 0 M2N2: 43 - - - -
|
||||
// 0 M2N3: 44 26 - - -
|
||||
// 0 M3N0: 45 - 5 - -
|
||||
// 0 M3N1: 46 27 - - 16
|
||||
// 0 M3N2: 47 - - - -
|
||||
// 0 M3N3: 48 28 - - -
|
||||
// 0 M4N0: 49 - 6 - -
|
||||
// 0 M4N1: 50 29 - - -
|
||||
// 0 M4N2: 51 - - 5 -
|
||||
// 0 M4N3: 52 30 - - -
|
||||
// 0 M5N0: 53 - 7 - -
|
||||
// 0 M5N1: 54 31 - - -
|
||||
// 0 M5N2: 55 - - 6 -
|
||||
// 0 M5N3: 56 32 - - -
|
||||
// 0 M6N0: 57 - 8 - -
|
||||
// 0 M6N1: 58 1 - - -
|
||||
// 0 M6N2: 59 - - 7 -
|
||||
// 0 M6N3: 60 2 - - -
|
||||
// 0 M7N0: 61 - - - -
|
||||
// 0 M7N1: 62 3 - - -
|
||||
// 0 M7N2: 63 - - 8 -
|
||||
// 0 M7N3: 64 4 - - -
|
||||
|
||||
#if 0
|
||||
constexpr auto dsread_num_perK = dsread_per_wg * MIterPerWarp;
|
||||
constexpr auto dswrite_num_perK = (dsread_num_perK + MWarp * NWarp - 1) / (MWarp * NWarp);
|
||||
constexpr auto dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
|
||||
|
||||
// index_t dsread_perM[MIterPerWarp];
|
||||
// index_t dswrite_perM[MIterPerWarp];
|
||||
index_t dsread_perM[MIterPerWarp];
|
||||
index_t dswrite_perM[MIterPerWarp];
|
||||
index_t load_perM[MIterPerWarp];
|
||||
|
||||
constexpr int dswrite_inst = dswrite_num_perK;
|
||||
constexpr int NIter_num = NIterPerWarp*mfma_per_wg;
|
||||
|
||||
#pragma unroll
|
||||
for(int i=0;i<MIterPerWarp;i++)
|
||||
{
|
||||
dsread_perM[i] = 2;
|
||||
if(i==0)
|
||||
{
|
||||
dswrite_perM[0] = (dswrite_inst - MIterPerWarp + 2) > 0 ? dswrite_inst - MIterPerWarp + 2 : 0;
|
||||
}
|
||||
else if(i==MIterPerWarp-1)
|
||||
{
|
||||
dswrite_perM[MIterPerWarp-1] = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
dswrite_perM[i] = (i + 2 - dswrite_inst) > 0 ? 1 : 0;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int i=0;i<4;i++)
|
||||
{
|
||||
load_perM[i] = 2;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int i=4;i<8;i++)
|
||||
{
|
||||
load_perM[i] = 1;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int i=0;i<MIterPerWarp;i++)
|
||||
{
|
||||
int biger_num = dsread_perM[i] > load_perM[i] ? (dsread_perM[i] > dswrite_perM[i] ? dsread_perM[i] : dswrite_perM[i]) : (load_perM[i] > dswrite_perM[i] ? load_perM[i] : dswrite_perM[i]);
|
||||
int total_num = dsread_perM[i] + load_perM[i] + dswrite_perM[i];
|
||||
int gap = (total_num+NIter_num-1)/NIter_num;
|
||||
|
||||
index_t inst_order[MIterPerWarp*10];
|
||||
#pragma unroll
|
||||
for(int j=0;j<MIterPerWarp*10;j++)
|
||||
{
|
||||
inst_order[j] = 0;
|
||||
}
|
||||
|
||||
int index=0;
|
||||
#pragma unroll
|
||||
for(int j=0;j<biger_num;j++)
|
||||
{
|
||||
if(dswrite_perM[i]>j)
|
||||
{
|
||||
inst_order[index] = 1;
|
||||
index++;
|
||||
}
|
||||
if(load_perM[i]>j)
|
||||
{
|
||||
inst_order[index] = 2;
|
||||
index++;
|
||||
}
|
||||
if(dsread_perM[i]>j)
|
||||
{
|
||||
inst_order[index] = 3;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int j=0;j<NIter_num;j++)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
#pragma unroll
|
||||
for(int m=0;m<gap;m++)
|
||||
{
|
||||
if(m%2==0)
|
||||
{
|
||||
if(inst_order[j+m*NIter_num]==1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
if(inst_order[j+m*NIter_num]==2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if(inst_order[j+m*NIter_num]==3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(inst_order[(m+1)*NIter_num-1-j]==1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
if(inst_order[(m+1)*NIter_num-1-j]==2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if(inst_order[(m+1)*NIter_num-1-j]==3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
|
||||
if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128)
|
||||
{
|
||||
constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad;
|
||||
constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp;
|
||||
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
|
||||
|
||||
static_for<0, A_LDS_Read_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler()
|
||||
{
|
||||
#if 0
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
|
||||
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block_ping = make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
|
||||
auto a_lds_block_pong = make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
auto a_copy_lds_window_ping =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
auto a_copy_lds_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping_tmp = make_tile_window(
|
||||
a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
auto a_warp_window_pong_tmp = make_tile_window(
|
||||
a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows_ping;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows_pong(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Block GEMM
|
||||
auto block_flatmm = BlockFlatmm();
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_flatmm.MakeCBlockTile();
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
|
||||
auto b_flat_dram_window = // tile_window_with_static_distribution
|
||||
make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
// pingpong buffer for B
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
|
||||
// Prefetch A0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefetch A1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
auto a_warp_tensor_ping = load_tile(a_warp_windows_ping(mIter)(kIter));
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_ping, b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
block_sync_lds();
|
||||
|
||||
HotLoopScheduler();
|
||||
|
||||
//Next K
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i+1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
auto a_warp_tensor_pong = load_tile(a_warp_windows_pong(mIter)(kIter));
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_pong, b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
block_sync_lds();
|
||||
|
||||
HotLoopScheduler();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
auto a_warp_tensor_ping = load_tile(a_warp_windows_ping(mIter)(kIter));
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_ping, b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
TailHotLoopScheduler();
|
||||
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
auto a_warp_tensor_pong = load_tile(a_warp_windows_pong(mIter)(kIter));
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_pong, b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
auto a_warp_tensor_ping = load_tile(a_warp_windows_ping(mIter)(kIter));
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_ping, b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -340,6 +341,37 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
@@ -431,12 +463,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy<
|
||||
typename Problem::ADataType,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,188 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
static constexpr index_t KBPerLoad = 32;
|
||||
static constexpr index_t N_Pack = 2; // it's fixed for fp4
|
||||
static constexpr index_t K_Pack = 2; // it's fixed for fp4
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
static_assert(MPerXdl == 16 && NPerXdl == 16);
|
||||
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
|
||||
number<ContiguousThreadsCntInDS_READ_16B>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp16xF4_ADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M2, M1 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ALDS_TileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
|
||||
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
|
||||
constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M0 = TileShape::WarpTile::at(I0);
|
||||
|
||||
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
|
||||
|
||||
constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 8
|
||||
constexpr int XDL_PerThreadK = KBPerLoad / K2; // 4
|
||||
constexpr int K0 = K_Lane; // 4
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Repeat>,
|
||||
tuple<sequence<M0>, sequence<K0, XDL_PerThreadK, K2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NWavePerBlk, N_Pack>, // second
|
||||
// direction
|
||||
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4ScaleBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t XDLPerBlock = TileShape::kK / TileShape::WarpTile::at(I2);
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
|
||||
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
|
||||
|
||||
constexpr index_t NWavePerBlk = N_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>, // ?
|
||||
tuple<sequence<NWavePerBlk>, // second direction
|
||||
sequence<K_Lane, N_Lane, N_Pack * K_Pack>>, // first
|
||||
// direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<1>, sequence<2, 2>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -181,6 +181,150 @@ using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
||||
VectorSizeA_,
|
||||
VectorSizeB_>;
|
||||
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct FlatmmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Traits::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Traits::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kPadM = Traits::kPadM;
|
||||
static constexpr bool kPadN = Traits::kPadN;
|
||||
static constexpr bool kPadK = Traits::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm_problem",
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
|
||||
? pixels_per_thread
|
||||
: PackedSize * VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
|
||||
? pixels_per_thread
|
||||
: PackedSize * VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return PackedSize * VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
|
||||
{
|
||||
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
|
||||
constexpr index_t M0 = get_warp_size() / N2;
|
||||
constexpr index_t M1 = BlockGemmShape::kM / M0;
|
||||
|
||||
return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
|
||||
constexpr index_t N0 = get_warp_size() / M2;
|
||||
constexpr index_t N1 = BlockGemmShape::kN / N0;
|
||||
|
||||
return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeA = []() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentA();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadM ? 1 : GetAlignmentA();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr index_t VectorSizeB = []() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentB();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentB();
|
||||
}
|
||||
}();
|
||||
static constexpr index_t VectorSizeC = []() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentC();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadM ? 1 : GetAlignmentC();
|
||||
}
|
||||
}();
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
|
||||
10
include/ck_tile/ops/moe_flatmm.hpp
Normal file
10
include/ck_tile/ops/moe_flatmm.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/moe_flatmm/kernel/moe_flatmm_kernel.hpp"
|
||||
#include "ck_tile/ops/moe_flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
1260
include/ck_tile/ops/moe_flatmm/kernel/moe_flatmm_kernel.hpp
Normal file
1260
include/ck_tile/ops/moe_flatmm/kernel/moe_flatmm_kernel.hpp
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,8 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp)
|
||||
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
|
||||
add_gtest_executable(test_ck_tile_mx_scale test_mx_scale.cpp)
|
||||
endif()
|
||||
162
test/ck_tile/data_type/test_mx_scale.cpp
Normal file
162
test/ck_tile/data_type/test_mx_scale.cpp
Normal file
@@ -0,0 +1,162 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
using ck_tile::bf16_t;
|
||||
using ck_tile::bf16x2_t;
|
||||
using ck_tile::fp16_t;
|
||||
using ck_tile::fp16x2_t;
|
||||
using ck_tile::fp32_t;
|
||||
using ck_tile::fp32x2_t;
|
||||
using ck_tile::number;
|
||||
using ck_tile::pk_fp4_t;
|
||||
|
||||
template <typename SRC, typename DST, bool is_device>
|
||||
CK_TILE_HOST void test_convert();
|
||||
|
||||
using ck_tile::e8m0_raw_t;
|
||||
using ck_tile::e8m0_t;
|
||||
|
||||
TEST(OCP_Scale, NumericLimits)
|
||||
{
|
||||
EXPECT_EQ(ck_tile::numeric<e8m0_t>::has_inf(), false);
|
||||
EXPECT_EQ(ck_tile::numeric<e8m0_t>::zero(), ck_tile::numeric<e8m0_t>::signaling_NaN());
|
||||
EXPECT_EQ(ck_tile::numeric<e8m0_t>::min(), e8m0_t{e8m0_raw_t{0b00000000}});
|
||||
EXPECT_EQ(ck_tile::numeric<e8m0_t>::max(), e8m0_t{e8m0_raw_t{0b11111110}});
|
||||
}
|
||||
TEST(OCP_Scale, NumericBasic)
|
||||
{
|
||||
auto scale_1 = e8m0_t{1.0f};
|
||||
auto scale_2 = e8m0_t{e8m0_raw_t{ck_tile::numeric_traits<e8m0_t>::bias}}; // 2^0
|
||||
EXPECT_EQ(scale_1, scale_2);
|
||||
|
||||
auto scale_3 = e8m0_t{8.0f};
|
||||
auto scale_4 = e8m0_t{e8m0_raw_t{3 + ck_tile::numeric_traits<e8m0_t>::bias}}; // 2^3
|
||||
EXPECT_EQ(scale_3, scale_4);
|
||||
}
|
||||
|
||||
TEST(OCP_Scale, ScaledConvertDevice)
|
||||
{
|
||||
constexpr bool is_device = true;
|
||||
test_convert<fp32_t, fp32_t, is_device>(); // fp32 -> fp4 -> fp32
|
||||
test_convert<fp16_t, fp16_t, is_device>();
|
||||
test_convert<bf16_t, bf16_t, is_device>();
|
||||
test_convert<fp32_t, fp16_t, is_device>();
|
||||
test_convert<fp32_t, bf16_t, is_device>();
|
||||
test_convert<fp16_t, fp32_t, is_device>();
|
||||
test_convert<bf16_t, fp32_t, is_device>();
|
||||
}
|
||||
TEST(OCP_Scale, ScaledConvertHost)
|
||||
{
|
||||
constexpr bool is_device = false;
|
||||
test_convert<fp32_t, fp32_t, is_device>(); // fp32 -> fp4 -> fp32
|
||||
test_convert<fp16_t, fp16_t, is_device>();
|
||||
test_convert<bf16_t, bf16_t, is_device>();
|
||||
test_convert<fp32_t, fp16_t, is_device>();
|
||||
test_convert<fp32_t, bf16_t, is_device>();
|
||||
test_convert<fp16_t, fp32_t, is_device>();
|
||||
test_convert<bf16_t, fp32_t, is_device>();
|
||||
}
|
||||
TEST(OCP_Scale, tensorInit)
|
||||
{
|
||||
using scale_t = e8m0_t;
|
||||
ck_tile::HostTensor<scale_t> scales({10, 10});
|
||||
ck_tile::FillUniformDistribution<scale_t>{1.f, 1.f}(scales);
|
||||
scales.SetZero();
|
||||
}
|
||||
|
||||
#define toPF4(x, y) ck_tile::scaled_type_convert<pk_fp4_t>(x, y)
|
||||
#define toDST(x, y) ck_tile::scaled_type_convert<DST>(x, y)
|
||||
#define toDSTx2(x, y) ck_tile::scaled_type_convert<DSTx2_t>(x, y)
|
||||
|
||||
#define toF32(x) ck_tile::type_convert<float>(x)
|
||||
#define toPF4_(x) ck_tile::type_convert<pk_fp4_t>(x)
|
||||
#define toSRC(x) ck_tile::type_convert<SRC>(x)
|
||||
#define toDST_(x) ck_tile::type_convert<DST>(x)
|
||||
|
||||
template <typename Kernel, typename... Args>
|
||||
__global__ void MyKernel(Args... args)
|
||||
{
|
||||
Kernel{}(args...);
|
||||
}
|
||||
template <typename SRC, typename DST, int N>
|
||||
struct SrcPkfp4Dst
|
||||
{
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()(const SRC* src, DST* dst, e8m0_t scale1, e8m0_t scale2) const
|
||||
{
|
||||
|
||||
using SRCx2_t = ck_tile::ext_vector_t<SRC, 2>;
|
||||
using DSTx2_t = ck_tile::ext_vector_t<DST, 2>;
|
||||
|
||||
ck_tile::static_for<0, N, 2>{}([&](auto i) {
|
||||
const auto input2 = SRCx2_t{src[i], src[i + 1]};
|
||||
|
||||
if(i % 4 == 0)
|
||||
{
|
||||
// ex: fp32_t -> fp4 -> bf16_t
|
||||
dst[i] = toDST(toPF4(src[i], scale1), scale2);
|
||||
// ex: fp32x2_t -> pk_fp4 -> unpack<0> -> bf16_t
|
||||
dst[i + 1] = toDST(toPF4_(toPF4(input2, scale1).unpack(number<1>{})), scale2);
|
||||
}
|
||||
else
|
||||
{
|
||||
// ex: fp32x2_t -> pk_fp4_t -> bf16x2_t
|
||||
reinterpret_cast<DSTx2_t*>(dst)[i >> 1] = toDSTx2(toPF4(input2, scale1), scale2);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SRC, typename DST, bool is_device>
|
||||
CK_TILE_HOST void test_convert()
|
||||
{
|
||||
const auto test_data = std::array{4.f, 6.f, 8.f, 10.f};
|
||||
const auto ref_data = std::array{8.f, 16.f, 16.f, 16.f};
|
||||
const auto scale1 = e8m0_t{8.0f};
|
||||
const auto scale2 = e8m0_t{16.0f};
|
||||
|
||||
static_assert(test_data.size() == ref_data.size());
|
||||
static_assert(test_data.size() % 2 == 0);
|
||||
|
||||
constexpr int N = test_data.size();
|
||||
std::array<SRC, N> in;
|
||||
std::array<DST, N> ref, out;
|
||||
|
||||
// prepare input and ground truth in host
|
||||
for(int i = 0; i < N; ++i)
|
||||
{
|
||||
in[i] = toSRC(test_data[i]);
|
||||
ref[i] = toDST_(ref_data[i]);
|
||||
EXPECT_EQ(test_data[i], toF32(in[i]));
|
||||
EXPECT_EQ(ref_data[i], toF32(ref[i]));
|
||||
}
|
||||
|
||||
using job = SrcPkfp4Dst<SRC, DST, N>;
|
||||
|
||||
if constexpr(is_device)
|
||||
{
|
||||
auto in_d = std::make_unique<ck_tile::DeviceMem>(in.size() * sizeof(SRC));
|
||||
auto out_d = std::make_unique<ck_tile::DeviceMem>(out.size() * sizeof(DST));
|
||||
in_d->ToDevice(in.data());
|
||||
|
||||
MyKernel<job><<<1, 1>>>(reinterpret_cast<const SRC*>(in_d->GetDeviceBuffer()),
|
||||
reinterpret_cast<DST*>(out_d->GetDeviceBuffer()),
|
||||
scale1,
|
||||
scale2);
|
||||
|
||||
out_d->FromDevice(out.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
job{}(in.data(), out.data(), scale1, scale2);
|
||||
}
|
||||
|
||||
for(int i = 0; i < N; ++i)
|
||||
EXPECT_EQ(ref[i], out[i]) << "i:" << i;
|
||||
}
|
||||
Reference in New Issue
Block a user