gemm1 up-only pass. GU WIP

This commit is contained in:
OscarXu
2025-05-12 21:08:42 +08:00
parent f1a534f6e7
commit fe8bb251da
6 changed files with 171 additions and 18 deletions

View File

@@ -269,7 +269,7 @@ int main(int argc, char* argv[])
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / (valid_tile_num / experts);
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;

View File

@@ -268,7 +268,7 @@ int main(int argc, char* argv[])
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0};
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_B = (N * 2 + Scale_Block_N - 1) / Scale_Block_N;
ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2;
ck::index_t KBatch = 1;
@@ -279,7 +279,7 @@ int main(int argc, char* argv[])
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / (valid_tile_num / experts);
expert_ids.mData[i] = (i / valid_tile_num) / experts;
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
@@ -303,7 +303,7 @@ int main(int argc, char* argv[])
{Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<B1DataType> b1_e_n_k(HostTensorDescriptor(
{experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N * 2 + Scale_Block_N - 1) / Scale_Block_N},
{experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N * 2},
{(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
@@ -381,8 +381,8 @@ int main(int argc, char* argv[])
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
// a0_t_k.savetxt("a.txt");
// expert_ids.savetxt("expert_ids.txt", "int");
// sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
// d2_e_n.savetxt("d2_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
@@ -503,8 +503,8 @@ int main(int argc, char* argv[])
expert_ids,
max_token_id,
MPerBlock,
a0_t_k,
b0_e_n_k,
a_t_k,
b_e_n_k,
d2_e_n,
c_t_k_n,
PassThrough{},

View File

@@ -131,6 +131,18 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
}
}
// Thread-group read stage.
//
// Forwards src_descs, src_bufs and thread_scratch_id unchanged to
// threadwise_transfer_.RunRead, but only for threads that take part in the
// transfer: either the thread cluster has as many elements as the thread
// group has threads, or the calling thread's id is below the cluster size.
// All other threads return without doing anything.
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
                        const SrcBuffers& src_bufs,
                        Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
    const bool cluster_spans_group =
        ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize();

    // The thread id is only queried when the cluster does not span the group.
    if(!cluster_spans_group &&
       !(ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()))
    {
        return;
    }

    threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
}
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
@@ -170,7 +182,18 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets)
{
RunRead(src_descs, src_bufs);
RunWrite(dst_descs, dst_bufs, scatter_offsets);
}
template <typename SrcBuffers, typename DstBuffers>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
{
RunRead(src_descs, src_bufs, scatter_weights);

View File

@@ -1206,7 +1206,7 @@ struct GridwiseMoeGemmBlockScale
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N * (IsInputGemm ? 2 : 1), ScaleBlockN),
make_tuple(math::integer_divide_ceil(problem.N , ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
@@ -1267,7 +1267,7 @@ struct GridwiseMoeGemmBlockScale
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride =
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N * (IsInputGemm ? 2 : 1), ScaleBlockN) *
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockK));
// N0, K0, Blocksize*KPack
@@ -1958,7 +1958,7 @@ struct GridwiseMoeGemmBlockScale
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N * (IsInputGemm ? 2 : 1), ScaleBlockN),
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
@@ -2019,7 +2019,7 @@ struct GridwiseMoeGemmBlockScale
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride =
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N * (IsInputGemm ? 2 : 1), ScaleBlockN) *
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N , ScaleBlockN) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockK));
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
@@ -2158,8 +2158,8 @@ struct GridwiseMoeGemmBlockScale
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
scale_gather_offsets(m0) =
token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK);
scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
math::integer_divide_ceil(problem.K, ScaleBlockK);
});
// printf("blkid: %d, tid:%d, a_thread_offset: %d, scale_gather_offsets: %d\n", block_m_id,
@@ -2222,7 +2222,7 @@ struct GridwiseMoeGemmBlockScale
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up =
ThreadwiseTensorSliceTransfer_v2<BScaleType,

View File

@@ -149,6 +149,119 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
// RunRead (overload without scatter weights).
//
// For each of the src_num_access accesses this:
//   1. loads one vector per source buffer at the current source coordinate,
//   2. applies element_op_ to turn the source vectors into destination-typed
//      vectors,
//   3. stores the result and a per-access validity flag into the thread
//      scratch selected by thread_scratch_id,
//   4. advances every source coordinate by the space-filling-curve step.
// After the loop, sources flagged in SrcResetCoordinateAfterRunFlags are moved
// by GetSrcCoordinateResetStep().
//
// Enabled only when the number of source descriptors equals the number of
// source buffers.
template <typename SrcBuffers,
index_t ThreadScratchId = 0,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// loop over space-filling curve
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
// per-access staging: raw source vectors and element-op results
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
// starts true and stays true only while every source coordinate of this
// access is valid (despite the name, true means "all valid")
bool oob_val = true;
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
src_coords_[i]);
oob_val = oob_val & is_src_valid;
// The flag passed to Get is hard-coded to true; validity is only recorded
// in oob_val here.
// NOTE(review): RunWrite calls OOBCheck(thread_scratch_id) before writing;
// presumably that is where the recorded flag is consumed - confirm.
src_vectors(i).template AsType<src_vector_t>()(I0) =
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
});
// Choose how many scalars element_op_ processes per call: the widest of
// 8 / 4 / 2 whose is_packN_invocable member exists and is true, clamped to
// SrcScalarPerVector; otherwise 1.
constexpr auto get_elem_op_vec_len = []() {
if constexpr(is_detected<is_pack8_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack4_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack2_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
return 1;
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
// apply pointwise function
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
using elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
},
Number<nSrc>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
},
Number<nDst>{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(element_op_, dst_data_refs, src_data_refs);
});
// stash this access's results and validity flag in the selected scratch slot
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;
// move coordinate
// (skipped after the last access; the reset step below handles the rest)
if constexpr(iAccess.value != src_num_access - 1)
{
constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
static_for<0, nSrc, 1>{}([&](auto i) {
move_tensor_coordinate(src_descs[i],
src_coords_(i),
make_tensor_coordinate_step(src_descs[i], forward_step));
});
}
});
// move coordinate back to slice origin (or not)
static_for<0, nSrc, 1>{}([&](auto i) {
if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
{
const auto src_reset_step =
make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep());
move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
}
});
}
template <typename SrcBuffers,
index_t ThreadScratchId = 0,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
@@ -413,7 +526,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
OOBCheck(thread_scratch_id);
@@ -485,6 +598,21 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
DstDescs::Size() == DstBuffers::Size(),
bool> = false>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets)
{
RunRead(src_descs, src_bufs);
RunWrite(dst_descs, dst_bufs, scatter_offsets);
}
template <typename SrcBuffers,
typename DstBuffers,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
DstDescs::Size() == DstBuffers::Size(),
bool> = false>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,

View File

@@ -163,7 +163,7 @@ struct ReferenceMoeGemm1BlockScale : public device::BaseOperator
arg.c_element_op_(v_c, v_acc);
arg.c_element_op_(v_c_up, v_acc_up);
#if 0
if constexpr(ActivationType == 1)
{
if constexpr(is_same_v<BDataType, pk_i4_t>)
@@ -184,6 +184,8 @@ struct ReferenceMoeGemm1BlockScale : public device::BaseOperator
tensor_operation::element_wise::Gelu{}(v_c, v_c);
arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up;
}
#endif
arg.c_t_k_n_(t, topk_id, n) = v_c;
}
};