[Ck tile] layernorm2d fwd optimize (#1637)

* optimze small N case using vec io and using rcp div

* [Ck_tile] layernorm, add param to control fastdiv; change generate codes and test pass

* [Ck_tile] fix blockSize compute in Generic2dBlockShape

* [Ck_tile]fix kfastfdiv template style

* [Ck_tile] layernorm, fix stype in review

---------

Co-authored-by: dummycoderfe <noplydummmycoder@163.com>
This commit is contained in:
dummycoderfe
2024-11-08 12:28:23 +08:00
committed by GitHub
parent 75c5bfa364
commit 686a58a912
8 changed files with 144 additions and 84 deletions

View File

@@ -11,9 +11,10 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = void>
struct BlockWelford
{
using Problem = remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
using ComputeDataType = typename Problem::ComputeDataType;
using Problem = remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
using ComputeDataType = typename Problem::ComputeDataType;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
CK_TILE_DEVICE constexpr BlockWelford() {}
@@ -89,7 +90,8 @@ struct BlockWelford
template <typename Problem_, typename Policy_ = void>
struct BlockWelfordSync
{
using Problem = remove_cvref_t<Problem_>;
using Problem = remove_cvref_t<Problem_>;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
CK_TILE_DEVICE void
@@ -173,8 +175,9 @@ struct BlockWelfordSync
template <typename Problem_, typename Policy_ = void>
struct BlockWelfordCrossWarpSync
{
using Problem = remove_cvref_t<Problem_>;
using BlockShape = typename Problem::BlockShape;
using Problem = remove_cvref_t<Problem_>;
using BlockShape = typename Problem::BlockShape;
static constexpr bool kFastFDiv = Problem::kFastFDiv;
template <typename MeanDistributedTensor_>
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
@@ -351,12 +354,23 @@ CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_
}
// Note: this function must be called after all the computation
template <typename VarDistributedTensor_>
template <typename VarDistributedTensor_, bool FastFdiv_ = false>
CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_& var_tensor,
int count)
int count,
bool_constant<FastFdiv_> = {})
{
using DataType = typename VarDistributedTensor_::DataType;
tile_elementwise_inout([&count](auto& x) { x = x / type_convert<DataType>(count); },
var_tensor);
tile_elementwise_inout(
[&count](auto& x) {
if(FastFdiv_ && std::is_same_v<DataType, float>)
{
x = x * __builtin_amdgcn_rcpf(type_convert<DataType>(count));
}
else
{
x = x / type_convert<DataType>(count);
}
},
var_tensor);
}
} // namespace ck_tile

View File

@@ -7,12 +7,13 @@
namespace ck_tile {
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_>
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_, bool kFastFDiv_>
struct BlockWelfordProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kFastFDiv = kFastFDiv_;
};
} // namespace ck_tile