diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index f5004f72af..2a7cc35c7c 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -309,6 +309,8 @@ struct CShuffleEpilogue tuple, sequence<1, 1>>, sequence<1, 2>, sequence<2, 2>>; + static_assert(GetVectorSizeC() % kN2 == 0); + constexpr auto dram_tile_distribution = make_static_tile_distribution(IntrThreadShuffleEncode{}); @@ -324,7 +326,6 @@ struct CShuffleEpilogue auto shuffle_acc = make_static_distributed_tensor(dram_tile_distribution); auto c_out_tensor = make_static_distributed_tensor(dram_tile_distribution); - // auto c_out_tensor = make_static_distributed_tensor(dram_tile_distribution); static_for<0, MRepeat, 1>{}([&](auto mIter) { shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( @@ -332,6 +333,7 @@ struct CShuffleEpilogue merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); static_for<0, NRepeat, 1>{}([&](auto n_idx) { + // transpose thread matrix c_out_tensor.get_thread_buffer()[n_idx + 0 * NRepeat] = type_convert( shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0]); c_out_tensor.get_thread_buffer()[n_idx + 1 * NRepeat] = type_convert( @@ -342,8 +344,6 @@ struct CShuffleEpilogue shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3]); }); - // c_out_tensor = cast_tile(c_out_tensor_fp32); - if constexpr(MemoryOperation == memory_operation_enum::set) { store_tile(out_dram_window, c_out_tensor); @@ -472,16 +472,16 @@ struct CShuffleEpilogue template = 0> CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, const DsDramWindows& ds_dram_windows, void* p_smem, - ScaleM scale_m, - ScaleN scale_n) + ScaleMWindow scale_m_window, + ScaleNWindow scale_n_window) { constexpr int kM0 = MWave; constexpr int kM2 = 4; @@ -499,9 +499,43 @@ struct CShuffleEpilogue tuple, sequence<1, 1>>, 
sequence<1, 2>, sequence<2, 2>>; + static_assert(GetVectorSizeC() % kN2 == 0); + constexpr auto dram_tile_distribution = make_static_tile_distribution(IntrThreadShuffleEncode{}); + constexpr int DynamicTileOffsetFlag = 0; + + auto permute_scale_n_view_1 = transform_tensor_view( + scale_n_window.get_bottom_tensor_view(), + make_tuple(make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, + number{}, + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2, 3, 4>{})); + auto permute_scale_n_view = transform_tensor_view( + permute_scale_n_view_1, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod(make_tuple(number{}, + number{}, + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 4, 2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + auto scale_m_window_with_dist = make_tile_window( + scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution()); + auto scale_n_window_with_dist = make_tile_window(permute_scale_n_view, + scale_n_window.get_window_lengths(), + scale_n_window.get_window_origin(), + o_acc_tile.get_tile_distribution()); + + auto scale_m_buffer = load_tile(scale_m_window_with_dist); + auto scale_n_buffer = load_tile(scale_n_window_with_dist); + auto d_dram_windows = generate_tuple( [&](auto idx) { return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); @@ -519,56 +553,39 @@ struct CShuffleEpilogue make_static_distributed_tensor(dram_tile_distribution); auto c_out_tensor = make_static_distributed_tensor(dram_tile_distribution); - const index_t iMWarp = get_warp_id() / NWave; - const index_t iNWarp = get_warp_id() - iMWarp * NWave; - const index_t iMLane = get_lane_id() / NPerXdl; - const index_t iNLane = get_lane_id() % NPerXdl; - - float vec_scale_A[kM2 * MRepeat]; - float vec_scale_B[NRepeat]; - - _Pragma("unroll") for(int i = 0; i < NRepeat; ++i) - { - vec_scale_B[i] 
= scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl]; - } - _Pragma("unroll") for(int i = 0; i < MRepeat; ++i) - { - vec_scale_A[i * kM2 + 0] = - scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; - vec_scale_A[i * kM2 + 1] = - scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; - vec_scale_A[i * kM2 + 2] = - scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; - vec_scale_A[i * kM2 + 3] = - scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; - } + constexpr int NumAccPerEpiTile = NRepeat * c_warp_y_lengths.product(); static_for<0, MRepeat, 1>{}([&](auto mIter) { shuffle_acc[mIter].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); - static_for<0, NRepeat, 1>{}([&](auto n_idx) { - shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 0] *= vec_scale_B[n_idx]; - shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 1] *= vec_scale_B[n_idx]; - shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 2] *= vec_scale_B[n_idx]; - shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 3] *= vec_scale_B[n_idx]; - }); + auto epi_scale_n = scale_n_buffer.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); + + static_for<0, NumAccPerEpiTile, 1>{}( + [&](auto i) { shuffle_acc[mIter].get_thread_buffer()[i] *= epi_scale_n[i]; }); }); static_for<0, MRepeat, 1>{}([&](auto mIter) { + auto epi_scale_m = scale_m_buffer.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); + static_for<0, NRepeat, 1>{}([&](auto n_idx) { + // transpose thread matrix c_out_tensor_fp32.get_thread_buffer()[n_idx + 0 * NRepeat] = shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] * - vec_scale_A[mIter * kM2 + 0]; + 
epi_scale_m[n_idx * c_warp_y_lengths.product() + 0]; c_out_tensor_fp32.get_thread_buffer()[n_idx + 1 * NRepeat] = shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] * - vec_scale_A[mIter * kM2 + 1]; + epi_scale_m[n_idx * c_warp_y_lengths.product() + 1]; c_out_tensor_fp32.get_thread_buffer()[n_idx + 2 * NRepeat] = shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] * - vec_scale_A[mIter * kM2 + 2]; + epi_scale_m[n_idx * c_warp_y_lengths.product() + 2]; c_out_tensor_fp32.get_thread_buffer()[n_idx + 3 * NRepeat] = shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] * - vec_scale_A[mIter * kM2 + 3]; + epi_scale_m[n_idx * c_warp_y_lengths.product() + 3]; }); c_out_tensor = cast_tile(c_out_tensor_fp32); @@ -592,16 +609,16 @@ struct CShuffleEpilogue template = 0> CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, const DsDramWindows& ds_dram_windows, void* p_smem, - ScaleM scale_m, - ScaleN scale_n) + ScaleMWindow scale_m_window, + ScaleNWindow scale_n_window) { constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); @@ -650,63 +667,32 @@ struct CShuffleEpilogue to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - constexpr int kM2 = 4; // Val - constexpr int kM1 = (64 / NPerXdl); // Thr - constexpr int kM0 = MPerXdl / kM1 / kM2; // Val + auto scale_m_window_with_dist = make_tile_window( + scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution()); + auto scale_n_window_with_dist = make_tile_window( + scale_n_window, scale_n_window.get_window_origin(), o_acc_tile.get_tile_distribution()); - const index_t iMWarp = get_warp_id() / NWave; - const index_t iNWarp = get_warp_id() - iMWarp * NWave; - const index_t iMLane = get_lane_id() / NPerXdl; - const index_t iNLane = get_lane_id() % NPerXdl; + auto scale_m_buffer = 
load_tile(scale_m_window_with_dist); + auto scale_n_buffer = load_tile(scale_n_window_with_dist); - float vec_scale_A[kM0 * kM2 * MRepeat]; - float vec_scale_B[NRepeat]; + constexpr int NumAccPerEpiTile = + NumMXdlPerWavePerShuffle * NumNXdlPerWavePerShuffle * c_warp_y_lengths.product(); + auto epi_tile_idx_slice = + [&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) { + return acc_tile_like_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence{}, + c_warp_y_lengths)); + }; - _Pragma("unroll") for(int i = 0; i < NRepeat; ++i) - { - vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane]; - } - _Pragma("unroll") for(int i = 0; i < MRepeat; ++i) - { - _Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0) - { - vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 0] = - scale_m[0 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl + - i * MPerXdl * MWave]; - vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 1] = - scale_m[1 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl + - i * MPerXdl * MWave]; - vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 2] = - scale_m[2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl + - i * MPerXdl * MWave]; - vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 3] = - scale_m[3 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl + - i * MPerXdl * MWave]; - } - } + lds_tile[0].get_thread_buffer() = epi_tile_idx_slice(o_acc_tile, number<0>{}, number<0>{}); - lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{}, - c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths)); - static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) { - static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) { - constexpr int acc_xdl_offset = - (m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product(); - _Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0) - { - 
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *= - vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 0] * vec_scale_B[n_xdl]; - lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *= - vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 1] * vec_scale_B[n_xdl]; - lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *= - vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 2] * vec_scale_B[n_xdl]; - lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *= - vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 3] * vec_scale_B[n_xdl]; - } - }); - }); + auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, number<0>{}, number<0>{}); + auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, number<0>{}, number<0>{}); + static_for<0, NumAccPerEpiTile, 1>{}( + [&](auto i) { lds_tile[0].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i]; }); static_for<0, num_access, 1>{}([&](auto iAccess) { constexpr int read_stage = iAccess % 2; @@ -724,40 +710,14 @@ struct CShuffleEpilogue if constexpr(iAccess < num_access - 1) { - lds_tile[write_stage].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths)); - static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) { - static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) { - constexpr int acc_xdl_offset = - (m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product(); - _Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0) - { - lds_tile[write_stage] - .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *= - vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 + - m_xdl * kM0 * kM2 + m0 * kM2 + 0] * - vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl]; - lds_tile[write_stage] - .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *= - vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 + - m_xdl * kM0 * kM2 + m0 * kM2 + 1] * - vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl]; 
- lds_tile[write_stage] - .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *= - vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 + - m_xdl * kM0 * kM2 + m0 * kM2 + 2] * - vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl]; - lds_tile[write_stage] - .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *= - vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 + - m_xdl * kM0 * kM2 + m0 * kM2 + 3] * - vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl]; - } - }); + lds_tile[write_stage].get_thread_buffer() = + epi_tile_idx_slice(o_acc_tile, mIter, nIter); + + epi_scale_m = epi_tile_idx_slice(scale_m_buffer, mIter, nIter); + epi_scale_n = epi_tile_idx_slice(scale_n_buffer, mIter, nIter); + + static_for<0, NumAccPerEpiTile, 1>{}([&](auto i) { + lds_tile[write_stage].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i]; }); } diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index c50f197cce..ede07e221c 100755 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -35,44 +35,29 @@ struct FlatmmScalePointer static constexpr int GranularityK = SharedGranularityK; const float* ptr; - index_t scale_stride = 1; CK_TILE_HOST_DEVICE FlatmmScalePointer() = default; CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {} - CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t stride) - : ptr(ptr_), scale_stride(stride) + CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_) + : ptr(ptr_) { } CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const { FlatmmScalePointer ret; - // if constexpr(GranularityMN == 0) - // { - // ret.scalar = scalar; - // } - // else if constexpr(GranularityMN == 1) - // { - // ret.ptr = ptr + offset; - // } - // else - // { - // ret.ptr = ptr + offset / GranularityMN; - // } - return ret; - } - - CK_TILE_HOST_DEVICE 
float operator[](index_t i) const - { - if constexpr(GranularityMN == 1) + if constexpr(GranularityMN == 0) { - return ptr[i]; + ret.ptr = ptr + offset / GranularityK; } else { - return ptr[i / GranularityMN]; + ret.ptr = ptr + offset / GranularityMN / GranularityK; } + return ret; } + + CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete; }; template @@ -83,54 +68,39 @@ struct FlatmmScalePointer static_assert(GranularityMN != 0); - union - { - const float* ptr; - float scalar; // if shared granularity is 0, all rows/columns use the same scale value - }; + const float* ptr; + index_t length; CK_TILE_HOST_DEVICE FlatmmScalePointer() = default; - CK_TILE_HOST_DEVICE FlatmmScalePointer(float scalar_) : scalar(scalar_) {} - CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {} - CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t stride) - : ptr(ptr_) + CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {} + CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_) + : ptr(ptr_), length(length_) { } CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const { FlatmmScalePointer ret; - if constexpr(GranularityMN == 0) + if constexpr(GranularityMN == 1) { - ret.scalar = scalar; - } - else if constexpr(GranularityMN == 1) - { - ret.ptr = ptr + offset; + ret.ptr = ptr + offset; + ret.length = length - offset; } else { - ret.ptr = ptr + offset / GranularityMN; + ret.ptr = ptr + offset / GranularityMN; + ret.length = length - offset / GranularityMN; } return ret; } - CK_TILE_HOST_DEVICE FlatmmScalePointer& advance() { return *this; } - CK_TILE_HOST_DEVICE float operator[](index_t i) const { - if constexpr(GranularityMN == 0) - { - return scalar; - } - else if constexpr(GranularityMN == 1) - { - return ptr[i]; - } + // with additional oob check + if constexpr(GranularityMN == 1) + return i < length ? 
ptr[i] : 0; else - { - return ptr[i / GranularityMN]; - } + return i / GranularityMN < length ? ptr[i / GranularityMN] : 0; } }; @@ -141,14 +111,11 @@ struct FlatmmScalePointer<-1, 0> static constexpr int GranularityMN = -1; static constexpr int GranularityK = 0; - CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default; - CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(float) {} - CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {} - CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, [[maybe_unused]] index_t stride) - { - } + const float* ptr = nullptr; - CK_TILE_HOST_DEVICE FlatmmScalePointer& advance() { return *this; } + CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default; + CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {} + CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {} CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const { @@ -679,7 +646,45 @@ struct FlatmmKernel } }(); - return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view); + constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN; + constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN; + + constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK; + constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK; + + auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale + : 1; // per-token scale + auto scale_stride_n = ScaleGranularityN == 0 ? 
0 // per-tensor scale + : 1; // per-channel scale + + static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1, + "only support per-tensor or per-row scaling"); + static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1, + "only support per-tensor or per-column scaling"); + + const auto scale_m_view = make_naive_tensor_view( + kargs.scale_m_ptr.ptr, + make_tuple( + kargs.M / ScaleGranularityM, + ScaleGranularityKA == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKA), + make_tuple(scale_stride_m, 0), + number{}, + number<1>{}); + const auto scale_n_view = make_naive_tensor_view( + kargs.scale_n_ptr.ptr, + make_tuple( + ScaleGranularityKB == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKB, + kargs.N / ScaleGranularityN), + make_tuple(0, scale_stride_n), + number{}, + number<1>{}); + + return make_tuple(a_tensor_view, + b_flat_tensor_view, + ds_tensor_view, + e_tensor_view, + scale_m_view, + scale_n_view); } template @@ -745,7 +750,12 @@ struct FlatmmKernel } }(); - return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view); + return make_tuple(a_pad_view, + b_flat_tensor_view, + ds_pad_view, + e_pad_view, + views.at(number<4>{}), + views.at(number<5>{})); } template @@ -805,7 +815,28 @@ struct FlatmmKernel make_tuple(number{}, number{}), {i_m, i_n}); - return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window); + constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK; + constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK; + + auto scale_m_window = make_tile_window( + views.at(number<4>{}), + make_tuple(number{}, + number{}), + {i_m, 0}); + auto scale_n_window = make_tile_window( + views.at(number<5>{}), + make_tuple(number{}, + number{}), + {0, i_n}); + + return make_tuple(a_block_window, + b_flat_block_window, + ds_block_window, + e_block_window, + scale_m_window, + 
scale_n_window); } template @@ -837,6 +868,9 @@ struct FlatmmKernel const auto& c_block_tile = FlatmmPipeline{}.template operator()( a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong); + auto scale_m_window = gemm_tile_windows.at(number<4>{}); + auto scale_n_window = gemm_tile_windows.at(number<5>{}); + // Run Epilogue Pipeline if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1) { @@ -847,8 +881,8 @@ struct FlatmmKernel c_block_tile, d_block_window, smem_ptr_ping, - kargs.scale_m_ptr + block_idx_m, - kargs.scale_n_ptr + block_idx_n); + scale_m_window, + scale_n_window); } else if(UseDefaultScheduler || (get_warp_id() == 0)) { diff --git a/include/ck_tile/ops/moe_flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/moe_flatmm/kernel/moe_flatmm_kernel.hpp index 5f1e6a763f..1b91aa1b38 100644 --- a/include/ck_tile/ops/moe_flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/moe_flatmm/kernel/moe_flatmm_kernel.hpp @@ -197,8 +197,10 @@ struct MoeFlatmmKernel // MXF4_Pipeline only has the of scale B and granularityK is 32 static constexpr bool MXFP4_Pipeline = std::is_same_v; static constexpr int MXFP4N_Pack = 2; + static constexpr int MXFP4K_Pack = 2; static constexpr int N_Pack = MXFP4_Pipeline ? MXFP4N_Pack : 1; + static constexpr int K_Pack = MXFP4_Pipeline ? MXFP4K_Pack : 1; static constexpr int WeightPackedSize = numeric_traits::PackedSize; @@ -659,13 +661,16 @@ struct MoeFlatmmKernel {0, // offset_m is included when construct C-scatter-window offsets output_N_offset}); - constexpr int GranularityK = 32; + constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline + constexpr int XDLPerLoadScaleB = + MXFP4_Pipeline ? 
4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4 - auto scale_block_window = make_tile_window( - views.at(I3), - make_tuple(number{}, - number{}), - {coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0}); + auto scale_block_window = + make_tile_window(views.at(I3), + make_tuple(number{}, + number{}), + {coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0}); return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window); } @@ -771,11 +776,10 @@ struct MoeFlatmmKernel smem_ptr_pong); } }(); - using AccTile = decltype(c_block_tile); - // Run EpiloguePipeline Pipeline auto& c_block_window = gemm_tile_windows.at(number<2>{}); + // Run EpiloguePipeline { using EpiProblem = typename EpiloguePipeline::Problem; using ODataType = typename EpiloguePipeline::ODataType; @@ -786,9 +790,11 @@ struct MoeFlatmmKernel constexpr index_t MPerIterationShuffle = EpiloguePipeline::MPerIterationShuffle; constexpr index_t NPerIterationShuffle = EpiloguePipeline::NPerIterationShuffle; - constexpr index_t EpiVectorSizeC = EpiloguePipeline::GetVectorSizeC(); - constexpr index_t MRepeat = EpiloguePipeline::MRepeat; - constexpr index_t NRepeat = EpiloguePipeline::NRepeat; + constexpr index_t EpiVectorSizeC = EpiloguePipeline::GetVectorSizeC(); + constexpr index_t MRepeat = EpiloguePipeline::MRepeat; + constexpr index_t NRepeat = EpiloguePipeline::NRepeat; + constexpr index_t OutputNRepeat = IsGateUp ? NRepeat / 2 : NRepeat; + constexpr index_t BlockedXDLN_PerWarp = EpiloguePipeline::BlockedXDLN_PerWarp; static_assert(!IsGateUp || NumNXdlPerWavePerShuffle % 2 == 0); @@ -805,6 +811,195 @@ struct MoeFlatmmKernel auto o_lds_block = make_tensor_view( reinterpret_cast(smem_ptr_ping), lds_block_desc); + constexpr int ScaleGranularityM = decltype(kargs.scale_m)::GranularityMN; + constexpr int ScaleGranularityN = decltype(kargs.scale_n)::GranularityMN; + + constexpr index_t scale_stride_m = ScaleGranularityM == 0 ? 
0 // per-tensor scale + : 1; // per-token scale + constexpr index_t scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale + : 1; // per-channel scale + + auto output_acc_tile_distr = + make_static_tile_distribution(detail::make_embed_tile_distribution_encoding( + tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}, + typename CWarpDstr::DstrEncode{})); + + const auto scale_m_coord = + output_acc_tile_distr.calculate_index(); // 2d thread offset, [i_row, i_col] + + constexpr ck_tile::index_t ScaleMRepeat = + decltype(output_acc_tile_distr)::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}]; + statically_indexed_array scale_m_offsets; + + static_for<0, ScaleMRepeat, 1>{}([&](auto m0) { + const auto row_idx = + coord_m + m0 * (TilePartitioner::MPerBlock / ScaleMRepeat) + scale_m_coord[I0]; + scale_m_offsets[m0] = row_to_token_idx(row_idx); + }); + + constexpr int DynamicTileOffsetFlag = 0; + + constexpr bool EnableBias = decltype(kargs.exp_bias)::GranularityMN != -1; + + auto make_col_broadcast_window = [&](auto scale_pointer) { + return make_tile_window( + make_naive_tensor_view( + kargs.scale_n.ptr + expert_id * kargs.N, + make_tuple(1, kargs.N), + make_tuple(0, scale_stride_n), + number{}, + number<1>{}), + make_tuple(number{}, + number{}), + {0, IsGateUp ? 
coord_n / 2 : coord_n}, + output_acc_tile_distr); + }; + + auto permute_tensor_view = [&](auto naive_view, auto is_needed_to_permute_N_PACK) { + if constexpr(!is_needed_to_permute_N_PACK) + { + return naive_view; + } + else + { + auto view1 = transform_tensor_view( + naive_view, + make_tuple( + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, + number{}, + number{}, + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2, 3, 4, 5>{})); + return transform_tensor_view( + view1, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 2, 4, 3, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + }; + + auto scale_m_window = make_tile_scatter_gather( + make_naive_tensor_view( + kargs.scale_m.ptr, + make_tuple(kargs.M, 1), + make_tuple(scale_stride_m, 0), + number{}, + number<1>{}), + make_tuple(number{}, + number{}), + {0, 0}, // offset m is included in gather offsets + output_acc_tile_distr, + scale_m_offsets); + + auto scale_n_window = make_tile_window( + make_naive_tensor_view( + kargs.scale_n.ptr + expert_id * kargs.N, + make_tuple(1, kargs.N), + make_tuple(0, scale_stride_n), + number{}, + number<1>{}), // MXF4_Pipeline doesn't use scale_n, so there is no need to + // permute as n_pack + make_tuple(number{}, + number{}), + {0, IsGateUp ? 
coord_n / 2 : coord_n}, + output_acc_tile_distr); + + auto scale_n_up_window = make_tile_window( + make_naive_tensor_view( + kargs.scale_n.ptr + expert_id * kargs.N + kargs.N / 2, + make_tuple(1, kargs.N), + make_tuple(0, scale_stride_n), + number{}, + number<1>{}), + make_tuple(number{}, + number{}), + {0, coord_n / 2}, + output_acc_tile_distr); + + auto exp_bias_view = make_naive_tensor_view( + kargs.exp_bias.ptr + expert_id * kargs.N, + make_tuple(1, kargs.N), + make_tuple(0, scale_stride_n), + number{}, + number<1>{}); + + auto exp_bias_window = make_tile_window( + permute_tensor_view(exp_bias_view, number<(MXFP4_Pipeline && !IsInputGemm)>{}), + make_tuple(number{}, + number{}), + {0, IsGateUp ? coord_n / 2 : coord_n}, + output_acc_tile_distr); + + auto exp_bias_up_window = + make_tile_window(make_naive_tensor_view( + kargs.exp_bias.ptr + expert_id * kargs.N + kargs.N / 2, + make_tuple(1, kargs.N), + make_tuple(0, scale_stride_n), + number{}, + number<1>{}), + make_tuple(number{}, + number{}), + {0, coord_n / 2}, + output_acc_tile_distr); + + auto exp_weight_window = + make_tile_window(make_naive_tensor_view( + static_cast(kargs.p_sorted_expert_weights), + make_tuple(kargs.M, 1), + make_tuple(1, 0), + number{}, + number<1>{}), + make_tuple(number{}, + number{}), + {coord_m, 0}, + output_acc_tile_distr); + + using ScaleMBuffer = decltype(load_tile(scale_m_window)); + using ScaleNBuffer = decltype(load_tile(scale_n_window)); + using ExpBiasBuffer = decltype(load_tile(exp_bias_window)); + using ExpWeightBuffer = decltype(load_tile(exp_weight_window)); + + ScaleMBuffer scale_m_buffer; + ScaleNBuffer scale_n_buffer, scale_n_up_buffer; + + ExpBiasBuffer exp_bias_buffer, exp_bias_up_buffer; + ExpWeightBuffer exp_weight_buffer; + + if constexpr(!MXFP4_Pipeline) + { + scale_m_buffer = load_tile(scale_m_window); + scale_n_buffer = load_tile(scale_n_window); + if constexpr(IsGateUp) + scale_n_up_buffer = load_tile(scale_n_up_window); + } + + if constexpr(EnableBias) + { + 
exp_bias_buffer = load_tile(exp_bias_window); + if constexpr(IsGateUp) + exp_bias_up_buffer = load_tile(exp_bias_up_window); + } + if constexpr(!IsInputGemm) + exp_weight_buffer = load_tile(exp_weight_window); + auto in_lds_window = make_tile_window( o_lds_block, make_tuple(number{}, number{}), @@ -862,354 +1057,111 @@ struct MoeFlatmmKernel constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - constexpr int kM2 = 4; // Val - constexpr int kM1 = (64 / NPerXdl); // Thr - constexpr int kM0 = MPerXdl / kM1 / kM2; // Val - constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle * OutputNumNXdlPerWavePerShuffle; - constexpr bool EnableBias = decltype(kargs.exp_bias)::GranularityMN != -1; - const index_t iMWarp = get_warp_id() / NWave; - const index_t iNWarp = get_warp_id() - iMWarp * NWave; - const index_t iMLane = get_lane_id() / NPerXdl; - const index_t iNLane = get_lane_id() % NPerXdl; + auto epi_tile_idx_slice = + [&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) { + return acc_tile_like_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences( + sequence{}, + c_warp_y_lengths)); + }; - float vec_scale_A[kM0 * kM2 * MRepeat]; - float vec_scale_B[NRepeat]; - - float vec_expert_weights[kM0 * kM2 * MRepeat]; - float vec_expert_bias[kM0 * kM2 * MRepeat]; - - const float* expert_weights = static_cast(kargs.p_sorted_expert_weights); - - //===----------------------------------------------------------------------===// - // Load scales and expert weights - //===----------------------------------------------------------------------===// - if constexpr(!MXFP4_Pipeline) - { - if constexpr(IsGateUp) - { - static_for<0, NRepeat / 2, 1>{}([&](auto i) { - vec_scale_B[i * 2] = - kargs.scale_n[expert_id * kargs.N + coord_n / 2 + i * NWave * NPerXdl + - iNWarp * NPerXdl + 
iNLane]; - vec_scale_B[i * 2 + 1] = - kargs.scale_n[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 + - i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane]; - }); - } - else - { - static_for<0, NRepeat, 1>{}([&](auto i) { - vec_scale_B[i] = - kargs.scale_n[expert_id * kargs.N + coord_n + i * NWave * NPerXdl + - iNWarp * NPerXdl + iNLane]; - }); - } - } - if constexpr(MXFP4_Pipeline && EnableBias) - { - if constexpr(IsGateUp) - { - static_for<0, NRepeat / 2, 1>{}([&](auto i) { - vec_expert_bias[i * 2] = - kargs.exp_bias[expert_id * kargs.N + coord_n / 2 + i * NWave * NPerXdl + - iNWarp * NPerXdl + iNLane]; - vec_expert_bias[i * 2 + 1] = - kargs.exp_bias[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 + - i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane]; - }); - } - else - { - static_for<0, NRepeat, 2>{}([&](auto i) { - vec_expert_bias[i] = - kargs.exp_bias[expert_id * kargs.N + coord_n + i * NWave * NPerXdl + - iNWarp * 2 * NPerXdl + iNLane]; - vec_expert_bias[i + 1] = - kargs.exp_bias[expert_id * kargs.N + coord_n + i * NWave * NPerXdl + - iNWarp * 2 * NPerXdl + NPerXdl + iNLane]; - }); - } - } - - static_for<0, MRepeat, 1>{}([&](auto i) { - static_for<0, kM0, 1>{}([&](auto m0) { - static_for<0, kM2, 1>{}([&](auto m2) { - index_t M2_offset = m2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl + - i * MPerXdl * MWave + coord_m; - if constexpr(!MXFP4_Pipeline) - vec_scale_A[i * kM0 * kM2 + m0 * kM2 + m2] = - kargs.scale_m[row_to_token_idx(M2_offset)]; - if constexpr(!IsInputGemm) - vec_expert_weights[i * kM0 * kM2 + m0 * kM2 + m2] = - expert_weights[M2_offset]; - }); - }); - }); - - //===----------------------------------------------------------------------===// - // Pingpong process start - //===----------------------------------------------------------------------===// - if constexpr(IsGateUp) - { - LDSTileTensor gate_tensor, up_tensor; - - // gate and up are interleaved along NRepeat dimension. 
+ auto gate_up_epi_tile_idx_interleave_slice = [&](auto& dest_gate_tensor, + auto& dest_up_tensor, + const auto& acc_tile_like_tensor, + auto epi_m_idx, + auto epi_n_idx) { static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) { - gate_tensor.set_y_sliced_thread_data( + dest_gate_tensor.set_y_sliced_thread_data( merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros), merge_sequences(sequence{}, c_warp_y_lengths), - c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 2 * n_xdl>{}, - c_warp_y_index_zeros), + acc_tile_like_tensor.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), merge_sequences(sequence{}, c_warp_y_lengths))); - - up_tensor.set_y_sliced_thread_data( + dest_up_tensor.set_y_sliced_thread_data( merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros), merge_sequences(sequence{}, c_warp_y_lengths), - c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 2 * n_xdl + 1>{}, + acc_tile_like_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence{}, c_warp_y_lengths))); }); + }; - static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) { - static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) { - constexpr int acc_xdl_offset = - (m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) * - c_warp_y_lengths.product(); - - static_for<0, kM0, 1>{}([&](auto m0) { - static_for<0, kM2, 1>{}([&](auto m2) { - if constexpr(!MXFP4_Pipeline) - { - gate_tensor - .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *= - vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] * - vec_scale_B[2 * n_xdl]; - up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *= - vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] * - vec_scale_B[2 * n_xdl + 1]; - } - if constexpr(EnableBias) - { - gate_tensor - .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] += - vec_expert_bias[2 * 
n_xdl]; - up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] += - vec_expert_bias[2 * n_xdl + 1]; - } - }); - }); - }); - }); - static_for<0, ActVectorSize, 1>{}([&](auto idx) { - lds_tile[0].get_thread_buffer().at(idx) = - ActivationOp{}(gate_tensor.get_thread_buffer().at(idx), - up_tensor.get_thread_buffer().at(idx)); - }); - } - else - { - lds_tile[0].get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{}, - c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths)); - static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) { - static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) { - constexpr int acc_xdl_offset = - (m_xdl * NumNXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product(); - static_for<0, kM0, 1>{}([&](auto m0) { - static_for<0, kM2, 1>{}([&](auto m2) { - if constexpr(!MXFP4_Pipeline) - lds_tile[0] - .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *= - vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] * - vec_scale_B[n_xdl]; - if constexpr(EnableBias) - lds_tile[0] - .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] += - vec_expert_bias[n_xdl]; - if constexpr(!IsInputGemm) - lds_tile[0] - .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *= - vec_expert_weights[m_xdl * kM0 * kM2 + m0 * kM2 + m2]; - }); - }); - }); - }); - if constexpr(IsInputGemm) + auto process_epi_tile = [&](auto lds_stage, auto epi_m, auto epi_n) { + if constexpr(IsGateUp) { + LDSTileTensor gate_tensor, up_tensor; + + gate_up_epi_tile_idx_interleave_slice( + gate_tensor, up_tensor, c_block_tile, epi_m, epi_n); + auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n); + auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n); + auto epi_scale_n_up = epi_tile_idx_slice(scale_n_up_buffer, epi_m, epi_n); + + auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n); + auto epi_exp_bias_up = 
epi_tile_idx_slice(exp_bias_up_buffer, epi_m, epi_n); + static_for<0, ActVectorSize, 1>{}([&](auto idx) { - lds_tile[0].get_thread_buffer().at(idx) = - ActivationOp{}(lds_tile[0].get_thread_buffer().at(idx)); + if constexpr(!MXFP4_Pipeline) + { + gate_tensor.get_thread_buffer()[idx] *= + epi_scale_m[idx] * epi_scale_n[idx]; + up_tensor.get_thread_buffer()[idx] *= + epi_scale_m[idx] * epi_scale_n_up[idx]; + } + if constexpr(EnableBias) + { + gate_tensor.get_thread_buffer()[idx] += epi_exp_bias[idx]; + up_tensor.get_thread_buffer()[idx] += epi_exp_bias_up[idx]; + } + lds_tile[lds_stage].get_thread_buffer().at(idx) = + ActivationOp{}(gate_tensor.get_thread_buffer().at(idx), + up_tensor.get_thread_buffer().at(idx)); }); } - } - - static_for<0, num_access, 1>{}([&](auto iAccess) { - constexpr int read_stage = iAccess % 2; - constexpr int write_stage = read_stage ^ 1; - - block_sync_lds(); - constexpr auto idx_y_start = SFC::get_index(number{}); - constexpr auto idx_y_start_next = SFC::get_index(number{}); - - constexpr auto mIter = number{}) / MPerIterationShuffle>{}; - constexpr auto nIter = number{}) / NPerIterationShuffle>{}; - - constexpr auto mIter_next = - number{}) / MPerIterationShuffle>{}; - constexpr auto nIter_next = - number{}) / NPerIterationShuffle>{}; - - const auto c_warptile_in_tensor_casted = cast_tile(lds_tile[read_stage]); - - store_tile(in_lds_window, c_warptile_in_tensor_casted); - - if constexpr(iAccess < num_access - 1) + else { - if constexpr(IsGateUp) - { - LDSTileTensor gate_tensor, up_tensor; + lds_tile[lds_stage].get_thread_buffer() = + epi_tile_idx_slice(c_block_tile, epi_m, epi_n); + auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n); + auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n); + auto epi_exp_weight = epi_tile_idx_slice(exp_weight_buffer, epi_m, epi_n); + auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n); - static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) { - 
gate_tensor.set_y_sliced_thread_data( - merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths), - c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths))); - - up_tensor.set_y_sliced_thread_data( - merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths), - c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths))); - }); - - static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) { - static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) { - constexpr int acc_xdl_offset = - (m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) * - c_warp_y_lengths.product(); - static_for<0, kM0, 1>{}([&](auto m0) { - static_for<0, kM2, 1>{}([&](auto m2) { - if constexpr(!MXFP4_Pipeline) - { - gate_tensor.get_thread_buffer()[acc_xdl_offset + - m0 * kM2 + m2] *= - vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle * - kM0 * kM2 + - m_xdl * kM0 * kM2 + m0 * kM2 + m2] * - vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle + - 2 * n_xdl]; - up_tensor.get_thread_buffer()[acc_xdl_offset + - m0 * kM2 + m2] *= - vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle * - kM0 * kM2 + - m_xdl * kM0 * kM2 + m0 * kM2 + m2] * - vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle + - 2 * n_xdl + 1]; - } - if constexpr(EnableBias) - { - gate_tensor.get_thread_buffer()[acc_xdl_offset + - m0 * kM2 + m2] += - vec_expert_bias[nIter_next * - NumNXdlPerWavePerShuffle + - 2 * n_xdl]; - up_tensor.get_thread_buffer()[acc_xdl_offset + - m0 * kM2 + m2] += - vec_expert_bias[nIter_next * - NumNXdlPerWavePerShuffle + - 2 * n_xdl + 1]; - } - }); - }); - }); - }); - static_for<0, ActVectorSize, 1>{}([&](auto idx) { - lds_tile[write_stage].get_thread_buffer().at(idx) = - ActivationOp{}(gate_tensor.get_thread_buffer().at(idx), - 
up_tensor.get_thread_buffer().at(idx)); - }); - } - else - { - lds_tile[write_stage].get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences( - sequence{}, - c_warp_y_lengths)); - static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) { - static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) { - constexpr int acc_xdl_offset = - (m_xdl * NumNXdlPerWavePerShuffle + n_xdl) * - c_warp_y_lengths.product(); - static_for<0, kM0, 1>{}([&](auto m0) { - static_for<0, kM2, 1>{}([&](auto m2) { - if constexpr(!MXFP4_Pipeline) - lds_tile[write_stage] - .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + - m2] *= - vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle * - kM0 * kM2 + - m_xdl * kM0 * kM2 + m0 * kM2 + m2] * - vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle + - n_xdl]; - if constexpr(EnableBias) - lds_tile[write_stage] - .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + - m2] += - vec_expert_bias[nIter_next * - NumNXdlPerWavePerShuffle + - n_xdl]; - if constexpr(!IsInputGemm) - lds_tile[write_stage] - .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + - m2] *= vec_expert_weights - [mIter_next * NumMXdlPerWavePerShuffle * kM0 * kM2 + - m_xdl * kM0 * kM2 + m0 * kM2 + m2]; - }); - }); - }); - }); - if constexpr(IsInputGemm) - { - static_for<0, ActVectorSize, 1>{}([&](auto idx) { - lds_tile[write_stage].get_thread_buffer().at(idx) = ActivationOp{}( - lds_tile[write_stage].get_thread_buffer().at(idx)); - }); - } - } + static_for<0, ActVectorSize, 1>{}([&](auto idx) { + if constexpr(!MXFP4_Pipeline) + lds_tile[lds_stage].get_thread_buffer()[idx] *= + epi_scale_m[idx] * epi_scale_n[idx]; + if constexpr(EnableBias) + lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx]; + if constexpr(!IsInputGemm) + lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx]; + else // for mlp1 gate-only + lds_tile[lds_stage].get_thread_buffer()[idx] = + 
ActivationOp{}(lds_tile[lds_stage].get_thread_buffer()[idx]); + }); } + }; - constexpr int MPerThread = TileEncodingPattern::Y2; - statically_indexed_array offsets; - - auto c_coord = dram_tile_distribution.calculate_index(); + constexpr int NumMEpiTile = MRepeat / NumMXdlPerWavePerShuffle; + constexpr int MPerThread = TileEncodingPattern::Y2; + statically_indexed_array, NumMEpiTile> + c_scatter_offsets; + auto c_coord = dram_tile_distribution.calculate_index(); + static_for<0, NumMEpiTile, 1>{}([&](auto mIter) { static_for<0, MPerThread, 1>{}([&](auto m0) { auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0; auto fused_token = @@ -1217,34 +1169,57 @@ struct MoeFlatmmKernel index_t scatter_token_id = fused_token & token_id_mask; if constexpr(IsInputGemm) - { scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); - } - offsets[m0] = scatter_token_id * kargs.stride_C; + c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; }); + }); + + //===----------------------------------------------------------------------===// + // Pingpong process start + //===----------------------------------------------------------------------===// + process_epi_tile(number<0>{}, number<0>{}, number<0>{}); + + static_for<0, num_access, 1>{}([&](auto iAccess) { + constexpr int read_stage = iAccess % 2; + constexpr int write_stage = read_stage ^ 1; + + block_sync_lds(); + constexpr auto idx_y_start = SFC::get_index(number{}); + constexpr auto mIter = number{}) / MPerIterationShuffle>{}; + + const auto c_warptile_in_tensor_casted = cast_tile(lds_tile[read_stage]); + + store_tile(in_lds_window, c_warptile_in_tensor_casted); + + if constexpr(iAccess < num_access - 1) + { + constexpr auto idx_y_start_next = SFC::get_index(number{}); + constexpr auto mIter_next = + number{}) / MPerIterationShuffle>{}; + constexpr auto nIter_next = + number{}) / NPerIterationShuffle>{}; + + process_epi_tile(number{}, mIter_next, nIter_next); + } 
block_sync_lds(); auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); - auto c_scatter_tile_window = make_tile_scatter_gather(c_block_window.get_bottom_tensor_view(), c_block_window.get_window_lengths(), c_block_window.get_window_origin(), dram_tile_distribution, - offsets); + c_scatter_offsets[mIter]); if constexpr(!IsInputGemm || EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add) - { c_scatter_tile_window.update(c_out_tensor); - } else - { c_scatter_tile_window.store(c_out_tensor); - } + if constexpr(iAccess != num_access - 1) { constexpr auto step = SFC::get_forward_step(iAccess);