mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Fused elementwise normalization (#492)
* add fused addition lyernorm * add fused addition lyernorm * changed CMakelist * removed annotates * modified descriptor of C * fixed bug in gridwise add layernorm * format the files * modified name from add&layernorm into elementwise&layernorm * created fused elementwise layernorm branch * change input into tuple type * add sweep once to reduce load & read of C from global memory * modified Argument api * modified way to malloc c in global memory * changed gamma and beta to m_k_desc * fixed bug when sweep once and move CDataType when define device level struct * add src dim for gamma and beta * implement optimization for coalesced * delete a annotation line * fixed some bug to meet the requirements of ck * add bandwidth computing in example, and fixed the time unit * move device_elementwise_layernorm_impl.hpp into device/impl * fixed bug in device_elementwise_layernorm_impl.hpp * changed name from layernorm into normalization * clang-format the changed files * changed the names * moved immidiate results into lds, it become faster in non-sweeponce cases * changed naming of C into X to make the defination more clear * changed naming in example * add tests for elementwise normalization * move example_elementwise_layernorm_blockwise into folder 44_elementwise_normalization * move test_elementwise_layernorm_fp16 into new folder * move elementwise_normalization_instances into a new folder * add more tests in test_elementwise_layernorm_fp16.cpp * added some corner cases in test * fixed method to compute lds size for matrix X * changed name of 44_elementwise_normalization into 45_elementwise_normalization * modified some comments * modified some other confused comments * reduce redundant tests in test_elementwise_layernorm_fp16.cpp
This commit is contained in:
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// FP16
|
||||
void add_device_elementwise_normalization_rank_2_1_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwiseNormalization<ck::Tuple<F16, F16>,
|
||||
F16,
|
||||
F16,
|
||||
F32,
|
||||
F16,
|
||||
element_wise::Add,
|
||||
PassThrough,
|
||||
2,
|
||||
1>>>&);
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename YDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElementwiseNormalization<
|
||||
InDataTypeTuple,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
F32,
|
||||
YDataType,
|
||||
ck::tensor_operation::element_wise::Add,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Rank,
|
||||
NumReduceDim>>
|
||||
{
|
||||
using DeviceOp = DeviceElementwiseNormalization<InDataTypeTuple,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
F32,
|
||||
YDataType,
|
||||
ck::tensor_operation::element_wise::Add,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(is_same_v<GammaDataType, F16> && is_same_v<BetaDataType, F16> &&
|
||||
is_same_v<YDataType, F16>)
|
||||
{
|
||||
if constexpr(Rank == 2 && NumReduceDim == 1)
|
||||
{
|
||||
add_device_elementwise_normalization_rank_2_1_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user