mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Add two pass pipeline
This commit is contained in:
@@ -54,7 +54,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
|
||||
const GammaWindow& gamma_window_,
|
||||
XWindow& x_window,
|
||||
YScaleWindow& yscale_window,
|
||||
QYWindow& y_window,
|
||||
QYWindow& qy_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
@@ -121,6 +121,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
|
||||
block_reduce2d_sync(absmax, reduce_max_func);
|
||||
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
|
||||
|
||||
// ex: yscale = absmax / 127 if int8
|
||||
auto yscale = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return v_ / type_convert<ComputeDataType>(numeric<QYDataType>::max());
|
||||
@@ -128,14 +129,14 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
|
||||
absmax);
|
||||
store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));
|
||||
|
||||
// quantize to
|
||||
// quantize y to qy
|
||||
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
|
||||
sweep_tile(qy, [&, yscale_ = yscale](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
auto qy_ = y[idx] / yscale_[i_idx];
|
||||
qy(idx) = saturates<QYDataType>{}(qy_);
|
||||
});
|
||||
store_tile(y_window, qy);
|
||||
store_tile(qy_window, qy);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -52,9 +52,9 @@ struct AddRmsnorm2dRdquantFwdPipelineTwoPass
|
||||
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
|
||||
const BWindow& b_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
XWindow& x_window,
|
||||
XWindow& x_window_,
|
||||
YScaleWindow& yscale_window,
|
||||
QYWindow& y_window,
|
||||
QYWindow& qy_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
@@ -63,15 +63,17 @@ struct AddRmsnorm2dRdquantFwdPipelineTwoPass
|
||||
make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
|
||||
auto b_window =
|
||||
make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
|
||||
auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
|
||||
auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; };
|
||||
auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; };
|
||||
auto reduce_absmax_func = [](const auto& v0, const auto& v1) { return max(v0, abs(v1)); };
|
||||
auto reduce_max_func = [](const auto& v0, const auto& v1) { return max(v0, v1); };
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
auto reduce_absmax_func = ReduceOp::AbsMax{};
|
||||
auto reduce_max_func = ReduceOp::Max{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
@@ -81,7 +83,7 @@ struct AddRmsnorm2dRdquantFwdPipelineTwoPass
|
||||
|
||||
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(a_window)));
|
||||
auto square_sum = block_reduce2d.template MakeYBlockTile<XTensorType>();
|
||||
set_tile(square_sum, 0);
|
||||
set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
@@ -100,6 +102,8 @@ struct AddRmsnorm2dRdquantFwdPipelineTwoPass
|
||||
|
||||
block_reduce2d(x, square_sum, reduce_square_sum_func);
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
move_tile_window(a_window, {0, Block_N});
|
||||
move_tile_window(b_window, {0, Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
@@ -115,33 +119,142 @@ struct AddRmsnorm2dRdquantFwdPipelineTwoPass
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
if constexpr(kSaveX)
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
else
|
||||
{
|
||||
move_tile_window(a_window, {0, -Block_N});
|
||||
move_tile_window(b_window, {0, -Block_N});
|
||||
}
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
|
||||
// rmsnorm computation + absmax + quantization
|
||||
using YTensorType = XTensorType;
|
||||
auto absmax = block_reduce2d.template MakeYBlockTile<YTensorType>();
|
||||
set_tile(absmax, reduce_absmax_func.GetIdentityValue<ComputeDataType>());
|
||||
|
||||
// rmsnorm computation + absmax(threadwise reduce)
|
||||
if constexpr(kSaveX)
|
||||
__syncthreads();
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
auto x = [&]() {
|
||||
if constexpr(kSaveX)
|
||||
{
|
||||
return load_tile(x_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto a = load_tile(a_window);
|
||||
const auto b = load_tile(b_window);
|
||||
return tile_elementwise_in(
|
||||
[&](const auto& a_, const auto& b_) {
|
||||
return type_convert<ComputeDataType>(a_) +
|
||||
type_convert<ComputeDataType>(b_);
|
||||
},
|
||||
a,
|
||||
b);
|
||||
}
|
||||
}();
|
||||
|
||||
auto y = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
auto gamma = load_tile(gamma_window);
|
||||
auto y = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
|
||||
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) {
|
||||
sweep_tile(y, [&](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
|
||||
auto y_ = x_ * inv_rms[i_idx] * gamma_;
|
||||
|
||||
y(idx) = type_convert<ComputeDataType>(y_);
|
||||
});
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
block_reduce2d(y, absmax, reduce_absmax_func);
|
||||
|
||||
if constexpr(kSaveX)
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
else
|
||||
{
|
||||
move_tile_window(a_window, {0, -Block_N});
|
||||
move_tile_window(b_window, {0, -Block_N});
|
||||
}
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
}
|
||||
|
||||
// compute absmax, cross-lane->cross-warp
|
||||
block_reduce2d_sync(absmax, reduce_max_func);
|
||||
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
|
||||
|
||||
// ex: yscale = absmax / 127 if int8
|
||||
auto yscale = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return v_ / type_convert<ComputeDataType>(numeric<QYDataType>::max());
|
||||
},
|
||||
absmax);
|
||||
store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));
|
||||
|
||||
// quantize y to qy
|
||||
// recompute rmsnorm, try to save y in the future
|
||||
if constexpr(kSaveX)
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
else
|
||||
{
|
||||
move_tile_window(a_window, {0, Block_N});
|
||||
move_tile_window(b_window, {0, Block_N});
|
||||
}
|
||||
move_tile_window(gamma_window, {Block_N});
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
auto x = [&]() {
|
||||
if constexpr(kSaveX)
|
||||
{
|
||||
return load_tile(x_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto a = load_tile(a_window);
|
||||
const auto b = load_tile(b_window);
|
||||
return tile_elementwise_in(
|
||||
[&](const auto& a_, const auto& b_) {
|
||||
return type_convert<ComputeDataType>(a_) +
|
||||
type_convert<ComputeDataType>(b_);
|
||||
},
|
||||
a,
|
||||
b);
|
||||
}
|
||||
}();
|
||||
|
||||
auto gamma = load_tile(gamma_window);
|
||||
auto y = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
|
||||
|
||||
sweep_tile(y, [&](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = x_ * inv_rms[i_idx] * gamma_;
|
||||
auto qy_ = y_ / yscale[i_idx];
|
||||
qy(idx) = saturates<QYDataType>{}(qy_);
|
||||
});
|
||||
|
||||
store_tile(qy_window, qy);
|
||||
|
||||
if constexpr(kSaveX)
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
else
|
||||
{
|
||||
move_tile_window(a_window, {0, Block_N});
|
||||
move_tile_window(b_window, {0, Block_N});
|
||||
}
|
||||
move_tile_window(gamma_window, {Block_N});
|
||||
move_tile_window(qy_window, {0, Block_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user