mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
Layernorm4d (#1022)
* Rename folder * Add layernorm 4d fwd example * Rename original layernorm example * Add layernorm 4d f16 test * Add layernorm4d_fwd client example * Support layernorm4D in ckProfiler * Rename groupnorm to groupnorm fwd in example * Rename layernorm and group fwd in test * Rename normalization to normalization_fwd (instances) * Add fwd to DeviceNormalization * Rename external api header * Rename folder, because we can also add bwd in this folder * Add fwd in layernorm and groupnorm (profiler * Fix compile error --------- Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -1,14 +0,0 @@
|
||||
set(DEVICE_NORMALIZATION_INSTANCES)
|
||||
|
||||
list(APPEND DEVICE_NORMALIZATION_INSTANCES
|
||||
device_layernorm2d_f16_instance.cpp
|
||||
device_layernorm4d_f16_instance.cpp
|
||||
device_groupnorm_f16_instance.cpp
|
||||
device_groupnorm_swish_f16_instance.cpp
|
||||
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
|
||||
device_layernorm2d_f32_instance.cpp
|
||||
device_layernorm4d_f32_instance.cpp
|
||||
device_groupnorm_f32_instance.cpp
|
||||
device_groupnorm_swish_f32_instance.cpp)
|
||||
|
||||
add_instance_library(device_normalization_instance ${DEVICE_NORMALIZATION_INSTANCES})
|
||||
@@ -1,201 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_instances =
|
||||
// clang-format off
|
||||
std::tuple <
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_splitk_f16_instances =
|
||||
// clang-format off
|
||||
std::tuple <
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_generic_instance = std::tuple<
|
||||
// clang-format off
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_splitk_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f32_generic_instance = std::tuple<
|
||||
// clang-format off
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_f32_f32_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_splitk_f16_f32_f32_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple<
|
||||
// clang-format off
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,14 @@
|
||||
set(DEVICE_NORMALIZATION_FWD_INSTANCES)
|
||||
|
||||
list(APPEND DEVICE_NORMALIZATION_FWD_INSTANCES
|
||||
device_layernorm2d_fwd_f16_instance.cpp
|
||||
device_layernorm4d_fwd_f16_instance.cpp
|
||||
device_groupnorm_fwd_f16_instance.cpp
|
||||
device_groupnorm_fwd_swish_f16_instance.cpp
|
||||
device_groupnorm_fwd_swish_f16_f32_f32_f16_instance.cpp
|
||||
device_layernorm2d_fwd_f32_instance.cpp
|
||||
device_layernorm4d_fwd_f32_instance.cpp
|
||||
device_groupnorm_fwd_f32_instance.cpp
|
||||
device_groupnorm_fwd_swish_f32_instance.cpp)
|
||||
|
||||
add_instance_library(device_normalization_fwd_instance ${DEVICE_NORMALIZATION_FWD_INSTANCES})
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
#include "normalization_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -10,8 +10,8 @@ namespace instance {
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
void add_device_normalization_rank_5_3_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Pass, 5, 3>>>&
|
||||
void add_device_normalization_fwd_rank_5_3_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
#include "normalization_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -10,8 +10,8 @@ namespace instance {
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
void add_device_normalization_rank_5_3_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 5, 3>>>&
|
||||
void add_device_normalization_fwd_rank_5_3_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Pass, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
#include "normalization_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -10,8 +10,8 @@ namespace instance {
|
||||
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
|
||||
void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F16, F32, Swish, 5, 3>>>&
|
||||
void add_device_normalization_fwd_rank_5_3_swish_f16_f32_f32_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
#include "normalization_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -10,8 +10,8 @@ namespace instance {
|
||||
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
|
||||
void add_device_normalization_rank_5_3_swish_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Swish, 5, 3>>>&
|
||||
void add_device_normalization_fwd_rank_5_3_swish_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Swish, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
#include "normalization_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -10,8 +10,8 @@ namespace instance {
|
||||
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
|
||||
void add_device_normalization_rank_5_3_swish_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&
|
||||
void add_device_normalization_fwd_rank_5_3_swish_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
#include "normalization_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -10,8 +10,8 @@ namespace instance {
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
void add_device_normalization_rank_2_1_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Pass, 2, 1>>>&
|
||||
void add_device_normalization_fwd_rank_2_1_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 2, 1>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
#include "normalization_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -10,8 +10,8 @@ namespace instance {
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
void add_device_normalization_rank_2_1_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 2, 1>>>&
|
||||
void add_device_normalization_fwd_rank_2_1_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Pass, 2, 1>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
#include "normalization_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -10,8 +10,8 @@ namespace instance {
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
void add_device_normalization_rank_4_3_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Pass, 4, 3>>>&
|
||||
void add_device_normalization_fwd_rank_4_3_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 4, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
#include "normalization_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -10,8 +10,8 @@ namespace instance {
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
void add_device_normalization_rank_4_3_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 4, 3>>>&
|
||||
void add_device_normalization_fwd_rank_4_3_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Pass, 4, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
@@ -0,0 +1,201 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_instances =
|
||||
// clang-format off
|
||||
std::tuple <
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_splitk_f16_instances =
|
||||
// clang-format off
|
||||
std::tuple <
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8, 2>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_generic_instance = std::tuple<
|
||||
// clang-format off
|
||||
DeviceNormalizationFwdImpl<F16, F16, F16, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_splitk_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f32_generic_instance = std::tuple<
|
||||
// clang-format off
|
||||
DeviceNormalizationFwdImpl<F32, F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_f32_f32_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_splitk_f16_f32_f32_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector>
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1>, // irregular size
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4, 2>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4, 1>,
|
||||
DeviceNormalizationFwdSplitKImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple<
|
||||
// clang-format off
|
||||
DeviceNormalizationFwdImpl<F16, F32, F32, F32, F16, F32, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user