mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
fix bugs
This commit is contained in:
@@ -181,58 +181,55 @@ struct DynamicQuantEpilogue
|
||||
|
||||
template <typename ODramWindowTmp,
|
||||
typename YScaleWindow,
|
||||
typename MaxTile,
|
||||
typename OAccTiles>
|
||||
CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_window,
|
||||
YScaleWindow& y_scale_window,
|
||||
OAccTile& y_scale_window,
|
||||
MaxTile& row_absmax,
|
||||
OAccTiles& o_acc_tiles,
|
||||
const bool isArray,
|
||||
void* smem)
|
||||
{
|
||||
// auto reduce = GetBlockReduce2d();
|
||||
// auto reduce_sync = GetBlockReduce2dSync();
|
||||
// auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
|
||||
//
|
||||
// // auto o_acc_tmp = o_acc_tile;
|
||||
//
|
||||
// const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
|
||||
// auto absmax = ReduceOp::AbsMax{};
|
||||
//
|
||||
// const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
|
||||
// float rtn;
|
||||
// asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
|
||||
// : "=v"(rtn)
|
||||
// : "v"(acc_), "v"(v_0_), "v"(v_1_));
|
||||
// return rtn;
|
||||
// };
|
||||
//
|
||||
// auto row_absmax = decltype(reduce(o_acc_tiles[0], absmax.GetIdentityValue<AccDataType>(), absmax)){};
|
||||
// clear_tile(row_absmax);
|
||||
//
|
||||
// // static_for<0, BlockShape::Repeat_N, 1>{}([&](auto repeat_n)
|
||||
// #pragma unroll
|
||||
// for (int repeat_n = 0; repeat_n < BlockShape::Repeat_N; ++repeat_n)
|
||||
// {
|
||||
// auto row_absmax_local = [&]() {
|
||||
// // if constexpr(UseMax3 && std::is_same_v<AccDataType, float>)
|
||||
// // {
|
||||
// // // fast max3+abs implementation
|
||||
// // return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
|
||||
// // }
|
||||
// // else
|
||||
// // {
|
||||
// return reduce(o_acc_tiles[repeat_n], absmax.GetIdentityValue<AccDataType>(), absmax);
|
||||
// // }
|
||||
// }();
|
||||
// ck_tile::sweep_tile(row_absmax, [&](auto idx) {
|
||||
// row_absmax(idx) = max(row_absmax[idx], row_absmax_local[idx]);
|
||||
// });
|
||||
// // });
|
||||
// }
|
||||
// reduce_sync(row_absmax, f_absmax);
|
||||
// reduce_crosswarp_sync(row_absmax, smem, f_absmax);
|
||||
auto reduce = GetBlockReduce2d();
|
||||
auto reduce_sync = GetBlockReduce2dSync();
|
||||
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
|
||||
|
||||
// auto o_acc_tmp = o_acc_tile;
|
||||
|
||||
const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
|
||||
auto absmax = ReduceOp::AbsMax{};
|
||||
|
||||
const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
|
||||
float rtn;
|
||||
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
|
||||
: "=v"(rtn)
|
||||
: "v"(acc_), "v"(v_0_), "v"(v_1_));
|
||||
return rtn;
|
||||
};
|
||||
|
||||
auto row_absmax = decltype(reduce(o_acc_tiles[0], absmax.GetIdentityValue<AccDataType>(), absmax)){};
|
||||
clear_tile(row_absmax);
|
||||
|
||||
// static_for<0, BlockShape::Repeat_N, 1>{}([&](auto repeat_n)
|
||||
#pragma unroll
|
||||
for (int repeat_n = 0; repeat_n < BlockShape::Repeat_N; ++repeat_n)
|
||||
{
|
||||
auto row_absmax_local = [&]() {
|
||||
// if constexpr(UseMax3 && std::is_same_v<AccDataType, float>)
|
||||
// {
|
||||
// // fast max3+abs implementation
|
||||
// return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
return reduce(o_acc_tiles[repeat_n], absmax.GetIdentityValue<AccDataType>(), absmax);
|
||||
// }
|
||||
}();
|
||||
ck_tile::sweep_tile(row_absmax, [&](auto idx) {
|
||||
row_absmax(idx) = max(row_absmax[idx], row_absmax_local[idx]);
|
||||
});
|
||||
// });
|
||||
}
|
||||
reduce_sync(row_absmax, f_absmax);
|
||||
reduce_crosswarp_sync(row_absmax, smem, f_absmax);
|
||||
|
||||
// here y_scale is Acc TYpe, need convert to YScale type later
|
||||
auto y_scale = tile_elementwise_in(
|
||||
|
||||
@@ -125,8 +125,10 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
using AccTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
|
||||
using AccResTensorType = decltype(load_tile(x_residual_window));
|
||||
|
||||
AccTensorType x_warp_tensors[Repeat_N];
|
||||
AccTensorType o_warp_tensors[Repeat_N];
|
||||
|
||||
auto square_sum = decltype(block_reduce2d(AccTensorType{},
|
||||
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
|
||||
@@ -136,7 +138,6 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
const auto sm_scale_window =
|
||||
make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());
|
||||
|
||||
#pragma unroll
|
||||
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
|
||||
{
|
||||
auto x = load_tile(x_window);
|
||||
@@ -158,12 +159,6 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
// compute x = x_resi + x
|
||||
x_warp_tensors[repeat_n](idx) = type_convert<ComputeDataType>(x_resi(idx)) + x_warp_tensors[repeat_n](idx);
|
||||
});
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
store_tile(y_residual_window, cast_tile<YResidualDataType>(x_warp_tensors[repeat_n]));
|
||||
if constexpr(x_resi.is_valid())
|
||||
move_tile_window(y_residual_window, {0, Stride_N});
|
||||
}
|
||||
}
|
||||
|
||||
// compute mean square each-thread->cross-lane->cross-warp
|
||||
@@ -175,13 +170,14 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
square_sum(idx) += square_sum_local[idx];
|
||||
});
|
||||
}
|
||||
auto sm_scale = load_tile(sm_scale_window);
|
||||
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
|
||||
|
||||
auto sm_scale = load_tile(sm_scale_window);
|
||||
|
||||
// compute inv-rms
|
||||
auto inv_rms = tile_elementwise_in(
|
||||
[&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
|
||||
@@ -194,32 +190,42 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
|
||||
static_for<0, Repeat_N, 1>{}([&](auto repeat_n)
|
||||
{
|
||||
sweep_tile(x_warp_tensors[0], [&, inv_rms_ = inv_rms](auto idx) {
|
||||
sweep_tile(o_warp_tensors[0], [&, inv_rms_ = inv_rms](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
|
||||
auto rmsn_ = x_warp_tensors[repeat_n][idx] * inv_rms_[i_idx] * gamma_;
|
||||
auto rmsn_ = o_warp_tensors[repeat_n][idx] * inv_rms_[i_idx] * gamma_;
|
||||
|
||||
if constexpr(sm_scale.is_valid())
|
||||
{
|
||||
const auto xs_ = type_convert<ComputeDataType>(sm_scale[j_idx]);
|
||||
x_warp_tensors[repeat_n](idx) = rmsn_ * xs_;
|
||||
o_warp_tensors[repeat_n](idx) = rmsn_ * xs_;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
for (int repeat_n = 0; repeat_n < Repeat_N; ++repeat_n)
|
||||
{
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
store_tile(y_residual_window, cast_tile<YResidualDataType>(x_warp_tensors[repeat_n]));
|
||||
if constexpr(AccResTensorType::is_valid())
|
||||
move_tile_window(y_residual_window, {0, Stride_N});
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
if constexpr(kSaveUnquant)
|
||||
{
|
||||
Epilogue{}(
|
||||
unquant_y_window, o_all_window, sm_scale_window_, y_scale_window_, x_warp_tensors, true, smem);
|
||||
unquant_y_window, o_all_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(o_window, sm_scale_window_, y_scale_window_, x_warp_tensors, true, smem);
|
||||
Epilogue{}(o_window, sm_scale_window_, y_scale_window_, o_warp_tensors, true, smem);
|
||||
}
|
||||
}
|
||||
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
|
||||
|
||||
Reference in New Issue
Block a user