mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Single-kernel GEMM + layernorm (#263)
* dump lds content in appropriate precision type
* add squared add reduction op; allows sq sum
* initial stub from regular gemm impl
* layernorm example code & host verification
* initial layernorm implementation
* tidy up
* make C0 precision type consistent with C
* clang-tidy and additional comments
* tighten up example code
* account for extra flops/bytes from normalization
* clang-format
* c0 bias/beta/gamma now have its own precision type
* AccElemOp for gemm outputs prior to feeding to layernorm
* update workgroup mapping
* rename kernel template param to reflect its dual use
* use LDS mem pool for reduction workspace
* change cshuffle precision type to f16; clean up
* clang-format
* correct naming
* explicit cast
* fully implemented gemm + bias + activation + add + norm
* activation in correct order
* reflect reduction API's recent change
* amend
* clean up; add comment
* keep up with recent changes in reduction API
* format
* resolve merge conflicts

Co-authored-by: Chao Liu <chao.liu2@amd.com>
@@ -220,12 +220,24 @@ struct Tensor
    Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}

    template <typename OutT>
    Tensor<OutT> CopyAsType()
    {
        Tensor<OutT> ret(mDesc);
        for(size_t i = 0; i < mData.size(); i++)
        {
            ret.mData[i] = static_cast<OutT>(mData[i]);
        }
        return ret;
    }

    Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}

    Tensor& operator=(const Tensor& other)
    {
        mDesc = other.mDesc;
        mData = other.mData;
        return *this;
    }

    template <typename F>
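The new CopyAsType member gives the host Tensor an elementwise precision cast; the layernorm reference in the new file below uses it to convert its f32 accumulator into the final output type. A minimal usage sketch (sizes and the f32-to-f16 type choice are illustrative, not taken from this diff):

    size_t M = 4, N = 8;
    Tensor<float> acc(HostTensorDescriptor(std::vector<size_t>({M, N})));
    // ... compute into acc ...
    // elementwise static_cast into a new tensor with the same descriptor
    Tensor<ck::half_t> out = acc.CopyAsType<ck::half_t>();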
@@ -0,0 +1,236 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
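// i.e. Layernorm normalizes each length-N row x of its MxN input:
//   Layernorm(x) = (x - E[x]) / sqrt(Var[x] + epsilon)
// with mean and variance reduced along the N dimension (see RunLayernorm below)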
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename C0DataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename AccElementwiseOperation,
          typename CElementwiseOperation>
struct ReferenceGemmLayernorm : public device::BaseOperator
{
    using ReferenceGemmInstance = ReferenceGemm<ADataType,
                                                BDataType,
                                                AccDataType,
                                                AccDataType,
                                                AElementwiseOperation,
                                                BElementwiseOperation,
                                                element_wise::PassThrough>;

    template <typename InDataType, typename OutDataType, typename ComputeDataType>
    static void RunLayernorm(Tensor<OutDataType>& result,
                             const Tensor<ComputeDataType>& acc, // MxN
                             const Tensor<InDataType>& gamma,    // 1xN
                             const Tensor<InDataType>& beta,     // 1xN
                             const InDataType epsilon = 1e-5)
    {
        assert(acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] &&
               acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]);

        size_t M = acc.mDesc.GetLengths()[0];
        size_t N = acc.mDesc.GetLengths()[1];

        Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
        Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
        Tensor<ComputeDataType> acc_layernorm(acc);

        // reduce N dim
        for(size_t i = 0; i < M; i++)
        {
            ComputeDataType sum_acc_sq = 0;
            ComputeDataType sum_acc    = 0;
            for(size_t j = 0; j < N; j++)
            {
                sum_acc_sq += acc_layernorm(i, j) * acc_layernorm(i, j);
                sum_acc += acc_layernorm(i, j);
            }
            avg_acc_sq(i) = sum_acc_sq / N;
            avg_acc(i)    = sum_acc / N;
        }
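
        // note: the normalization below recovers the row variance via the
        // single-pass identity Var[x] = E[x^2] - E[x]^2 from the two averages above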

        // normalize
        acc_layernorm.ForEach([&](auto& self, auto idx) {
            self(idx[0], idx[1]) =
                (self(idx[0], idx[1]) - avg_acc(idx[0])) /
                sqrt(avg_acc_sq(idx[0]) - avg_acc(idx[0]) * avg_acc(idx[0]) + epsilon);
        });

        // affine
        acc_layernorm.ForEach([&](auto& self, auto idx) {
            self(idx[0], idx[1]) = self(idx[0], idx[1]) * gamma(idx[1]) + beta(idx[1]);
        });

        // cast
        result = acc_layernorm.template CopyAsType<OutDataType>();
    }

    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_m_k,
                 const Tensor<BDataType>& b_k_n,
                 Tensor<CDataType>& c_m_n,
                 const Tensor<C0DataType>& c0_n_bias,  // 1xN
                 const Tensor<C0DataType>& c0_m_n_add, // MxN
                 const Tensor<C0DataType>& c0_n_gamma, // 1xN
                 const Tensor<C0DataType>& c0_n_beta,  // 1xN
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 AccElementwiseOperation acc_element_op,
                 CElementwiseOperation c_element_op,
                 const CDataType epsilon = 1e-5)
            : a_m_k_{a_m_k},
              b_k_n_{b_k_n},
              c_m_n_{c_m_n},
              c0_n_bias_{c0_n_bias},
              c0_m_n_add_{c0_m_n_add},
              c0_n_gamma_{c0_n_gamma},
              c0_n_beta_{c0_n_beta},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              acc_element_op_{acc_element_op},
              c_element_op_{c_element_op},
              epsilon_{epsilon}
        {
        }

        const Tensor<ADataType>& a_m_k_;
        const Tensor<BDataType>& b_k_n_;
        Tensor<CDataType>& c_m_n_;
        const Tensor<C0DataType>& c0_n_bias_;
        const Tensor<C0DataType>& c0_m_n_add_;
        const Tensor<C0DataType>& c0_n_gamma_;
        const Tensor<C0DataType>& c0_n_beta_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        AccElementwiseOperation acc_element_op_;
        CElementwiseOperation c_element_op_;

        const CDataType epsilon_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        // using Argument = ReferenceGemm::Argument;

        float Run(const Argument& arg)
        {
            Tensor<AccDataType> acc_m_n(arg.c_m_n_.mDesc);
            acc_m_n.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});

            auto ref_gemm     = ReferenceGemmInstance{};
            auto ref_invoker  = ref_gemm.MakeInvoker();
            auto ref_argument = ref_gemm.MakeArgument(arg.a_m_k_,
                                                      arg.b_k_n_,
                                                      acc_m_n,
                                                      arg.a_element_op_,
                                                      arg.b_element_op_,
                                                      element_wise::PassThrough{});

            // gemm
            ref_invoker.Run(ref_argument);

            // activation(acc + bias)
            acc_m_n.ForEach([&](auto& self, auto idx) {
                AccDataType out;
                arg.acc_element_op_(out, acc_m_n(idx[0], idx[1]) + arg.c0_n_bias_(idx[1]));
                self(idx[0], idx[1]) = out;
            });

            // add from other layers
            acc_m_n.ForEach([&](auto& self, auto idx) {
                self(idx[0], idx[1]) += arg.c0_m_n_add_(idx[0], idx[1]);
            });

            // layernorm
            RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_gamma_, arg.c0_n_beta_);

            // elementwise op
            arg.c_m_n_.ForEach([&](auto& self, auto idx) {
                arg.c_element_op_(self(idx[0], idx[1]), self(idx[0], idx[1]));
            });

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<BDataType>& b_k_n,
                             Tensor<CDataType>& c_m_n,
                             const Tensor<C0DataType>& c0_n_bias,  // 1xN
                             const Tensor<C0DataType>& c0_m_n_add, // MxN
                             const Tensor<C0DataType>& c0_n_gamma, // 1xN
                             const Tensor<C0DataType>& c0_n_beta,  // 1xN
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             AccElementwiseOperation acc_element_op,
                             CElementwiseOperation c_element_op,
                             const CDataType epsilon = 1e-5)
    {
        return Argument{a_m_k,
                        b_k_n,
                        c_m_n,
                        c0_n_bias,
                        c0_m_n_add,
                        c0_n_gamma,
                        c0_n_beta,
                        a_element_op,
                        b_element_op,
                        acc_element_op,
                        c_element_op,
                        epsilon};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceGemmLayernorm"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
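For context, a minimal host-side sketch of how this reference op might be driven through the MakeArgument/MakeInvoker factories above. The f16/f32 type choices, the problem sizes, and the use of PassThrough for all four element-wise ops are illustrative assumptions, not part of this commit:

    using F16         = ck::half_t;
    using F32         = float;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // f32 accumulation, f16 everywhere else; all element-wise ops are identity
    using RefOp = ck::tensor_operation::host::ReferenceGemmLayernorm<
        F16, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, PassThrough>;

    size_t M = 256, N = 128, K = 64;
    Tensor<F16> a_m_k(HostTensorDescriptor(std::vector<size_t>({M, K})));
    Tensor<F16> b_k_n(HostTensorDescriptor(std::vector<size_t>({K, N})));
    Tensor<F16> c_m_n(HostTensorDescriptor(std::vector<size_t>({M, N})));
    Tensor<F16> bias(HostTensorDescriptor(std::vector<size_t>({N})));   // 1xN
    Tensor<F16> add(HostTensorDescriptor(std::vector<size_t>({M, N}))); // MxN
    Tensor<F16> gamma(HostTensorDescriptor(std::vector<size_t>({N})));  // 1xN
    Tensor<F16> beta(HostTensorDescriptor(std::vector<size_t>({N})));   // 1xN
    // ... fill a_m_k, b_k_n, bias, add, gamma and beta with test data ...

    auto arg = RefOp::MakeArgument(a_m_k, b_k_n, c_m_n, bias, add, gamma, beta,
                                   PassThrough{}, PassThrough{}, PassThrough{},
                                   PassThrough{});
    RefOp::MakeInvoker().Run(arg); // c_m_n now holds the host reference result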