mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Ck moe hot fix (#1979)
* fix useless code and remove usless oob * clang format * fix coredump in e2e test * fix2 * fix clang format * fix output oob * clang format * rm useless comments --------- Co-authored-by: coderfeli <coderfeli@163.com> Co-authored-by: illsilin <Illia.Silin@amd.com>
This commit is contained in:
@@ -1563,12 +1563,8 @@ struct GridwiseMoeGemm
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<index_t, EMRepeats>
|
||||
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets;
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock]
|
||||
// & 0xff000000) >> 24;
|
||||
|
||||
auto dstidx = sfc_cde_block.GetIndex(access_id);
|
||||
const index_t c_token_pos =
|
||||
@@ -1576,7 +1572,9 @@ struct GridwiseMoeGemm
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]];
|
||||
float weight = token_offset < problem.NumTokens
|
||||
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
|
||||
: 0.0;
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
@@ -2074,9 +2072,6 @@ struct GridwiseMoeGemm
|
||||
StaticallyIndexedArray<index_t, EMRepeats>
|
||||
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock]
|
||||
// & 0xff000000) >> 24;
|
||||
|
||||
auto dstidx = sfc_cde_block.GetIndex(access_id);
|
||||
const index_t c_token_pos =
|
||||
@@ -2084,7 +2079,9 @@ struct GridwiseMoeGemm
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]];
|
||||
float weight = token_offset < problem.NumTokens
|
||||
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
|
||||
: 0.0;
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
|
||||
@@ -430,13 +430,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
}
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
|
||||
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
|
||||
constexpr InMemoryDataOperationEnum DstInMemOp =
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
|
||||
|
||||
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_offset, true, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
dst_offset, is_dst_valid, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
});
|
||||
|
||||
// move coordinate
|
||||
|
||||
Reference in New Issue
Block a user