Fix do_weight in gemm1. Fix cshuffle_datatype. Clang-format

This commit is contained in:
OscarXu
2025-05-28 18:29:06 +08:00
parent fc9ef98e7b
commit 772debdf8f
5 changed files with 152 additions and 374 deletions

View File

@@ -40,7 +40,7 @@ using B1DataType = F32;
// using EDataType = F16;
using EDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CShuffleDataType = EDataType;
using D2DataType = F32;
using DsDataType = ck::Tuple<D2DataType>;
@@ -126,7 +126,7 @@ static constexpr ck::index_t Scale_Block_K = 128;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr bool MulRoutedWeight = false;
static constexpr bool MulRoutedWeight = true;
#if 0
static constexpr ck::index_t MPerBlock = 32;
@@ -466,7 +466,7 @@ int main(int argc, char* argv[])
Tensor<float> b_e_n_k({experts, K, N * 2});
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
Tensor<EDataType> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
// handle scale before ref.
for(int t = 0; t < tokens; ++t)
@@ -491,7 +491,7 @@ int main(int argc, char* argv[])
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm1BlockScale<float,
float,
EDataType,
float,
D2DataType,
AccDataType,
PassThrough,

View File

@@ -40,7 +40,7 @@ using B1DataType = F32;
using EDataType = F16;
// using EDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = F32; // todo: change to EDataType
using CShuffleDataType = EDataType;
using D2DataType = F32;
using DsDataType = ck::Tuple<D2DataType>;
@@ -64,23 +64,25 @@ struct MulABScaleExpertWeight
__host__ __device__ constexpr void
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
{
// (void) d2;
e = ck::type_convert<EDataType>(c * d2);
// for real kernel use
(void)d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void
operator()<EDataType, float, float>(EDataType& e, const float& c, const float& d2) const
{
// for real kernel use
(void)d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d2) const
{
// for reference cpu
e = ck::type_convert<EDataType>(c * d2);
}
// template <>
// __host__ __device__ constexpr void
// operator()<float, float, float>(float& e, const float& c, const float& d2) const
// {
// // for reference cpu
// e = ck::type_convert<EDataType>(c * d2);
// }
};
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
@@ -213,7 +215,7 @@ int main(int argc, char* argv[])
{
// use default case
}
else if(argc == 3)
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
@@ -317,9 +319,9 @@ int main(int argc, char* argv[])
{
case 0: break;
case 1:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-1.0, 1.0});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
@@ -467,7 +469,7 @@ int main(int argc, char* argv[])
Tensor<float> a_t_k_k({tokens, topk, K});
Tensor<float> b_e_n_k({experts, K, N});
Tensor<EDataType> c_t_n({tokens, N});
Tensor<float> c_t_n({tokens, N});
for(int t = 0; t < tokens; ++t)
{
@@ -496,7 +498,7 @@ int main(int argc, char* argv[])
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm2BlockScale<float,
float,
EDataType,
float,
D2DataType,
AccDataType,
PassThrough,

View File

@@ -143,19 +143,6 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
}
}
// Thread-group level read that additionally carries per-row scatter weights.
// The call is forwarded unchanged to threadwise_transfer_.RunRead, which receives
// src_descs, src_bufs, scatter_weights and the scratch slot selector.
//   src_descs         - descriptors of the source tensors
//   src_bufs          - source buffers to read from
//   scatter_weights   - float weights, one per scatter_num entry, passed through as-is
//   thread_scratch_id - compile-time id of the thread scratch slot to fill
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// Only threads that map into the thread cluster take part: either the group has
// exactly as many threads as the cluster, or this thread's id is below the cluster size.
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id);
}
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
@@ -188,18 +175,6 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
RunWrite(dst_descs, dst_bufs, scatter_offsets);
}
// Convenience wrapper for the weighted scatter path: performs the read step with
// scatter_weights and then the write step with scatter_offsets, in that order.
//   src_descs / src_bufs - source descriptors and buffers handed to RunRead
//   dst_descs / dst_bufs - destination descriptors and buffers handed to RunWrite
//   scatter_offsets      - per-entry destination offsets consumed by RunWrite
//   scatter_weights      - per-entry float weights consumed by RunRead
template <typename SrcBuffers, typename DstBuffers>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
{
RunRead(src_descs, src_bufs, scatter_weights);
RunWrite(dst_descs, dst_bufs, scatter_offsets);
}
template <index_t ISrc>
__device__ void
MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)

