update buffer load to lds feature, build passed

This commit is contained in:
mtgu0705
2025-05-21 02:40:20 -05:00
parent 3a05fa135a
commit 513f92f5b9
4 changed files with 720 additions and 342 deletions

View File

@@ -40,7 +40,7 @@ using B0DataType = F4;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CShuffleDataType = F16;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
@@ -62,8 +62,8 @@ struct MulABScaleExpertWeight
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
@@ -580,7 +580,7 @@ int main(int argc, char* argv[])
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<CShuffleDataType> c_t_n({tokens, N});
Tensor<float> c_t_n({tokens, N});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
@@ -588,7 +588,8 @@ int main(int argc, char* argv[])
B0DataType,
XDataType,
D2DataType,
CShuffleDataType,
float, // using float for Cshuffle type
// in reference
AccDataType,
PassThrough,
PassThrough,

View File

@@ -0,0 +1,336 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
/**
* Transfer that uses direct load instructions to copy data from global to LDS memory.
*
* Traditional loads first copy data from global to registers, and then from registers to LDS.
* Direct loads do not need an intermediate step, data is copied directly from global to LDS,
* without the use of additional registers.
*
* However, the instruction has limitations:
* - each thread must copy exactly a single DWORD - 4 bytes;
* - threads within a single wavefront must write consecutive DWORDS into LDS,
* (data in global do not need to be contiguous, each thread might have its own offset).
*
* To make sure that all the transfers finished, the `waitcnt` instruction must be used with
* `vmcnt` instead of `lgkmcnt`.
*
* Limitations of the transfer class:
* - `SrcData` must be the same as `DstData` - no possibility to convert the data type in flight;
* - `DstVectorDim` must be the last dimension;
* - `SrcVectorDim` must be the last dimension if `ScalarPerVector` is greater than 1;
* - `ScalarPerVector` times the number of bytes of `DstData` must be equal to a single DWORD = 4B
* (for examlpe if `DstData` is fp32, then `ScalarPerVector` must be 1; if `DstData` is fp16,
* `ScalarPerVector` must be 2);
* - if `ScalarPerVector` is greater than 1, the contiguous dimension in src and dst must be
* the same dimension;
* - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64,
* they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way
* to guarantee that.
*/
template <typename ThreadGroup,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t ScalarPerVector,
typename IndexType,
index_t GatherDim = 1,
bool SrcXor = true>
struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto block_slice_lengths = BlockSliceLengths{};
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
static constexpr auto thread_single_load_size = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
static __device__ constexpr bool AreThreadClusterLengthsValid()
{
// Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to
// LDS by the threads from a single wavefront.
// Examples (assuming 64 threads in a wavefront, 128 in a thread block):
// 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
// data type = fp32 -> ScalarPerVector = 1
// INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31
// write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of
// [0, 4, 0].
// VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration,
// threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs).
// 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
// data type = fp16 -> ScalarPerVector = 2
// NOTE: ThreadClusterLengths must take into account that each thread writes two
// elements (single DWORD) along the contiguous dimension.
// INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write
// 8 * 2 elements of K1PerBlock and there are only 8;
// ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31
// write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32
// writes [1, 0, 0] instead of [0, 8, 0].
// VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
// first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
// elements = 64 consecutive DWORDs.
int num_contiguous_dwords = 4;
bool is_contiguous = true;
static_for<0, nDim, 1>{}([&](auto i) {
if(is_contiguous)
{
num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1];
}
if(thread_slice_lengths[nDim - i - 1] > 1)
{
CK_PRINT<Number<thread_slice_lengths[nDim - i - 1]>>();
is_contiguous = false;
}
});
constexpr index_t wavefront_size = get_warp_size();
const bool wave_contiguous = num_contiguous_dwords % wavefront_size == 0;
bool thread_slice_lengths_correct = true;
static_for<0, nDim, 1>{}([&](auto i) {
if(thread_slice_lengths[i] <= 0)
{
thread_slice_lengths_correct = false;
}
});
return wave_contiguous && thread_slice_lengths_correct;
}
__device__ constexpr ThreadGroupTensorSliceTransfer_Gather_DirectLoad(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const StaticallyIndexedArray<IndexType, gather_num>& gather_offsets)
: gather_offsets_(gather_offsets)
{
static_assert(ck::is_same_v<SrcData, DstData>,
"Direct load transfer does not support datatypes conversion. Source and "
"destination data types must be the same.");
static_assert(
DstVectorDim == nDim - 1,
"Direct load transfer requires the destination vector dimension to be the last one.");
static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
"When loading more than one element per thread at once, the contiguous "
"dimension must be the same between source and destination.");
// constexpr auto dword_bytes = 4;
// constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
// static_assert(bytes_per_thread_load == dword_bytes,
// "Direct load transfer requires each thread to load exactly a single "
// "DWORD of data.");
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size(),
"Inconsistent number of dimensions across lengths and descriptors.");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"The number of threads cannot be less than the number of elements in "
"thread cluster lengths.");
// static_assert(
// AreThreadClusterLengthsValid(),
// "Thread cluster lengths are incorrect. They must be set in a way that allows a single
// " "wavefront to write contiguous DWORDs into LDS memory. ");
const auto thread_cluster_idx =
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
auto adjusted_src_origin_idx = [&]() {
Index idx;
static_for<0, nDim, 1>{}([&](auto i) {
idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number<i>{}];
});
return idx;
}();
src_coord_ = make_tensor_coordinate(src_desc, adjusted_src_origin_idx);
src_slice_origin_ = adjusted_src_origin_idx;
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
dst_slice_origin_ = dst_slice_origin_idx;
}
__device__ void ResetDstSliceWindow(const DstDesc& dst_desc)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_);
}
template <typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
"Source data must come from a global memory buffer.");
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"Destination data must be stored in an LDS memory buffer.");
static_assert(
ck::is_same_v<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>,
"SrcBuffer and SrcData data types must be consistent.");
static_assert(
ck::is_same_v<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>,
"DstBuffer and DstData data types must be consistent.");
constexpr auto dst_access_lengths = thread_slice_lengths;
const auto dst_forward_steps = generate_steps(dst_desc, 1);
const auto dst_backward_steps = generate_steps(dst_desc, -1);
const auto src_forward_steps = generate_steps(src_desc, 1);
const auto src_backward_steps = generate_steps(src_desc, -1);
// Loop over the destination block and copy data.
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// CK_PRINT<decltype(dst_access_lengths), decltype(ordered_dst_access_idx)>();
auto gather_offset = gather_offsets_(Number<GatherDim>{});
const auto src_offset = src_coord_.GetOffset() + gather_offset;
const auto dst_offset = dst_coord_.GetOffset();
// printf("Tid: %03d, src_offset: %d, dst_offset: %d\n", get_thread_local_1d_id(),
// src_coord_.GetOffset(), dst_coord_.GetOffset());
// Check if src data is not in the logic padding area.
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
dst_buf, src_offset, dst_offset, is_src_valid);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
});
move_on_dim_(i) &= i.value != GatherDim;
});
return move_on_dim_;
}
();
// Decide whether to move forward or backward.
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
}
else
{
move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
}
}
});
});
// Reset the destination slice since the entire buffer has been already filled.
ResetDstSliceWindow(dst_desc);
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
src_slice_origin_ = src_slice_origin_ + step;
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_);
}
template <typename DescType>
__device__ auto generate_steps(const DescType& desc, int sign)
{
return generate_tuple(
[&](auto i) {
Index step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
step_idx(j) = (i.value == j.value) ? sign * thread_steps[i] : 0;
});
return make_tensor_coordinate_step(desc, step_idx);
},
Number<nDim>{});
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
SrcCoord src_coord_;
DstCoord dst_coord_;
Index src_slice_origin_;
Index dst_slice_origin_;
StaticallyIndexedArray<IndexType, gather_num> gather_offsets_;
};
} // namespace ck

