mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
fix scale_m gather load for a8w8 moe
This commit is contained in:
@@ -404,6 +404,77 @@ struct tile_scatter_gather
|
||||
});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
typename Ys2PageIdxMap,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
|
||||
const Ys2PageIdxMap& ys_to_page_idx_map,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
const auto idx_gather = ys_to_page_idx_map(idx_ys_start);
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value = [&]() {
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
valids_[idx_gather],
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
|
||||
});
|
||||
|
||||
// ys_to_page_idx_map handles all offset calculation.
|
||||
// So ther is no need to move thread coordinate redundantly.
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: currently async load only implemented in inline asm
|
||||
template <typename LdsTileWindow_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
|
||||
@@ -50,9 +50,9 @@ struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
|
||||
ck_tile::index_t stride_C_,
|
||||
ck_tile::index_t n_padded_zeros_ = 0,
|
||||
ck_tile::index_t k_padded_zeros_ = 0,
|
||||
ScaleM scale_m_ = {},
|
||||
ScaleN scale_n_ = {},
|
||||
ExpertBias exp_bias_ = {})
|
||||
ScaleM scale_m_ = {},
|
||||
ScaleN scale_n_ = {},
|
||||
ExpertBias exp_bias_ = {})
|
||||
: ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>(a_ptr_,
|
||||
b_ptr_,
|
||||
{}, // d_ptr_array
|
||||
@@ -524,7 +524,7 @@ struct MoeFlatmmKernel
|
||||
const int expert_id,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const int n_pad_zeros = 12)
|
||||
const int n_pad_zeros = 12)
|
||||
{
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -596,7 +596,7 @@ struct MoeFlatmmKernel
|
||||
|
||||
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
|
||||
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
|
||||
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
@@ -878,14 +878,22 @@ struct MoeFlatmmKernel
|
||||
const auto scale_m_coord =
|
||||
output_acc_tile_distr.calculate_index(); // 2d thread offset, [i_row, i_col]
|
||||
|
||||
constexpr ck_tile::index_t ScaleMRepeat =
|
||||
decltype(output_acc_tile_distr)::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
|
||||
statically_indexed_array<ck_tile::index_t, ScaleMRepeat> scale_m_offsets;
|
||||
constexpr index_t kM2 = 4; // Val-dim
|
||||
constexpr index_t kM1 = get_warp_size() / NPerXdl; // Thr-dim
|
||||
constexpr index_t kM0 = MPerXdl / kM1 / kM2; // Var-dim
|
||||
|
||||
static_for<0, ScaleMRepeat, 1>{}([&](auto m0) {
|
||||
const auto row_idx =
|
||||
coord_m + m0 * (TilePartitioner::MPerBlock / ScaleMRepeat) + scale_m_coord[I0];
|
||||
scale_m_offsets[m0] = row_to_token_idx(row_idx);
|
||||
constexpr index_t ScaleMRepeat = MRepeat * kM0 * kM2;
|
||||
statically_indexed_array<index_t, ScaleMRepeat> scale_m_offsets;
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
const auto row_idx =
|
||||
coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
|
||||
scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
|
||||
row_to_token_idx(row_idx);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
constexpr int DynamicTileOffsetFlag = 0;
|
||||
@@ -924,18 +932,18 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
};
|
||||
|
||||
auto scale_m_window = make_tile_scatter_gather(
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_m.ptr,
|
||||
make_tuple(kargs.M, 1),
|
||||
make_tuple(scale_stride_m, 0),
|
||||
number<ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1>{},
|
||||
number<1>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, 0}, // offset m is included in gather offsets
|
||||
output_acc_tile_distr,
|
||||
scale_m_offsets);
|
||||
auto scale_m_window =
|
||||
make_tile_scatter_gather(make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_m.ptr,
|
||||
make_tuple(kargs.M, 1),
|
||||
make_tuple(scale_stride_m, 0),
|
||||
number<1>{}, // gather load can't vectorize
|
||||
number<1>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, 0}, // offset m is included in gather offsets
|
||||
output_acc_tile_distr,
|
||||
scale_m_offsets);
|
||||
|
||||
auto scale_n_window = make_tile_window(
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -1015,7 +1023,10 @@ struct MoeFlatmmKernel
|
||||
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
{
|
||||
scale_m_buffer = load_tile(scale_m_window);
|
||||
scale_m_window.load(scale_m_buffer, [](auto ys_coord) {
|
||||
return ys_coord.at(I0) * number<kM0 * kM2>{} + ys_coord.at(I2) * number<kM2>{} +
|
||||
ys_coord.at(I3);
|
||||
});
|
||||
scale_n_buffer = load_tile(scale_n_window);
|
||||
if constexpr(IsGateUp)
|
||||
scale_n_up_buffer = load_tile(scale_n_up_window);
|
||||
|
||||
Reference in New Issue
Block a user