mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Refactor scale and bias function, encapsulate scale/bias pointer to a tensor view and load data by tilewise operation
This commit is contained in:
@@ -309,6 +309,8 @@ struct CShuffleEpilogue
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>;
|
||||
static_assert(GetVectorSizeC() % kN2 == 0);
|
||||
|
||||
constexpr auto dram_tile_distribution =
|
||||
make_static_tile_distribution(IntrThreadShuffleEncode{});
|
||||
|
||||
@@ -324,7 +326,6 @@ struct CShuffleEpilogue
|
||||
|
||||
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
// auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
@@ -332,6 +333,7 @@ struct CShuffleEpilogue
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
// transpose <kM2 x NRepeat> thread matrix
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 0 * NRepeat] = type_convert<ODataType>(
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0]);
|
||||
c_out_tensor.get_thread_buffer()[n_idx + 1 * NRepeat] = type_convert<ODataType>(
|
||||
@@ -342,8 +344,6 @@ struct CShuffleEpilogue
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3]);
|
||||
});
|
||||
|
||||
// c_out_tensor = cast_tile<ODataType>(c_out_tensor_fp32);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
@@ -472,16 +472,16 @@ struct CShuffleEpilogue
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
typename ScaleMWindow,
|
||||
typename ScaleNWindow,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem,
|
||||
ScaleM scale_m,
|
||||
ScaleN scale_n)
|
||||
ScaleMWindow scale_m_window,
|
||||
ScaleNWindow scale_n_window)
|
||||
{
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = 4;
|
||||
@@ -499,9 +499,43 @@ struct CShuffleEpilogue
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>;
|
||||
static_assert(GetVectorSizeC() % kN2 == 0);
|
||||
|
||||
constexpr auto dram_tile_distribution =
|
||||
make_static_tile_distribution(IntrThreadShuffleEncode{});
|
||||
|
||||
constexpr int DynamicTileOffsetFlag = 0;
|
||||
|
||||
auto permute_scale_n_view_1 = transform_tensor_view(
|
||||
scale_n_window.get_bottom_tensor_view(),
|
||||
make_tuple(make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
|
||||
make_unmerge_transform(make_tuple(number<DynamicTileOffsetFlag>{},
|
||||
number<NWave>{},
|
||||
number<NPerXdl>{},
|
||||
number<NRepeat>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2, 3, 4>{}));
|
||||
auto permute_scale_n_view = transform_tensor_view(
|
||||
permute_scale_n_view_1,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(number<DynamicTileOffsetFlag>{},
|
||||
number<NRepeat>{},
|
||||
number<NWave>{},
|
||||
number<NPerXdl>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 4, 2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
auto scale_m_window_with_dist = make_tile_window(
|
||||
scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution());
|
||||
auto scale_n_window_with_dist = make_tile_window(permute_scale_n_view,
|
||||
scale_n_window.get_window_lengths(),
|
||||
scale_n_window.get_window_origin(),
|
||||
o_acc_tile.get_tile_distribution());
|
||||
|
||||
auto scale_m_buffer = load_tile(scale_m_window_with_dist);
|
||||
auto scale_n_buffer = load_tile(scale_n_window_with_dist);
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
@@ -519,56 +553,39 @@ struct CShuffleEpilogue
|
||||
make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
const index_t iMLane = get_lane_id() / NPerXdl;
|
||||
const index_t iNLane = get_lane_id() % NPerXdl;
|
||||
|
||||
float vec_scale_A[kM2 * MRepeat];
|
||||
float vec_scale_B[NRepeat];
|
||||
|
||||
_Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
|
||||
{
|
||||
vec_scale_B[i] = scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl];
|
||||
}
|
||||
_Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
|
||||
{
|
||||
vec_scale_A[i * kM2 + 0] =
|
||||
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 1] =
|
||||
scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 2] =
|
||||
scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 3] =
|
||||
scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
}
|
||||
constexpr int NumAccPerEpiTile = NRepeat * c_warp_y_lengths.product();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
shuffle_acc[mIter].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 0] *= vec_scale_B[n_idx];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 1] *= vec_scale_B[n_idx];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 2] *= vec_scale_B[n_idx];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 3] *= vec_scale_B[n_idx];
|
||||
});
|
||||
auto epi_scale_n = scale_n_buffer.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
static_for<0, NumAccPerEpiTile, 1>{}(
|
||||
[&](auto i) { shuffle_acc[mIter].get_thread_buffer()[i] *= epi_scale_n[i]; });
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
auto epi_scale_m = scale_m_buffer.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
// transpose <kM2 x NRepeat> thread matrix
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 0 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] *
|
||||
vec_scale_A[mIter * kM2 + 0];
|
||||
epi_scale_m[n_idx * c_warp_y_lengths.product() + 0];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 1 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] *
|
||||
vec_scale_A[mIter * kM2 + 1];
|
||||
epi_scale_m[n_idx * c_warp_y_lengths.product() + 1];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 2 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] *
|
||||
vec_scale_A[mIter * kM2 + 2];
|
||||
epi_scale_m[n_idx * c_warp_y_lengths.product() + 2];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 3 * NRepeat] =
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] *
|
||||
vec_scale_A[mIter * kM2 + 3];
|
||||
epi_scale_m[n_idx * c_warp_y_lengths.product() + 3];
|
||||
});
|
||||
|
||||
c_out_tensor = cast_tile<ODataType>(c_out_tensor_fp32);
|
||||
@@ -592,16 +609,16 @@ struct CShuffleEpilogue
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
typename ScaleMWindow,
|
||||
typename ScaleNWindow,
|
||||
int EnablePermuateN_ = TiledMMAPermuteN,
|
||||
std::enable_if_t<!EnablePermuateN_, int> = 0>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem,
|
||||
ScaleM scale_m,
|
||||
ScaleN scale_n)
|
||||
ScaleMWindow scale_m_window,
|
||||
ScaleNWindow scale_n_window)
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
|
||||
@@ -650,63 +667,32 @@ struct CShuffleEpilogue
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr int kM2 = 4; // Val
|
||||
constexpr int kM1 = (64 / NPerXdl); // Thr
|
||||
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
|
||||
auto scale_m_window_with_dist = make_tile_window(
|
||||
scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution());
|
||||
auto scale_n_window_with_dist = make_tile_window(
|
||||
scale_n_window, scale_n_window.get_window_origin(), o_acc_tile.get_tile_distribution());
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
const index_t iMLane = get_lane_id() / NPerXdl;
|
||||
const index_t iNLane = get_lane_id() % NPerXdl;
|
||||
auto scale_m_buffer = load_tile(scale_m_window_with_dist);
|
||||
auto scale_n_buffer = load_tile(scale_n_window_with_dist);
|
||||
|
||||
float vec_scale_A[kM0 * kM2 * MRepeat];
|
||||
float vec_scale_B[NRepeat];
|
||||
constexpr int NumAccPerEpiTile =
|
||||
NumMXdlPerWavePerShuffle * NumNXdlPerWavePerShuffle * c_warp_y_lengths.product();
|
||||
auto epi_tile_idx_slice =
|
||||
[&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) {
|
||||
return acc_tile_like_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
|
||||
epi_n_idx * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
};
|
||||
|
||||
_Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
|
||||
{
|
||||
vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
}
|
||||
_Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
|
||||
{
|
||||
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||
{
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 0] =
|
||||
scale_m[0 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 1] =
|
||||
scale_m[1 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 2] =
|
||||
scale_m[2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 3] =
|
||||
scale_m[3 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
}
|
||||
}
|
||||
lds_tile[0].get_thread_buffer() = epi_tile_idx_slice(o_acc_tile, number<0>{}, number<0>{});
|
||||
|
||||
lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||
{
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 0] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 1] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 2] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 3] * vec_scale_B[n_xdl];
|
||||
}
|
||||
});
|
||||
});
|
||||
auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, number<0>{}, number<0>{});
|
||||
auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, number<0>{}, number<0>{});
|
||||
static_for<0, NumAccPerEpiTile, 1>{}(
|
||||
[&](auto i) { lds_tile[0].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i]; });
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
constexpr int read_stage = iAccess % 2;
|
||||
@@ -724,40 +710,14 @@ struct CShuffleEpilogue
|
||||
|
||||
if constexpr(iAccess < num_access - 1)
|
||||
{
|
||||
lds_tile[write_stage].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter * NumMXdlPerWavePerShuffle,
|
||||
nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||
{
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 0] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 1] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 2] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 3] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
}
|
||||
});
|
||||
lds_tile[write_stage].get_thread_buffer() =
|
||||
epi_tile_idx_slice(o_acc_tile, mIter, nIter);
|
||||
|
||||
epi_scale_m = epi_tile_idx_slice(scale_m_buffer, mIter, nIter);
|
||||
epi_scale_n = epi_tile_idx_slice(scale_n_buffer, mIter, nIter);
|
||||
|
||||
static_for<0, NumAccPerEpiTile, 1>{}([&](auto i) {
|
||||
lds_tile[write_stage].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i];
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -35,44 +35,29 @@ struct FlatmmScalePointer
|
||||
static constexpr int GranularityK = SharedGranularityK;
|
||||
|
||||
const float* ptr;
|
||||
index_t scale_stride = 1;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t stride)
|
||||
: ptr(ptr_), scale_stride(stride)
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
|
||||
: ptr(ptr_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
|
||||
{
|
||||
FlatmmScalePointer ret;
|
||||
// if constexpr(GranularityMN == 0)
|
||||
// {
|
||||
// ret.scalar = scalar;
|
||||
// }
|
||||
// else if constexpr(GranularityMN == 1)
|
||||
// {
|
||||
// ret.ptr = ptr + offset;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// ret.ptr = ptr + offset / GranularityMN;
|
||||
// }
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const
|
||||
{
|
||||
if constexpr(GranularityMN == 1)
|
||||
if constexpr(GranularityMN == 0)
|
||||
{
|
||||
return ptr[i];
|
||||
ret.ptr = ptr + offset / GranularityK;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ptr[i / GranularityMN];
|
||||
ret.ptr = ptr + offset / GranularityMN / GranularityK;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
|
||||
};
|
||||
|
||||
template <int SharedGranularityMN>
|
||||
@@ -83,54 +68,39 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
|
||||
static_assert(GranularityMN != 0);
|
||||
|
||||
union
|
||||
{
|
||||
const float* ptr;
|
||||
float scalar; // if shared granularity is 0, all rows/columns use the same scale value
|
||||
};
|
||||
const float* ptr;
|
||||
index_t length;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(float scalar_) : scalar(scalar_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t stride)
|
||||
: ptr(ptr_)
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_)
|
||||
: ptr(ptr_), length(length_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
|
||||
{
|
||||
FlatmmScalePointer ret;
|
||||
if constexpr(GranularityMN == 0)
|
||||
if constexpr(GranularityMN == 1)
|
||||
{
|
||||
ret.scalar = scalar;
|
||||
}
|
||||
else if constexpr(GranularityMN == 1)
|
||||
{
|
||||
ret.ptr = ptr + offset;
|
||||
ret.ptr = ptr + offset;
|
||||
ret.length = length - offset;
|
||||
}
|
||||
else
|
||||
{
|
||||
ret.ptr = ptr + offset / GranularityMN;
|
||||
ret.ptr = ptr + offset / GranularityMN;
|
||||
ret.length = length - offset / GranularityMN;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer& advance() { return *this; }
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const
|
||||
{
|
||||
if constexpr(GranularityMN == 0)
|
||||
{
|
||||
return scalar;
|
||||
}
|
||||
else if constexpr(GranularityMN == 1)
|
||||
{
|
||||
return ptr[i];
|
||||
}
|
||||
// with additional oob check
|
||||
if constexpr(GranularityMN == 1)
|
||||
return i < length ? ptr[i] : 0;
|
||||
else
|
||||
{
|
||||
return ptr[i / GranularityMN];
|
||||
}
|
||||
return i / GranularityMN < length ? ptr[i / GranularityMN] : 0;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -141,14 +111,11 @@ struct FlatmmScalePointer<-1, 0>
|
||||
static constexpr int GranularityMN = -1;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(float) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, [[maybe_unused]] index_t stride)
|
||||
{
|
||||
}
|
||||
const float* ptr = nullptr;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer& advance() { return *this; }
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
|
||||
{
|
||||
@@ -679,7 +646,45 @@ struct FlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view);
|
||||
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
|
||||
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
|
||||
|
||||
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
|
||||
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
|
||||
|
||||
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-token scale
|
||||
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-channel scale
|
||||
|
||||
static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
|
||||
"only support per-tensor or per-row scaling");
|
||||
static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
|
||||
"only support per-tensor or per-column scaling");
|
||||
|
||||
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_m_ptr.ptr,
|
||||
make_tuple(
|
||||
kargs.M / ScaleGranularityM,
|
||||
ScaleGranularityKA == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKA),
|
||||
make_tuple(scale_stride_m, 0),
|
||||
number<ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1>{},
|
||||
number<1>{});
|
||||
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_n_ptr.ptr,
|
||||
make_tuple(
|
||||
ScaleGranularityKB == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKB,
|
||||
kargs.N / ScaleGranularityN),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number<ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1>{},
|
||||
number<1>{});
|
||||
|
||||
return make_tuple(a_tensor_view,
|
||||
b_flat_tensor_view,
|
||||
ds_tensor_view,
|
||||
e_tensor_view,
|
||||
scale_m_view,
|
||||
scale_n_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
@@ -745,7 +750,12 @@ struct FlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view);
|
||||
return make_tuple(a_pad_view,
|
||||
b_flat_tensor_view,
|
||||
ds_pad_view,
|
||||
e_pad_view,
|
||||
views.at(number<4>{}),
|
||||
views.at(number<5>{}));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
@@ -805,7 +815,28 @@ struct FlatmmKernel
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
|
||||
constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK;
|
||||
constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK;
|
||||
|
||||
auto scale_m_window = make_tile_window(
|
||||
views.at(number<4>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<ScaleGranularityKA == 0 ? TilePartitioner::NPerBlock
|
||||
: TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
auto scale_n_window = make_tile_window(
|
||||
views.at(number<5>{}),
|
||||
make_tuple(number<ScaleGranularityKB == 0 ? TilePartitioner::MPerBlock
|
||||
: TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
|
||||
return make_tuple(a_block_window,
|
||||
b_flat_block_window,
|
||||
ds_block_window,
|
||||
e_block_window,
|
||||
scale_m_window,
|
||||
scale_n_window);
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
|
||||
@@ -837,6 +868,9 @@ struct FlatmmKernel
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
|
||||
|
||||
auto scale_m_window = gemm_tile_windows.at(number<4>{});
|
||||
auto scale_n_window = gemm_tile_windows.at(number<5>{});
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
|
||||
{
|
||||
@@ -847,8 +881,8 @@ struct FlatmmKernel
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
kargs.scale_m_ptr + block_idx_m,
|
||||
kargs.scale_n_ptr + block_idx_n);
|
||||
scale_m_window,
|
||||
scale_n_window);
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
|
||||
@@ -197,8 +197,10 @@ struct MoeFlatmmKernel
|
||||
// MXF4_Pipeline only has the of scale B and granularityK is 32
|
||||
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
|
||||
static constexpr int MXFP4N_Pack = 2;
|
||||
static constexpr int MXFP4K_Pack = 2;
|
||||
|
||||
static constexpr int N_Pack = MXFP4_Pipeline ? MXFP4N_Pack : 1;
|
||||
static constexpr int K_Pack = MXFP4_Pipeline ? MXFP4K_Pack : 1;
|
||||
|
||||
static constexpr int WeightPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
@@ -659,13 +661,16 @@ struct MoeFlatmmKernel
|
||||
{0, // offset_m is included when construct C-scatter-window offsets
|
||||
output_N_offset});
|
||||
|
||||
constexpr int GranularityK = 32;
|
||||
constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
|
||||
constexpr int XDLPerLoadScaleB =
|
||||
MXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
|
||||
|
||||
auto scale_block_window = make_tile_window(
|
||||
views.at(I3),
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / GranularityK>{}),
|
||||
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
auto scale_block_window =
|
||||
make_tile_window(views.at(I3),
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * K_Pack *
|
||||
XDLPerLoadScaleB / GranularityK>{}),
|
||||
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
|
||||
return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
|
||||
}
|
||||
@@ -771,11 +776,10 @@ struct MoeFlatmmKernel
|
||||
smem_ptr_pong);
|
||||
}
|
||||
}();
|
||||
using AccTile = decltype(c_block_tile);
|
||||
|
||||
// Run EpiloguePipeline Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(number<2>{});
|
||||
|
||||
// Run EpiloguePipeline
|
||||
{
|
||||
using EpiProblem = typename EpiloguePipeline::Problem;
|
||||
using ODataType = typename EpiloguePipeline::ODataType;
|
||||
@@ -786,9 +790,11 @@ struct MoeFlatmmKernel
|
||||
constexpr index_t MPerIterationShuffle = EpiloguePipeline::MPerIterationShuffle;
|
||||
constexpr index_t NPerIterationShuffle = EpiloguePipeline::NPerIterationShuffle;
|
||||
|
||||
constexpr index_t EpiVectorSizeC = EpiloguePipeline::GetVectorSizeC();
|
||||
constexpr index_t MRepeat = EpiloguePipeline::MRepeat;
|
||||
constexpr index_t NRepeat = EpiloguePipeline::NRepeat;
|
||||
constexpr index_t EpiVectorSizeC = EpiloguePipeline::GetVectorSizeC();
|
||||
constexpr index_t MRepeat = EpiloguePipeline::MRepeat;
|
||||
constexpr index_t NRepeat = EpiloguePipeline::NRepeat;
|
||||
constexpr index_t OutputNRepeat = IsGateUp ? NRepeat / 2 : NRepeat;
|
||||
constexpr index_t BlockedXDLN_PerWarp = EpiloguePipeline::BlockedXDLN_PerWarp;
|
||||
|
||||
static_assert(!IsGateUp || NumNXdlPerWavePerShuffle % 2 == 0);
|
||||
|
||||
@@ -805,6 +811,195 @@ struct MoeFlatmmKernel
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<ODataType*>(smem_ptr_ping), lds_block_desc);
|
||||
|
||||
constexpr int ScaleGranularityM = decltype(kargs.scale_m)::GranularityMN;
|
||||
constexpr int ScaleGranularityN = decltype(kargs.scale_n)::GranularityMN;
|
||||
|
||||
constexpr index_t scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-token scale
|
||||
constexpr index_t scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-channel scale
|
||||
|
||||
auto output_acc_tile_distr =
|
||||
make_static_tile_distribution(detail::make_embed_tile_distribution_encoding(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MRepeat, MWave>, sequence<OutputNRepeat, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{},
|
||||
typename CWarpDstr::DstrEncode{}));
|
||||
|
||||
const auto scale_m_coord =
|
||||
output_acc_tile_distr.calculate_index(); // 2d thread offset, [i_row, i_col]
|
||||
|
||||
constexpr ck_tile::index_t ScaleMRepeat =
|
||||
decltype(output_acc_tile_distr)::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
|
||||
statically_indexed_array<ck_tile::index_t, ScaleMRepeat> scale_m_offsets;
|
||||
|
||||
static_for<0, ScaleMRepeat, 1>{}([&](auto m0) {
|
||||
const auto row_idx =
|
||||
coord_m + m0 * (TilePartitioner::MPerBlock / ScaleMRepeat) + scale_m_coord[I0];
|
||||
scale_m_offsets[m0] = row_to_token_idx(row_idx);
|
||||
});
|
||||
|
||||
constexpr int DynamicTileOffsetFlag = 0;
|
||||
|
||||
constexpr bool EnableBias = decltype(kargs.exp_bias)::GranularityMN != -1;
|
||||
|
||||
auto make_col_broadcast_window = [&](auto scale_pointer) {
|
||||
return make_tile_window(
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_n.ptr + expert_id * kargs.N,
|
||||
make_tuple(1, kargs.N),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number<ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1>{},
|
||||
number<1>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<IsGateUp ? TilePartitioner::NPerBlock / 2
|
||||
: TilePartitioner::NPerBlock>{}),
|
||||
{0, IsGateUp ? coord_n / 2 : coord_n},
|
||||
output_acc_tile_distr);
|
||||
};
|
||||
|
||||
auto permute_tensor_view = [&](auto naive_view, auto is_needed_to_permute_N_PACK) {
|
||||
if constexpr(!is_needed_to_permute_N_PACK)
|
||||
{
|
||||
return naive_view;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto view1 = transform_tensor_view(
|
||||
naive_view,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
|
||||
make_unmerge_transform(make_tuple(number<DynamicTileOffsetFlag>{},
|
||||
number<NRepeat / N_Pack>{},
|
||||
number<NWave>{},
|
||||
number<N_Pack>{},
|
||||
number<NPerXdl>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2, 3, 4, 5>{}));
|
||||
return transform_tensor_view(
|
||||
view1,
|
||||
make_tuple(make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<DynamicTileOffsetFlag>{},
|
||||
number<NRepeat / N_Pack>{},
|
||||
number<N_Pack>{},
|
||||
number<NWave>{},
|
||||
number<NPerXdl>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2, 4, 3, 5>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
};
|
||||
|
||||
auto scale_m_window = make_tile_scatter_gather(
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_m.ptr,
|
||||
make_tuple(kargs.M, 1),
|
||||
make_tuple(scale_stride_m, 0),
|
||||
number<ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1>{},
|
||||
number<1>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, 0}, // offset m is included in gather offsets
|
||||
output_acc_tile_distr,
|
||||
scale_m_offsets);
|
||||
|
||||
auto scale_n_window = make_tile_window(
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_n.ptr + expert_id * kargs.N,
|
||||
make_tuple(1, kargs.N),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number<ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1>{},
|
||||
number<1>{}), // MXF4_Pipeline does't use scale_n, so there is no need to
|
||||
// permute as n_pack
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<IsGateUp ? TilePartitioner::NPerBlock / 2
|
||||
: TilePartitioner::NPerBlock>{}),
|
||||
{0, IsGateUp ? coord_n / 2 : coord_n},
|
||||
output_acc_tile_distr);
|
||||
|
||||
auto scale_n_up_window = make_tile_window(
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_n.ptr + expert_id * kargs.N + kargs.N / 2,
|
||||
make_tuple(1, kargs.N),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number<ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1>{},
|
||||
number<1>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock / 2>{}),
|
||||
{0, coord_n / 2},
|
||||
output_acc_tile_distr);
|
||||
|
||||
auto exp_bias_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.exp_bias.ptr + expert_id * kargs.N,
|
||||
make_tuple(1, kargs.N),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
|
||||
auto exp_bias_window = make_tile_window(
|
||||
permute_tensor_view(exp_bias_view, number<(MXFP4_Pipeline && !IsInputGemm)>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<IsGateUp ? TilePartitioner::NPerBlock / 2
|
||||
: TilePartitioner::NPerBlock>{}),
|
||||
{0, IsGateUp ? coord_n / 2 : coord_n},
|
||||
output_acc_tile_distr);
|
||||
|
||||
auto exp_bias_up_window =
|
||||
make_tile_window(make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.exp_bias.ptr + expert_id * kargs.N + kargs.N / 2,
|
||||
make_tuple(1, kargs.N),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock / 2>{}),
|
||||
{0, coord_n / 2},
|
||||
output_acc_tile_distr);
|
||||
|
||||
auto exp_weight_window =
|
||||
make_tile_window(make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const float*>(kargs.p_sorted_expert_weights),
|
||||
make_tuple(kargs.M, 1),
|
||||
make_tuple(1, 0),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{coord_m, 0},
|
||||
output_acc_tile_distr);
|
||||
|
||||
using ScaleMBuffer = decltype(load_tile(scale_m_window));
|
||||
using ScaleNBuffer = decltype(load_tile(scale_n_window));
|
||||
using ExpBiasBuffer = decltype(load_tile(exp_bias_window));
|
||||
using ExpWeightBuffer = decltype(load_tile(exp_weight_window));
|
||||
|
||||
ScaleMBuffer scale_m_buffer;
|
||||
ScaleNBuffer scale_n_buffer, scale_n_up_buffer;
|
||||
|
||||
ExpBiasBuffer exp_bias_buffer, exp_bias_up_buffer;
|
||||
ExpWeightBuffer exp_weight_buffer;
|
||||
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
{
|
||||
scale_m_buffer = load_tile(scale_m_window);
|
||||
scale_n_buffer = load_tile(scale_n_window);
|
||||
if constexpr(IsGateUp)
|
||||
scale_n_up_buffer = load_tile(scale_n_up_window);
|
||||
}
|
||||
|
||||
if constexpr(EnableBias)
|
||||
{
|
||||
exp_bias_buffer = load_tile(exp_bias_window);
|
||||
if constexpr(IsGateUp)
|
||||
exp_bias_up_buffer = load_tile(exp_bias_up_window);
|
||||
}
|
||||
if constexpr(!IsInputGemm)
|
||||
exp_weight_buffer = load_tile(exp_weight_window);
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
|
||||
@@ -862,354 +1057,111 @@ struct MoeFlatmmKernel
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr int kM2 = 4; // Val
|
||||
constexpr int kM1 = (64 / NPerXdl); // Thr
|
||||
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
|
||||
|
||||
constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
|
||||
OutputNumNXdlPerWavePerShuffle;
|
||||
constexpr bool EnableBias = decltype(kargs.exp_bias)::GranularityMN != -1;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
const index_t iMLane = get_lane_id() / NPerXdl;
|
||||
const index_t iNLane = get_lane_id() % NPerXdl;
|
||||
auto epi_tile_idx_slice =
|
||||
[&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) {
|
||||
return acc_tile_like_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
|
||||
epi_n_idx * OutputNumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(
|
||||
sequence<NumMXdlPerWavePerShuffle, OutputNumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
};
|
||||
|
||||
float vec_scale_A[kM0 * kM2 * MRepeat];
|
||||
float vec_scale_B[NRepeat];
|
||||
|
||||
float vec_expert_weights[kM0 * kM2 * MRepeat];
|
||||
float vec_expert_bias[kM0 * kM2 * MRepeat];
|
||||
|
||||
const float* expert_weights = static_cast<const float*>(kargs.p_sorted_expert_weights);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Load scales and expert weights
|
||||
//===----------------------------------------------------------------------===//
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
{
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
static_for<0, NRepeat / 2, 1>{}([&](auto i) {
|
||||
vec_scale_B[i * 2] =
|
||||
kargs.scale_n[expert_id * kargs.N + coord_n / 2 + i * NWave * NPerXdl +
|
||||
iNWarp * NPerXdl + iNLane];
|
||||
vec_scale_B[i * 2 + 1] =
|
||||
kargs.scale_n[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
|
||||
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto i) {
|
||||
vec_scale_B[i] =
|
||||
kargs.scale_n[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
|
||||
iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
}
|
||||
if constexpr(MXFP4_Pipeline && EnableBias)
|
||||
{
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
static_for<0, NRepeat / 2, 1>{}([&](auto i) {
|
||||
vec_expert_bias[i * 2] =
|
||||
kargs.exp_bias[expert_id * kargs.N + coord_n / 2 + i * NWave * NPerXdl +
|
||||
iNWarp * NPerXdl + iNLane];
|
||||
vec_expert_bias[i * 2 + 1] =
|
||||
kargs.exp_bias[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
|
||||
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 2>{}([&](auto i) {
|
||||
vec_expert_bias[i] =
|
||||
kargs.exp_bias[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
|
||||
iNWarp * 2 * NPerXdl + iNLane];
|
||||
vec_expert_bias[i + 1] =
|
||||
kargs.exp_bias[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
|
||||
iNWarp * 2 * NPerXdl + NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto i) {
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
index_t M2_offset = m2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave + coord_m;
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + m2] =
|
||||
kargs.scale_m[row_to_token_idx(M2_offset)];
|
||||
if constexpr(!IsInputGemm)
|
||||
vec_expert_weights[i * kM0 * kM2 + m0 * kM2 + m2] =
|
||||
expert_weights[M2_offset];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pingpong process start
|
||||
//===----------------------------------------------------------------------===//
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
LDSTileTensor gate_tensor, up_tensor;
|
||||
|
||||
// gate and up are interleaved along NRepeat dimension.
|
||||
auto gate_up_epi_tile_idx_interleave_slice = [&](auto& dest_gate_tensor,
|
||||
auto& dest_up_tensor,
|
||||
const auto& acc_tile_like_tensor,
|
||||
auto epi_m_idx,
|
||||
auto epi_n_idx) {
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
gate_tensor.set_y_sliced_thread_data(
|
||||
dest_gate_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{}, c_warp_y_lengths),
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 2 * n_xdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
acc_tile_like_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
|
||||
epi_n_idx * OutputNumNXdlPerWavePerShuffle + 2 * n_xdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths)));
|
||||
|
||||
up_tensor.set_y_sliced_thread_data(
|
||||
dest_up_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{}, c_warp_y_lengths),
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 2 * n_xdl + 1>{},
|
||||
acc_tile_like_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
|
||||
epi_n_idx * OutputNumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl + 1>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths)));
|
||||
});
|
||||
};
|
||||
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
{
|
||||
gate_tensor
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[2 * n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[2 * n_xdl + 1];
|
||||
}
|
||||
if constexpr(EnableBias)
|
||||
{
|
||||
gate_tensor
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] +=
|
||||
vec_expert_bias[2 * n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] +=
|
||||
vec_expert_bias[2 * n_xdl + 1];
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
lds_tile[0].get_thread_buffer().at(idx) =
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
up_tensor.get_thread_buffer().at(idx));
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
lds_tile[0].get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * NumNXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[n_xdl];
|
||||
if constexpr(EnableBias)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] +=
|
||||
vec_expert_bias[n_xdl];
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_expert_weights[m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
if constexpr(IsInputGemm)
|
||||
auto process_epi_tile = [&](auto lds_stage, auto epi_m, auto epi_n) {
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
LDSTileTensor gate_tensor, up_tensor;
|
||||
|
||||
gate_up_epi_tile_idx_interleave_slice(
|
||||
gate_tensor, up_tensor, c_block_tile, epi_m, epi_n);
|
||||
auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n);
|
||||
auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n);
|
||||
auto epi_scale_n_up = epi_tile_idx_slice(scale_n_up_buffer, epi_m, epi_n);
|
||||
|
||||
auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
|
||||
auto epi_exp_bias_up = epi_tile_idx_slice(exp_bias_up_buffer, epi_m, epi_n);
|
||||
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
lds_tile[0].get_thread_buffer().at(idx) =
|
||||
ActivationOp{}(lds_tile[0].get_thread_buffer().at(idx));
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
{
|
||||
gate_tensor.get_thread_buffer()[idx] *=
|
||||
epi_scale_m[idx] * epi_scale_n[idx];
|
||||
up_tensor.get_thread_buffer()[idx] *=
|
||||
epi_scale_m[idx] * epi_scale_n_up[idx];
|
||||
}
|
||||
if constexpr(EnableBias)
|
||||
{
|
||||
gate_tensor.get_thread_buffer()[idx] += epi_exp_bias[idx];
|
||||
up_tensor.get_thread_buffer()[idx] += epi_exp_bias_up[idx];
|
||||
}
|
||||
lds_tile[lds_stage].get_thread_buffer().at(idx) =
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
up_tensor.get_thread_buffer().at(idx));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
constexpr int read_stage = iAccess % 2;
|
||||
constexpr int write_stage = read_stage ^ 1;
|
||||
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(number<iAccess.value>{});
|
||||
constexpr auto idx_y_start_next = SFC::get_index(number<iAccess.value + 1>{});
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / MPerIterationShuffle>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / NPerIterationShuffle>{};
|
||||
|
||||
constexpr auto mIter_next =
|
||||
number<idx_y_start_next.at(number<0>{}) / MPerIterationShuffle>{};
|
||||
constexpr auto nIter_next =
|
||||
number<idx_y_start_next.at(number<1>{}) / NPerIterationShuffle>{};
|
||||
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile[read_stage]);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
|
||||
if constexpr(iAccess < num_access - 1)
|
||||
else
|
||||
{
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
LDSTileTensor gate_tensor, up_tensor;
|
||||
lds_tile[lds_stage].get_thread_buffer() =
|
||||
epi_tile_idx_slice(c_block_tile, epi_m, epi_n);
|
||||
auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n);
|
||||
auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n);
|
||||
auto epi_exp_weight = epi_tile_idx_slice(exp_weight_buffer, epi_m, epi_n);
|
||||
auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
|
||||
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
gate_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths),
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
|
||||
nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths)));
|
||||
|
||||
up_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths),
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
|
||||
nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl + 1>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths)));
|
||||
});
|
||||
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
{
|
||||
gate_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl + 1];
|
||||
}
|
||||
if constexpr(EnableBias)
|
||||
{
|
||||
gate_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] +=
|
||||
vec_expert_bias[nIter_next *
|
||||
NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] +=
|
||||
vec_expert_bias[nIter_next *
|
||||
NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl + 1];
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
lds_tile[write_stage].get_thread_buffer().at(idx) =
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
up_tensor.get_thread_buffer().at(idx));
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
lds_tile[write_stage].get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
|
||||
nIter_next * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(
|
||||
sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * NumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
n_xdl];
|
||||
if constexpr(EnableBias)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
m2] +=
|
||||
vec_expert_bias[nIter_next *
|
||||
NumNXdlPerWavePerShuffle +
|
||||
n_xdl];
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
m2] *= vec_expert_weights
|
||||
[mIter_next * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
lds_tile[write_stage].get_thread_buffer().at(idx) = ActivationOp{}(
|
||||
lds_tile[write_stage].get_thread_buffer().at(idx));
|
||||
});
|
||||
}
|
||||
}
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] *=
|
||||
epi_scale_m[idx] * epi_scale_n[idx];
|
||||
if constexpr(EnableBias)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
|
||||
else // for mlp1 gate-only
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] =
|
||||
ActivationOp{}(lds_tile[lds_stage].get_thread_buffer()[idx]);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
constexpr int MPerThread = TileEncodingPattern::Y2;
|
||||
statically_indexed_array<index_t, MPerThread> offsets;
|
||||
|
||||
auto c_coord = dram_tile_distribution.calculate_index();
|
||||
constexpr int NumMEpiTile = MRepeat / NumMXdlPerWavePerShuffle;
|
||||
constexpr int MPerThread = TileEncodingPattern::Y2;
|
||||
statically_indexed_array<statically_indexed_array<index_t, MPerThread>, NumMEpiTile>
|
||||
c_scatter_offsets;
|
||||
auto c_coord = dram_tile_distribution.calculate_index();
|
||||
static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
|
||||
static_for<0, MPerThread, 1>{}([&](auto m0) {
|
||||
auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
|
||||
auto fused_token =
|
||||
@@ -1217,34 +1169,57 @@ struct MoeFlatmmKernel
|
||||
|
||||
index_t scatter_token_id = fused_token & token_id_mask;
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
scatter_token_id =
|
||||
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
|
||||
}
|
||||
offsets[m0] = scatter_token_id * kargs.stride_C;
|
||||
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
|
||||
});
|
||||
});
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pingpong process start
|
||||
//===----------------------------------------------------------------------===//
|
||||
process_epi_tile(number<0>{}, number<0>{}, number<0>{});
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
constexpr int read_stage = iAccess % 2;
|
||||
constexpr int write_stage = read_stage ^ 1;
|
||||
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(number<iAccess.value>{});
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / MPerIterationShuffle>{};
|
||||
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile[read_stage]);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
|
||||
if constexpr(iAccess < num_access - 1)
|
||||
{
|
||||
constexpr auto idx_y_start_next = SFC::get_index(number<iAccess.value + 1>{});
|
||||
constexpr auto mIter_next =
|
||||
number<idx_y_start_next.at(number<0>{}) / MPerIterationShuffle>{};
|
||||
constexpr auto nIter_next =
|
||||
number<idx_y_start_next.at(number<1>{}) / NPerIterationShuffle>{};
|
||||
|
||||
process_epi_tile(number<write_stage>{}, mIter_next, nIter_next);
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto c_out_tensor =
|
||||
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
auto c_scatter_tile_window =
|
||||
make_tile_scatter_gather(c_block_window.get_bottom_tensor_view(),
|
||||
c_block_window.get_window_lengths(),
|
||||
c_block_window.get_window_origin(),
|
||||
dram_tile_distribution,
|
||||
offsets);
|
||||
c_scatter_offsets[mIter]);
|
||||
|
||||
if constexpr(!IsInputGemm ||
|
||||
EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)
|
||||
{
|
||||
c_scatter_tile_window.update(c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
c_scatter_tile_window.store(c_out_tensor);
|
||||
}
|
||||
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
Reference in New Issue
Block a user