mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add bias for f16xf4 moe_flatmm
This commit is contained in:
@@ -50,10 +50,8 @@ template <typename FlatmmConfig,
|
||||
typename ELayout,
|
||||
ck_tile::MoeFlatmmKind moe_kind = ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename ScaleM,
|
||||
typename ScaleN>
|
||||
float a16w4_moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
typename MoeFlatmmHostArgs>
|
||||
float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
|
||||
@@ -158,15 +158,18 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
|
||||
ck_tile::HostTensor<ck_tile::index_t> max_token_id(
|
||||
ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
ck_tile::HostTensor<AccDataType> expert_bias(ck_tile::HostTensorDescriptor({experts * N}, {1}));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
// for verification only, no need to satify weight normalization
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.0f, 1.0f}(expert_bias);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.0f, 1.0f}(expert_weight);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 0.0f}(expert_bias);
|
||||
}
|
||||
|
||||
max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
|
||||
@@ -206,19 +209,17 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_tensor.SetZero();
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_dev{sizeof(ck_tile::index_t) *
|
||||
sorted_token_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_ids_dev{sizeof(ck_tile::index_t) *
|
||||
expert_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem max_token_id_dev{sizeof(ck_tile::index_t) *
|
||||
max_token_id.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_weight_dev{sizeof(AccDataType) *
|
||||
expert_weight.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem sorted_token_ids_dev{sorted_token_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_ids_dev{expert_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem max_token_id_dev{max_token_id.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_weight_dev{expert_weight.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_bias_dev{expert_bias.get_element_space_size_in_bytes()};
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.data());
|
||||
expert_weight_dev.ToDevice(expert_weight.data());
|
||||
expert_bias_dev.ToDevice(expert_bias.data());
|
||||
scale_b_shuffle_dev_buf.ToDevice(scale_b_shuffle.data());
|
||||
|
||||
const ck_tile::index_t* p_sorted_token_ids_dev =
|
||||
@@ -229,13 +230,17 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
|
||||
const AccDataType* p_sorted_expert_weight_dev =
|
||||
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
|
||||
|
||||
auto scale_b_shuffle_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<1>{
|
||||
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()), experts * N};
|
||||
|
||||
using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs<
|
||||
ck_tile::FlatmmScalePointer<-1>,
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>>;
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>,
|
||||
ck_tile::FlatmmScalePointer<1>>;
|
||||
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
|
||||
p_sorted_expert_weight_dev,
|
||||
p_expert_ids_dev,
|
||||
@@ -254,7 +259,8 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
stride_B,
|
||||
stride_C,
|
||||
nullptr,
|
||||
scale_b_shuffle_dev_ptr};
|
||||
scale_b_shuffle_dev_ptr,
|
||||
exp_bias_dev_ptr};
|
||||
|
||||
invoke_a16w4_moe_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
@@ -328,7 +334,8 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
1,
|
||||
ScaleGranularityK,
|
||||
static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()));
|
||||
static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()));
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
|
||||
@@ -86,7 +86,7 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
// TODO: replace the magic declaration
|
||||
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
|
||||
|
||||
ck_tile::index_t sorted_tile_num = num_tokens * topk / MPerBlock;
|
||||
ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk;
|
||||
ck_tile::index_t valid_tile_num = sorted_tile_num;
|
||||
ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
|
||||
@@ -161,14 +161,10 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
// for verification only, no need to satify weight normalization
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_dev{sizeof(ck_tile::index_t) *
|
||||
sorted_token_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_ids_dev{sizeof(ck_tile::index_t) *
|
||||
expert_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem max_token_id_dev{sizeof(ck_tile::index_t) *
|
||||
max_token_id.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_weight_dev{sizeof(AccDataType) *
|
||||
expert_weight.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem sorted_token_ids_dev{sorted_token_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_ids_dev{expert_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem max_token_id_dev{max_token_id.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_weight_dev{expert_weight.get_element_space_size_in_bytes()};
|
||||
|
||||
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem per_channel_scale_dev_buf(
|
||||
|
||||
@@ -40,7 +40,8 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
index_t scale_granularity_n,
|
||||
index_t scale_granularity_k,
|
||||
float* scale_A_ptr,
|
||||
float* scale_B_ptr)
|
||||
float* scale_B_ptr,
|
||||
float* expert_bias_ptr)
|
||||
{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
|
||||
@@ -200,18 +201,26 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
acc += acc_temp * scale_A * scale_B;
|
||||
acc_up += acc_up_temp * scale_A * scale_B_up;
|
||||
|
||||
float bias = 0.f, bias_up = 0.f;
|
||||
if(expert_bias_ptr != nullptr)
|
||||
{
|
||||
bias = expert_bias_ptr[expert_id * N + col];
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
|
||||
}
|
||||
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? scatter_token_id * strideC + col
|
||||
: col * strideC + scatter_token_id;
|
||||
if constexpr(MoeGemmKind < 2)
|
||||
{
|
||||
C[c_index] = ck_tile::type_convert<CDataType>(
|
||||
ActivationOp{}(acc, MoeGemmKind == 1 ? acc_up : 1));
|
||||
ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
|
||||
}
|
||||
else
|
||||
{
|
||||
// moe gemm2 don't use activation.
|
||||
CDataType res = ck_tile::type_convert<CDataType>(acc * expert_weight_ptr[row]);
|
||||
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
|
||||
using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
|
||||
ck_tile::fp16x2_t,
|
||||
ck_tile::bf16x2_t>;
|
||||
@@ -261,7 +270,8 @@ void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
|
||||
index_t scale_granularity_n,
|
||||
index_t scale_granularity_k,
|
||||
float* scale_A_ptr,
|
||||
float* scale_B_ptr)
|
||||
float* scale_B_ptr,
|
||||
float* exp_bias = nullptr)
|
||||
{
|
||||
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
|
||||
int totalElements = M * problem_N;
|
||||
@@ -296,7 +306,8 @@ void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
|
||||
scale_granularity_n,
|
||||
scale_granularity_k,
|
||||
scale_A_ptr,
|
||||
scale_B_ptr);
|
||||
scale_B_ptr,
|
||||
exp_bias);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
class ExpertBias = FlatmmScalePointer<-1>>
|
||||
struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
|
||||
{
|
||||
ck_tile::index_t NumTokens;
|
||||
@@ -24,6 +26,7 @@ struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
|
||||
const ck_tile::index_t* p_sorted_expert_ids;
|
||||
const ck_tile::index_t* p_max_token_id;
|
||||
const void* p_sorted_expert_weights;
|
||||
ExpertBias exp_bias;
|
||||
|
||||
CK_TILE_HOST MoeFlatmmHostArgs() noexcept = default;
|
||||
CK_TILE_HOST MoeFlatmmHostArgs(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
@@ -43,8 +46,9 @@ struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
|
||||
ck_tile::index_t stride_A_,
|
||||
ck_tile::index_t stride_B_,
|
||||
ck_tile::index_t stride_C_,
|
||||
ScaleM scale_m_ = {},
|
||||
ScaleN scale_n_ = {})
|
||||
ScaleM scale_m_ = {},
|
||||
ScaleN scale_n_ = {},
|
||||
ExpertBias exp_bias_ = {})
|
||||
: ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>(a_ptr_,
|
||||
b_ptr_,
|
||||
{}, // d_ptr_array
|
||||
@@ -65,7 +69,8 @@ struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
|
||||
p_sorted_token_ids(p_sorted_token_ids_),
|
||||
p_sorted_expert_ids(p_sorted_expert_ids_),
|
||||
p_max_token_id(p_max_token_id_),
|
||||
p_sorted_expert_weights(p_sorted_expert_weights_)
|
||||
p_sorted_expert_weights(p_sorted_expert_weights_),
|
||||
exp_bias(exp_bias_)
|
||||
{
|
||||
}
|
||||
};
|
||||
@@ -91,13 +96,14 @@ struct MoeSilu
|
||||
|
||||
struct Swiglu
|
||||
{
|
||||
float alpha = 1.702f; // default value used in gpt-oss
|
||||
float limit = 7.0f; // default value used in gpt-oss
|
||||
const float alpha;
|
||||
const float limit;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
Swiglu() = default;
|
||||
CK_TILE_HOST_DEVICE
|
||||
Swiglu(float alpha_, float limit_) : alpha(alpha_), limit(limit_) {}
|
||||
Swiglu(float alpha_ = 1.702f, float limit_ = 7.0f) // use value in gpt-oss as default
|
||||
: alpha(alpha_), limit(limit_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE T operator()(T gate, T linear) const
|
||||
@@ -190,7 +196,9 @@ struct MoeFlatmmKernel
|
||||
|
||||
static constexpr int WeightPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
class ExpertBias = FlatmmScalePointer<-1>>
|
||||
struct MoeFlatmmKernelArgs
|
||||
{
|
||||
const ck_tile::index_t* p_sorted_token_ids;
|
||||
@@ -211,30 +219,34 @@ struct MoeFlatmmKernel
|
||||
ck_tile::index_t k_batch;
|
||||
ScaleM scale_m;
|
||||
ScaleN scale_n;
|
||||
ExpertBias exp_bias;
|
||||
};
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
class ExpertBias = FlatmmScalePointer<-1>>
|
||||
CK_TILE_HOST static constexpr auto
|
||||
MakeKernelArgs(const MoeFlatmmHostArgs<ScaleM, ScaleN>& hostArgs)
|
||||
MakeKernelArgs(const MoeFlatmmHostArgs<ScaleM, ScaleN, ExpertBias>& hostArgs)
|
||||
{
|
||||
return MoeFlatmmKernelArgs<ScaleM, ScaleN>{hostArgs.p_sorted_token_ids,
|
||||
hostArgs.p_sorted_expert_ids,
|
||||
hostArgs.p_max_token_id,
|
||||
hostArgs.p_sorted_expert_weights,
|
||||
hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.NumTokens,
|
||||
hostArgs.TopK,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.k_batch,
|
||||
hostArgs.scale_m,
|
||||
hostArgs.scale_n};
|
||||
return MoeFlatmmKernelArgs<ScaleM, ScaleN, ExpertBias>{hostArgs.p_sorted_token_ids,
|
||||
hostArgs.p_sorted_expert_ids,
|
||||
hostArgs.p_max_token_id,
|
||||
hostArgs.p_sorted_expert_weights,
|
||||
hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.NumTokens,
|
||||
hostArgs.TopK,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.k_batch,
|
||||
hostArgs.scale_m,
|
||||
hostArgs.scale_n,
|
||||
hostArgs.exp_bias};
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
@@ -249,8 +261,8 @@ struct MoeFlatmmKernel
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
static constexpr auto GridSize(const MoeFlatmmKernelArgs<ScaleM, ScaleN>& kargs)
|
||||
template <class MoeFlatmmKernelArgs>
|
||||
static constexpr auto GridSize(const MoeFlatmmKernelArgs& kargs)
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
|
||||
}
|
||||
@@ -647,8 +659,8 @@ struct MoeFlatmmKernel
|
||||
return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
|
||||
}
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs<ScaleM, ScaleN> kargs) const
|
||||
template <class MoeFlatmmKernelArgs>
|
||||
CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs) const
|
||||
{
|
||||
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
@@ -845,6 +857,7 @@ struct MoeFlatmmKernel
|
||||
|
||||
constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
|
||||
OutputNumNXdlPerWavePerShuffle;
|
||||
constexpr bool EnableBias = decltype(kargs.exp_bias)::GranularityMN != -1;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
@@ -855,6 +868,7 @@ struct MoeFlatmmKernel
|
||||
float vec_scale_B[NRepeat];
|
||||
|
||||
float vec_expert_weights[kM0 * kM2 * MRepeat];
|
||||
float vec_expert_bias[kM0 * kM2 * MRepeat];
|
||||
|
||||
const float* expert_weights = static_cast<const float*>(kargs.p_sorted_expert_weights);
|
||||
|
||||
@@ -883,6 +897,32 @@ struct MoeFlatmmKernel
|
||||
});
|
||||
}
|
||||
}
|
||||
if constexpr(MXFP4_Pipeline && EnableBias)
|
||||
{
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
static_for<0, NRepeat / 2, 1>{}([&](auto i) {
|
||||
vec_expert_bias[i * 2] =
|
||||
kargs.exp_bias[expert_id * kargs.N + coord_n / 2 + i * NWave * NPerXdl +
|
||||
iNWarp * NPerXdl + iNLane];
|
||||
vec_expert_bias[i * 2 + 1] =
|
||||
kargs.exp_bias[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
|
||||
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 2>{}([&](auto i) {
|
||||
vec_expert_bias[i] =
|
||||
kargs.exp_bias[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
|
||||
iNWarp * 2 * NPerXdl + iNLane];
|
||||
vec_expert_bias[i + 1] =
|
||||
kargs.exp_bias[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
|
||||
iNWarp * 2 * NPerXdl + NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto i) {
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
@@ -926,15 +966,16 @@ struct MoeFlatmmKernel
|
||||
c_warp_y_lengths)));
|
||||
});
|
||||
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
{
|
||||
gate_tensor
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
@@ -942,11 +983,19 @@ struct MoeFlatmmKernel
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[2 * n_xdl + 1];
|
||||
});
|
||||
}
|
||||
if constexpr(EnableBias)
|
||||
{
|
||||
gate_tensor
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] +=
|
||||
vec_expert_bias[2 * n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] +=
|
||||
vec_expert_bias[2 * n_xdl + 1];
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
});
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
lds_tile[0].get_thread_buffer().at(idx) =
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
@@ -967,15 +1016,19 @@ struct MoeFlatmmKernel
|
||||
(m_xdl * NumNXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_expert_weights[m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[n_xdl];
|
||||
if constexpr(EnableBias)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] +=
|
||||
vec_expert_bias[n_xdl];
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_expert_weights[m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1041,14 +1094,15 @@ struct MoeFlatmmKernel
|
||||
c_warp_y_lengths)));
|
||||
});
|
||||
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
{
|
||||
gate_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
@@ -1063,10 +1117,24 @@ struct MoeFlatmmKernel
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl + 1];
|
||||
});
|
||||
}
|
||||
if constexpr(EnableBias)
|
||||
{
|
||||
gate_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] +=
|
||||
vec_expert_bias[nIter_next *
|
||||
NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] +=
|
||||
vec_expert_bias[nIter_next *
|
||||
NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl + 1];
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
lds_tile[write_stage].get_thread_buffer().at(idx) =
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
@@ -1090,12 +1158,6 @@ struct MoeFlatmmKernel
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
m2] *= vec_expert_weights
|
||||
[mIter_next * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
@@ -1105,6 +1167,19 @@ struct MoeFlatmmKernel
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
n_xdl];
|
||||
if constexpr(EnableBias)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
m2] +=
|
||||
vec_expert_bias[nIter_next *
|
||||
NumNXdlPerWavePerShuffle +
|
||||
n_xdl];
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
m2] *= vec_expert_weights
|
||||
[mIter_next * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user