Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-01 12:11:19 +00:00
Moe gemm activation (#2026)
* clean up useless code and remove useless OOB check
* clang format
* fix coredump in e2e test
* fix2
* fix clang format
* fix output OOB
* implement int64, but result not correct yet
* int64 index ok now
* input and output all ok
* fix uint32
* revert v1 test
* use uint32
* work to support 130k tokens
* moe sorting: fix moe buffer
* fix merge
* update moe api, fix aiter build
* fix build
* fuse silu
* silu ok
* scale ok
* add silu
* change code
* gemm2 ok
* gu fusion compatible ok, fix warnings
* gu fusion for m32/m64 ok
* support bf16 cshuffle
* i4 gemm2 ok
* i4 gemm2 ok and i4 gemm1 builds
* 16x16 run ok
* change flops; change cshuffle dtype
* fuse gelu/silu activation in moe gemm1
* fp8 with activation ready
* int4 activation ready
* remove useless changes
* remove useless code change
* fix clang format
* add the arch limit for int4 moe gemm
* fuse moe activation
* fix fp8 16x16
* fix no-quant case
* fix bugs
* fix fp8 gu fusion bug
* remove useless comments
* refine activation code & complete moe example
* fix int8 bugs
* merge tkw1

---------

Co-authored-by: coderfeli <coderfeli@163.com>
Co-authored-by: feli <felix.li@amd.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: root <root@hjbog-srdc-51.amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
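For context on the int64 index work above: once the kernel targets the log's ~130k tokens, a flattened element offset (token rows times hidden width) can exceed what a 32-bit index_t holds, which is why the transfer structs below gain an IndexType template parameter. A minimal standalone sketch of the overflow, with hypothetical sizes not taken from this PR:

```cpp
#include <cstdint>
#include <iostream>

int main()
{
    // Hypothetical MoE shapes: ~130k tokens ("13w" in the commit log) and a
    // wide intermediate dimension. Neither value comes from the PR itself.
    const std::int64_t tokens = 130000;
    const std::int64_t hidden = 32768;

    const std::int64_t off64 = tokens * hidden;                   // 4,259,840,000: needs 64 bits
    const std::int32_t off32 = static_cast<std::int32_t>(off64);  // wraps past 2^31 - 1

    std::cout << "64-bit offset: " << off64 << "\n"  // correct
              << "32-bit offset: " << off32 << "\n"; // garbage after overflow
}
```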
@@ -41,6 +41,7 @@ template <typename SliceLengths,
           bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
                                            // RunWrite(), will be fused with MoveDstSliceWindow to
                                            // save addr computation
+          typename IndexType,
           index_t GatherDim = 1,
           index_t NumThreadScratch = 1>
 struct ThreadwiseTensorSliceTransfer_v3r1_gather
@@ -88,7 +89,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
             const DstDesc& dst_desc,
             const Index& dst_slice_origin,
             const DstElementwiseOperation& dst_element_op,
-            const StaticallyIndexedArray<index_t, gather_num>& gather_offsets)
+            const StaticallyIndexedArray<IndexType, gather_num>& gather_offsets)
         : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
           dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
           src_element_op_(src_element_op),
@@ -221,7 +222,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
             auto gather_offset =
                 gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);

-            const index_t ld_offset = src_coord_.GetOffset() + gather_offset;
+            const IndexType ld_offset = src_coord_.GetOffset() + gather_offset;
             src_oob_thread_scratch_tuple_(thread_scratch_id)
                 .template SetAsType<bool>(src_data_idx_seq, true);

@@ -935,7 +936,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
     DstCoord dst_coord_;
     const SrcElementwiseOperation src_element_op_;
     const DstElementwiseOperation dst_element_op_;
-    StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
+    StaticallyIndexedArray<IndexType, gather_num> gather_offsets_;
 };

 } // namespace ck
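The hunks above all make the same move: the gather offset array and the load offset stop hard-coding index_t and instead use the new IndexType template parameter, so callers can instantiate 64-bit indexing for large buffers. A simplified standalone sketch of that pattern, with hypothetical names rather than the actual CK struct:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Simplified stand-in for CK's thread-wise gather transfer: the offset type
// is a template parameter so callers opt into 64-bit indexing when buffers
// grow past what 32-bit offsets can address.
template <typename IndexType, std::size_t GatherNum>
struct GatherTransferSketch
{
    std::array<IndexType, GatherNum> gather_offsets_; // was int32-only before the change

    IndexType load_offset(IndexType base, std::size_t slot) const
    {
        // The sum stays in IndexType, so no intermediate 32-bit truncation.
        return base + gather_offsets_[slot];
    }
};

// 32-bit for small problems, 64-bit for >2^31-element address spaces.
using SmallGather = GatherTransferSketch<std::int32_t, 4>;
using LargeGather = GatherTransferSketch<std::int64_t, 4>;
```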
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -43,6 +43,7 @@ template <typename SrcDatas,
           index_t DstScalarPerVector,
           typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
           typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
+          typename IndexType,
           index_t ScatterDim = 1,
           bool OutputScatter = true,
           index_t ScatterWeightIdx = 3,
@@ -153,7 +154,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
               enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
     __device__ void RunRead(const SrcDescs& src_descs,
                             const SrcBuffers& src_bufs,
-                            StaticallyIndexedArray<float, scatter_num>& scatter_weights,
                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
         // loop over space-filling curve
@@ -172,31 +172,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
                     src_coords_[i]);

                 oob_val = oob_val & is_src_valid;
-                if(i.value == ScatterWeightIdx)
-                {
-                    static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
-                                  "scatter weight dim, should only one vec");
-                    constexpr auto iScatter =
-                        SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
-                    static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
-                        src_vectors(i).template AsType<float>()(j) =
-                            scatter_weights(Number<iScatter>{});
-                    });
-                }
-                else if constexpr(SrcScalarPerVectors{}[i] == 1)
-                {
-                    auto data_types = SrcDatas{};
-                    using DataType  = remove_cvref_t<decltype(data_types[i])>;
-                    const auto tmp =
-                        src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
-                    static_for<0, SrcScalarPerVector, 1>{}(
-                        [&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
-                }
-                else
-                {
-                    src_vectors(i).template AsType<src_vector_t>()(I0) =
-                        src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
-                }
+                src_vectors(i).template AsType<src_vector_t>()(I0) =
+                    src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
             });

         constexpr auto get_elem_op_vec_len = []() {
@@ -412,7 +389,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
               enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
     __device__ void RunWrite(const DstDescs& dst_descs,
                              DstBuffers dst_bufs,
-                             StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
+                             StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
                              Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
         OOBCheck(thread_scratch_id);
@@ -420,8 +397,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter

         // loop over space-filling curve
         static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
-            auto dst_vectors    = dst_vectors_tuple_[thread_scratch_id][iAccess];
-            auto scatter_offset = 0;
+            auto dst_vectors         = dst_vectors_tuple_[thread_scratch_id][iAccess];
+            IndexType scatter_offset = 0;
             if constexpr(OutputScatter)
             {
                 constexpr auto iScatter =
@@ -431,8 +408,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
             // copy data from buf_vectors into dst_bufs
             static_for<0, nDst, 1>{}([&](auto i) {
                 using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
-                auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
+                IndexType dst_offset    = scatter_offset + (dst_coords_[i].GetOffset());
+                const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
+                // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
+                // dst_coords_[i]);
                 constexpr InMemoryDataOperationEnum DstInMemOp =
                     static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
                 dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
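Note the new validity check in this hunk: instead of the coordinate-based helper (kept as a comment), the widened linear offset is compared directly against the element-space size. A minimal sketch of that guard, with illustrative names rather than CK's API:

```cpp
#include <cstdint>

// Guard a scatter store with an explicit bound on the linear offset,
// mirroring the dst_offset < GetElementSpaceSize() check in the hunk above.
// scatter_store and element_space_size are hypothetical names.
template <typename IndexType, typename T>
void scatter_store(T* dst, IndexType element_space_size, IndexType dst_offset, T value)
{
    const bool is_dst_valid = dst_offset < element_space_size;
    if(is_dst_valid) // out-of-bounds lanes skip the store instead of corrupting memory
    {
        dst[dst_offset] = value;
    }
}
```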
@@ -488,10 +467,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
                          const SrcBuffers& src_bufs,
                          const DstDescs& dst_descs,
                          DstBuffers dst_bufs,
-                         StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
-                         StaticallyIndexedArray<float, scatter_num>& scatter_weights)
+                         StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets)
     {
-        RunRead(src_descs, src_bufs, scatter_weights);
+        RunRead(src_descs, src_bufs);
         RunWrite(dst_descs, dst_bufs, scatter_offsets);
     }

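Taken together, the scatter-side changes shrink the fused entry point: scatter weights no longer pass through RunRead, and Run only threads the IndexType-typed scatter offsets into RunWrite. A condensed host-side sketch of the resulting call shape, with names and signatures that are illustrative rather than CK's device API:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

// Condensed sketch of the Run = RunRead + RunWrite composition after this
// change: no scatter_weights parameter, and offsets use a caller-chosen
// IndexType. All names here are illustrative, not CK's actual API.
template <typename IndexType, std::size_t ScatterNum>
struct ScatterTransferSketch
{
    std::vector<float> scratch_; // stands in for the thread scratch buffers

    void RunRead(const std::vector<float>& src) { scratch_ = src; }

    void RunWrite(std::vector<float>& dst,
                  const std::array<IndexType, ScatterNum>& scatter_offsets)
    {
        for(std::size_t i = 0; i < ScatterNum && i < scratch_.size(); ++i)
        {
            const IndexType off = scatter_offsets[i];
            if(off < static_cast<IndexType>(dst.size())) // bound check, as above
                dst[off] = scratch_[i];
        }
    }

    // Fused entry point: only the scatter offsets flow through now.
    void Run(const std::vector<float>& src,
             std::vector<float>& dst,
             const std::array<IndexType, ScatterNum>& scatter_offsets)
    {
        RunRead(src);
        RunWrite(dst, scatter_offsets);
    }
};

// 64-bit offsets for large MoE buffers:
using LargeScatter = ScatterTransferSketch<std::int64_t, 4>;
```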