fix scale_m gather load for a8w8 moe

This commit is contained in:
Feng Shijie
2025-09-09 04:18:48 +00:00
parent 027f5311c6
commit b09b6cdce9
2 changed files with 107 additions and 25 deletions

View File

@@ -404,6 +404,77 @@ struct tile_scatter_gather
});
}
/// @brief Gather-load into a distributed tensor, selecting the page of every access
///        through a caller-supplied mapping.
///
/// For each access of the space-filling curve `SFC_Ys`, the starting Y index of that
/// access is passed to `ys_to_page_idx_map`. The value it returns is used to index
/// `page_idx_` (and `valids_` when `ValidArray` is not `std::nullptr_t`), and the
/// resulting page offset is handed to `get_vectorized_elements` together with the
/// pre-computed bottom-tensor coordinate of the current `iCoord`.
///
/// @tparam DistributedTensor     type of the destination tensor; must expose
///                               `get_thread_buffer()`.
/// @tparam Ys2PageIdxMap         callable taking the Y start index of an access and
///                               returning an index usable with `page_idx_` / `valids_`.
/// @tparam i_access_unsupport_   accepted for signature compatibility; not used by this
///                               overload.
/// @tparam oob_conditional_check forwarded to `get_vectorized_elements` as a
///                               `bool_constant`.
/// @param dst_tensor         destination whose thread buffer receives the loaded values.
/// @param ys_to_page_idx_map mapping from Y start index to gather index (see above).
template <typename DistributedTensor,
          typename Ys2PageIdxMap,
          index_t i_access_unsupport_ = -1,
          bool oob_conditional_check  = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
                         const Ys2PageIdxMap& ys_to_page_idx_map,
                         number<i_access_unsupport_>          = {},
                         bool_constant<oob_conditional_check> = {}) const
{
    using Traits   = load_store_traits;
    using vector_t = typename Traits::vector_t;
    using SFC_Ys   = typename Traits::SFC_Ys;

    constexpr auto tile_dstr = TileDstr{};

    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        // Only the bottom-tensor coordinate is needed here: the window-adaptor
        // coordinate (slot I0) is never read because this overload does not move the
        // thread coordinate between accesses.
        auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

            // translate the Y start index into a gather slot, then into a page offset
            const auto idx_gather  = ys_to_page_idx_map(idx_ys_start);
            const auto page_offset = page_idx_[idx_gather];

            // read from bottom tensor
            const vector_t vec_value = [&]() {
                if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
                {
                    // no validity array: plain (optionally OOB-checked) vector read
                    return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                        bottom_tensor_thread_coord,
                        page_offset,
                        bool_constant<oob_conditional_check>{});
                }
                else
                {
                    // validity array present: pass the flag of the same gather slot
                    return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                        bottom_tensor_thread_coord,
                        page_offset,
                        valids_[idx_gather],
                        bool_constant<oob_conditional_check>{});
                }
            }();

            // write into distributed tensor
            static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
                // Y index of element j: advance only along the vectorized Y dimension
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                        : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                // offset in the thread buffer, counted in units of PackedSize
                constexpr index_t d =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                    Traits::PackedSize;

                dst_tensor.get_thread_buffer().template at<d>() =
                    vec_value.template get_as<DataType>()[j / Traits::PackedSize];
            });

            // ys_to_page_idx_map handles all offset calculation,
            // so there is no need to move the thread coordinate redundantly.
        });
    });
}
// TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,

View File

@@ -50,9 +50,9 @@ struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
ck_tile::index_t stride_C_,
ck_tile::index_t n_padded_zeros_ = 0,
ck_tile::index_t k_padded_zeros_ = 0,
ScaleM scale_m_ = {},
ScaleN scale_n_ = {},
ExpertBias exp_bias_ = {})
ScaleM scale_m_ = {},
ScaleN scale_n_ = {},
ExpertBias exp_bias_ = {})
: ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>(a_ptr_,
b_ptr_,
{}, // d_ptr_array
@@ -524,7 +524,7 @@ struct MoeFlatmmKernel
const int expert_id,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const int n_pad_zeros = 12)
const int n_pad_zeros = 12)
{
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
@@ -596,7 +596,7 @@ struct MoeFlatmmKernel
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
@@ -878,14 +878,22 @@ struct MoeFlatmmKernel
const auto scale_m_coord =
output_acc_tile_distr.calculate_index(); // 2d thread offset, [i_row, i_col]
constexpr ck_tile::index_t ScaleMRepeat =
decltype(output_acc_tile_distr)::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
statically_indexed_array<ck_tile::index_t, ScaleMRepeat> scale_m_offsets;
constexpr index_t kM2 = 4; // Val-dim
constexpr index_t kM1 = get_warp_size() / NPerXdl; // Thr-dim
constexpr index_t kM0 = MPerXdl / kM1 / kM2; // Var-dim
static_for<0, ScaleMRepeat, 1>{}([&](auto m0) {
const auto row_idx =
coord_m + m0 * (TilePartitioner::MPerBlock / ScaleMRepeat) + scale_m_coord[I0];
scale_m_offsets[m0] = row_to_token_idx(row_idx);
constexpr index_t ScaleMRepeat = MRepeat * kM0 * kM2;
statically_indexed_array<index_t, ScaleMRepeat> scale_m_offsets;
static_for<0, MRepeat, 1>{}([&](auto mIter) {
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
const auto row_idx =
coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
row_to_token_idx(row_idx);
});
});
});
constexpr int DynamicTileOffsetFlag = 0;
@@ -924,18 +932,18 @@ struct MoeFlatmmKernel
}
};
auto scale_m_window = make_tile_scatter_gather(
make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m.ptr,
make_tuple(kargs.M, 1),
make_tuple(scale_stride_m, 0),
number<ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1>{},
number<1>{}),
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, 0}, // offset m is included in gather offsets
output_acc_tile_distr,
scale_m_offsets);
auto scale_m_window =
make_tile_scatter_gather(make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m.ptr,
make_tuple(kargs.M, 1),
make_tuple(scale_stride_m, 0),
number<1>{}, // gather load can't vectorize
number<1>{}),
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, 0}, // offset m is included in gather offsets
output_acc_tile_distr,
scale_m_offsets);
auto scale_n_window = make_tile_window(
make_naive_tensor_view<address_space_enum::global>(
@@ -1015,7 +1023,10 @@ struct MoeFlatmmKernel
if constexpr(!MXFP4_Pipeline)
{
scale_m_buffer = load_tile(scale_m_window);
scale_m_window.load(scale_m_buffer, [](auto ys_coord) {
return ys_coord.at(I0) * number<kM0 * kM2>{} + ys_coord.at(I2) * number<kM2>{} +
ys_coord.at(I3);
});
scale_n_buffer = load_tile(scale_n_window);
if constexpr(IsGateUp)
scale_n_up_buffer = load_tile(scale_n_up_window);