mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
Ck tile/layernorm: implement naive reduce, opt performance (#1784)
* add no welford * enable output raw * raw of int8 * fix build * fix smoke test err * [ck_tile]layernorm: fix welford ok, set int8 and bf16 small N as default and others open by generate * [cktile]layernorm, fix err commit files and remove useless * fix quant 8192 err & change norm_reduce class and file name --------- Co-authored-by: coderfeli <coderfeli@163.com> Co-authored-by: carlushuang <carlus.huang@amd.com>
This commit is contained in:
@@ -4,8 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
|
||||
#include "ck_tile/ops/welford/block/block_welford.hpp"
|
||||
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
|
||||
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduce()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
|
||||
return BlockWelford<P_>{};
|
||||
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv,
|
||||
Problem::Traits::kWelford>;
|
||||
return BlockNormReduce<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceSync()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv,
|
||||
Problem::Traits::kWelford>;
|
||||
|
||||
return BlockWelfordSync<P_>{};
|
||||
return BlockNormReduceSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv,
|
||||
Problem::Traits::kWelford>;
|
||||
|
||||
return BlockWelfordCrossWarpSync<P_>{};
|
||||
return BlockNormReduceCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv,
|
||||
Problem::Traits::kWelford>;
|
||||
|
||||
using block_welford = BlockWelford<P_>;
|
||||
using block_welford = BlockNormReduce<P_>;
|
||||
using x_block_tile =
|
||||
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using mean_var_block_tile =
|
||||
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
|
||||
|
||||
return GetBlockWelfordCrossWarpSync<Problem>()
|
||||
return GetBlockNormReduceCrossWarpSync<Problem>()
|
||||
.template GetSmemSize<mean_var_block_tile>();
|
||||
}
|
||||
else
|
||||
|
||||
@@ -37,6 +37,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
|
||||
static constexpr bool kWelford = Problem::Traits::kWelford;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
@@ -95,11 +96,16 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
int cur_count = 0;
|
||||
int max_count =
|
||||
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
|
||||
auto block_welford = Policy::template GetBlockWelford<Problem>();
|
||||
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
|
||||
auto block_welford_cross_warp_sync =
|
||||
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
|
||||
auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
|
||||
auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
|
||||
auto block_norm_reduce_cross_warp_sync =
|
||||
Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(cast_tile<ComputeDataType>(x));
|
||||
auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
|
||||
clear_tile(mean);
|
||||
clear_tile(var);
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
@@ -117,12 +123,21 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
|
||||
}
|
||||
|
||||
// compute welford each-thread->cross-lane->cross-warp
|
||||
auto [mean, var] = block_welford(acc, cur_count, max_count);
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
block_welford_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
|
||||
|
||||
// compute reduce each-thread->cross-lane->cross-warp
|
||||
block_norm_reduce(acc, mean, var, cur_count, max_count);
|
||||
block_norm_reduce_sync(mean, var, cur_count);
|
||||
block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
|
||||
if(kWelford)
|
||||
{
|
||||
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile(mean, [&](auto idx) {
|
||||
mean(idx) = mean(idx) / type_convert<MeanDataType>(row_size);
|
||||
var(idx) = var(idx) / type_convert<MeanDataType>(row_size) - mean(idx) * mean(idx);
|
||||
});
|
||||
}
|
||||
// compute inv-std
|
||||
auto inv_std = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
@@ -153,8 +168,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
|
||||
|
||||
auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
|
||||
ln(idx) = ln_;
|
||||
ln(idx) = ln_;
|
||||
});
|
||||
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
|
||||
|
||||
@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
|
||||
static constexpr bool kWelford = Problem::Traits::kWelford;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
@@ -77,6 +78,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
void* smem,
|
||||
Epilogue) const
|
||||
{
|
||||
static_assert(kWelford == true, "2 pass only supports welford merge");
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto gamma_window = make_tile_window(
|
||||
@@ -102,14 +104,14 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
int max_count =
|
||||
(num_n_tile_iteration - 1) * count_per_iter +
|
||||
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
|
||||
auto block_welford = Policy::template GetBlockWelford<Problem>();
|
||||
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
|
||||
auto block_welford_cross_warp_sync =
|
||||
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
|
||||
auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
|
||||
auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
|
||||
auto block_norm_reduce_cross_warp_sync =
|
||||
Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
|
||||
auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
@@ -133,11 +135,11 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
move_tile_window(y_residual_window, {0, Block_N});
|
||||
}
|
||||
}
|
||||
block_welford(acc, mean, var, cur_count, max_count);
|
||||
block_norm_reduce(acc, mean, var, cur_count, max_count);
|
||||
}
|
||||
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
block_welford_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_norm_reduce_sync(mean, var, cur_count);
|
||||
block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
|
||||
|
||||
// compute inv-std
|
||||
|
||||
@@ -40,6 +40,7 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
|
||||
template <bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kFastFDiv_,
|
||||
bool kWelford_,
|
||||
bool kTwoPass_,
|
||||
Layernorm2dFusedAddEnum kFusedAdd_,
|
||||
Layernorm2dFusedQuantEnum kFusedQuant_>
|
||||
@@ -48,6 +49,7 @@ struct Layernorm2dFwdTraits
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kFastFDiv = kFastFDiv_;
|
||||
static constexpr bool kWelford = kWelford_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
|
||||
Reference in New Issue
Block a user