mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Ck tile/layernorm: implement naive reduce, opt performance (#1784)
* add no welford * enable output raw * raw of int8 * fix build * fix smoke test err * [ck_tile]layernorm: fix welford ok, set int8 and bf16 small N as default and others open by generate * [cktile]layernorm, fix err commit files and remove uselss * fix quant 8192 err & change norm_reduce class and file name --------- Co-authored-by: coderfeli <coderfeli@163.com> Co-authored-by: carlushuang <carlus.huang@amd.com>
This commit is contained in:
@@ -4,8 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
|
||||
#include "ck_tile/ops/welford/block/block_welford.hpp"
|
||||
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
|
||||
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduce()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
|
||||
return BlockWelford<P_>{};
|
||||
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv,
|
||||
Problem::Traits::kWelford>;
|
||||
return BlockNormReduce<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceSync()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv,
|
||||
Problem::Traits::kWelford>;
|
||||
|
||||
return BlockWelfordSync<P_>{};
|
||||
return BlockNormReduceSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv,
|
||||
Problem::Traits::kWelford>;
|
||||
|
||||
return BlockWelfordCrossWarpSync<P_>{};
|
||||
return BlockNormReduceCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv,
|
||||
Problem::Traits::kWelford>;
|
||||
|
||||
using block_welford = BlockWelford<P_>;
|
||||
using block_welford = BlockNormReduce<P_>;
|
||||
using x_block_tile =
|
||||
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using mean_var_block_tile =
|
||||
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
|
||||
|
||||
return GetBlockWelfordCrossWarpSync<Problem>()
|
||||
return GetBlockNormReduceCrossWarpSync<Problem>()
|
||||
.template GetSmemSize<mean_var_block_tile>();
|
||||
}
|
||||
else
|
||||
|
||||
@@ -37,6 +37,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
|
||||
static constexpr bool kWelford = Problem::Traits::kWelford;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
@@ -95,11 +96,16 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
int cur_count = 0;
|
||||
int max_count =
|
||||
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
|
||||
auto block_welford = Policy::template GetBlockWelford<Problem>();
|
||||
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
|
||||
auto block_welford_cross_warp_sync =
|
||||
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
|
||||
auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
|
||||
auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
|
||||
auto block_norm_reduce_cross_warp_sync =
|
||||
Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(cast_tile<ComputeDataType>(x));
|
||||
auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
|
||||
clear_tile(mean);
|
||||
clear_tile(var);
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
@@ -117,12 +123,21 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
|
||||
}
|
||||
|
||||
// compute welford each-thread->cross-lane->cross-warp
|
||||
auto [mean, var] = block_welford(acc, cur_count, max_count);
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
block_welford_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
|
||||
|
||||
// compute reduce each-thread->cross-lane->cross-warp
|
||||
block_norm_reduce(acc, mean, var, cur_count, max_count);
|
||||
block_norm_reduce_sync(mean, var, cur_count);
|
||||
block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
|
||||
if(kWelford)
|
||||
{
|
||||
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile(mean, [&](auto idx) {
|
||||
mean(idx) = mean(idx) / type_convert<MeanDataType>(row_size);
|
||||
var(idx) = var(idx) / type_convert<MeanDataType>(row_size) - mean(idx) * mean(idx);
|
||||
});
|
||||
}
|
||||
// compute inv-std
|
||||
auto inv_std = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
@@ -153,8 +168,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
|
||||
|
||||
auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
|
||||
ln(idx) = ln_;
|
||||
ln(idx) = ln_;
|
||||
});
|
||||
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
|
||||
|
||||
@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
|
||||
static constexpr bool kWelford = Problem::Traits::kWelford;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
@@ -77,6 +78,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
void* smem,
|
||||
Epilogue) const
|
||||
{
|
||||
static_assert(kWelford == true, "2 pass only supports welford merge");
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto gamma_window = make_tile_window(
|
||||
@@ -102,14 +104,14 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
int max_count =
|
||||
(num_n_tile_iteration - 1) * count_per_iter +
|
||||
block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
|
||||
auto block_welford = Policy::template GetBlockWelford<Problem>();
|
||||
auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
|
||||
auto block_welford_cross_warp_sync =
|
||||
Policy::template GetBlockWelfordCrossWarpSync<Problem>();
|
||||
auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
|
||||
auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
|
||||
auto block_norm_reduce_cross_warp_sync =
|
||||
Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
|
||||
auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
|
||||
auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
@@ -133,11 +135,11 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
move_tile_window(y_residual_window, {0, Block_N});
|
||||
}
|
||||
}
|
||||
block_welford(acc, mean, var, cur_count, max_count);
|
||||
block_norm_reduce(acc, mean, var, cur_count, max_count);
|
||||
}
|
||||
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
block_welford_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_norm_reduce_sync(mean, var, cur_count);
|
||||
block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
|
||||
|
||||
// compute inv-std
|
||||
|
||||
@@ -40,6 +40,7 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
|
||||
template <bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kFastFDiv_,
|
||||
bool kWelford_,
|
||||
bool kTwoPass_,
|
||||
Layernorm2dFusedAddEnum kFusedAdd_,
|
||||
Layernorm2dFusedQuantEnum kFusedQuant_>
|
||||
@@ -48,6 +49,7 @@ struct Layernorm2dFwdTraits
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kFastFDiv = kFastFDiv_;
|
||||
static constexpr bool kWelford = kWelford_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/welford/block/block_welford.hpp"
|
||||
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
|
||||
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
|
||||
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
|
||||
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
|
||||
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
@@ -4,22 +4,23 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
|
||||
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockWelford
|
||||
struct BlockNormReduce
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using ComputeDataType = typename Problem::ComputeDataType;
|
||||
static constexpr bool kFastFDiv = Problem::kFastFDiv;
|
||||
static constexpr bool kWelford = Problem::kWelford;
|
||||
|
||||
CK_TILE_DEVICE constexpr BlockWelford() {}
|
||||
CK_TILE_DEVICE constexpr BlockNormReduce() {}
|
||||
|
||||
// [CAUSION] - max_count_ is to deal with the padding problem
|
||||
// max_count_ is depend on caller, eg: naive and splitN welford will have different
|
||||
// max_count_ is depend on caller, eg: naive and splitN norm_reduce will have different
|
||||
// calculation of max_count_
|
||||
// -> use block_welford_calculate_max_count to compute
|
||||
template <typename XDistributedTensor_,
|
||||
@@ -40,18 +41,24 @@ struct BlockWelford
|
||||
if(cur_count_ < max_count_)
|
||||
{
|
||||
++cur_count_;
|
||||
|
||||
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
|
||||
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
|
||||
constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
|
||||
|
||||
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
|
||||
|
||||
welford_update(mean_tensor(out_dstr_idx),
|
||||
var_tensor(out_dstr_idx),
|
||||
x,
|
||||
cur_count_,
|
||||
constant<kFastFDiv>{});
|
||||
if(kWelford)
|
||||
{
|
||||
welford_update(mean_tensor(out_dstr_idx),
|
||||
var_tensor(out_dstr_idx),
|
||||
x,
|
||||
cur_count_,
|
||||
constant<kFastFDiv>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
mean_tensor(out_dstr_idx) += x;
|
||||
var_tensor(out_dstr_idx) += x * x;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -91,10 +98,11 @@ struct BlockWelford
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockWelfordSync
|
||||
struct BlockNormReduceSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
static constexpr bool kFastFDiv = Problem::kFastFDiv;
|
||||
static constexpr bool kWelford = Problem::kWelford;
|
||||
|
||||
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
|
||||
CK_TILE_DEVICE void
|
||||
@@ -152,36 +160,48 @@ struct BlockWelfordSync
|
||||
(number<lid_over_rid_derivative << istage.value>{}.value);
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
|
||||
const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
|
||||
const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
|
||||
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
|
||||
const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
|
||||
if(kWelford)
|
||||
{
|
||||
const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
|
||||
|
||||
// welford merge
|
||||
welford_merge(v_local_mean,
|
||||
v_local_var,
|
||||
v_local_count,
|
||||
v_remote_mean,
|
||||
v_remote_var,
|
||||
v_remote_count,
|
||||
constant<kFastFDiv>{});
|
||||
// norm_reduce merge
|
||||
welford_merge(v_local_mean,
|
||||
v_local_var,
|
||||
v_local_count,
|
||||
v_remote_mean,
|
||||
v_remote_var,
|
||||
v_remote_count,
|
||||
constant<kFastFDiv>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
v_local_mean += v_remote_mean;
|
||||
v_local_var += v_remote_var;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
mean_tensor.get_thread_buffer()(i) = v_local_mean;
|
||||
var_tensor.get_thread_buffer()(i) = v_local_var;
|
||||
|
||||
count = v_local_count;
|
||||
if(kWelford)
|
||||
{
|
||||
count = v_local_count;
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockWelfordCrossWarpSync
|
||||
struct BlockNormReduceCrossWarpSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockShape = typename Problem::BlockShape;
|
||||
static constexpr bool kFastFDiv = Problem::kFastFDiv;
|
||||
static constexpr bool kWelford = Problem::kWelford;
|
||||
using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;
|
||||
|
||||
template <typename MeanDistributedTensor_>
|
||||
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
|
||||
@@ -252,7 +272,7 @@ struct BlockWelfordCrossWarpSync
|
||||
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
|
||||
|
||||
// Note: we always pack everything into fp32x4
|
||||
fp32x4_t* smem_ptr = reinterpret_cast<fp32x4_t*>(smem);
|
||||
smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
|
||||
@@ -267,11 +287,13 @@ struct BlockWelfordCrossWarpSync
|
||||
if(lane_id == 0)
|
||||
{
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
fp32x4_t local_scratch_;
|
||||
smem_dtype local_scratch_;
|
||||
local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
|
||||
local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
|
||||
local_scratch_[2] = bit_cast<float>(count);
|
||||
|
||||
if(kWelford)
|
||||
{
|
||||
local_scratch_[2] = bit_cast<float>(count);
|
||||
}
|
||||
smem_ptr[smem_offset + i * num_warps] = local_scratch_;
|
||||
});
|
||||
}
|
||||
@@ -280,7 +302,7 @@ struct BlockWelfordCrossWarpSync
|
||||
// load from smem. here we let everythread to do compute :)
|
||||
index_t local_warp_id = warp_id / num_reduce_warps;
|
||||
index_t local_smem_os = local_warp_id * num_reduce_warps;
|
||||
fp32x4_t all_scratch[thread_buf_size * num_reduce_warps];
|
||||
smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
|
||||
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
|
||||
all_scratch[i_0 * num_reduce_warps + i_1] =
|
||||
@@ -293,32 +315,40 @@ struct BlockWelfordCrossWarpSync
|
||||
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
|
||||
// TODO: use descriptor for this
|
||||
auto v_local = all_scratch[i_0 * num_reduce_warps];
|
||||
auto v_local_mean = bit_cast<DataType>(v_local[0]);
|
||||
auto v_local_var = bit_cast<DataType>(v_local[1]);
|
||||
auto v_local_count = bit_cast<int>(v_local[2]);
|
||||
auto v_local = all_scratch[i_0 * num_reduce_warps];
|
||||
auto v_local_mean = bit_cast<DataType>(v_local[0]);
|
||||
auto v_local_var = bit_cast<DataType>(v_local[1]);
|
||||
int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;
|
||||
|
||||
// further reduce mean/var
|
||||
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
|
||||
constexpr auto i_1 = number<i_1_n1 + 1>{};
|
||||
const fp32x4_t v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
|
||||
const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
|
||||
const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
|
||||
const auto v_remote_var = bit_cast<DataType>(v_remote[1]);
|
||||
const auto v_remote_count = bit_cast<int>(v_remote[2]);
|
||||
if(kWelford)
|
||||
{
|
||||
const auto v_remote_count = bit_cast<int>(v_remote[2]);
|
||||
|
||||
welford_merge(v_local_mean,
|
||||
v_local_var,
|
||||
v_local_count,
|
||||
v_remote_mean,
|
||||
v_remote_var,
|
||||
v_remote_count,
|
||||
constant<kFastFDiv>{});
|
||||
welford_merge(v_local_mean,
|
||||
v_local_var,
|
||||
v_local_count,
|
||||
v_remote_mean,
|
||||
v_remote_var,
|
||||
v_remote_count,
|
||||
constant<kFastFDiv>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
v_local_mean += v_remote_mean;
|
||||
v_local_var += v_remote_var;
|
||||
}
|
||||
});
|
||||
|
||||
mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
|
||||
var_tensor.get_thread_buffer()(i_0) = v_local_var;
|
||||
|
||||
count = v_local_count;
|
||||
if(kWelford)
|
||||
count = v_local_count;
|
||||
});
|
||||
}
|
||||
};
|
||||
@@ -7,13 +7,18 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_, bool kFastFDiv_>
|
||||
struct BlockWelfordProblem
|
||||
template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename BlockShape_,
|
||||
bool kFastFDiv_,
|
||||
bool kWelford_>
|
||||
struct BlockNormReduceProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
static constexpr bool kFastFDiv = kFastFDiv_;
|
||||
static constexpr bool kWelford = kWelford_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user