mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
support max3 in smoothquant and add+ rmsnorm + rdquant (#1654)
* Fix cmake example build * Support max3 in smoothquant one pass * support max3 in two pass * support max3 in add_rmsnorm_rdquant
This commit is contained in:
@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
|
||||
static constexpr bool kSaveX = Problem::kSaveX;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool UseMax3 = true; // TODO - Move to trait
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
@@ -69,9 +70,16 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
|
||||
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
auto reduce_absmax_func = ReduceOp::AbsMax{};
|
||||
auto reduce_max_func = ReduceOp::Max{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
|
||||
float rtn;
|
||||
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
|
||||
: "=v"(rtn)
|
||||
: "v"(acc_), "v"(v_0_), "v"(v_1_));
|
||||
return rtn;
|
||||
};
|
||||
auto reduce_max_func = ReduceOp::Max{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
@@ -116,8 +124,23 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
|
||||
});
|
||||
|
||||
// compute absmax, each-thread->cross-lane->cross-warp
|
||||
auto absmax = block_reduce2d(
|
||||
y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
|
||||
auto absmax = [&]() {
|
||||
constexpr auto x_size_per_row =
|
||||
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
|
||||
if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
|
||||
x_size_per_row % 2 == 0)
|
||||
{
|
||||
return block_reduce2d(y,
|
||||
reduce_absmax_func.GetIdentityValue<ComputeDataType>(),
|
||||
reduce_absmax3_func,
|
||||
sequence<1, 2>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return block_reduce2d(
|
||||
y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
|
||||
}
|
||||
}();
|
||||
block_reduce2d_sync(absmax, reduce_max_func);
|
||||
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
|
||||
|
||||
|
||||
@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
|
||||
static constexpr bool kSaveX = Problem::kSaveX;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool UseMax3 = true; // TODO - Move to trait
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
@@ -76,9 +77,16 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
|
||||
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
auto reduce_absmax_func = ReduceOp::AbsMax{};
|
||||
auto reduce_max_func = ReduceOp::Max{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
|
||||
float rtn;
|
||||
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
|
||||
: "=v"(rtn)
|
||||
: "v"(acc_), "v"(v_0_), "v"(v_1_));
|
||||
return rtn;
|
||||
};
|
||||
auto reduce_max_func = ReduceOp::Max{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
@@ -177,7 +185,13 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
|
||||
y(idx) = type_convert<ComputeDataType>(y_);
|
||||
});
|
||||
|
||||
block_reduce2d(y, absmax, reduce_absmax_func);
|
||||
constexpr auto x_size_per_row =
|
||||
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
|
||||
if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
|
||||
x_size_per_row % 2 == 0)
|
||||
block_reduce2d(y, absmax, reduce_absmax3_func, sequence<1, 2>{});
|
||||
else
|
||||
block_reduce2d(y, absmax, reduce_absmax_func);
|
||||
|
||||
if constexpr(kSaveX)
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
|
||||
Reference in New Issue
Block a user