mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
@@ -13,8 +13,8 @@
|
||||
#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"
|
||||
|
||||
using XDataType = ck::half_t;
|
||||
using GammaDataType = ck::half_t;
|
||||
using BetaDataType = ck::half_t;
|
||||
using GammaDataType = float;
|
||||
using BetaDataType = float;
|
||||
using YDataType = ck::half_t;
|
||||
using ComputeDataType = float;
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
|
||||
@@ -25,6 +25,10 @@ void add_device_normalization_rank_5_3_swish_f16_instances(
|
||||
void add_device_normalization_rank_5_3_swish_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&);
|
||||
|
||||
// [x, gamma, beta, y] = [f16, f32, f32, f16]
|
||||
void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>&);
|
||||
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
@@ -70,6 +74,14 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_normalization_rank_5_3_swish_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F32> &&
|
||||
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F16>)
|
||||
{
|
||||
if constexpr(Rank == 5 && NumReduceDim == 3)
|
||||
{
|
||||
add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -7,4 +7,5 @@ add_instance_library(device_normalization_instance
|
||||
device_groupnorm_f32_instance.cpp
|
||||
device_groupnorm_swish_f16_instance.cpp
|
||||
device_groupnorm_swish_f32_instance.cpp
|
||||
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
|
||||
void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_f16_f32_f32_f16_instances<Swish, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -69,6 +69,32 @@ using device_normalization_f32_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_f32_f32_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
Reference in New Issue
Block a user