layernorm and groupnorm backward data (#1083)

* rename folder

* Add type string

* Remove typo

* Add deviceOp to backward x

* Add comment to describe the behavior of backward normalization

* Add kernel function, prepare to implement

* implement generic kernel

* Check vector size

* Add sweep once pipeline for small reduce size

* Fix bug of KRaw_ error

* Fix bug of dx stride

* sanity check for mean and rstd

* backward x for groupnorm

* Add bwd x instance

* add layernorm 2d bwd gamma beta instances

* Change save mean var type from f32 to f16 in f16 mode

* Change the example to f16

* Add groupnorm bwd gamma beta instance

* Add groupnorm bwd x instance

* Fix naming

* Add layernorm bwd x ckprofiler

* Add groupnorm bwd x profiler

* clang format

* Rename bwd x to bwd data

* Fix bug of verification in profiler

* Add test of layernorm and groupnorm bwd data

* Add missing cmake

* Add layernorm2d bwd data

* rename fwd example

* Add groupnorm client example

* Fix typo. replace Invarient with Invariant

* Add checking before running the best instance
This commit is contained in:
rocking
2023-12-19 04:23:11 +08:00
committed by GitHub
parent ad0a8e4cd2
commit a69aa2a11a
65 changed files with 3050 additions and 110 deletions

View File

@@ -0,0 +1,8 @@
set(DEVICE_NORMALIZATION_bwd_data_INSTANCES)
list(APPEND DEVICE_NORMALIZATION_bwd_data_INSTANCES
device_groupnorm_bwd_data_f32_instance.cpp
device_layernorm2d_bwd_data_f16_instance.cpp
device_layernorm2d_bwd_data_f32_instance.cpp)
add_instance_library(device_normalization_bwd_data_instance ${DEVICE_NORMALIZATION_bwd_data_INSTANCES})

View File

@@ -0,0 +1,22 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_data_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_groupnorm_bwd_data_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdData<F32, F32, F32, F32, F32, 5, 3>>>&
instances)
{
add_device_operation_instances(instances, device_groupnorm_bwd_data_f32_generic_instance{});
add_device_operation_instances(instances, device_groupnorm_bwd_data_f32_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,23 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_data_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_data_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdData<F16, F16, F16, F16, F16, 2, 1>>>&
instances)
{
add_device_operation_instances(instances,
device_layernorm_bwd_data_f16_generic_instance<2, 1>{});
add_device_operation_instances(instances, device_layernorm_bwd_data_f16_instances<2, 1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,23 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_data_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_data_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdData<F32, F32, F32, F32, F32, 2, 1>>>&
instances)
{
add_device_operation_instances(instances,
device_layernorm_bwd_data_f32_generic_instance<2, 1>{});
add_device_operation_instances(instances, device_layernorm_bwd_data_f32_instances<2, 1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_data_f16_instances =
// clang-format off
std::tuple <
// DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDXFastestDimReduced, DXDstVectorSize>
DeviceNormalizationBwdDataImpl<F16, F16, F16, F16, F32, F16, Rank, Reduce, 256, 1, 256, 1, 2, true, 2, true, 2, true, 2, false, 1, true, 2>,
DeviceNormalizationBwdDataImpl<F16, F16, F16, F16, F32, F16, Rank, Reduce, 256, 1, 256, 1, 4, true, 4, true, 4, true, 4, false, 1, true, 4>,
DeviceNormalizationBwdDataImpl<F16, F16, F16, F16, F32, F16, Rank, Reduce, 256, 1, 256, 1, 8, true, 8, true, 8, true, 8, false, 1, true, 8>
// clang-format on
>;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_data_f16_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationBwdDataImpl<F16, F16, F16, F16, F32, F16, Rank, Reduce, 64, 1, 64, 1, 1, true, 1, true, 1, true, 1, false, 1, true, 1>
// clang-format on
>;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_data_f32_instances =
// clang-format off
std::tuple <
// DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDXFastestDimReduced, DXDstVectorSize>
DeviceNormalizationBwdDataImpl<F32, F32, F32, F32, F32, F32, Rank, Reduce, 256, 1, 256, 1, 2, true, 2, true, 2, true, 2, false, 1, true, 2>,
DeviceNormalizationBwdDataImpl<F32, F32, F32, F32, F32, F32, Rank, Reduce, 256, 1, 256, 1, 4, true, 4, true, 4, true, 4, false, 1, true, 4>
// clang-format on
>;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_data_f32_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationBwdDataImpl<F32, F32, F32, F32, F32, F32, Rank, Reduce, 64, 1, 64, 1, 1, true, 1, true, 1, true, 1, false, 1, true, 1>
// clang-format on
>;
using device_groupnorm_bwd_data_f32_instances =
// clang-format off
std::tuple <
// DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDXFastestDimReduced, DXDstVectorSize>
DeviceNormalizationBwdDataImpl<F32, F32, F32, F32, F32, F32, 5, 3, 256, 1, 256, 1, 2, true, 2, true, 2, true, 2, false, 1, true, 2>,
DeviceNormalizationBwdDataImpl<F32, F32, F32, F32, F32, F32, 5, 3, 256, 1, 256, 1, 4, true, 4, true, 4, true, 4, false, 1, true, 4>
// clang-format on
>;
using device_groupnorm_bwd_data_f32_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationBwdDataImpl<F32, F32, F32, F32, F32, F32, 5, 3, 64, 1, 64, 1, 1, true, 1, true, 1, true, 1, false, 1, true, 1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,8 @@
set(DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES)
list(APPEND DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES
device_groupnorm_bwd_gamma_beta_f32_instance.cpp
device_layernorm2d_bwd_gamma_beta_f16_instance.cpp
device_layernorm2d_bwd_gamma_beta_f32_instance.cpp)
add_instance_library(device_normalization_bwd_gamma_beta_instance ${DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES})

View File

@@ -0,0 +1,23 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_gamma_beta_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_groupnorm_bwd_gamma_beta_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 5, 3>>>&
instances)
{
add_device_operation_instances(instances, device_groupnorm_bwd_gamma_beta_f32_instances{});
add_device_operation_instances(instances,
device_groupnorm_bwd_gamma_beta_f32_generic_instance{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_gamma_beta_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F16, F16, F16, F16, F16, 2, 1>>>&
instances)
{
add_device_operation_instances(instances,
device_layernorm_bwd_gamma_beta_f16_generic_instance<2, 1>{});
add_device_operation_instances(instances,
device_layernorm_bwd_gamma_beta_f16_instances<2, 1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_gamma_beta_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 2, 1>>>&
instances)
{
add_device_operation_instances(instances,
device_layernorm_bwd_gamma_beta_f32_generic_instance<2, 1>{});
add_device_operation_instances(instances,
device_layernorm_bwd_gamma_beta_f32_instances<2, 1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_gamma_beta_f16_instances =
// clang-format off
std::tuple <
// DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize>
DeviceNormalizationBwdGammaBetaImpl<F16, F16, F16, F32, F16, F16, Rank, Reduce, 256, 1, 256, 2, 1, false, 2, false, 2, true, 1, 2, 2>,
DeviceNormalizationBwdGammaBetaImpl<F16, F16, F16, F32, F16, F16, Rank, Reduce, 256, 1, 256, 4, 1, false, 4, false, 4, true, 1, 4, 4>,
DeviceNormalizationBwdGammaBetaImpl<F16, F16, F16, F32, F16, F16, Rank, Reduce, 256, 1, 256, 8, 1, false, 8, false, 8, true, 1, 8, 8>
// clang-format on
>;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_gamma_beta_f16_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationBwdGammaBetaImpl<F16, F16, F16, F32, F16, F16, Rank, Reduce, 64, 1, 64, 1, 1, false, 1, false, 1, true, 1, 1, 1>
// clang-format on
>;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_gamma_beta_f32_instances =
// clang-format off
std::tuple <
// DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize>
DeviceNormalizationBwdGammaBetaImpl<F32, F32, F32, F32, F32, F32, Rank, Reduce, 256, 1, 256, 2, 1, false, 2, false, 2, true, 1, 2, 2>,
DeviceNormalizationBwdGammaBetaImpl<F32, F32, F32, F32, F32, F32, Rank, Reduce, 256, 1, 256, 4, 1, false, 4, false, 4, true, 1, 4, 4>
// clang-format on
>;
template <index_t Rank, index_t Reduce>
using device_layernorm_bwd_gamma_beta_f32_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationBwdGammaBetaImpl<F32, F32, F32, F32, F32, F32, Rank, Reduce, 64, 1, 64, 1, 1, false, 1, false, 1, true, 1, 1, 1>
// clang-format on
>;
using device_groupnorm_bwd_gamma_beta_f32_instances =
// clang-format off
std::tuple <
// DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize>
DeviceNormalizationBwdGammaBetaImpl<F32, F32, F32, F32, F32, F32, 5, 3, 256, 1, 256, 2, 1, false, 2, false, 2, false, 1, 2, 2>,
DeviceNormalizationBwdGammaBetaImpl<F32, F32, F32, F32, F32, F32, 5, 3, 256, 1, 256, 4, 1, false, 4, false, 4, false, 1, 4, 4>
// clang-format on
>;
using device_groupnorm_bwd_gamma_beta_f32_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationBwdGammaBetaImpl<F32, F32, F32, F32, F32, F32, 5, 3, 64, 1, 64, 1, 1, false, 1, false, 1, false, 1, 1, 1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -11,7 +11,7 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_fwd_rank_5_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 5, 3>>>&
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Pass, 5, 3>>>&
instances)
{
add_device_operation_instances(instances,

View File

@@ -11,7 +11,7 @@ namespace instance {
using Swish = ck::tensor_operation::element_wise::Swish;
void add_device_normalization_fwd_rank_5_3_swish_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Swish, 5, 3>>>&
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Swish, 5, 3>>>&
instances)
{
add_device_operation_instances(instances,

View File

@@ -11,7 +11,7 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_fwd_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 2, 1>>>&
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Pass, 2, 1>>>&
instances)
{
add_device_operation_instances(instances,

View File

@@ -11,7 +11,7 @@ namespace instance {
using Pass = ck::tensor_operation::element_wise::PassThrough;
void add_device_normalization_fwd_rank_4_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 4, 3>>>&
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Pass, 4, 3>>>&
instances)
{
add_device_operation_instances(instances,

View File

@@ -23,24 +23,24 @@ using device_normalization_f16_instances =
// clang-format off
std::tuple <
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
// clang-format on
>;
@@ -49,31 +49,31 @@ using device_normalization_splitk_f16_instances =
// clang-format off
std::tuple <
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
// clang-format on
>;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_normalization_f16_generic_instance = std::tuple<
// clang-format off
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
// clang-format on
>;