Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-03 21:21:22 +00:00

Moe gemm activation (#2026)

* fix useless code and remove useless OOB
* clang format
* fix coredump in e2e test
* fix2
* fix clang format
* fix output OOB
* implement int64 indexing, result not correct yet
* int64 index ok now
* input and output all ok
* fix uint32
* revert v1 test
* use uint32
* work to support 130k (13w) tokens
* moe sorting: fix moebuf
* fix merge
* update moe api, fix aiter build
* fix build
* fuse silu
* silu ok
* ascale ok
* add silu
* change code
* gemm2 ok
* gate-up fusion compatible ok, fix warnings
* gate-up fusion for m32 and m64 ok
* support bf16 cshuffle
* i4 gemm2 ok
* i4 gemm2 ok and i4 gemm1 builds
* 16x16 runs ok
* change flops; change cshuffle dtype
* fuse gelu/silu activation in moe gemm1
* fp8 with activation ready
* int4 activation ready
* remove useless changes
* remove useless code change
* fix clang format
* add the arch limit of int4 moe gemm
* fuse moe activation
* fix fp8 16x16
* fix no-quant case
* fix bugs
* fix fp8 gate-up fusion bug
* remove useless comments
* refine activation code & complete moe example
* fix int8 bugs
* merge tkw1

---------

Co-authored-by: coderfeli <coderfeli@163.com>
Co-authored-by: feli <felix.li@amd.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: root <root@hjbog-srdc-51.amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
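Note for readers of the diff below: the core of this change is fusing the gated activation ("gelu_and_mul" / "silu_and_mul") into the first MoE GEMM, so the gate and up projections are computed and combined in a single kernel, while gemm2 keeps the plain scale-and-accumulate path. A minimal scalar reference of what the fused epilogue computes, as standalone C++ (function names are illustrative, not part of the CK API):

#include <cmath>

// silu_and_mul: out = SiLU(gate) * up, with SiLU(x) = x * sigmoid(x).
inline float silu_and_mul_ref(float gate, float up)
{
    return (gate / (1.0f + std::exp(-gate))) * up;
}

// gelu_and_mul: out = GELU(gate) * up. This sketch uses the erf form of GELU;
// CK's element_wise::Gelu may use a tanh approximation instead.
inline float gelu_and_mul_ref(float gate, float up)
{
    return 0.5f * gate * (1.0f + std::erf(gate * 0.70710678f)) * up;
}

For gemm1 both halves come from one per-expert B tensor that stores the gate weights followed by the up weights; the diff hunks below wire that layout, the dequantization scales, and the routed top-k weights into the epilogue.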
@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"
@@ -26,12 +26,17 @@ namespace ck {
// two lds chunks.
// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds
// buffer when we declare __shared__ inside blkgemmpipe

enum Activation
{
gelu_and_mul = 0,
silu_and_mul = 1
};

template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
bool IsInputGemm = false,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
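The LDS comment above explains why the kernels in this file take the shared-memory pointers (p_shared, and p_shared1 for the 2-LDS pipeline) as arguments instead of declaring __shared__ inside the block-GEMM pipeline: shared memory is held for the lifetime of the kernel, so a single kernel-scope allocation lets the A/B staging buffers and the C shuffle reuse the same bytes. A hedged sketch of that pattern (the kernel name and the Argument type name are assumptions, not the actual entry point):

// Sketch only: allocate LDS once at kernel scope and hand it to the gridwise GEMM.
template <typename GridwiseGemm>
__global__ void kernel_moe_gemm_sketch(typename GridwiseGemm::Argument karg)
{
    // One block-wide allocation sized by the gridwise GEMM; the 2-LDS variant
    // would declare a second buffer (p_shared1) of the same size for ping-pong.
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    (void)p_shared; // the real kernel forwards this to GridwiseGemm::Run(...)
    (void)karg;
}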
@@ -45,22 +50,19 @@ __global__ void

auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

GridwiseGemm::template Run<HasMainKBlockLoop,
CGlobalMemoryDataOperation,
IsInputGemm,
MulRoutedWeight,
TailNum>(karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -70,8 +72,6 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
bool IsInputGemm = false,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -86,23 +86,20 @@ __global__ void

auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

GridwiseGemm::template Run_2Lds<HasMainKBlockLoop,
CGlobalMemoryDataOperation,
IsInputGemm,
MulRoutedWeight,
TailNum>(karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -154,7 +151,12 @@ template <typename ALayout,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
index_t ActivationOperation = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
bool PerTokenQuant = false,
typename IndexType = index_t,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType,
@@ -227,6 +229,7 @@ struct GridwiseMoeGemm
const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
const index_t gridx = NSwizzle ? nblock * mblock : nblock;
const index_t gridy = NSwizzle ? 1 : mblock;

return std::make_tuple(gridx, gridy, 1);
}
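The grid-size helper above folds both block dimensions into gridDim.x when NSwizzle is enabled (the kernel then recovers the per-block M/N ids from blockIdx.x); without swizzling, N blocks map to gridDim.x and M blocks to gridDim.y. A host-side sketch of the same computation (standalone C++, names illustrative):

#include <cstdint>
#include <tuple>

// Mirrors the gridx/gridy selection shown above: ceil(M / MPerBlock) block rows
// and ceil(N / NPerBlock) block columns.
inline std::tuple<uint32_t, uint32_t, uint32_t>
moe_gemm_grid_size_sketch(int64_t M, int64_t N, int MPerBlock, int NPerBlock, bool n_swizzle)
{
    const int64_t mblock = (M + MPerBlock - 1) / MPerBlock;
    const int64_t nblock = (N + NPerBlock - 1) / NPerBlock;
    const uint32_t gridx = static_cast<uint32_t>(n_swizzle ? nblock * mblock : nblock);
    const uint32_t gridy = static_cast<uint32_t>(n_swizzle ? 1 : mblock);
    return std::make_tuple(gridx, gridy, 1u);
}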
@@ -305,7 +308,7 @@ struct GridwiseMoeGemm
}

__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
@@ -497,8 +500,8 @@ struct GridwiseMoeGemm
}

template <typename ELayout>
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
__host__ __device__ static auto MakeCGridDescriptor_M_N(
IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
@@ -909,7 +912,8 @@ struct GridwiseMoeGemm
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>())>;
KPack,
IsInputGemm>())>;

__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
@@ -1141,9 +1145,7 @@ struct GridwiseMoeGemm

template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Odd>
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids,
const index_t* p_max_token_id,
@@ -1203,6 +1205,7 @@ struct GridwiseMoeGemm
return {blockIdx.x, blockIdx.y};
}
}();

const index_t block_n_id = block_mn.first;
const index_t block_m_id = block_mn.second;
const index_t token0 =
@@ -1218,7 +1221,7 @@ struct GridwiseMoeGemm

if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets;
StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
@@ -1226,9 +1229,10 @@ struct GridwiseMoeGemm
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = token_offset * problem.K;
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));

// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
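Two details of the gather-offset code above are worth spelling out: p_sorted_token_ids packs the top-k slot into the high 8 bits and the token index into the low 24 bits, and the offset is cast to IndexType before the multiply by K so that large token counts (the commit log targets roughly 130k tokens) cannot overflow 32-bit arithmetic. A hedged standalone sketch of the decode, taking IndexType to be a 64-bit integer; whether the top-k expansion applies on the input or the output side of each GEMM is not visible in this hunk, so it is exposed as a plain flag here:

#include <cstdint>

// Decode one entry of p_sorted_token_ids into a row offset.
// Bit layout assumed from the masks above: bits [23:0] = token index,
// bits [31:24] = top-k slot for that token.
inline int64_t make_row_offset_sketch(uint32_t fused_token,
                                      int64_t row_stride, // problem.K for gather, problem.N for scatter
                                      int32_t top_k,
                                      bool expand_topk)   // address rows per (token, top-k slot)
{
    int32_t token_offset = static_cast<int32_t>(fused_token & 0xffffff);
    if(expand_topk)
    {
        token_offset = token_offset * top_k + static_cast<int32_t>(fused_token >> 24);
    }
    // Widen before multiplying: token_offset * row_stride exceeds INT32_MAX
    // well before ~1e5 tokens with typical hidden sizes.
    return static_cast<int64_t>(token_offset) * row_stride;
}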
@@ -1239,7 +1243,6 @@ struct GridwiseMoeGemm
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());

// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
@@ -1269,6 +1272,7 @@ struct GridwiseMoeGemm
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
IndexType,
1,
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1311,24 +1315,74 @@ struct GridwiseMoeGemm
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
decltype(c_thread_buf) c_thread_buf_up;

StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
float,
c_thread_buf.num_of_v_,
c_thread_buf.s_per_v,
true>
c_thread_buf_fp32;

const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);

blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
if constexpr(IsInputGemm)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
3,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_blockwise_copy_up,
b_grid_buf,
b_grid_buf_up,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
c_thread_buf_up,
num_k_block_main_loop);
}
else
{
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
}

// shuffle C and write out
{
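For gemm1 the per-expert B tensor is laid out as the gate weights followed by the up weights, which is why expert_stride is doubled earlier and the "up" pipeline above reads from p_b_grid + expert_stride / 2 / BPackedSize. A hedged addressing sketch (element-granular; BPackedSize > 1 models packed types such as two int4 values per storage element):

#include <cstdint>

// Sketch of the gate/up weight base pointers implied by the buffers above.
template <typename BDataType>
struct GateUpPtrs
{
    const BDataType* gate;
    const BDataType* up; // nullptr for gemm2, which has no up projection
};

template <typename BDataType>
GateUpPtrs<BDataType> expert_weight_ptrs_sketch(const BDataType* p_b_grid,
                                                int64_t N,
                                                int64_t K,
                                                int b_packed_size,
                                                int expert_id,
                                                bool is_input_gemm)
{
    const int64_t expert_stride = N * K * (is_input_gemm ? 2 : 1); // elements per expert
    const BDataType* gate = p_b_grid + expert_id * expert_stride / b_packed_size;
    const BDataType* up =
        is_input_gemm ? gate + (expert_stride / 2) / b_packed_size : nullptr;
    return {gate, up};
}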
@@ -1356,6 +1410,185 @@ struct GridwiseMoeGemm
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

// mul scales
const float* p_sorted_weights_0 = p_ds_grid[I0];
const float* p_scale_b = p_ds_grid[I1];

static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;

if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
{
if constexpr(PerTokenQuant)
{
constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
}
else
{
p_scale_b += expert_id;
}

vector_type<int32_t, 4> scale_token_ids;
vector_type<float, 4> topk_weights;
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(PerTokenQuant)
{
scale_token_ids =
*c_style_pointer_cast<const vector_type<int32_t, M4>*>(
p_sorted_token_ids + m_pos);
}
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
float scale_a = [&]() {
if constexpr(PerTokenQuant)
{
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
const index_t token_offset = fused_token & 0xffffff;
return token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset]
: 0.0;
}
else
{
return p_sorted_weights_0[0];
}
}();
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) =
scale_a * scale_b * c_thread_buf[cidx];
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
topk_weights.AsType<float>()[m4];
}
}
});
});
});
});
}
else
{
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};

if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) = topk_weights.AsType<float>()[m4] *
c_thread_buf_fp32[cidx];
}
}
});
});
});
});
}

constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
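To summarize the scale handling in the epilogue above: when PerTokenQuant is set, the A scale is read per token (through the sorted token ids) and the B scale per output channel, with each expert storing N gate scales followed by N up scales for gemm1; otherwise a single scale per expert is used. The accumulator is scaled by scale_a * scale_b (and by the routed top-k weight when MulRoutedWeight is set) before the activation is applied. A hedged standalone sketch of the B-scale lookup only (names illustrative; the actual kernel indexes the scales through the thread/wave coordinates shown above):

#include <cstdint>

// Look up the dequantization scale for weight column n of one expert.
// per_channel corresponds to the PerTokenQuant switch above; up_half selects
// the up-projection scale block for gemm1.
inline float lookup_b_scale_sketch(const float* p_scale_b,
                                   int expert_id,
                                   int64_t n,
                                   int64_t N,
                                   bool per_channel,
                                   bool is_input_gemm,
                                   bool up_half)
{
    if(!per_channel)
    {
        return p_scale_b[expert_id]; // one scale per expert, shared by gate and up
    }
    const int scale_stride = is_input_gemm ? 2 : 1; // gate + up blocks for gemm1
    const float* expert_scales = p_scale_b + int64_t(expert_id) * N * scale_stride;
    return expert_scales[n + (up_half ? N : 0)];
}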
@@ -1453,17 +1686,8 @@ struct GridwiseMoeGemm

const auto ds_grid_buf = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
const DDataType* ptr_ = p_ds_grid[i];
// hack logic here to support different kind of strides. todo fix it.
// ascale t, 1; bscale E, N, 1, move ptr to E
if(i.value == 1)
{
ptr_ +=
expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1);
}
return make_dynamic_buffer<AddressSpaceEnum::Global>(
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
@@ -1526,7 +1750,8 @@ struct GridwiseMoeGemm
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
IndexType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
@@ -1538,7 +1763,6 @@ struct GridwiseMoeGemm

auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
@@ -1568,35 +1792,21 @@ struct GridwiseMoeGemm
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
constexpr auto ENThreads =
CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
const float* p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets;
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;

auto dstidx = sfc_cde_block.GetIndex(access_id);
const index_t c_token_pos =
block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
: 0.0;
IndexType token_offset = fused_token & 0xffffff;
if constexpr(IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
if constexpr(MulRoutedWeight)
{
const float* p_sorted_weights_2 = p_ds_grid[I2];
if constexpr(sizeof(ADataType) < 2)
weight = p_sorted_weights_2[c_token_pos + m0] * weight;
else
weight = p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
});

block_sync_lds();
@@ -1604,7 +1814,7 @@ struct GridwiseMoeGemm
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_thread_buf_fp32,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
@@ -1617,8 +1827,7 @@ struct GridwiseMoeGemm
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf),
scatter_offsets,
scatter_weights);
scatter_offsets);

if constexpr(access_id < num_access - 1)
{
@@ -1643,9 +1852,7 @@ struct GridwiseMoeGemm

template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Odd>
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids,
const index_t* p_max_token_id,
@@ -1721,7 +1928,7 @@ struct GridwiseMoeGemm
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats>
StaticallyIndexedArray<IndexType, AMRepeats>
gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
@@ -1730,7 +1937,7 @@ struct GridwiseMoeGemm
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = token_offset * problem.K;
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
@@ -1773,6 +1980,7 @@ struct GridwiseMoeGemm
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
IndexType,
1,
2>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1967,11 +2175,12 @@ struct GridwiseMoeGemm
const DDataType* ptr_ = p_ds_grid[i];
// hack logic here to support different kind of strides. todo fix it.
// ascale t, 1; bscale E, N, 1, move ptr to E
if(i.value == 1)
{
ptr_ +=
expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1);
}
// if(i.value == 1)
// {
// ptr_ +=
// expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N :
// 1);
// }
return make_dynamic_buffer<AddressSpaceEnum::Global>(
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
},
@@ -2036,7 +2245,8 @@ struct GridwiseMoeGemm
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
IndexType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
@@ -2078,12 +2288,9 @@ struct GridwiseMoeGemm
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
constexpr auto ENThreads =
CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
const float* p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;

auto dstidx = sfc_cde_block.GetIndex(access_id);
const index_t c_token_pos =
@@ -2091,23 +2298,11 @@ struct GridwiseMoeGemm
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
: 0.0;
if constexpr(IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
if constexpr(MulRoutedWeight)
{
const float* p_sorted_weights_2 = p_ds_grid[I2];
if constexpr(sizeof(ADataType) < 2)
weight = p_sorted_weights_2[c_token_pos + m0] * weight;
else
weight = p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
});

block_sync_lds();
@@ -2128,8 +2323,7 @@ struct GridwiseMoeGemm
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf),
scatter_offsets,
scatter_weights);
scatter_offsets);

if constexpr(access_id < num_access - 1)
{