Groupnorm + swish external api (#668)

* Rename to proper naming

* Add example of groupnorm + swish

* Extract duplicate code in example

* Add groupnorm + swish instances

* Ractor instance generation, split into multiple cpp file

* Add external api and client example

* Refine profiler message

* Use ck math version of exp

* Refine problem size in example

* Add host version of exp
This commit is contained in:
rocking5566
2023-04-10 21:02:17 +08:00
committed by GitHub
parent 3248387bbb
commit ed3a2e5226
23 changed files with 626 additions and 173 deletions

View File

@@ -96,6 +96,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish;
template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;

View File

@@ -0,0 +1,81 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_normalization_rank_5_3_swish_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Swish, 5, 3>>>&);
// FP32
void add_device_normalization_rank_5_3_swish_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&);
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
F32,
YDataType,
ck::tensor_operation::element_wise::Swish,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
F32,
YDataType,
ck::tensor_operation::element_wise::Swish,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
{
if constexpr(Rank == 5 && NumReduceDim == 3)
{
add_device_normalization_rank_5_3_swish_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
{
if constexpr(Rank == 5 && NumReduceDim == 3)
{
add_device_normalization_rank_5_3_swish_f32_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck