mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
[Ck tile] layernorm2d fwd optimize (#1637)
* optimze small N case using vec io and using rcp div * [Ck_tile] layernorm, add param to control fastdiv; change generate codes and test pass * [Ck_tile] fix blockSize compute in Generic2dBlockShape * [Ck_tile]fix kfastfdiv template style * [Ck_tile] layernorm, fix stype in review --------- Co-authored-by: dummycoderfe <noplydummmycoder@163.com>
This commit is contained in:
@@ -11,9 +11,10 @@ namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockWelford
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using ComputeDataType = typename Problem::ComputeDataType;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using ComputeDataType = typename Problem::ComputeDataType;
|
||||
static constexpr bool kFastFDiv = Problem::kFastFDiv;
|
||||
|
||||
CK_TILE_DEVICE constexpr BlockWelford() {}
|
||||
|
||||
@@ -89,7 +90,8 @@ struct BlockWelford
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockWelfordSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
static constexpr bool kFastFDiv = Problem::kFastFDiv;
|
||||
|
||||
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
|
||||
CK_TILE_DEVICE void
|
||||
@@ -173,8 +175,9 @@ struct BlockWelfordSync
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockWelfordCrossWarpSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockShape = typename Problem::BlockShape;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockShape = typename Problem::BlockShape;
|
||||
static constexpr bool kFastFDiv = Problem::kFastFDiv;
|
||||
|
||||
template <typename MeanDistributedTensor_>
|
||||
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
|
||||
@@ -351,12 +354,23 @@ CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_
|
||||
}
|
||||
|
||||
// Note: this function must be called after all the computation
|
||||
template <typename VarDistributedTensor_>
|
||||
template <typename VarDistributedTensor_, bool FastFdiv_ = false>
|
||||
CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_& var_tensor,
|
||||
int count)
|
||||
int count,
|
||||
bool_constant<FastFdiv_> = {})
|
||||
{
|
||||
using DataType = typename VarDistributedTensor_::DataType;
|
||||
tile_elementwise_inout([&count](auto& x) { x = x / type_convert<DataType>(count); },
|
||||
var_tensor);
|
||||
tile_elementwise_inout(
|
||||
[&count](auto& x) {
|
||||
if(FastFdiv_ && std::is_same_v<DataType, float>)
|
||||
{
|
||||
x = x * __builtin_amdgcn_rcpf(type_convert<DataType>(count));
|
||||
}
|
||||
else
|
||||
{
|
||||
x = x / type_convert<DataType>(count);
|
||||
}
|
||||
},
|
||||
var_tensor);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -7,12 +7,13 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_>
|
||||
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_, bool kFastFDiv_>
|
||||
struct BlockWelfordProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
static constexpr bool kFastFDiv = kFastFDiv_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user