mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] support split-k a16w4 gemm1 (#3389)
* initial version to support moe gemm1 split-k
* add missing args
* fix build warning
* update reference
* for split-k disable bias and weight
* remove debug log
* fix format
* fix div by zero errors
* fix cmake config
* update
* resolve conflicts
* remove useless changes
* reformat
* fix
* remove useless changes
* fix ci
---------
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: root <root@smci355-ccs-aus-m01-25.cs-aus.dcgpu>
[ROCm/composable_kernel commit: dae85ead64]
This commit is contained in:
@@ -31,13 +31,14 @@ if(has_supported_gpu)
|
||||
add_executable(tile_example_grouped_flatmm grouped_flatmm.cpp)
|
||||
target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
if (GPU_TARGETS MATCHES "gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx94")
|
||||
add_executable(tile_example_mixed_prec_flatmm mixed_prec/mixed_prec_flatmm.cpp)
|
||||
target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_a16w4_moe_flatmm mixed_prec/a16w4_moe_flatmm.cpp)
|
||||
target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
endif()
|
||||
if (GPU_TARGETS MATCHES "gfx95")
|
||||
include(mxgemm/mx_flatmm_instance.cmake)
|
||||
mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES)
|
||||
message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}")
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
|
||||
@@ -191,13 +191,15 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
std::cout << "Launching kernel " << Kernel::GetName() << "\n"
|
||||
<< "with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
<< "\n"
|
||||
<< "k_batch: " << kargs.k_batch << std::endl;
|
||||
}
|
||||
|
||||
if(s.flush_cache_)
|
||||
@@ -471,10 +473,33 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[])
|
||||
throw std::runtime_error("Unsupported precision type for gemm2!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm1_split_k")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_split_k!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
|
||||
"[gemm1_gate_up | gemm2]");
|
||||
"[gemm1_gate_up | gemm1_split_k | gemm2]");
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
@@ -69,7 +69,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("gemm_kind",
|
||||
"gemm1_gate_up",
|
||||
"Gemm kind in FFN network [gemm1_gate_up | gemm2] - "
|
||||
"Gemm kind in FFN network [gemm1_gate_up | gemm2 | gemm1_split_k] - "
|
||||
"gemm1_gate_up by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
@@ -80,7 +80,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 16x16 (950 only, may use a larger tile than warp_tile=0)")
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.")
|
||||
.insert("k_batch", "1", "parallism to control splik-k.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -67,9 +67,12 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
return -1;
|
||||
};
|
||||
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using CDataType = PrecActType;
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using CDataType =
|
||||
std::conditional_t<kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k, float, PrecActType>;
|
||||
using AccDataType = float;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
@@ -88,6 +91,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
|
||||
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
|
||||
const ck_tile::index_t experts = arg_parser.get_int("experts");
|
||||
const ck_tile::index_t k_batch = arg_parser.get_int("k_batch");
|
||||
|
||||
// TODO: replace the magic declaration
|
||||
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
|
||||
@@ -231,14 +235,15 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
|
||||
|
||||
auto scale_b_shuffle_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()),
|
||||
N / ScaleGranularityN};
|
||||
auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<1>{
|
||||
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()), experts * N};
|
||||
|
||||
using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs<
|
||||
ck_tile::FlatmmScalePointer<-1>,
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>,
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>,
|
||||
ck_tile::FlatmmScalePointer<1>>;
|
||||
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
|
||||
p_sorted_expert_weight_dev,
|
||||
@@ -250,7 +255,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
num_tokens,
|
||||
experts,
|
||||
topk,
|
||||
1, // k_batch
|
||||
k_batch, // k_batch
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
|
||||
@@ -85,8 +85,9 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
c_rslt_host.SetZero();
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffle.data());
|
||||
|
||||
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK>{
|
||||
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
|
||||
auto scale_b_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
|
||||
|
||||
invoke_mixed_prec_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
|
||||
@@ -25,14 +25,16 @@ using BF16 = ck_tile::bf16_t;
|
||||
using ROW = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using COL = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
inline constexpr auto ODD = ck_tile::TailNumber::Odd;
|
||||
inline constexpr auto EVEN = ck_tile::TailNumber::Even;
|
||||
|
||||
inline constexpr int ScaleGranularityM = 1;
|
||||
inline constexpr int ScaleGranularityN = 1;
|
||||
inline constexpr int ScaleGranularityK = 32;
|
||||
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>;
|
||||
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>;
|
||||
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>;
|
||||
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>;
|
||||
|
||||
template float mx_flatmm_calc<FLATMM_CONFIG,
|
||||
A_DATA_TYPE,
|
||||
|
||||
@@ -105,10 +105,12 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
|
||||
|
||||
auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
|
||||
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
auto scale_a_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
|
||||
auto scale_b_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
|
||||
invoke_mx_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
|
||||
@@ -18,7 +18,7 @@ template <typename ADataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
|
||||
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3:gemm1_split_k
|
||||
typename ActivationOp = identity>
|
||||
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
const ck_tile::index_t* p_sorted_expert_ids_,
|
||||
@@ -43,10 +43,11 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
float* scale_B_ptr,
|
||||
float* expert_bias_ptr)
|
||||
{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
|
||||
int row = idx / problem_N; // Compute row index
|
||||
int col = idx % problem_N; // Compute column index
|
||||
constexpr auto is_split_k = MoeGemmKind == 3;
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
|
||||
int row = idx / problem_N; // Compute row index
|
||||
int col = idx % problem_N; // Compute column index
|
||||
|
||||
index_t gather_token_id = 0;
|
||||
index_t scatter_token_id = 0;
|
||||
@@ -203,7 +204,7 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
acc_up += acc_up_temp * scale_A * scale_B_up;
|
||||
|
||||
float bias = 0.f, bias_up = 0.f;
|
||||
if(expert_bias_ptr != nullptr)
|
||||
if(expert_bias_ptr != nullptr && !is_split_k)
|
||||
{
|
||||
bias = expert_bias_ptr[expert_id * N + col];
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
@@ -221,23 +222,24 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
else
|
||||
{
|
||||
// moe gemm2 don't use activation.
|
||||
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
|
||||
using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
|
||||
ck_tile::fp16x2_t,
|
||||
ck_tile::bf16x2_t>;
|
||||
ResV2Type add_v{0, 0};
|
||||
auto weight =
|
||||
is_split_k ? ck_tile::type_convert<AccDataType>(1.0f) : expert_weight_ptr[row];
|
||||
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * weight);
|
||||
|
||||
thread_buffer<CDataType, 2> add_v = 0;
|
||||
if(c_index % 2)
|
||||
{
|
||||
// result is the second value of fp16 pair.
|
||||
add_v.y = res;
|
||||
add_v.template get_as<CDataType>()[1] = res;
|
||||
}
|
||||
else
|
||||
{
|
||||
// result is the first value of fp16 pair.
|
||||
add_v.x = res;
|
||||
add_v.template get_as<CDataType>()[0] = res;
|
||||
}
|
||||
// mask last bit to make sure atomicAdd pointer is aligned of DWORD.
|
||||
atomic_add<ResV2Type>(reinterpret_cast<ResV2Type*>(C + (c_index & 0xffff'fffe)), add_v);
|
||||
atomic_add_g<CDataType, 2>(reinterpret_cast<CDataType*>(C + (c_index & 0xffff'fffe)),
|
||||
add_v);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -249,7 +251,7 @@ template <typename ADataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
|
||||
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3:gemm1_split_k
|
||||
typename ActivationOp = identity>
|
||||
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
|
||||
const index_t* p_sorted_expert_ids_,
|
||||
|
||||
@@ -28,17 +28,18 @@ struct FlatmmProblem
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
template <int SharedGranularityMN, int SharedGranularityK = 0>
|
||||
template <int SharedGranularityMN, int SharedGranularityK = 0, typename ScaleType_ = float>
|
||||
struct FlatmmScalePointer
|
||||
{
|
||||
using ScaleType = ScaleType_;
|
||||
static constexpr int GranularityMN = SharedGranularityMN;
|
||||
static constexpr int GranularityK = SharedGranularityK;
|
||||
|
||||
const float* ptr;
|
||||
const ScaleType* ptr;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_)
|
||||
: ptr(ptr_)
|
||||
{
|
||||
}
|
||||
@@ -57,23 +58,24 @@ struct FlatmmScalePointer
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
|
||||
CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete;
|
||||
};
|
||||
|
||||
template <int SharedGranularityMN>
|
||||
struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
template <int SharedGranularityMN, typename ScaleType_>
|
||||
struct FlatmmScalePointer<SharedGranularityMN, 0, ScaleType_>
|
||||
{
|
||||
using ScaleType = ScaleType_;
|
||||
static constexpr int GranularityMN = SharedGranularityMN;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
static_assert(GranularityMN != 0);
|
||||
|
||||
const float* ptr;
|
||||
const ScaleType* ptr;
|
||||
index_t length;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_)
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, index_t length_)
|
||||
: ptr(ptr_), length(length_)
|
||||
{
|
||||
}
|
||||
@@ -94,7 +96,7 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const
|
||||
CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const
|
||||
{
|
||||
// with additional oob check
|
||||
if constexpr(GranularityMN == 1)
|
||||
@@ -105,23 +107,24 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
};
|
||||
|
||||
// shared granularityMN = -1 means no scale
|
||||
template <>
|
||||
struct FlatmmScalePointer<-1, 0>
|
||||
template <typename ScaleType_>
|
||||
struct FlatmmScalePointer<-1, 0, ScaleType_>
|
||||
{
|
||||
using ScaleType = ScaleType_;
|
||||
static constexpr int GranularityMN = -1;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
const float* ptr = nullptr;
|
||||
const ScaleType* ptr = nullptr;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*, index_t) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
|
||||
{
|
||||
return FlatmmScalePointer{};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
|
||||
CK_TILE_HOST_DEVICE constexpr ScaleType operator[](index_t) const
|
||||
{
|
||||
return 1; // alway return 1, it doesn't change the result
|
||||
}
|
||||
|
||||
@@ -132,6 +132,7 @@ enum class MoeFlatmmKind
|
||||
kFFN_gemm1_gate_only,
|
||||
kFFN_gemm1_gate_up,
|
||||
kFFN_gemm2,
|
||||
kFFN_gemm1_split_k,
|
||||
};
|
||||
|
||||
namespace moe {
|
||||
@@ -222,8 +223,10 @@ struct MoeFlatmmKernel
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
|
||||
static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
|
||||
static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
|
||||
static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
|
||||
static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
|
||||
static constexpr bool IsGemm1SplitK = kind == MoeFlatmmKind::kFFN_gemm1_split_k;
|
||||
static constexpr bool IsBShuffled = true;
|
||||
|
||||
// static constexpr index_t kBlockSize = EpiloguePipeline::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = EpiloguePipeline::kMPerBlock;
|
||||
@@ -395,15 +398,6 @@ struct MoeFlatmmKernel
|
||||
a_k_split_offset = k_id * KRead * kargs.stride_A;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead * kargs.stride_B;
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead;
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
{
|
||||
splitted_k = KRead;
|
||||
@@ -412,6 +406,22 @@ struct MoeFlatmmKernel
|
||||
{
|
||||
splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
|
||||
}
|
||||
|
||||
if constexpr(IsBShuffled)
|
||||
{
|
||||
b_k_split_offset = k_id * splitted_k * NPerXdl;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead * kargs.stride_B;
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
@@ -573,15 +583,16 @@ struct MoeFlatmmKernel
|
||||
return DTesnorIsValid;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = IsInputGemm ? memory_operation_enum::set
|
||||
: memory_operation_enum::atomic_add,
|
||||
template <memory_operation_enum DstInMemOp = (IsInputGemm && !IsGemm1SplitK)
|
||||
? memory_operation_enum::set
|
||||
: memory_operation_enum::atomic_add,
|
||||
typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
EDataType* e_ptr,
|
||||
[[maybe_unused]] const AccDataType* exp_weight_ptr,
|
||||
const int expert_id,
|
||||
[[maybe_unused]] const int expert_id,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
@@ -742,13 +753,13 @@ struct MoeFlatmmKernel
|
||||
{
|
||||
index_t scale_k =
|
||||
BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
|
||||
const auto scale_k_offset =
|
||||
(splitk_batch_offset.b_k_split_offset / BGranularityK) * K_Pack;
|
||||
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
using ScaleType = e8m0_t;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
|
||||
scale_n.ptr + expert_id * kargs.N * scale_k + scale_k_offset,
|
||||
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
@@ -1386,11 +1397,16 @@ struct MoeFlatmmKernel
|
||||
if constexpr(!BMXFP4_Pipeline)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] *=
|
||||
epi_scale_m[idx] * epi_scale_n[idx];
|
||||
if constexpr(EnableBias)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
|
||||
else // for mlp1 gate-only
|
||||
if(kind !=
|
||||
MoeFlatmmKind::kFFN_gemm1_split_k) // disable weight and bias for split-k
|
||||
{
|
||||
if constexpr(EnableBias)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
|
||||
}
|
||||
if constexpr(kind ==
|
||||
MoeFlatmmKind::kFFN_gemm1_gate_only) // for mlp1 gate-only
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] =
|
||||
ActivationOp{}(lds_tile[lds_stage].get_thread_buffer()[idx]);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user