View File

@@ -89,21 +89,21 @@ __global__ void
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -1206,7 +1206,7 @@ struct GridwiseMoeGemmBlockScale
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N , ScaleBlockN),
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
@@ -1265,10 +1265,11 @@ struct GridwiseMoeGemmBlockScale
}
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride =
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockK));
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockK));
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
@@ -1461,22 +1462,24 @@ struct GridwiseMoeGemmBlockScale
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const BScaleType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
ScaleSliceSizeK,
1,
false>(
b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
ScaleSliceSizeK,
1,
false>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
a_grid_desc_ak0_m_ak1,
@@ -1577,39 +1580,36 @@ struct GridwiseMoeGemmBlockScale
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
if constexpr(IsInputGemm) // gu fusion, elementwise
{
static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
static_assert(N4 == 4);
const index_t n1 = get_warp_local_1d_id() / MWave;
const index_t n3 = threadIdx.x % get_warp_size() / NPerXdl;
static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
static_assert(M0 * M1 * M2 == MPerBlock);
static_assert(N4 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m2 = threadIdx.x % get_warp_size() % M2;
vector_type<float, 4> topk_weights;
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
const index_t n_pos = block_n_id * NPerBlock + n0 * N1 * N2 * N3 * N4 +
n1 * N2 * N3 * N4 + n2 * N3 * N4 + n3 * N4;
if constexpr(MulRoutedWeight)
float topk_weight;
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
if constexpr(MulRoutedWeight)
{
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
topk_weight = p_ds_grid[I0][m_pos];
}
static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, n2 * N4 + n4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion, elementwise
{
topk_weights = *c_style_pointer_cast<const vector_type<float, N4>*>(
p_ds_grid[I0] + n_pos);
}
// if((blockIdx.x == 0) && (blockIdx.y == 0)){printf("m0:%d, n_pos:%d\n", static_cast<int>(m0), n_pos);}
static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, n2 * N4 + n4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[n4];
up = up * topk_weights.AsType<float>()[n4];
gate = gate * topk_weight;
up = up * topk_weight;
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
@@ -1625,8 +1625,8 @@ struct GridwiseMoeGemmBlockScale
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[n4];
up = up * topk_weights.AsType<float>()[n4];
gate = gate * topk_weight;
up = up * topk_weight;
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
@@ -1636,11 +1636,18 @@ struct GridwiseMoeGemmBlockScale
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
});
}
else
{
if constexpr(MulRoutedWeight)
{
c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
}
}
});
});
});
}
});
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
@@ -1853,7 +1860,6 @@ struct GridwiseMoeGemmBlockScale
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
auto dstidx = sfc_cde_block.GetIndex(access_id);
const index_t c_token_pos =
@@ -1861,18 +1867,11 @@ struct GridwiseMoeGemmBlockScale
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = token_offset < problem.NumTokens ? 1 : 0.0;
if constexpr(IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
else
{
const float* p_sorted_weights_2 = p_ds_grid[I0];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
});
block_sync_lds();
@@ -1893,8 +1892,7 @@ struct GridwiseMoeGemmBlockScale
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf),
scatter_offsets,
scatter_weights);
scatter_offsets);
if constexpr(access_id < num_access - 1)
{
@@ -2019,10 +2017,11 @@ struct GridwiseMoeGemmBlockScale
}
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride =
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N , ScaleBlockN) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockK));
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockK));
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
@@ -2071,12 +2070,12 @@ struct GridwiseMoeGemmBlockScale
IndexType,
1,
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
gather_offsets);
make_multi_index(0, 0, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
gather_offsets);
// Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
@@ -2123,8 +2122,7 @@ struct GridwiseMoeGemmBlockScale
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
//scale
// scale
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
@@ -2160,7 +2158,7 @@ struct GridwiseMoeGemmBlockScale
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
math::integer_divide_ceil(problem.K, ScaleBlockK);
});
@@ -2222,22 +2220,24 @@ struct GridwiseMoeGemmBlockScale
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const BScaleType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
ScaleSliceSizeK,
1,
false>(
b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
ScaleSliceSizeK,
1,
false>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
a_grid_desc_ak0_m_ak1,
@@ -2316,7 +2316,7 @@ struct GridwiseMoeGemmBlockScale
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// TODO: hacky, fix it!
//only used to get lengths
// only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
@@ -2329,39 +2329,36 @@ struct GridwiseMoeGemmBlockScale
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
if constexpr(IsInputGemm) // gu fusion, elementwise
{
static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
static_assert(N4 == 4);
const index_t n1 = get_warp_local_1d_id() / MWave;
const index_t n3 = threadIdx.x % get_warp_size() / NPerXdl;
static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
static_assert(M0 * M1 * M2 == MPerBlock);
static_assert(N4 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m2 = threadIdx.x % get_warp_size() % M2;
vector_type<float, 4> topk_weights;
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
const index_t n_pos = block_n_id * NPerBlock + n0 * N1 * N2 * N3 * N4 +
n1 * N2 * N3 * N4 + n2 * N3 * N4 + n3 * N4;
if constexpr(MulRoutedWeight)
float topk_weight;
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
if constexpr(MulRoutedWeight)
{
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
topk_weight = p_ds_grid[I0][m_pos];
}
static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, n2 * N4 + n4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion, elementwise
{
topk_weights = *c_style_pointer_cast<const vector_type<float, N4>*>(
p_ds_grid[I0] + n_pos);
}
// if((blockIdx.x == 0) && (blockIdx.y == 0)){printf("m0:%d, n_pos:%d\n", static_cast<int>(m0), n_pos);}
static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, n2 * N4 + n4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[n4];
up = up * topk_weights.AsType<float>()[n4];
gate = gate * topk_weight;
up = up * topk_weight;
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
@@ -2377,8 +2374,8 @@ struct GridwiseMoeGemmBlockScale
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[n4];
up = up * topk_weights.AsType<float>()[n4];
gate = gate * topk_weight;
up = up * topk_weight;
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
@@ -2388,11 +2385,19 @@ struct GridwiseMoeGemmBlockScale
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
});
}
else
{
if constexpr(MulRoutedWeight)
{
c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
}
}
});
});
});
}
});
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
@@ -2491,11 +2496,8 @@ struct GridwiseMoeGemmBlockScale
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
const DDataType* ptr_ = p_ds_grid[i];
return make_dynamic_buffer<AddressSpaceEnum::Global>(
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
@@ -2605,7 +2607,6 @@ struct GridwiseMoeGemmBlockScale
// make sure it's safe to write to LDS
StaticallyIndexedArray<IndexType, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
auto dstidx = sfc_cde_block.GetIndex(access_id);
const index_t c_token_pos =
@@ -2613,18 +2614,11 @@ struct GridwiseMoeGemmBlockScale
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = token_offset < problem.NumTokens ? 1 : 0.0;
if constexpr(IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
else
{
const float* p_sorted_weights_2 = p_ds_grid[I0];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
});
block_sync_lds();
@@ -2645,8 +2639,7 @@ struct GridwiseMoeGemmBlockScale
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf),
scatter_offsets,
scatter_weights);
scatter_offsets);
if constexpr(access_id < num_access - 1)
{
@@ -2665,45 +2658,6 @@ struct GridwiseMoeGemmBlockScale
I0,
cde_lds_and_global_step);
}
// // print C
// printf("tid: %d, blkid: %d, "
// "c_thread_buf = <%1.f, %1.f, %1.f>\n "
// // "%1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f,"
// // "%1.f, %1.f, %1.f, %1.f, %1.f, %1.f\n"
// , get_thread_local_1d_id(), block_m_id,
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<0>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<1>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<2>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<3>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<4>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<5>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<6>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<7>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<8>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<9>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<10>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<11>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<12>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<13>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<14>{}],
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
// AsType<AccDataType>()[Number<3>{}]);
});
}
}

