From 5f9c2dbb8aaa27b94b445dc95e297116ff038bcf Mon Sep 17 00:00:00 2001 From: zanzhang Date: Tue, 26 Aug 2025 16:36:18 +0800 Subject: [PATCH] fix bugs --- .../ops/epilogue/dynamic_quant_epilogue.hpp | 87 +++++++++---------- .../rmsnorm2d_fwd_pipeline_one_pass.hpp | 32 ++++--- 2 files changed, 61 insertions(+), 58 deletions(-) diff --git a/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp b/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp index c04bcc9482..f3bc6a25ff 100644 --- a/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp @@ -181,58 +181,55 @@ struct DynamicQuantEpilogue template CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_window, YScaleWindow& y_scale_window, - OAccTile& y_scale_window, - MaxTile& row_absmax, OAccTiles& o_acc_tiles, const bool isArray, void* smem) { -// auto reduce = GetBlockReduce2d(); -// auto reduce_sync = GetBlockReduce2dSync(); -// auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync(); -// -// // auto o_acc_tmp = o_acc_tile; -// -// const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); }; -// auto absmax = ReduceOp::AbsMax{}; -// -// const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) { -// float rtn; -// asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)" -// : "=v"(rtn) -// : "v"(acc_), "v"(v_0_), "v"(v_1_)); -// return rtn; -// }; -// -// auto row_absmax = decltype(reduce(o_acc_tiles[0], absmax.GetIdentityValue(), absmax)){}; -// clear_tile(row_absmax); -// -// // static_for<0, BlockShape::Repeat_N, 1>{}([&](auto repeat_n) -// #pragma unroll -// for (int repeat_n = 0; repeat_n < BlockShape::Repeat_N; ++repeat_n) -// { -// auto row_absmax_local = [&]() { -// // if constexpr(UseMax3 && std::is_same_v) -// // { -// // // fast max3+abs implementation -// // return reduce(o_acc_tmp, type_convert(0), f_max3, sequence<1, 2>{}); -// // } -// // else -// // { -// return reduce(o_acc_tiles[repeat_n], absmax.GetIdentityValue(), absmax); -// // } -// }(); -// ck_tile::sweep_tile(row_absmax, [&](auto idx) { -// row_absmax(idx) = max(row_absmax[idx], row_absmax_local[idx]); -// }); -// // }); -// } -// reduce_sync(row_absmax, f_absmax); -// reduce_crosswarp_sync(row_absmax, smem, f_absmax); + auto reduce = GetBlockReduce2d(); + auto reduce_sync = GetBlockReduce2dSync(); + auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync(); + + // auto o_acc_tmp = o_acc_tile; + + const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); }; + auto absmax = ReduceOp::AbsMax{}; + + const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) { + float rtn; + asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)" + : "=v"(rtn) + : "v"(acc_), "v"(v_0_), "v"(v_1_)); + return rtn; + }; + + auto row_absmax = decltype(reduce(o_acc_tiles[0], absmax.GetIdentityValue(), absmax)){}; + clear_tile(row_absmax); + + // static_for<0, BlockShape::Repeat_N, 1>{}([&](auto repeat_n) +#pragma unroll + for (int repeat_n = 0; repeat_n < BlockShape::Repeat_N; ++repeat_n) + { + auto row_absmax_local = [&]() { + // if constexpr(UseMax3 && std::is_same_v) + // { + // // fast max3+abs implementation + // return reduce(o_acc_tmp, type_convert(0), f_max3, sequence<1, 2>{}); + // } + // else + // { + return reduce(o_acc_tiles[repeat_n], absmax.GetIdentityValue(), absmax); + // } + }(); + ck_tile::sweep_tile(row_absmax, [&](auto idx) { + row_absmax(idx) = max(row_absmax[idx], row_absmax_local[idx]); + }); + // }); + } + reduce_sync(row_absmax, f_absmax); + reduce_crosswarp_sync(row_absmax, smem, f_absmax); // here y_scale is Acc TYpe, need convert to YScale type later auto y_scale = tile_elementwise_in( diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp index f304b1516d..b9c32c2537 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp @@ -125,8 +125,10 @@ struct Rmsnorm2dFwdPipelineOnePass Policy::template GetBlockReduce2dCrossWarpSync(); using AccTensorType = decltype(cast_tile(load_tile(x_window))); + using AccResTensorType = decltype(load_tile(x_residual_window)); AccTensorType x_warp_tensors[Repeat_N]; + AccTensorType o_warp_tensors[Repeat_N]; auto square_sum = decltype(block_reduce2d(AccTensorType{}, reduce_square_sum_func.GetIdentityValue(), @@ -136,7 +138,6 @@ struct Rmsnorm2dFwdPipelineOnePass const auto sm_scale_window = make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution()); -#pragma unroll for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n) { auto x = load_tile(x_window); @@ -158,12 +159,6 @@ struct Rmsnorm2dFwdPipelineOnePass // compute x = x_resi + x x_warp_tensors[repeat_n](idx) = type_convert(x_resi(idx)) + x_warp_tensors[repeat_n](idx); }); - if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) - { - store_tile(y_residual_window, cast_tile(x_warp_tensors[repeat_n])); - if constexpr(x_resi.is_valid()) - move_tile_window(y_residual_window, {0, Stride_N}); - } } // compute mean square each-thread->cross-lane->cross-warp @@ -175,13 +170,14 @@ struct Rmsnorm2dFwdPipelineOnePass square_sum(idx) += square_sum_local[idx]; }); } - auto sm_scale = load_tile(sm_scale_window); const auto gamma = load_tile(gamma_window); block_reduce2d_sync(square_sum, reduce_sum_func); block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); + auto sm_scale = load_tile(sm_scale_window); + // compute inv-rms auto inv_rms = tile_elementwise_in( [&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum); @@ -194,32 +190,42 @@ struct Rmsnorm2dFwdPipelineOnePass static_for<0, Repeat_N, 1>{}([&](auto repeat_n) { - sweep_tile(x_warp_tensors[0], [&, inv_rms_ = inv_rms](auto idx) { + sweep_tile(o_warp_tensors[0], [&, inv_rms_ = inv_rms](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]); const auto gamma_ = type_convert(gamma[j_idx]); - auto rmsn_ = x_warp_tensors[repeat_n][idx] * inv_rms_[i_idx] * gamma_; + auto rmsn_ = o_warp_tensors[repeat_n][idx] * inv_rms_[i_idx] * gamma_; if constexpr(sm_scale.is_valid()) { const auto xs_ = type_convert(sm_scale[j_idx]); - x_warp_tensors[repeat_n](idx) = rmsn_ * xs_; + o_warp_tensors[repeat_n](idx) = rmsn_ * xs_; } }); }); + for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n) + { + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + { + store_tile(y_residual_window, cast_tile(x_warp_tensors[repeat_n])); + if constexpr(AccResTensorType::is_valid()) + move_tile_window(y_residual_window, {0, Stride_N}); + } + } + if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) { if constexpr(kSaveUnquant) { Epilogue{}( - unquant_y_window, o_all_window, sm_scale_window_, y_scale_window_, x_warp_tensors, true, smem); + unquant_y_window, o_all_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem); } else { - Epilogue{}(o_window, sm_scale_window_, y_scale_window_, x_warp_tensors, true, smem); + Epilogue{}(o_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem); } } else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)