mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
update buffer load to lds feature, build passed
This commit is contained in:
@@ -40,7 +40,7 @@ using B0DataType = F4;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
@@ -62,8 +62,8 @@ struct MulABScaleExpertWeight
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
|
||||
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
@@ -580,7 +580,7 @@ int main(int argc, char* argv[])
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<CShuffleDataType> c_t_n({tokens, N});
|
||||
Tensor<float> c_t_n({tokens, N});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
|
||||
@@ -588,7 +588,8 @@ int main(int argc, char* argv[])
|
||||
B0DataType,
|
||||
XDataType,
|
||||
D2DataType,
|
||||
CShuffleDataType,
|
||||
float, // using float for Cshuffle type
|
||||
// in reference
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
|
||||
@@ -0,0 +1,336 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
/**
|
||||
* Transfer that uses direct load instructions to copy data from global to LDS memory.
|
||||
*
|
||||
* Traditional loads first copy data from global to registers, and then from registers to LDS.
|
||||
* Direct loads do not need an intermediate step, data is copied directly from global to LDS,
|
||||
* without the use of additional registers.
|
||||
*
|
||||
* However, the instruction has limitations:
|
||||
* - each thread must copy exactly a single DWORD - 4 bytes;
|
||||
* - threads within a single wavefront must write consecutive DWORDS into LDS,
|
||||
* (data in global do not need to be contiguous, each thread might have its own offset).
|
||||
*
|
||||
* To make sure that all the transfers finished, the `waitcnt` instruction must be used with
|
||||
* `vmcnt` instead of `lgkmcnt`.
|
||||
*
|
||||
* Limitations of the transfer class:
|
||||
* - `SrcData` must be the same as `DstData` - no possibility to convert the data type in flight;
|
||||
* - `DstVectorDim` must be the last dimension;
|
||||
* - `SrcVectorDim` must be the last dimension if `ScalarPerVector` is greater than 1;
|
||||
* - `ScalarPerVector` times the number of bytes of `DstData` must be equal to a single DWORD = 4B
|
||||
* (for examlpe if `DstData` is fp32, then `ScalarPerVector` must be 1; if `DstData` is fp16,
|
||||
* `ScalarPerVector` must be 2);
|
||||
* - if `ScalarPerVector` is greater than 1, the contiguous dimension in src and dst must be
|
||||
* the same dimension;
|
||||
* - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64,
|
||||
* they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way
|
||||
* to guarantee that.
|
||||
*/
|
||||
template <typename ThreadGroup,
|
||||
typename BlockSliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t ScalarPerVector,
|
||||
typename IndexType,
|
||||
index_t GatherDim = 1,
|
||||
bool SrcXor = true>
|
||||
struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
static constexpr auto block_slice_lengths = BlockSliceLengths{};
|
||||
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
|
||||
|
||||
static constexpr auto thread_single_load_size = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
|
||||
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
|
||||
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
|
||||
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
|
||||
|
||||
static __device__ constexpr bool AreThreadClusterLengthsValid()
|
||||
{
|
||||
// Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to
|
||||
// LDS by the threads from a single wavefront.
|
||||
// Examples (assuming 64 threads in a wavefront, 128 in a thread block):
|
||||
// 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
|
||||
// data type = fp32 -> ScalarPerVector = 1
|
||||
// INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31
|
||||
// write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of
|
||||
// [0, 4, 0].
|
||||
// VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration,
|
||||
// threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs).
|
||||
// 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
|
||||
// data type = fp16 -> ScalarPerVector = 2
|
||||
// NOTE: ThreadClusterLengths must take into account that each thread writes two
|
||||
// elements (single DWORD) along the contiguous dimension.
|
||||
// INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write
|
||||
// 8 * 2 elements of K1PerBlock and there are only 8;
|
||||
// ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31
|
||||
// write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32
|
||||
// writes [1, 0, 0] instead of [0, 8, 0].
|
||||
// VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
|
||||
// first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
|
||||
// elements = 64 consecutive DWORDs.
|
||||
int num_contiguous_dwords = 4;
|
||||
bool is_contiguous = true;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if(is_contiguous)
|
||||
{
|
||||
num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1];
|
||||
}
|
||||
if(thread_slice_lengths[nDim - i - 1] > 1)
|
||||
{
|
||||
CK_PRINT<Number<thread_slice_lengths[nDim - i - 1]>>();
|
||||
is_contiguous = false;
|
||||
}
|
||||
});
|
||||
constexpr index_t wavefront_size = get_warp_size();
|
||||
const bool wave_contiguous = num_contiguous_dwords % wavefront_size == 0;
|
||||
|
||||
bool thread_slice_lengths_correct = true;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if(thread_slice_lengths[i] <= 0)
|
||||
{
|
||||
thread_slice_lengths_correct = false;
|
||||
}
|
||||
});
|
||||
|
||||
return wave_contiguous && thread_slice_lengths_correct;
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_Gather_DirectLoad(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const StaticallyIndexedArray<IndexType, gather_num>& gather_offsets)
|
||||
: gather_offsets_(gather_offsets)
|
||||
{
|
||||
static_assert(ck::is_same_v<SrcData, DstData>,
|
||||
"Direct load transfer does not support datatypes conversion. Source and "
|
||||
"destination data types must be the same.");
|
||||
|
||||
static_assert(
|
||||
DstVectorDim == nDim - 1,
|
||||
"Direct load transfer requires the destination vector dimension to be the last one.");
|
||||
|
||||
static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
|
||||
"When loading more than one element per thread at once, the contiguous "
|
||||
"dimension must be the same between source and destination.");
|
||||
|
||||
// constexpr auto dword_bytes = 4;
|
||||
// constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
|
||||
// static_assert(bytes_per_thread_load == dword_bytes,
|
||||
// "Direct load transfer requires each thread to load exactly a single "
|
||||
// "DWORD of data.");
|
||||
|
||||
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
|
||||
nDim == ThreadClusterLengths::Size(),
|
||||
"Inconsistent number of dimensions across lengths and descriptors.");
|
||||
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"The number of threads cannot be less than the number of elements in "
|
||||
"thread cluster lengths.");
|
||||
|
||||
// static_assert(
|
||||
// AreThreadClusterLengthsValid(),
|
||||
// "Thread cluster lengths are incorrect. They must be set in a way that allows a single
|
||||
// " "wavefront to write contiguous DWORDs into LDS memory. ");
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
|
||||
|
||||
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
|
||||
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
auto adjusted_src_origin_idx = [&]() {
|
||||
Index idx;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number<i>{}];
|
||||
});
|
||||
return idx;
|
||||
}();
|
||||
|
||||
src_coord_ = make_tensor_coordinate(src_desc, adjusted_src_origin_idx);
|
||||
src_slice_origin_ = adjusted_src_origin_idx;
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
dst_slice_origin_ = dst_slice_origin_idx;
|
||||
}
|
||||
|
||||
__device__ void ResetDstSliceWindow(const DstDesc& dst_desc)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
|
||||
"Source data must come from a global memory buffer.");
|
||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
|
||||
"Destination data must be stored in an LDS memory buffer.");
|
||||
|
||||
static_assert(
|
||||
ck::is_same_v<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>,
|
||||
"SrcBuffer and SrcData data types must be consistent.");
|
||||
static_assert(
|
||||
ck::is_same_v<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>,
|
||||
"DstBuffer and DstData data types must be consistent.");
|
||||
|
||||
constexpr auto dst_access_lengths = thread_slice_lengths;
|
||||
|
||||
const auto dst_forward_steps = generate_steps(dst_desc, 1);
|
||||
const auto dst_backward_steps = generate_steps(dst_desc, -1);
|
||||
const auto src_forward_steps = generate_steps(src_desc, 1);
|
||||
const auto src_backward_steps = generate_steps(src_desc, -1);
|
||||
|
||||
// Loop over the destination block and copy data.
|
||||
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
|
||||
// CK_PRINT<decltype(dst_access_lengths), decltype(ordered_dst_access_idx)>();
|
||||
auto gather_offset = gather_offsets_(Number<GatherDim>{});
|
||||
const auto src_offset = src_coord_.GetOffset() + gather_offset;
|
||||
const auto dst_offset = dst_coord_.GetOffset();
|
||||
// printf("Tid: %03d, src_offset: %d, dst_offset: %d\n", get_thread_local_1d_id(),
|
||||
// src_coord_.GetOffset(), dst_coord_.GetOffset());
|
||||
// Check if src data is not in the logic padding area.
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
|
||||
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
|
||||
dst_buf, src_offset, dst_offset, is_src_valid);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
|
||||
});
|
||||
move_on_dim_(i) &= i.value != GatherDim;
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// Decide whether to move forward or backward.
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Reset the destination slice since the entire buffer has been already filled.
|
||||
ResetDstSliceWindow(dst_desc);
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
|
||||
{
|
||||
src_slice_origin_ = src_slice_origin_ + step;
|
||||
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_);
|
||||
}
|
||||
|
||||
template <typename DescType>
|
||||
__device__ auto generate_steps(const DescType& desc, int sign)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
Index step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
step_idx(j) = (i.value == j.value) ? sign * thread_steps[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(desc, step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
Index src_slice_origin_;
|
||||
Index dst_slice_origin_;
|
||||
StaticallyIndexedArray<IndexType, gather_num> gather_offsets_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -299,20 +299,20 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
@@ -348,20 +348,20 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,13 +10,13 @@
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp"
|
||||
|
||||
#define DEBUG_LOG 0
|
||||
|
||||
@@ -72,7 +72,6 @@ __global__ void
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
@@ -80,29 +79,29 @@ template <typename GridwiseGemm,
|
||||
TailNumber TailNum = TailNumber::Even>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_grid,
|
||||
karg.p_b_scale_grid,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
p_shared1,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
@@ -111,7 +110,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -312,10 +310,18 @@ struct GridwiseMoeGemmMXBNS
|
||||
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
|
||||
{
|
||||
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
|
||||
constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
|
||||
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
constexpr auto permuted_desc = transform_tensor_descriptor(
|
||||
TileDesc_K0_MN_K1{},
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(Number<MN>{}, Number<K0>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
permuted_desc,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
|
||||
make_unmerge_transform(make_tuple(Number<MNXdlPerWave / MNXdlPack>{},
|
||||
Number<MNWaves>{},
|
||||
@@ -398,12 +404,29 @@ struct GridwiseMoeGemmMXBNS
|
||||
// not pad M or K
|
||||
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
const auto a_grid_desc_permuted = transform_tensor_descriptor(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_tuple(make_pass_through_transform(K / KPerBlock),
|
||||
make_xor_with_modulo_transform(make_tuple(M, AK0Number)),
|
||||
make_pass_through_transform(AK1Value)),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
|
||||
|
||||
const auto a_grid_desc = transform_tensor_descriptor(
|
||||
a_grid_desc_permuted,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)),
|
||||
make_pass_through_transform(M),
|
||||
make_pass_through_transform(AK1Value)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return a_grid_desc;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -487,12 +510,29 @@ struct GridwiseMoeGemmMXBNS
|
||||
// not pad N or K
|
||||
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
const auto b_grid_desc_permuted = transform_tensor_descriptor(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_tuple(make_pass_through_transform(K / KPerBlock),
|
||||
make_xor_with_modulo_transform(make_tuple(N, BK0Number)),
|
||||
make_pass_through_transform(BK1Value)),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
|
||||
|
||||
const auto b_grid_desc = transform_tensor_descriptor(
|
||||
b_grid_desc_permuted,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(BK1Value)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return b_grid_desc;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -810,9 +850,10 @@ struct GridwiseMoeGemmMXBNS
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
// contiguous in LDS
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
|
||||
make_tuple(Number<AK0Number>{}, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
// xor tensor transformation request more unnecessary vgpr usage, would cause register spill
|
||||
// in some cases.
|
||||
@@ -927,9 +968,10 @@ struct GridwiseMoeGemmMXBNS
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
// contiguous in lds
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
|
||||
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
@@ -1492,12 +1534,9 @@ struct GridwiseMoeGemmMXBNS
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
|
||||
// A matrix blockwise direct to LDS copy
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
@@ -1506,55 +1545,34 @@ struct GridwiseMoeGemmMXBNS
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
IndexType,
|
||||
1,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
gather_offsets);
|
||||
1>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
gather_offsets);
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
@@ -1790,7 +1808,6 @@ struct GridwiseMoeGemmMXBNS
|
||||
m0 * M2 * M1 * M3 * M4 * M5 +
|
||||
m1 * M2 * M3 * M4 * M5 +
|
||||
imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
|
||||
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
topk_weights =
|
||||
@@ -2131,7 +2148,6 @@ struct GridwiseMoeGemmMXBNS
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
@@ -2144,8 +2160,8 @@ struct GridwiseMoeGemmMXBNS
|
||||
const BScaleDataType* p_b_scale_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared,
|
||||
void* p_shared1,
|
||||
void* p_shared_0,
|
||||
void* p_shared_1,
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
@@ -2183,8 +2199,8 @@ struct GridwiseMoeGemmMXBNS
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
|
||||
// static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
|
||||
|
||||
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
|
||||
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
|
||||
if(expert_block_id * MPerBlock >= max_token_id)
|
||||
return;
|
||||
@@ -2252,112 +2268,100 @@ struct GridwiseMoeGemmMXBNS
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
|
||||
problem.N * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
|
||||
|
||||
// Gride buffer creation
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
#if 1
|
||||
printf("blkx: %u, blky: %u, tidx: %u, a_grid_size: %ld\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
#endif
|
||||
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
// A, B scale buffer
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// dummy
|
||||
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
|
||||
|
||||
// A matrix blockwise direct to LDS copy
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
LDSTypeA,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
IndexType,
|
||||
1,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
gather_offsets);
|
||||
|
||||
// Thread-wise copy
|
||||
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
|
||||
auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
|
||||
1>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
gather_offsets);
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bpreshuffled),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
Sequence<Number<NXdlPerWave / NXdlPack>{},
|
||||
I1,
|
||||
Number<NXdlPack>{},
|
||||
Number<KRepeat>{},
|
||||
Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 0, 3, 4>,
|
||||
4,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
|
||||
a_block_space_size_aligned * sizeof(ADataType)),
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
|
||||
a_block_space_size_aligned * sizeof(ADataType)),
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
|
||||
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
|
||||
|
||||
// Blockwise GEMM pipeline
|
||||
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
|
||||
@@ -2429,22 +2433,25 @@ struct GridwiseMoeGemmMXBNS
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bpreshuffled),
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
3,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
@@ -2472,7 +2479,7 @@ struct GridwiseMoeGemmMXBNS
|
||||
a_grid_buf,
|
||||
a_block_bufs,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_blockwise_copy_up,
|
||||
@@ -2495,23 +2502,23 @@ struct GridwiseMoeGemmMXBNS
|
||||
else
|
||||
{
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_grid_desc_ak0_m_ak1, // A
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_bufs,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_grid_desc_bk0_n_bk1, // B
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
a_scale_grid_desc_am_ak,
|
||||
c_thread_buf, // C
|
||||
a_scale_grid_desc_am_ak, // A scale
|
||||
a_scale_thread_copy,
|
||||
a_scale_grid_buf,
|
||||
b_scale_grid_desc_bn_ak,
|
||||
b_scale_grid_desc_bn_ak, // B scale
|
||||
b_scale_thread_copy,
|
||||
b_scale_grid_buf,
|
||||
num_k_block_main_loop);
|
||||
@@ -2522,89 +2529,102 @@ struct GridwiseMoeGemmMXBNS
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
|
||||
CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
|
||||
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
|
||||
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
|
||||
|
||||
// mul scales
|
||||
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
|
||||
static_assert(M4 == 4);
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
|
||||
static_assert(M5 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
|
||||
vector_type<float, 4> topk_weights; // for gemm2 only
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
|
||||
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
|
||||
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
|
||||
p_ds_grid[I2] + m_pos);
|
||||
}
|
||||
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0 / MXdlPack,
|
||||
n0 / NXdlPack,
|
||||
m0 % MXdlPack,
|
||||
n0 % NXdlPack,
|
||||
m2 * M4 + m4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m4];
|
||||
up = up * topk_weights.AsType<float>()[m4];
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m4];
|
||||
up = up * topk_weights.AsType<float>()[m4];
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
|
||||
static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
|
||||
static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
|
||||
static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
|
||||
const index_t m_pos = block_m_id * MPerBlock +
|
||||
m0 * M2 * M1 * M3 * M4 * M5 +
|
||||
m1 * M2 * M3 * M4 * M5 +
|
||||
imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
c_thread_buf_fp32(cidx) =
|
||||
topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
|
||||
topk_weights =
|
||||
*c_style_pointer_cast<const vector_type<float, M5>*>(
|
||||
p_ds_grid[I2] + m_pos);
|
||||
}
|
||||
}
|
||||
static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation ==
|
||||
Activation::silu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m5];
|
||||
up = up * topk_weights.AsType<float>()[m5];
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m5];
|
||||
up = up * topk_weights.AsType<float>()[m5];
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
c_thread_buf_fp32(cidx) =
|
||||
topk_weights.AsType<float>()[m5] *
|
||||
c_thread_buf_fp32[cidx];
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -2614,28 +2634,33 @@ struct GridwiseMoeGemmMXBNS
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
static_cast<CShuffleDataType*>(p_shared_0),
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_tuple(make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per
|
||||
// shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 * M3 * M4 = MPerXdl
|
||||
M3,
|
||||
M4)),
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per
|
||||
// shuffle
|
||||
N1, // N1 = NWave
|
||||
N2))), // N2 = NPerXdl
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(
|
||||
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
|
||||
// shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 * M3 * M4 = MPerXdl
|
||||
M3,
|
||||
M4,
|
||||
M5)),
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
|
||||
// per shuffle
|
||||
N1, // N1 = NWave
|
||||
N2, // N2 = NXdlPack
|
||||
N3))), // N3 = NPerXdl
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<>{},
|
||||
Sequence<0, 2, 4, 6, 7, 8>{},
|
||||
Sequence<>{},
|
||||
Sequence<1, 3, 5, 9>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -2647,8 +2672,8 @@ struct GridwiseMoeGemmMXBNS
|
||||
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
@@ -2657,8 +2682,8 @@ struct GridwiseMoeGemmMXBNS
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
@@ -2666,36 +2691,39 @@ struct GridwiseMoeGemmMXBNS
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
|
||||
CShuffleNXdlPerWavePerShuffle / NXdlPack,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
N2,
|
||||
M3,
|
||||
I1,
|
||||
M5,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
|
||||
9,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
n_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
m_thread_data_on_block_idx[I5],
|
||||
n_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
using EDataType = CDataType;
|
||||
|
||||
@@ -2716,16 +2744,18 @@ struct GridwiseMoeGemmMXBNS
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
generate_tie(
|
||||
[&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_buf),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
generate_tie(
|
||||
[&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of starting index of C/Ds blockwise copy
|
||||
const auto idx_c_ds_block_begin =
|
||||
@@ -2746,51 +2776,63 @@ struct GridwiseMoeGemmMXBNS
|
||||
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
|
||||
constexpr index_t scatter_weight_idx = 3; // hack fix felix
|
||||
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
|
||||
ThisThreadBlock,
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
|
||||
// Sequence support
|
||||
// arbitray type
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
|
||||
CDEBlockTransferCluster,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
|
||||
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
|
||||
3, // index_t SrcVectorDim,
|
||||
3, // index_t DstVectorDim,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
IndexType,
|
||||
1, // ScatterDim
|
||||
true, // OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
>{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
ThisThreadBlock,
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
|
||||
// Sequence support
|
||||
// arbitray type
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
|
||||
CDEBlockTransferCluster,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
|
||||
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
|
||||
3, // index_t SrcVectorDim,
|
||||
3, // index_t DstVectorDim,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
IndexType,
|
||||
1, // ScatterDim
|
||||
true, // OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
>{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
|
||||
NXdlPerWave / NXdlPack,
|
||||
1,
|
||||
1,
|
||||
MXdlPack,
|
||||
NXdlPack,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
|
||||
CShuffleNXdlPerWavePerShuffle / NXdlPack,
|
||||
1,
|
||||
1,
|
||||
MXdlPack,
|
||||
NXdlPack,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
@@ -2870,7 +2912,6 @@ struct GridwiseMoeGemmMXBNS
|
||||
});
|
||||
}
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user