View File

@@ -262,143 +262,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
});
}
// Read step of the scatter transfer with per-row weights.
// For every access on the source space-filling curve this loads one vector per source,
// runs element_op_ over them, and stores the resulting vectors together with a validity
// flag into the thread scratch slot selected by thread_scratch_id.
// The source at index ScatterWeightIdx is special: its vector is not loaded from
// src_bufs but filled from scatter_weights.
//   src_descs         - descriptors of the source tensors
//   src_bufs          - source buffers (must have as many entries as SrcDescs)
//   scatter_weights   - float weights indexed by the ScatterDim component of the access
//   thread_scratch_id - compile-time id of the thread scratch slot to fill
template <typename SrcBuffers,
index_t ThreadScratchId = 0,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// loop over space-filling curve
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
// Stays true only while every source coordinate of this access is valid
// (AND-accumulated below); stored alongside the data for this access.
bool oob_val = true;
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
src_coords_[i]);
oob_val = oob_val & is_src_valid;
// NOTE(review): this is a plain runtime `if`, while the following branches use
// `if constexpr`; both operands look like compile-time constants, so this could
// presumably be `if constexpr` as well - confirm before changing.
if(i.value == ScatterWeightIdx)
{
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
"scatter weight dim, should only one vec");
// Pick the weight that belongs to this access' position along ScatterDim
// and replicate it into every lane of the source vector.
constexpr auto iScatter =
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
src_vectors(i).template AsType<float>()(j) =
scatter_weights(Number<iScatter>{});
});
}
else if constexpr(SrcScalarPerVectors{}[i] == 1)
{
// Scalar source: load a single element and broadcast it to all lanes.
auto data_types = SrcDatas{};
using DataType = remove_cvref_t<decltype(data_types[i])>;
const auto tmp =
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
static_for<0, SrcScalarPerVector, 1>{}(
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
}
else
{
// Vector source: one full-width vector load.
src_vectors(i).template AsType<src_vector_t>()(I0) =
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
}
});
// Choose how many elements element_op_ processes per call: the first of 8, 4, 2
// whose is_packN_invocable flag is present and set, capped by SrcScalarPerVector;
// falls back to 1 when none applies.
constexpr auto get_elem_op_vec_len = []() {
if constexpr(is_detected<is_pack8_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack4_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack2_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
return 1;
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
// apply pointwise function
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
using elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
},
Number<nSrc>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
},
Number<nDst>{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(element_op_, dst_data_refs, src_data_refs);
});
// Park the element-op results and the validity flag for this access in the
// selected scratch slot.
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;
// move coordinate
// (skipped after the last access; the reset below handles the final position)
if constexpr(iAccess.value != src_num_access - 1)
{
constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
static_for<0, nSrc, 1>{}([&](auto i) {
move_tensor_coordinate(src_descs[i],
src_coords_(i),
make_tensor_coordinate_step(src_descs[i], forward_step));
});
}
});
// move coordinate back to slice origin (or not)
// Controlled per source by SrcResetCoordinateAfterRunFlags.
static_for<0, nSrc, 1>{}([&](auto i) {
if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
{
const auto src_reset_step =
make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep());
move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
}
});
}
#if 1
template <index_t ThreadScratchId = 0>
__device__ void OOBCheck(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
@@ -608,22 +471,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
RunWrite(dst_descs, dst_bufs, scatter_offsets);
}
// Convenience wrapper for the weighted scatter path: runs the read step with
// scatter_weights, then the write step with scatter_offsets, in that order.
// Enabled only when the buffer tuples have as many entries as their descriptor tuples.
//   src_descs / src_bufs - source descriptors and buffers handed to RunRead
//   dst_descs / dst_bufs - destination descriptors and buffers handed to RunWrite
//   scatter_offsets      - per-entry destination offsets consumed by RunWrite
//   scatter_weights      - per-entry float weights consumed by RunRead
template <typename SrcBuffers,
typename DstBuffers,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
DstDescs::Size() == DstBuffers::Size(),
bool> = false>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
{
RunRead(src_descs, src_bufs, scatter_weights);
RunWrite(dst_descs, dst_bufs, scatter_offsets);
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
if constexpr(src_num_access == 0)