mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Fix do_weight in gemm1. Fix cshuffle_datatype. Clang-format
This commit is contained in:
@@ -40,7 +40,7 @@ using B1DataType = F32;
|
||||
// using EDataType = F16;
|
||||
using EDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CShuffleDataType = EDataType;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D2DataType>;
|
||||
|
||||
@@ -126,7 +126,7 @@ static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr bool MulRoutedWeight = false;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
#if 0
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
@@ -466,7 +466,7 @@ int main(int argc, char* argv[])
|
||||
Tensor<float> b_e_n_k({experts, K, N * 2});
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
Tensor<EDataType> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
|
||||
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
|
||||
|
||||
// handle scale before ref.
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
@@ -491,7 +491,7 @@ int main(int argc, char* argv[])
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeGemm1BlockScale<float,
|
||||
float,
|
||||
EDataType,
|
||||
float,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
|
||||
@@ -40,7 +40,7 @@ using B1DataType = F32;
|
||||
using EDataType = F16;
|
||||
// using EDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32; // todo: change to EDataType
|
||||
using CShuffleDataType = EDataType;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D2DataType>;
|
||||
|
||||
@@ -64,23 +64,25 @@ struct MulABScaleExpertWeight
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
|
||||
{
|
||||
// (void) d2;
|
||||
e = ck::type_convert<EDataType>(c * d2);
|
||||
// for real kernel use
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, float, float>(EDataType& e, const float& c, const float& d2) const
|
||||
{
|
||||
// for real kernel use
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float, float, float>(float& e, const float& c, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
e = ck::type_convert<EDataType>(c * d2);
|
||||
}
|
||||
// template <>
|
||||
// __host__ __device__ constexpr void
|
||||
// operator()<float, float, float>(float& e, const float& c, const float& d2) const
|
||||
// {
|
||||
// // for reference cpu
|
||||
// e = ck::type_convert<EDataType>(c * d2);
|
||||
// }
|
||||
};
|
||||
|
||||
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
|
||||
@@ -213,7 +215,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 3)
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
@@ -317,9 +319,9 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-1.0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
@@ -467,7 +469,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
Tensor<float> a_t_k_k({tokens, topk, K});
|
||||
Tensor<float> b_e_n_k({experts, K, N});
|
||||
Tensor<EDataType> c_t_n({tokens, N});
|
||||
Tensor<float> c_t_n({tokens, N});
|
||||
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
@@ -496,7 +498,7 @@ int main(int argc, char* argv[])
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeGemm2BlockScale<float,
|
||||
float,
|
||||
EDataType,
|
||||
float,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
|
||||
@@ -143,19 +143,6 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
@@ -188,18 +175,6 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
}
|
||||
|
||||
template <typename SrcBuffers, typename DstBuffers>
|
||||
__device__ void Run(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
}
|
||||
|
||||
template <index_t ISrc>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
|
||||
|
||||
@@ -89,21 +89,21 @@ __global__ void
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid,
|
||||
p_shared,
|
||||
p_shared1,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid,
|
||||
p_shared,
|
||||
p_shared1,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
@@ -1206,7 +1206,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(math::integer_divide_ceil(problem.N , ScaleBlockN),
|
||||
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
|
||||
|
||||
@@ -1265,10 +1265,11 @@ struct GridwiseMoeGemmBlockScale
|
||||
}
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride =
|
||||
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK));
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
@@ -1461,22 +1462,24 @@ struct GridwiseMoeGemmBlockScale
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
|
||||
const BScaleType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
auto b_scale_thread_copy_up =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleType,
|
||||
BScaleType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(b_scale_thread_desc),
|
||||
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ScaleSliceSizeK,
|
||||
1,
|
||||
false>(
|
||||
b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
|
||||
BScaleType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(b_scale_thread_desc),
|
||||
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ScaleSliceSizeK,
|
||||
1,
|
||||
false>(
|
||||
b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
@@ -1577,39 +1580,36 @@ struct GridwiseMoeGemmBlockScale
|
||||
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
|
||||
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
|
||||
|
||||
if constexpr(IsInputGemm) // gu fusion, elementwise
|
||||
{
|
||||
static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
|
||||
static_assert(N4 == 4);
|
||||
const index_t n1 = get_warp_local_1d_id() / MWave;
|
||||
const index_t n3 = threadIdx.x % get_warp_size() / NPerXdl;
|
||||
static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock);
|
||||
static_assert(N4 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m2 = threadIdx.x % get_warp_size() % M2;
|
||||
|
||||
vector_type<float, 4> topk_weights;
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
|
||||
const index_t n_pos = block_n_id * NPerBlock + n0 * N1 * N2 * N3 * N4 +
|
||||
n1 * N2 * N3 * N4 + n2 * N3 * N4 + n3 * N4;
|
||||
if constexpr(MulRoutedWeight)
|
||||
float topk_weight;
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
|
||||
topk_weight = p_ds_grid[I0][m_pos];
|
||||
}
|
||||
static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
|
||||
static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, n2 * N4 + n4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
if constexpr(IsInputGemm) // gu fusion, elementwise
|
||||
{
|
||||
topk_weights = *c_style_pointer_cast<const vector_type<float, N4>*>(
|
||||
p_ds_grid[I0] + n_pos);
|
||||
}
|
||||
// if((blockIdx.x == 0) && (blockIdx.y == 0)){printf("m0:%d, n_pos:%d\n", static_cast<int>(m0), n_pos);}
|
||||
static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, n2 * N4 + n4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[n4];
|
||||
up = up * topk_weights.AsType<float>()[n4];
|
||||
gate = gate * topk_weight;
|
||||
up = up * topk_weight;
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
@@ -1625,8 +1625,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[n4];
|
||||
up = up * topk_weights.AsType<float>()[n4];
|
||||
gate = gate * topk_weight;
|
||||
up = up * topk_weight;
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
@@ -1636,11 +1636,18 @@ struct GridwiseMoeGemmBlockScale
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
@@ -1853,7 +1860,6 @@ struct GridwiseMoeGemmBlockScale
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
|
||||
auto dstidx = sfc_cde_block.GetIndex(access_id);
|
||||
const index_t c_token_pos =
|
||||
@@ -1861,18 +1867,11 @@ struct GridwiseMoeGemmBlockScale
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
float weight = token_offset < problem.NumTokens ? 1 : 0.0;
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
else
|
||||
{
|
||||
const float* p_sorted_weights_2 = p_ds_grid[I0];
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
@@ -1893,8 +1892,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(c_grid_buf),
|
||||
scatter_offsets,
|
||||
scatter_weights);
|
||||
scatter_offsets);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
@@ -2019,10 +2017,11 @@ struct GridwiseMoeGemmBlockScale
|
||||
}
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride =
|
||||
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N , ScaleBlockN) * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK));
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK));
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
@@ -2071,12 +2070,12 @@ struct GridwiseMoeGemmBlockScale
|
||||
IndexType,
|
||||
1,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
gather_offsets);
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
gather_offsets);
|
||||
|
||||
// Thread-wise copy
|
||||
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
|
||||
@@ -2123,8 +2122,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
|
||||
//scale
|
||||
// scale
|
||||
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
|
||||
constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
|
||||
constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
|
||||
@@ -2160,7 +2158,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
|
||||
scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK);
|
||||
});
|
||||
|
||||
@@ -2222,22 +2220,24 @@ struct GridwiseMoeGemmBlockScale
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
|
||||
const BScaleType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
auto b_scale_thread_copy_up =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleType,
|
||||
BScaleType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(b_scale_thread_desc),
|
||||
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ScaleSliceSizeK,
|
||||
1,
|
||||
false>(
|
||||
b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
|
||||
BScaleType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(b_scale_thread_desc),
|
||||
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ScaleSliceSizeK,
|
||||
1,
|
||||
false>(
|
||||
b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
@@ -2316,7 +2316,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
//only used to get lengths
|
||||
// only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
|
||||
|
||||
@@ -2329,39 +2329,36 @@ struct GridwiseMoeGemmBlockScale
|
||||
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
|
||||
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
|
||||
|
||||
if constexpr(IsInputGemm) // gu fusion, elementwise
|
||||
{
|
||||
static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
|
||||
static_assert(N4 == 4);
|
||||
const index_t n1 = get_warp_local_1d_id() / MWave;
|
||||
const index_t n3 = threadIdx.x % get_warp_size() / NPerXdl;
|
||||
static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock);
|
||||
static_assert(N4 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m2 = threadIdx.x % get_warp_size() % M2;
|
||||
|
||||
vector_type<float, 4> topk_weights;
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
|
||||
const index_t n_pos = block_n_id * NPerBlock + n0 * N1 * N2 * N3 * N4 +
|
||||
n1 * N2 * N3 * N4 + n2 * N3 * N4 + n3 * N4;
|
||||
if constexpr(MulRoutedWeight)
|
||||
float topk_weight;
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
|
||||
topk_weight = p_ds_grid[I0][m_pos];
|
||||
}
|
||||
static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
|
||||
static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, n2 * N4 + n4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
if constexpr(IsInputGemm) // gu fusion, elementwise
|
||||
{
|
||||
topk_weights = *c_style_pointer_cast<const vector_type<float, N4>*>(
|
||||
p_ds_grid[I0] + n_pos);
|
||||
}
|
||||
// if((blockIdx.x == 0) && (blockIdx.y == 0)){printf("m0:%d, n_pos:%d\n", static_cast<int>(m0), n_pos);}
|
||||
static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, n2 * N4 + n4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[n4];
|
||||
up = up * topk_weights.AsType<float>()[n4];
|
||||
gate = gate * topk_weight;
|
||||
up = up * topk_weight;
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
@@ -2377,8 +2374,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[n4];
|
||||
up = up * topk_weights.AsType<float>()[n4];
|
||||
gate = gate * topk_weight;
|
||||
up = up * topk_weight;
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
@@ -2388,11 +2385,19 @@ struct GridwiseMoeGemmBlockScale
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
@@ -2491,11 +2496,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
const DDataType* ptr_ = p_ds_grid[i];
|
||||
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
|
||||
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
@@ -2605,7 +2607,6 @@ struct GridwiseMoeGemmBlockScale
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<IndexType, EMRepeats>
|
||||
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
|
||||
auto dstidx = sfc_cde_block.GetIndex(access_id);
|
||||
const index_t c_token_pos =
|
||||
@@ -2613,18 +2614,11 @@ struct GridwiseMoeGemmBlockScale
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
float weight = token_offset < problem.NumTokens ? 1 : 0.0;
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
else
|
||||
{
|
||||
const float* p_sorted_weights_2 = p_ds_grid[I0];
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
@@ -2645,8 +2639,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(c_grid_buf),
|
||||
scatter_offsets,
|
||||
scatter_weights);
|
||||
scatter_offsets);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
@@ -2665,45 +2658,6 @@ struct GridwiseMoeGemmBlockScale
|
||||
I0,
|
||||
cde_lds_and_global_step);
|
||||
}
|
||||
|
||||
// // print C
|
||||
// printf("tid: %d, blkid: %d, "
|
||||
// "c_thread_buf = <%1.f, %1.f, %1.f>\n "
|
||||
// // "%1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f,"
|
||||
// // "%1.f, %1.f, %1.f, %1.f, %1.f, %1.f\n"
|
||||
// , get_thread_local_1d_id(), block_m_id,
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<0>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<1>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<2>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<3>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<4>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<5>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<6>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<7>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<8>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<9>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<10>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<11>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<12>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<13>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<14>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<3>{}]);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -262,143 +262,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcBuffers,
|
||||
index_t ThreadScratchId = 0,
|
||||
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// loop over space-filling curve
|
||||
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
|
||||
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
|
||||
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
|
||||
|
||||
bool oob_val = true;
|
||||
|
||||
// copy data from src_bufs into src_vectors
|
||||
static_for<0, nSrc, 1>{}([&](auto i) {
|
||||
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
|
||||
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
|
||||
src_coords_[i]);
|
||||
|
||||
oob_val = oob_val & is_src_valid;
|
||||
if(i.value == ScatterWeightIdx)
|
||||
{
|
||||
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
|
||||
"scatter weight dim, should only one vec");
|
||||
constexpr auto iScatter =
|
||||
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
|
||||
src_vectors(i).template AsType<float>()(j) =
|
||||
scatter_weights(Number<iScatter>{});
|
||||
});
|
||||
}
|
||||
else if constexpr(SrcScalarPerVectors{}[i] == 1)
|
||||
{
|
||||
auto data_types = SrcDatas{};
|
||||
using DataType = remove_cvref_t<decltype(data_types[i])>;
|
||||
const auto tmp =
|
||||
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
|
||||
static_for<0, SrcScalarPerVector, 1>{}(
|
||||
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
|
||||
}
|
||||
else
|
||||
{
|
||||
src_vectors(i).template AsType<src_vector_t>()(I0) =
|
||||
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
|
||||
}
|
||||
});
|
||||
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
if constexpr(is_detected<is_pack8_invocable_t, decltype(element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(element_op_)::is_pack8_invocable)
|
||||
return math::min(8, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack4_invocable_t, decltype(element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(element_op_)::is_pack4_invocable)
|
||||
return math::min(4, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack2_invocable_t, decltype(element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(element_op_)::is_pack2_invocable)
|
||||
return math::min(2, SrcScalarPerVector);
|
||||
}
|
||||
return 1;
|
||||
};
|
||||
|
||||
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
|
||||
|
||||
// apply pointwise function
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& {
|
||||
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
|
||||
|
||||
using elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
|
||||
return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iDst) -> auto& {
|
||||
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
|
||||
|
||||
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// apply pointwise function
|
||||
// pointwise function signature:
|
||||
// element_op_(dst_data_refs[I0],
|
||||
// dst_data_refs[I1],
|
||||
// ...,
|
||||
// src_data_refs[I0],
|
||||
// src_data_refs[I1],
|
||||
// ...)
|
||||
unpack2(element_op_, dst_data_refs, src_data_refs);
|
||||
});
|
||||
|
||||
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
|
||||
oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;
|
||||
|
||||
// move coordinate
|
||||
if constexpr(iAccess.value != src_num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
|
||||
|
||||
static_for<0, nSrc, 1>{}([&](auto i) {
|
||||
move_tensor_coordinate(src_descs[i],
|
||||
src_coords_(i),
|
||||
make_tensor_coordinate_step(src_descs[i], forward_step));
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// move coordinate back to slice origin (or not)
|
||||
static_for<0, nSrc, 1>{}([&](auto i) {
|
||||
if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
|
||||
{
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#if 1
|
||||
template <index_t ThreadScratchId = 0>
|
||||
__device__ void OOBCheck(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
@@ -608,22 +471,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
}
|
||||
|
||||
template <typename SrcBuffers,
|
||||
typename DstBuffers,
|
||||
enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
|
||||
DstDescs::Size() == DstBuffers::Size(),
|
||||
bool> = false>
|
||||
__device__ void Run(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
if constexpr(src_num_access == 0)
|
||||
|
||||
Reference in New Issue
Block a user