View File

@@ -299,20 +299,20 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
@@ -348,20 +348,20 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}

View File

@@ -10,13 +10,13 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp"
#define DEBUG_LOG 0
@@ -72,7 +72,6 @@ __global__ void
#endif // end of if (defined(__gfx9__))
}
#if 0
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
@@ -80,29 +79,29 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid,
karg.p_a_scale_grid,
karg.p_b_grid,
karg.p_b_scale_grid,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
p_shared_0,
p_shared_1,
karg,
karg.a_element_op,
karg.b_element_op,
@@ -111,7 +110,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
ignore = karg;
#endif // end of if (defined(__gfx9__))
}
#endif
template <typename ALayout,
typename BLayout,
@@ -312,10 +310,18 @@ struct GridwiseMoeGemmMXBNS
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
constexpr auto permuted_desc = transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_xor_with_modulo_transform(make_tuple(Number<MN>{}, Number<K0>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
return transform_tensor_descriptor(
permuted_desc,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(make_tuple(Number<MNXdlPerWave / MNXdlPack>{},
Number<MNWaves>{},
@@ -398,12 +404,29 @@ struct GridwiseMoeGemmMXBNS
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
return a_grid_desc_ak0_m_ak1;
const auto a_grid_desc_permuted = transform_tensor_descriptor(
a_grid_desc_ak0_m_ak1,
make_tuple(make_pass_through_transform(K / KPerBlock),
make_xor_with_modulo_transform(make_tuple(M, AK0Number)),
make_pass_through_transform(AK1Value)),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
const auto a_grid_desc = transform_tensor_descriptor(
a_grid_desc_permuted,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)),
make_pass_through_transform(M),
make_pass_through_transform(AK1Value)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_grid_desc;
}
}
@@ -487,12 +510,29 @@ struct GridwiseMoeGemmMXBNS
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
return b_grid_desc_bk0_n_bk1;
const auto b_grid_desc_permuted = transform_tensor_descriptor(
b_grid_desc_bk0_n_bk1,
make_tuple(make_pass_through_transform(K / KPerBlock),
make_xor_with_modulo_transform(make_tuple(N, BK0Number)),
make_pass_through_transform(BK1Value)),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
const auto b_grid_desc = transform_tensor_descriptor(
b_grid_desc_permuted,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)),
make_pass_through_transform(N),
make_pass_through_transform(BK1Value)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_grid_desc;
}
}
@@ -810,9 +850,10 @@ struct GridwiseMoeGemmMXBNS
// A matrix in LDS memory, dst of blockwise copy
if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// contiguous in LDS
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
make_tuple(Number<AK0Number>{}, Number<MPerBlock>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
}
// xor tensor transformation request more unnecessary vgpr usage, would cause register spill
// in some cases.
@@ -927,9 +968,10 @@ struct GridwiseMoeGemmMXBNS
// B matrix in LDS memory, dst of blockwise copy
if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// contiguous in lds
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
@@ -1492,12 +1534,9 @@ struct GridwiseMoeGemmMXBNS
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
// A matrix blockwise direct to LDS copy
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
@@ -1506,55 +1545,34 @@ struct GridwiseMoeGemmMXBNS
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
IndexType,
1,
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
gather_offsets);
1>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
gather_offsets);
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
make_multi_index(0, 0, 0));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -1790,7 +1808,6 @@ struct GridwiseMoeGemmMXBNS
m0 * M2 * M1 * M3 * M4 * M5 +
m1 * M2 * M3 * M4 * M5 +
imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
if constexpr(MulRoutedWeight)
{
topk_weights =
@@ -2131,7 +2148,6 @@ struct GridwiseMoeGemmMXBNS
}
}
#if 0
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
@@ -2144,8 +2160,8 @@ struct GridwiseMoeGemmMXBNS
const BScaleDataType* p_b_scale_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
void* p_shared1,
void* p_shared_0,
void* p_shared_1,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
@@ -2183,8 +2199,8 @@ struct GridwiseMoeGemmMXBNS
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
// static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
if(expert_block_id * MPerBlock >= max_token_id)
return;
@@ -2252,112 +2268,100 @@ struct GridwiseMoeGemmMXBNS
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
problem.N * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
// Gride buffer creation
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
#if 1
printf("blkx: %u, blky: %u, tidx: %u, a_grid_size: %ld\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
#endif
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
// A, B scale buffer
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
// dummy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
// A matrix blockwise direct to LDS copy
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
LDSTypeA,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
IndexType,
1,
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
gather_offsets);
// Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
gather_offsets);
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave / NXdlPack>{},
I1,
Number<NXdlPack>{},
Number<KRepeat>{},
Number<BK1Value>{}>,
Sequence<1, 2, 0, 3, 4>,
4,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
a_block_space_size_aligned * sizeof(ADataType)),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
a_block_space_size_aligned * sizeof(ADataType)),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
// Blockwise GEMM pipeline
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
@@ -2429,22 +2433,25 @@ struct GridwiseMoeGemmMXBNS
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
3,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
@@ -2472,7 +2479,7 @@ struct GridwiseMoeGemmMXBNS
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_blockwise_copy_up,
@@ -2495,23 +2502,23 @@ struct GridwiseMoeGemmMXBNS
else
{
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_grid_desc_ak0_m_ak1, // A
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_grid_desc_bk0_n_bk1, // B
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
a_scale_grid_desc_am_ak,
c_thread_buf, // C
a_scale_grid_desc_am_ak, // A scale
a_scale_thread_copy,
a_scale_grid_buf,
b_scale_grid_desc_bn_ak,
b_scale_grid_desc_bn_ak, // B scale
b_scale_thread_copy,
b_scale_grid_buf,
num_k_block_main_loop);
@@ -2522,89 +2529,102 @@ struct GridwiseMoeGemmMXBNS
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
// mul scales
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
static_assert(M5 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0 / MXdlPack,
n0 / NXdlPack,
m0 % MXdlPack,
n0 % NXdlPack,
m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock +
m0 * M2 * M1 * M3 * M4 * M5 +
m1 * M2 * M3 * M4 * M5 +
imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) =
topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
topk_weights =
*c_style_pointer_cast<const vector_type<float, M5>*>(
p_ds_grid[I2] + m_pos);
}
}
static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation ==
Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m5];
up = up * topk_weights.AsType<float>()[m5];
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m5];
up = up * topk_weights.AsType<float>()[m5];
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) =
topk_weights.AsType<float>()[m5] *
c_thread_buf_fp32[cidx];
}
}
});
});
});
});
});
@@ -2614,28 +2634,33 @@ struct GridwiseMoeGemmMXBNS
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
static_cast<CShuffleDataType*>(p_shared_0),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per
// shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per
// shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
// shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4,
M5)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
// per shuffle
N1, // N1 = NWave
N2, // N2 = NXdlPack
N3))), // N3 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{},
Sequence<0, 2, 4, 6, 7, 8>{},
Sequence<>{},
Sequence<1, 3, 5, 9>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
@@ -2647,8 +2672,8 @@ struct GridwiseMoeGemmMXBNS
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
@@ -2657,8 +2682,8 @@ struct GridwiseMoeGemmMXBNS
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
@@ -2666,36 +2691,39 @@ struct GridwiseMoeGemmMXBNS
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
CShuffleNXdlPerWavePerShuffle / NXdlPack,
I1,
I1,
M2,
N2,
M3,
I1,
M5,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
m_thread_data_on_block_idx[I5],
n_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
using EDataType = CDataType;
@@ -2716,16 +2744,18 @@ struct GridwiseMoeGemmMXBNS
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin =
@@ -2746,51 +2776,63 @@ struct GridwiseMoeGemmMXBNS
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx = 3; // hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// Sequence support
// arbitray type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferCluster,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
3, // index_t SrcVectorDim,
3, // index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors,
CShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
IndexType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
>{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op};
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// Sequence support
// arbitray type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferCluster,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
3, // index_t SrcVectorDim,
3, // index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors,
CShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
IndexType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
>{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op};
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
NXdlPerWave / NXdlPack,
1,
1,
MXdlPack,
NXdlPack,
M2,
1,
M4,
1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
CShuffleNXdlPerWavePerShuffle / NXdlPack,
1,
1,
MXdlPack,
NXdlPack,
M2,
1,
M4,
@@ -2870,7 +2912,6 @@ struct GridwiseMoeGemmMXBNS
});
}
}
#endif
};
} // namespace ck