Gemm layernorm welford (#413)
* Add device op of gemm layernorm
* [What] Rename F to H [Why] F and G prepare for welford tensors
* Add gridwise gemm + welford
* Extract template parameter
* Rename kernel; prepare to add second-half kernel
* Extract var
* Add second kernel for gemm + layernorm
* Move to the gemm_layernorm folder
* Rename F and G to mean and var
* Do not use snakeCurved; it makes determining padding for welford difficult
* Rewrite the device interface and rename some variables
* Add welford count
* Update interface
* Sync code; prepare to test on MI200
* Clean the code
* Implement layernorm
* Add comment to mention hipFree
* Write out e for debugging; this could be removed and h used instead
* 1. Allocate mean, var and count via SetWorkSpacePointer. 2. Add GetWorkSpaceSize to calculate the space size
* Add gemm layernorm host code
* Use reference layernorm
* Fix bug of blockwise welford for first kernel
* Fix bug of mean/var padding for layernorm
* Use sgpr for shuffleM_index
* Padding for GemmMeanVarCountGridDescriptor_M_NBlock
* Add layout parameter
* Check argument for gemm
* Calculate max count for tail block
* Share E and H memory in device op
* Hard-code the vector dim
* Refine the MakeDescriptor
* 1. Remove E parameter, because E is inside the device op. 2. Check vector size
* [What] Rename MakeMeanVarDescriptor_M_N [Why] Prepare to add count version of make descriptor
* Use 1D global memory for count
* Prevent redundant IO
* Update parameter
* Add pipeline v1/v2 selector
* Rename the example
* Add base class for gemm layernorm
* Refine naming to distinguish naive and welford
* Add comment to explain in detail
* We don't need to pad the N dimension in gemm for mean/var/count; set NPerTile to 1
* Rewrite the 2nd kernel; use multiple blocks along the N dimension in the layernorm kernel
* Share the vector size
* Refine variable names
* [What] Force LayernormThreadSliceSize_N = vector size [Why] Memory coalescing
* Add comment
* Extract divisor out of the loop in reference layernorm
* Pad different sizes for E and H in the layernorm kernel according to the block tile
* Refine naming
* Refine naming
* Prevent implicit cast
* [What] Use ck::math::sqrt instead of __builtin_amdgcn_sqrtf [Why] __builtin_amdgcn_sqrtf only supports float; double would cause casting
* Cast only the constant
* Change the post-shuffle thread descriptor
* Add EMeanVarDataType parameter
* Merge the mean and var threadwise copies
* Add missing index
* Fix typo
* Sync the variable with the previous if
* 1. Declare e inside host_gemm_layernorm(). 2. Prevent implicit cast in reference code

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
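For reference, the second-half kernel's math can be checked against a small host-side model: merge the per-NBlock Welford partials (mean, variance, count) written by the first GEMM + Welford kernel, then normalize each row of E with gamma and beta. The sketch below is illustrative plain C++ with made-up names, not CK's device API:

// Host-side model of the second-half kernel (illustrative only).
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

struct WelfordPartial
{
    float mean;
    float var;
    int32_t count;
};

// Standard parallel-Welford combination of two partial results; the
// running variance is converted through M2 = var * count.
WelfordPartial welford_merge(const WelfordPartial& a, const WelfordPartial& b)
{
    const int32_t count = a.count + b.count;
    if(count == 0)
        return {0.f, 0.f, 0};
    const float delta = b.mean - a.mean;
    const float mean  = a.mean + delta * b.count / count;
    const float m2    = a.var * a.count + b.var * b.count +
                        delta * delta * a.count * b.count / count;
    return {mean, m2 / count, count};
}

// One row m: merge that row's NBlock partials, then
// h[n] = (e[n] - mean) / sqrt(var + eps) * gamma[n] + beta[n].
void reference_second_half(const std::vector<float>& e_row,
                           const std::vector<WelfordPartial>& row_partials,
                           const std::vector<float>& gamma,
                           const std::vector<float>& beta,
                           std::vector<float>& h_row,
                           float epsilon)
{
    WelfordPartial acc{0.f, 0.f, 0};
    for(const WelfordPartial& p : row_partials)
        acc = welford_merge(acc, p);

    const float inv_std = 1.f / std::sqrt(acc.var + epsilon);
    for(std::size_t n = 0; n < e_row.size(); ++n)
        h_row[n] = (e_row[n] - acc.mean) * inv_std * gamma[n] + beta[n];
}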
@@ -0,0 +1,394 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"

namespace ck {

template <typename EMeanVarDataType,
          typename HDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename ComputeDataType,
          typename EHGridDesc_M_N,
          typename MeanVarGridDesc_M_NBlock,
          typename CountGridDesc_M_NBlock,
          typename GammaBetaGridDesc_N,
          typename HElementwiseOperation,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t NThreadClusterSize,
          index_t MThreadSliceSize,
          index_t NThreadSliceSize,
          index_t ESrcVectorSize,
          index_t HDstVectorSize,
          index_t GammaSrcVectorSize,
          index_t BetaSrcVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d
{
    static_assert(NThreadSliceSize % ESrcVectorSize == 0 &&
                      NThreadSliceSize % GammaSrcVectorSize == 0 &&
                      NThreadSliceSize % BetaSrcVectorSize == 0,
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(NThreadSliceSize % HDstVectorSize == 0,
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    using ThreadClusterLengths_M_N   = Sequence<MThreadClusterSize, NThreadClusterSize>;
    using ThreadBufferDimAccessOrder = Sequence<0, 1>;
    using ThreadClusterArrangeOrder  = Sequence<0, 1>;

    static constexpr auto thread_cluster_desc_m_n =
        make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_N = Sequence<MThreadSliceSize, NThreadSliceSize>;
    static constexpr auto thread_buffer_desc_m_n = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<NThreadSliceSize>{}));

    using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
    static constexpr auto thread_buffer_desc_m_1 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

    using ThreadBufferLengths_N = Sequence<NThreadSliceSize>;
    static constexpr auto thread_buffer_desc_n =
        make_naive_tensor_descriptor_packed(make_tuple(Number<NThreadSliceSize>{}));

    using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
    using ThreadWelfordDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelfordMerge<ComputeDataType, ThreadWelfordSrcDesc_M_1, ThreadWelfordDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_N,
                                              ThreadClusterArrangeOrder>;
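
    // ThreadwiseWelfordMerge and BlockwiseWelford combine partial
    // (mean, var, count) triples; the standard parallel-Welford
    // combination they realize is: with n = cA + cB and
    // delta = meanB - meanA,
    //   mean' = meanA + delta * cB / n
    //   M2'   = M2A + M2B + delta^2 * cA * cB / n,  var' = M2' / n
    // where M2 = var * count is the sum of squared deviations.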

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;

    __device__ static void Run(const EMeanVarDataType* __restrict__ p_e_grid,
                               const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
                               const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
                               const int32_t* __restrict__ p_in_welford_count_grid,
                               const GammaDataType* __restrict__ p_gamma_grid,
                               const BetaDataType* __restrict__ p_beta_grid,
                               HDataType* __restrict__ p_h_grid,
                               const EHGridDesc_M_N& e_grid_desc_m_n,
                               const EHGridDesc_M_N& h_grid_desc_m_n,
                               const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_nblock,
                               const CountGridDesc_M_NBlock& count_grid_desc_m_nblock,
                               const GammaBetaGridDesc_N& gamma_grid_desc_n,
                               const GammaBetaGridDesc_N& beta_grid_desc_n,
                               index_t numMeanVarCountBlockTileIteration_N,
                               index_t NBlockClusterLength,
                               ComputeDataType epsilon,
                               HElementwiseOperation h_element_op)
    {
        // Thread/Block id
        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const auto block_work_idx     = make_tuple(block_global_id / NBlockClusterLength,
                                                   block_global_id % NBlockClusterLength);
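
        // The flat block id is decomposed row-major into a 2D
        // (MBlock, NBlock) coordinate, so consecutive block ids
        // sweep the N dimension first.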

        const auto thread_cluster_idx =
            thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id));
        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_n_cluster_id = thread_cluster_idx[I1];

        // Global Memory
        const auto e_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());

        const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_mean_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());

        const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_var_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());

        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_count_grid, count_grid_desc_m_nblock.GetElementSpaceSize());

        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());

        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_grid, beta_grid_desc_n.GetElementSpaceSize());

        auto h_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_h_grid, h_grid_desc_m_n.GetElementSpaceSize());

        // VGPR
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            in_welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            in_welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            in_welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr,
                     ComputeDataType,
                     MThreadSliceSize * NThreadSliceSize,
                     true>
            e_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr,
                     ComputeDataType,
                     MThreadSliceSize * NThreadSliceSize,
                     true>
            gamma_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr,
                     ComputeDataType,
                     MThreadSliceSize * NThreadSliceSize,
                     true>
            beta_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr,
                     ComputeDataType,
                     MThreadSliceSize * NThreadSliceSize,
                     true>
            h_thread_buf;

        // IO
        auto threadwise_mean_load_m_nblock =
            ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
                                             ComputeDataType,
                                             MeanVarGridDesc_M_NBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             ThreadBufferDimAccessOrder,
                                             1,
                                             1,
                                             1,
                                             true>(
                mean_var_grid_desc_m_nblock,
                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_n_cluster_id));

        auto threadwise_var_load_m_nblock =
            ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
                                             ComputeDataType,
                                             MeanVarGridDesc_M_NBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             ThreadBufferDimAccessOrder,
                                             1,
                                             1,
                                             1,
                                             true>(
                mean_var_grid_desc_m_nblock,
                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_n_cluster_id));

        auto threadwise_count_load_m_nblock =
            ThreadwiseTensorSliceTransfer_v2<int32_t,
                                             int32_t,
                                             CountGridDesc_M_NBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             ThreadBufferDimAccessOrder,
                                             1,
                                             1,
                                             1,
                                             true>(
                count_grid_desc_m_nblock,
                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_n_cluster_id));

        auto threadwise_e_load_m_n =
            ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
                                             ComputeDataType,
                                             decltype(e_grid_desc_m_n),
                                             decltype(thread_buffer_desc_m_n),
                                             ThreadBufferLengths_M_N,
                                             ThreadBufferDimAccessOrder,
                                             1, // SrcVectorDim
                                             ESrcVectorSize,
                                             1,
                                             true>(
                e_grid_desc_m_n,
                make_multi_index(
                    block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                    block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize));

        auto threadwise_gamma_load_n =
            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                             ComputeDataType,
                                             decltype(gamma_grid_desc_n),
                                             decltype(thread_buffer_desc_n),
                                             ThreadBufferLengths_N,
                                             Sequence<0>, // DimAccessOrder,
                                             0,           // SrcVectorDim,
                                             GammaSrcVectorSize,
                                             1,
                                             true>(
                gamma_grid_desc_n,
                make_multi_index(block_work_idx[I1] * N_BlockTileSize +
                                 thread_n_cluster_id * NThreadSliceSize));

        auto threadwise_beta_load_n =
            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                             ComputeDataType,
                                             decltype(beta_grid_desc_n),
                                             decltype(thread_buffer_desc_n),
                                             ThreadBufferLengths_N,
                                             Sequence<0>, // DimAccessOrder,
                                             0,           // SrcVectorDim,
                                             BetaSrcVectorSize,
                                             1,
                                             true>(
                beta_grid_desc_n,
                make_multi_index(block_work_idx[I1] * N_BlockTileSize +
                                 thread_n_cluster_id * NThreadSliceSize));

        auto threadwise_h_store_m_n =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               HDataType,
                                               decltype(thread_buffer_desc_m_n),
                                               decltype(h_grid_desc_m_n),
                                               HElementwiseOperation,
                                               ThreadBufferLengths_M_N,
                                               ThreadBufferDimAccessOrder,
                                               1, // DstVectorDim
                                               HDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                h_grid_desc_m_n,
                make_multi_index(
                    block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                    block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize),
                h_element_op);
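
        // Note the two access patterns above: the mean/var/count loads move a
        // MThreadSliceSize x 1 window along the NBlock axis (stepped in the
        // merge loop below), while the e load and h store cover a full
        // MThreadSliceSize x NThreadSliceSize tile with vectorized access
        // along N (ESrcVectorSize / HDstVectorSize).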

        // step1: Merge mean and variance
        constexpr auto mean_var_count_thread_copy_step_I0_n =
            make_multi_index(I0, NThreadClusterSize);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            welford_mean_thread_buf(I)  = type_convert<ComputeDataType>(0.0f);
            welford_var_thread_buf(I)   = type_convert<ComputeDataType>(0.0f);
            welford_count_thread_buf(I) = 0;
        });

        for(index_t n = 0; n < numMeanVarCountBlockTileIteration_N; ++n)
        {
            threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
                                              welford_mean_global_val_buf,
                                              thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              in_welford_mean_thread_buf);

            threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
                                             welford_var_global_val_buf,
                                             thread_buffer_desc_m_1,
                                             make_tuple(I0, I0),
                                             in_welford_var_thread_buf);

            threadwise_count_load_m_nblock.Run(count_grid_desc_m_nblock,
                                               welford_count_global_val_buf,
                                               thread_buffer_desc_m_1,
                                               make_tuple(I0, I0),
                                               in_welford_count_thread_buf);

            ThreadwiseWelford::Run(in_welford_mean_thread_buf,
                                   in_welford_var_thread_buf,
                                   in_welford_count_thread_buf,
                                   welford_mean_thread_buf,
                                   welford_var_thread_buf,
                                   welford_count_thread_buf);

            threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
                                                             mean_var_count_thread_copy_step_I0_n);
            threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
                                                            mean_var_count_thread_copy_step_I0_n);
            threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_nblock,
                                                              mean_var_count_thread_copy_step_I0_n);
        }
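
        // Each iteration merges one more column of per-block partial results
        // into this thread's running (mean, var, count); the source window
        // advances by NThreadClusterSize columns per step. BlockwiseWelford
        // below then reduces the per-thread partials across the N thread
        // cluster, one M row at a time.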

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseWelford::Run(
                welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
        });

        // step2: normalization
        // h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
        threadwise_e_load_m_n.Run(e_grid_desc_m_n,
                                  e_global_val_buf,
                                  thread_buffer_desc_m_n,
                                  make_tuple(I0, I0),
                                  e_thread_buf);

        static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
            auto divisor = 1 / ck::math::sqrt(welford_var_thread_buf(m) + epsilon);
            static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
                constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
                h_thread_buf(Number<m_n>{}) =
                    (e_thread_buf(Number<m_n>{}) - welford_mean_thread_buf(m)) * divisor;
            });
        });

        threadwise_gamma_load_n.Run(gamma_grid_desc_n,
                                    gamma_global_val_buf,
                                    thread_buffer_desc_n,
                                    make_tuple(I0),
                                    gamma_thread_buf);

        static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
            static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
                constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
                h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
            });
        });

        threadwise_beta_load_n.Run(beta_grid_desc_n,
                                   beta_global_val_buf,
                                   thread_buffer_desc_n,
                                   make_tuple(I0),
                                   beta_thread_buf);

        static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
            static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
                constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
                h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
            });
        });

        threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
                                   make_tuple(I0, I0),
                                   h_thread_buf,
                                   h_grid_desc_m_n,
                                   h_global_val_buf);

    } // run
};

} // namespace ck

@@ -434,7 +434,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
         });
 
         static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
+            auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
             static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                     constexpr auto offset_m_k =
@@ -319,7 +319,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         });
 
         static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
+            auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
             static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                     constexpr auto offset_m_k =
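
Both hunks make the same fix: as the commit notes say, __builtin_amdgcn_sqrtf only supports float, so a double ComputeDataType would be cast through float and lose precision. ck::math::sqrt is type-generic; a minimal sketch of the kind of overload set such a wrapper needs (illustrative only, not CK's actual implementation) is:

// Illustrative overloads; ck::math::sqrt's real definition may differ.
__device__ inline float sqrt_generic(float x) { return __builtin_amdgcn_sqrtf(x); }
__device__ inline double sqrt_generic(double x) { return __builtin_sqrt(x); }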