This commit is contained in:
zanzhang
2025-08-26 16:36:18 +08:00
parent 923b8dc4a9
commit 5f9c2dbb8a
2 changed files with 61 additions and 58 deletions

View File

@@ -181,58 +181,55 @@ struct DynamicQuantEpilogue
template <typename ODramWindowTmp,
typename YScaleWindow,
typename MaxTile,
typename OAccTiles>
CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_window,
YScaleWindow& y_scale_window,
OAccTile& y_scale_window,
MaxTile& row_absmax,
OAccTiles& o_acc_tiles,
const bool isArray,
void* smem)
{
// auto reduce = GetBlockReduce2d();
// auto reduce_sync = GetBlockReduce2dSync();
// auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
//
// // auto o_acc_tmp = o_acc_tile;
//
// const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
// auto absmax = ReduceOp::AbsMax{};
//
// const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
// float rtn;
// asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
// : "=v"(rtn)
// : "v"(acc_), "v"(v_0_), "v"(v_1_));
// return rtn;
// };
//
// auto row_absmax = decltype(reduce(o_acc_tiles[0], absmax.GetIdentityValue<AccDataType>(), absmax)){};
// clear_tile(row_absmax);
//
// // static_for<0, BlockShape::Repeat_N, 1>{}([&](auto repeat_n)
// #pragma unroll
// for (int repeat_n = 0; repeat_n < BlockShape::Repeat_N; ++repeat_n)
// {
// auto row_absmax_local = [&]() {
// // if constexpr(UseMax3 && std::is_same_v<AccDataType, float>)
// // {
// // // fast max3+abs implementation
// // return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
// // }
// // else
// // {
// return reduce(o_acc_tiles[repeat_n], absmax.GetIdentityValue<AccDataType>(), absmax);
// // }
// }();
// ck_tile::sweep_tile(row_absmax, [&](auto idx) {
// row_absmax(idx) = max(row_absmax[idx], row_absmax_local[idx]);
// });
// // });
// }
// reduce_sync(row_absmax, f_absmax);
// reduce_crosswarp_sync(row_absmax, smem, f_absmax);
auto reduce = GetBlockReduce2d();
auto reduce_sync = GetBlockReduce2dSync();
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
// auto o_acc_tmp = o_acc_tile;
const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
auto absmax = ReduceOp::AbsMax{};
const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
float rtn;
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
auto row_absmax = decltype(reduce(o_acc_tiles[0], absmax.GetIdentityValue<AccDataType>(), absmax)){};
clear_tile(row_absmax);
// static_for<0, BlockShape::Repeat_N, 1>{}([&](auto repeat_n)
#pragma unroll
for (int repeat_n = 0; repeat_n < BlockShape::Repeat_N; ++repeat_n)
{
auto row_absmax_local = [&]() {
// if constexpr(UseMax3 && std::is_same_v<AccDataType, float>)
// {
// // fast max3+abs implementation
// return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
// }
// else
// {
return reduce(o_acc_tiles[repeat_n], absmax.GetIdentityValue<AccDataType>(), absmax);
// }
}();
ck_tile::sweep_tile(row_absmax, [&](auto idx) {
row_absmax(idx) = max(row_absmax[idx], row_absmax_local[idx]);
});
// });
}
reduce_sync(row_absmax, f_absmax);
reduce_crosswarp_sync(row_absmax, smem, f_absmax);
// here y_scale is Acc TYpe, need convert to YScale type later
auto y_scale = tile_elementwise_in(

View File

@@ -125,8 +125,10 @@ struct Rmsnorm2dFwdPipelineOnePass
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
using AccTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
using AccResTensorType = decltype(load_tile(x_residual_window));
AccTensorType x_warp_tensors[Repeat_N];
AccTensorType o_warp_tensors[Repeat_N];
auto square_sum = decltype(block_reduce2d(AccTensorType{},
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
@@ -136,7 +138,6 @@ struct Rmsnorm2dFwdPipelineOnePass
const auto sm_scale_window =
make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());
#pragma unroll
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
{
auto x = load_tile(x_window);
@@ -158,12 +159,6 @@ struct Rmsnorm2dFwdPipelineOnePass
// compute x = x_resi + x
x_warp_tensors[repeat_n](idx) = type_convert<ComputeDataType>(x_resi(idx)) + x_warp_tensors[repeat_n](idx);
});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, cast_tile<YResidualDataType>(x_warp_tensors[repeat_n]));
if constexpr(x_resi.is_valid())
move_tile_window(y_residual_window, {0, Stride_N});
}
}
// compute mean square each-thread->cross-lane->cross-warp
@@ -175,13 +170,14 @@ struct Rmsnorm2dFwdPipelineOnePass
square_sum(idx) += square_sum_local[idx];
});
}
auto sm_scale = load_tile(sm_scale_window);
const auto gamma = load_tile(gamma_window);
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
auto sm_scale = load_tile(sm_scale_window);
// compute inv-rms
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
@@ -194,32 +190,42 @@ struct Rmsnorm2dFwdPipelineOnePass
static_for<0, Repeat_N, 1>{}([&](auto repeat_n)
{
sweep_tile(x_warp_tensors[0], [&, inv_rms_ = inv_rms](auto idx) {
sweep_tile(o_warp_tensors[0], [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
auto rmsn_ = x_warp_tensors[repeat_n][idx] * inv_rms_[i_idx] * gamma_;
auto rmsn_ = o_warp_tensors[repeat_n][idx] * inv_rms_[i_idx] * gamma_;
if constexpr(sm_scale.is_valid())
{
const auto xs_ = type_convert<ComputeDataType>(sm_scale[j_idx]);
x_warp_tensors[repeat_n](idx) = rmsn_ * xs_;
o_warp_tensors[repeat_n](idx) = rmsn_ * xs_;
}
});
});
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
{
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, cast_tile<YResidualDataType>(x_warp_tensors[repeat_n]));
if constexpr(AccResTensorType::is_valid())
move_tile_window(y_residual_window, {0, Stride_N});
}
}
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
if constexpr(kSaveUnquant)
{
Epilogue{}(
unquant_y_window, o_all_window, sm_scale_window_, y_scale_window_, x_warp_tensors, true, smem);
unquant_y_window, o_all_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem);
}
else
{
Epilogue{}(o_window, sm_scale_window_, y_scale_window_, x_warp_tensors, true, smem);
Epilogue{}(o_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem);
}
}
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)