mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[Ck tile] layernorm2d fwd optimize (#1637)
* optimze small N case using vec io and using rcp div * [Ck_tile] layernorm, add param to control fastdiv; change generate codes and test pass * [Ck_tile] fix blockSize compute in Generic2dBlockShape * [Ck_tile]fix kfastfdiv template style * [Ck_tile] layernorm, fix stype in review --------- Co-authored-by: dummycoderfe <noplydummmycoder@163.com>
This commit is contained in:
@@ -47,7 +47,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
|
||||
return BlockWelford<P_>{};
|
||||
}
|
||||
@@ -57,7 +58,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
|
||||
return BlockWelfordSync<P_>{};
|
||||
}
|
||||
@@ -67,7 +69,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
|
||||
return BlockWelfordCrossWarpSync<P_>{};
|
||||
}
|
||||
@@ -79,7 +82,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
|
||||
using block_welford = BlockWelford<P_>;
|
||||
using x_block_tile =
|
||||
|
||||
@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
@@ -125,7 +126,15 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
// compute inv-std
|
||||
auto inv_std = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
|
||||
if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
|
||||
{
|
||||
return type_convert<ComputeDataType>(1.0f) *
|
||||
__builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
|
||||
}
|
||||
else
|
||||
{
|
||||
return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
|
||||
}
|
||||
},
|
||||
var);
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
|
||||
|
||||
template <bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kFastFDiv_,
|
||||
bool kTwoPass_,
|
||||
Layernorm2dFusedAddEnum kFusedAdd_,
|
||||
Layernorm2dFusedQuantEnum kFusedQuant_>
|
||||
@@ -46,6 +47,7 @@ struct Layernorm2dFwdTraits
|
||||
{
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kFastFDiv = kFastFDiv_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
|
||||
Reference in New Issue
Block a user