mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
embedding fuse layernorm (#405)
* add gridwise/device sparse embedding * update code * update code * remove useless makefile * code fix * workable * work properly * emb add * add more instance * format * remove useless code * fix format * fix clang-tidy * clean * fix a compile error Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: Chao Liu <lc.roy86@gmail.com>
This commit is contained in:
1
example/36_sparse_embedding/CMakeLists.txt
Normal file
1
example/36_sparse_embedding/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_sparse_embedding3_forward_layernorm sparse_embedding3_forward_layernorm.cpp)
|
||||
@@ -0,0 +1,222 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <getopt.h>
|
||||
#include <ctime>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_sparse_embedding3_forward_layernorm.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp"
|
||||
|
||||
// using EmbType = float;
|
||||
// using IndexType = int64_t;
|
||||
// using GammaDataType = float;
|
||||
// using BetaDataType = float;
|
||||
// using AccDataType = float;
|
||||
// using OutType = float;
|
||||
|
||||
using EmbType = ck::half_t;
|
||||
using IndexType = int64_t;
|
||||
using GammaDataType = ck::half_t;
|
||||
using BetaDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using OutType = ck::half_t;
|
||||
|
||||
// clang-format off
|
||||
// BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize
|
||||
using DeviceInstance_fp32_e256 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 256, 1, 1>;
|
||||
using DeviceInstance_fp32_e512 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 512, 1, 1>;
|
||||
using DeviceInstance_fp32_e768 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 768, 1, 1>;
|
||||
using DeviceInstance_fp32_e1024 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1024, 1, 1>;
|
||||
using DeviceInstance_fp32_e1536 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1536, 1, 1>;
|
||||
using DeviceInstance_fp32_e2048 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 2048, 1, 4>;
|
||||
using DeviceInstance_fp32_e4096 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 4096, 1, 4>;
|
||||
using DeviceInstance_fp32_e8192 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 8192, 1, 4>;
|
||||
using DeviceInstance_fp32_e16384 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 16384, 1, 4>;
|
||||
|
||||
using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 256, 1, 1>;
|
||||
using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 512, 1, 2>;
|
||||
using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 768, 1, 1>;
|
||||
using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1024, 1, 2>;
|
||||
using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1536, 1, 2>;
|
||||
using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 2048, 1, 2>;
|
||||
using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 4096, 1, 8>;
|
||||
using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 8192, 1, 8>;
|
||||
|
||||
template<typename emb_type, ck::index_t dim> struct emb_kernel{};
|
||||
|
||||
template<> struct emb_kernel<float, 256> { using kernel_type = DeviceInstance_fp32_e256; };
|
||||
template<> struct emb_kernel<float, 512> { using kernel_type = DeviceInstance_fp32_e512; };
|
||||
template<> struct emb_kernel<float, 768> { using kernel_type = DeviceInstance_fp32_e768; };
|
||||
template<> struct emb_kernel<float, 1024> { using kernel_type = DeviceInstance_fp32_e1024;};
|
||||
template<> struct emb_kernel<float, 1536> { using kernel_type = DeviceInstance_fp32_e1536;};
|
||||
template<> struct emb_kernel<float, 2048> { using kernel_type = DeviceInstance_fp32_e2048;};
|
||||
template<> struct emb_kernel<float, 4096> { using kernel_type = DeviceInstance_fp32_e4096;};
|
||||
template<> struct emb_kernel<float, 8192> { using kernel_type = DeviceInstance_fp32_e8192;};
|
||||
template<> struct emb_kernel<float, 16384>{ using kernel_type = DeviceInstance_fp32_e16384;};
|
||||
|
||||
template<> struct emb_kernel<ck::half_t, 256> { using kernel_type = DeviceInstance_fp16_e256; };
|
||||
template<> struct emb_kernel<ck::half_t, 512> { using kernel_type = DeviceInstance_fp16_e512; };
|
||||
template<> struct emb_kernel<ck::half_t, 768> { using kernel_type = DeviceInstance_fp16_e768; };
|
||||
template<> struct emb_kernel<ck::half_t, 1024> { using kernel_type = DeviceInstance_fp16_e1024; };
|
||||
template<> struct emb_kernel<ck::half_t, 1536> { using kernel_type = DeviceInstance_fp16_e1536; };
|
||||
template<> struct emb_kernel<ck::half_t, 2048> { using kernel_type = DeviceInstance_fp16_e2048; };
|
||||
template<> struct emb_kernel<ck::half_t, 4096> { using kernel_type = DeviceInstance_fp16_e4096; };
|
||||
template<> struct emb_kernel<ck::half_t, 8192> { using kernel_type = DeviceInstance_fp16_e8192; };
|
||||
|
||||
// clang-format on
|
||||
|
||||
int main()
|
||||
{
|
||||
bool time_kernel = true;
|
||||
|
||||
constexpr auto num_rows = 65536;
|
||||
constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{};
|
||||
// constexpr auto dims = ck::Sequence<256, 512>{};
|
||||
constexpr auto index_length = 2048;
|
||||
constexpr AccDataType epsilon = 1e-4;
|
||||
|
||||
auto f_host_tensor_desc_1d = [](std::size_t len_) {
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({len_}));
|
||||
};
|
||||
|
||||
auto f_host_tensor_desc_2d = [](std::size_t rows_, std::size_t cols_) {
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({rows_, cols_}));
|
||||
};
|
||||
|
||||
using ReferenceInstance =
|
||||
ck::tensor_operation::host::ReferenceSparseEmbedding3ForwardLayernorm<EmbType,
|
||||
IndexType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
OutType>;
|
||||
|
||||
ck::static_for<0, dims.Size(), 1>{}([&](auto I) {
|
||||
std::srand(std::time(nullptr));
|
||||
constexpr auto current_dim = dims.At(I);
|
||||
Tensor<EmbType> emb_a(f_host_tensor_desc_2d(num_rows, current_dim));
|
||||
Tensor<EmbType> emb_b(f_host_tensor_desc_2d(num_rows, current_dim));
|
||||
Tensor<EmbType> emb_c(f_host_tensor_desc_2d(num_rows, current_dim));
|
||||
|
||||
Tensor<IndexType> index_a(f_host_tensor_desc_1d(index_length));
|
||||
Tensor<IndexType> index_b(f_host_tensor_desc_1d(index_length));
|
||||
Tensor<IndexType> index_c(f_host_tensor_desc_1d(index_length));
|
||||
|
||||
Tensor<GammaDataType> gamma(f_host_tensor_desc_1d(current_dim));
|
||||
Tensor<BetaDataType> beta(f_host_tensor_desc_1d(current_dim));
|
||||
|
||||
Tensor<OutType> out(f_host_tensor_desc_2d(index_length, current_dim));
|
||||
|
||||
emb_a.GenerateTensorValue(GeneratorTensor_3<EmbType>{0.0, 1.0});
|
||||
emb_b.GenerateTensorValue(GeneratorTensor_3<EmbType>{0.0, 1.0});
|
||||
emb_c.GenerateTensorValue(GeneratorTensor_3<EmbType>{0.0, 1.0});
|
||||
|
||||
index_a.GenerateTensorValue(GeneratorTensor_2<IndexType>{0, num_rows});
|
||||
index_b.GenerateTensorValue(GeneratorTensor_2<IndexType>{0, num_rows});
|
||||
index_c.GenerateTensorValue(GeneratorTensor_2<IndexType>{0, num_rows});
|
||||
|
||||
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
|
||||
beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem emb_a_dev(sizeof(EmbType) * emb_a.mDesc.GetElementSpaceSize());
|
||||
DeviceMem emb_b_dev(sizeof(EmbType) * emb_b.mDesc.GetElementSpaceSize());
|
||||
DeviceMem emb_c_dev(sizeof(EmbType) * emb_c.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem index_a_dev(sizeof(IndexType) * index_a.mDesc.GetElementSpaceSize());
|
||||
DeviceMem index_b_dev(sizeof(IndexType) * index_b.mDesc.GetElementSpaceSize());
|
||||
DeviceMem index_c_dev(sizeof(IndexType) * index_c.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
|
||||
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem out_dev(sizeof(OutType) * out.mDesc.GetElementSpaceSize());
|
||||
|
||||
emb_a_dev.ToDevice(emb_a.mData.data());
|
||||
emb_b_dev.ToDevice(emb_b.mData.data());
|
||||
emb_c_dev.ToDevice(emb_c.mData.data());
|
||||
|
||||
index_a_dev.ToDevice(index_a.mData.data());
|
||||
index_b_dev.ToDevice(index_b.mData.data());
|
||||
index_c_dev.ToDevice(index_c.mData.data());
|
||||
|
||||
gamma_dev.ToDevice(gamma.mData.data());
|
||||
beta_dev.ToDevice(beta.mData.data());
|
||||
|
||||
auto device_instance = typename emb_kernel<EmbType, current_dim>::kernel_type{};
|
||||
auto argument_ptr = device_instance.MakeArgumentPointer(out_dev.GetDeviceBuffer(),
|
||||
emb_a_dev.GetDeviceBuffer(),
|
||||
emb_b_dev.GetDeviceBuffer(),
|
||||
emb_c_dev.GetDeviceBuffer(),
|
||||
index_a_dev.GetDeviceBuffer(),
|
||||
index_b_dev.GetDeviceBuffer(),
|
||||
index_c_dev.GetDeviceBuffer(),
|
||||
gamma_dev.GetDeviceBuffer(),
|
||||
beta_dev.GetDeviceBuffer(),
|
||||
num_rows,
|
||||
current_dim,
|
||||
index_length,
|
||||
epsilon);
|
||||
std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString()
|
||||
<< std::endl
|
||||
<< std::flush;
|
||||
|
||||
bool is_supported = device_instance.IsSupportedArgument(argument_ptr.get());
|
||||
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "Runtime parameters are not supported" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
auto invoker_ptr = device_instance.MakeInvokerPointer();
|
||||
float time_ms = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
bool pass = true;
|
||||
{
|
||||
Tensor<OutType> out_from_dev(f_host_tensor_desc_2d(index_length, current_dim));
|
||||
ReferenceInstance ref;
|
||||
auto ref_argument = ref.MakeArgument(out,
|
||||
emb_a,
|
||||
emb_b,
|
||||
emb_c,
|
||||
index_a,
|
||||
index_b,
|
||||
index_c,
|
||||
gamma,
|
||||
beta,
|
||||
num_rows,
|
||||
current_dim,
|
||||
index_length,
|
||||
epsilon);
|
||||
auto ref_invoker = ref.MakeInvoker();
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
out_dev.FromDevice(out_from_dev.mData.data());
|
||||
pass &= ck::utils::check_err(
|
||||
out_from_dev.mData, out.mData, "Error: Incorrect results", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
double total_read = current_dim * index_length * 3 * sizeof(EmbType) +
|
||||
current_dim * sizeof(GammaDataType) +
|
||||
current_dim * sizeof(BetaDataType);
|
||||
double total_write = current_dim * index_length * sizeof(OutType);
|
||||
double gbps = (total_read + total_write) / time_ms / 1e6;
|
||||
|
||||
std::cout << ", total bytes:" << (total_read + total_write) << ", time:" << time_ms
|
||||
<< ", gbps:" << gbps << ", valid:" << (pass ? "y" : "n") << std::endl
|
||||
<< std::flush;
|
||||
});
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -51,4 +51,5 @@ add_subdirectory(32_batched_gemm_scale_softmax_gemm)
|
||||
add_subdirectory(33_multiple_reduce)
|
||||
add_subdirectory(34_batchnorm)
|
||||
add_subdirectory(35_splitK_gemm)
|
||||
add_subdirectory(36_sparse_embedding)
|
||||
add_subdirectory(41_grouped_conv_conv_fwd)
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename EmbType,
|
||||
typename IndexType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename OutType,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t DimClusterSize,
|
||||
ck::index_t RowClusterSize,
|
||||
ck::index_t DimPerBlock,
|
||||
ck::index_t RowPerBlock,
|
||||
ck::index_t DimThreadSize,
|
||||
ck::index_t RowVectorSize>
|
||||
struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
|
||||
{
|
||||
|
||||
static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows));
|
||||
}
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(OutType* p_out,
|
||||
const EmbType* p_emb_a,
|
||||
const EmbType* p_emb_b,
|
||||
const EmbType* p_emb_c,
|
||||
const IndexType* p_index_a,
|
||||
const IndexType* p_index_b,
|
||||
const IndexType* p_index_c,
|
||||
const GammaDataType* p_gamma,
|
||||
const BetaDataType* p_beta,
|
||||
const ck::index_t NumRows,
|
||||
const ck::index_t EmbeddingDim,
|
||||
const ck::index_t IndexLength,
|
||||
const AccDataType epsilon)
|
||||
: p_out_(p_out),
|
||||
p_emb_a_(p_emb_a),
|
||||
p_emb_b_(p_emb_b),
|
||||
p_emb_c_(p_emb_c),
|
||||
p_index_a_(p_index_a),
|
||||
p_index_b_(p_index_b),
|
||||
p_index_c_(p_index_c),
|
||||
p_gamma_(p_gamma),
|
||||
p_beta_(p_beta),
|
||||
NumRows_(NumRows),
|
||||
EmbeddingDim_(EmbeddingDim),
|
||||
IndexLength_(IndexLength),
|
||||
epsilon_(epsilon)
|
||||
{
|
||||
grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize;
|
||||
}
|
||||
|
||||
OutType* p_out_;
|
||||
const EmbType* p_emb_a_;
|
||||
const EmbType* p_emb_b_;
|
||||
const EmbType* p_emb_c_;
|
||||
const IndexType* p_index_a_;
|
||||
const IndexType* p_index_b_;
|
||||
const IndexType* p_index_c_;
|
||||
const GammaDataType* p_gamma_;
|
||||
const BetaDataType* p_beta_;
|
||||
ck::index_t NumRows_;
|
||||
ck::index_t EmbeddingDim_;
|
||||
ck::index_t IndexLength_;
|
||||
AccDataType epsilon_;
|
||||
|
||||
size_t grid_size_;
|
||||
};
|
||||
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(void* p_out,
|
||||
const void* p_emb_a,
|
||||
const void* p_emb_b,
|
||||
const void* p_emb_c,
|
||||
const void* p_index_a,
|
||||
const void* p_index_b,
|
||||
const void* p_index_c,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
ck::index_t NumRows,
|
||||
ck::index_t EmbeddingDim,
|
||||
ck::index_t IndexLength,
|
||||
const AccDataType epsilon)
|
||||
{
|
||||
return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out),
|
||||
reinterpret_cast<const EmbType*>(p_emb_a),
|
||||
reinterpret_cast<const EmbType*>(p_emb_b),
|
||||
reinterpret_cast<const EmbType*>(p_emb_c),
|
||||
reinterpret_cast<const IndexType*>(p_index_a),
|
||||
reinterpret_cast<const IndexType*>(p_index_b),
|
||||
reinterpret_cast<const IndexType*>(p_index_c),
|
||||
reinterpret_cast<const GammaDataType*>(p_gamma),
|
||||
reinterpret_cast<const BetaDataType*>(p_beta),
|
||||
NumRows,
|
||||
EmbeddingDim,
|
||||
IndexLength,
|
||||
epsilon);
|
||||
}
|
||||
|
||||
using GridwiseSparseEmbedding =
|
||||
GridwiseSparseEmbedding3ForwardLayernorm<EmbType,
|
||||
IndexType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
OutType,
|
||||
decltype(MakeOutputDescriptor(1, 1)),
|
||||
BlockSize,
|
||||
DimClusterSize,
|
||||
RowClusterSize,
|
||||
DimPerBlock,
|
||||
RowPerBlock,
|
||||
DimThreadSize,
|
||||
RowVectorSize>;
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_);
|
||||
const auto kernel_main =
|
||||
kernel_sparse_embedding3_forward_layernorm<GridwiseSparseEmbedding,
|
||||
EmbType,
|
||||
IndexType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
OutType,
|
||||
decltype(out_desc)>;
|
||||
float avg_time = 0;
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_main,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_out_,
|
||||
arg.p_emb_a_,
|
||||
arg.p_emb_b_,
|
||||
arg.p_emb_c_,
|
||||
arg.p_index_a_,
|
||||
arg.p_index_b_,
|
||||
arg.p_index_c_,
|
||||
arg.p_gamma_,
|
||||
arg.p_beta_,
|
||||
out_desc,
|
||||
arg.epsilon_);
|
||||
|
||||
return (avg_time);
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
};
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument* p_arg)
|
||||
{
|
||||
return (RowPerBlock == p_arg->EmbeddingDim_) && (p_arg->NumRows_ % DimPerBlock == 0);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceSparseEmbedding3ForwardLayernorm_"<< BlockSize << "_" <<
|
||||
DimClusterSize << "x" << RowClusterSize << "_" <<
|
||||
DimPerBlock << "x" << RowPerBlock << "_" <<
|
||||
DimThreadSize << "x" << RowVectorSize;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,344 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseSparseEmbedding,
|
||||
typename EmbType,
|
||||
typename IndexType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename OutType,
|
||||
typename OutGridDesc>
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
__global__ void kernel_sparse_embedding3_forward_layernorm(OutType* p_out,
|
||||
const EmbType* p_emb_a,
|
||||
const EmbType* p_emb_b,
|
||||
const EmbType* p_emb_c,
|
||||
const IndexType* p_index_a,
|
||||
const IndexType* p_index_b,
|
||||
const IndexType* p_index_c,
|
||||
const GammaDataType* p_gamma,
|
||||
const BetaDataType* p_beta,
|
||||
const OutGridDesc out_grid_desc,
|
||||
const AccDataType epsilon)
|
||||
{
|
||||
GridwiseSparseEmbedding::Run(p_out,
|
||||
p_emb_a,
|
||||
p_emb_b,
|
||||
p_emb_c,
|
||||
p_index_a,
|
||||
p_index_b,
|
||||
p_index_c,
|
||||
p_gamma,
|
||||
p_beta,
|
||||
out_grid_desc,
|
||||
epsilon);
|
||||
}
|
||||
|
||||
template <typename EmbType,
|
||||
typename IndexType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename OutType,
|
||||
typename OutGridDesc,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t DimClusterSize,
|
||||
ck::index_t RowClusterSize,
|
||||
ck::index_t DimPerBlock, // Row x Dim, along Dim
|
||||
ck::index_t RowPerBlock, // Row x Dim, along Row
|
||||
ck::index_t DimThreadSize, // this is actually not vector, but number of registers
|
||||
ck::index_t RowVectorSize>
|
||||
struct GridwiseSparseEmbedding3ForwardLayernorm
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr index_t WaveSize = 64;
|
||||
|
||||
static_assert(BlockSize == RowClusterSize * DimClusterSize,
|
||||
"Invalid cluster distribution within block");
|
||||
static_assert(RowClusterSize % WaveSize == 0, "need to be wavewise");
|
||||
|
||||
static_assert(DimPerBlock % (DimClusterSize * DimThreadSize) == 0, "");
|
||||
static_assert(RowPerBlock % (RowClusterSize * RowVectorSize) == 0, "");
|
||||
|
||||
static constexpr auto DimSubBlocks = DimPerBlock / (DimClusterSize * DimThreadSize);
|
||||
static constexpr auto RowSubBlocks = RowPerBlock / (RowClusterSize * RowVectorSize);
|
||||
|
||||
static_assert((DimPerBlock % DimSubBlocks == 0) && (RowPerBlock % RowSubBlocks == 0), "");
|
||||
static constexpr auto DimPerSubBlock = DimPerBlock / DimSubBlocks;
|
||||
static constexpr auto RowPerSubBlock = RowPerBlock / RowSubBlocks;
|
||||
|
||||
using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<DimSubBlocks * DimThreadSize>{}, Number<RowSubBlocks * RowVectorSize>{})));
|
||||
|
||||
using ThreadwiseWolfordDescReduce = decltype(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
|
||||
|
||||
using ThreadwiseWelford =
|
||||
ThreadwiseWelford<AccDataType, ThreadwiseWolfordDesc2D, ThreadwiseWolfordDescReduce>;
|
||||
|
||||
using ThreadClusterLength = Sequence<DimClusterSize, RowClusterSize>;
|
||||
|
||||
using BlockwiseWelford =
|
||||
BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLength, Sequence<0, 1>>;
|
||||
|
||||
__device__ static void Run(OutType* p_out,
|
||||
const EmbType* p_emb_a,
|
||||
const EmbType* p_emb_b,
|
||||
const EmbType* p_emb_c,
|
||||
const IndexType* p_index_a,
|
||||
const IndexType* p_index_b,
|
||||
const IndexType* p_index_c,
|
||||
const GammaDataType* p_gamma,
|
||||
const BetaDataType* p_beta,
|
||||
const OutGridDesc,
|
||||
const AccDataType epsilon)
|
||||
{
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
|
||||
// const auto index_length = out_grid_desc.GetLength(I0);
|
||||
// const auto emb_dim = out_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(Sequence<DimClusterSize, RowClusterSize>{}, Sequence<0, 1>{});
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_dim_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_row_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
const auto wave_dim_id = __builtin_amdgcn_readfirstlane(thread_dim_cluster_id / WaveSize);
|
||||
|
||||
const auto index_start = block_global_id * DimPerBlock + wave_dim_id * DimThreadSize;
|
||||
|
||||
auto threadwise_welford = ThreadwiseWelford();
|
||||
threadwise_welford.max_count_ = RowSubBlocks * RowVectorSize;
|
||||
|
||||
constexpr auto thread_buf_size =
|
||||
DimSubBlocks * DimThreadSize * RowSubBlocks * RowVectorSize;
|
||||
constexpr auto thread_buf_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(DimSubBlocks, DimThreadSize, RowSubBlocks, RowVectorSize));
|
||||
constexpr auto mean_var_buf_size = DimSubBlocks * DimThreadSize;
|
||||
constexpr auto mean_var_buf_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(DimSubBlocks, DimThreadSize));
|
||||
constexpr auto gamma_beta_buf_size = RowSubBlocks * RowVectorSize;
|
||||
constexpr auto gamma_beta_buf_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_a;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_b;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_c;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_a;
|
||||
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_b;
|
||||
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_c;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true> acc_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, gamma_beta_buf_size, true>
|
||||
gamma_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, gamma_beta_buf_size, true>
|
||||
beta_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> mean_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> var_thread_buf;
|
||||
|
||||
auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
|
||||
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_a;
|
||||
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_b;
|
||||
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_c;
|
||||
|
||||
using src_vector_t = typename decltype(emb_vector_a)::type;
|
||||
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
|
||||
constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
|
||||
IndexType index_a = index_buf_a[Number<current_dim>{}];
|
||||
IndexType index_b = index_buf_b[Number<current_dim>{}];
|
||||
IndexType index_c = index_buf_c[Number<current_dim>{}];
|
||||
|
||||
auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
|
||||
sizeof(EmbType) * RowVectorSize;
|
||||
|
||||
int32x4_t emb_res_a =
|
||||
make_wave_buffer_resource_with_default_range(p_emb_a + index_a * RowPerBlock);
|
||||
int32x4_t emb_res_b =
|
||||
make_wave_buffer_resource_with_default_range(p_emb_b + index_b * RowPerBlock);
|
||||
int32x4_t emb_res_c =
|
||||
make_wave_buffer_resource_with_default_range(p_emb_c + index_c * RowPerBlock);
|
||||
emb_vector_a.template AsType<src_vector_t>()(I0) =
|
||||
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_a, thread_offset, 0);
|
||||
emb_vector_b.template AsType<src_vector_t>()(I0) =
|
||||
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_b, thread_offset, 0);
|
||||
emb_vector_c.template AsType<src_vector_t>()(I0) =
|
||||
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_c, thread_offset, 0);
|
||||
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
in_thread_buf_a(Number<register_offset>{}) =
|
||||
emb_vector_a.template AsType<EmbType>()[i_row_vec_];
|
||||
in_thread_buf_b(Number<register_offset>{}) =
|
||||
emb_vector_b.template AsType<EmbType>()[i_row_vec_];
|
||||
in_thread_buf_c(Number<register_offset>{}) =
|
||||
emb_vector_c.template AsType<EmbType>()[i_row_vec_];
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
|
||||
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
AccDataType va =
|
||||
ck::type_convert<AccDataType>(in_thread_buf_a(Number<register_offset>{}));
|
||||
AccDataType vb =
|
||||
ck::type_convert<AccDataType>(in_thread_buf_b(Number<register_offset>{}));
|
||||
AccDataType vc =
|
||||
ck::type_convert<AccDataType>(in_thread_buf_c(Number<register_offset>{}));
|
||||
|
||||
acc_thread_buf(Number<register_offset>{}) += va + vb + vc;
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
|
||||
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
constexpr auto mean_var_offset =
|
||||
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
|
||||
|
||||
threadwise_welford.cur_count_++;
|
||||
threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
|
||||
var_thread_buf(Number<mean_var_offset>{}),
|
||||
acc_thread_buf(Number<register_offset>{}));
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
auto threadwise_normalize_store_out = [&](auto i_dim_sub_, auto i_row_sub_) {
|
||||
int32x4_t out_res =
|
||||
make_wave_buffer_resource_with_default_range(p_out + index_start * RowPerBlock);
|
||||
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
|
||||
vector_type_maker_t<OutType, RowVectorSize> out_vector;
|
||||
using dst_vector_t = typename decltype(out_vector)::type;
|
||||
|
||||
constexpr auto mean_var_offset =
|
||||
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
|
||||
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
constexpr auto gamma_beta_offset =
|
||||
gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
|
||||
|
||||
auto acc_val = acc_thread_buf[Number<register_offset>{}];
|
||||
acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) /
|
||||
sqrt(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
|
||||
acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
|
||||
beta_thread_buf[Number<gamma_beta_offset>{}];
|
||||
|
||||
out_vector.template AsType<OutType>()(Number<i_row_vec_>{}) =
|
||||
type_convert<OutType>(acc_val);
|
||||
});
|
||||
|
||||
index_t thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
|
||||
sizeof(OutType) * RowVectorSize;
|
||||
|
||||
amd_buffer_store_impl<OutType, RowVectorSize>(
|
||||
out_vector.template AsType<dst_vector_t>()[Number<0>{}],
|
||||
out_res,
|
||||
thread_offset,
|
||||
0);
|
||||
});
|
||||
};
|
||||
|
||||
// first load index
|
||||
ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
|
||||
// prefer use s_load
|
||||
index_buf_a(i_idx_) = p_index_a[index_start + i_idx_.value];
|
||||
index_buf_b(i_idx_) = p_index_b[index_start + i_idx_.value];
|
||||
index_buf_c(i_idx_) = p_index_c[index_start + i_idx_.value];
|
||||
});
|
||||
|
||||
// load gamma/beta
|
||||
static_for<0, RowSubBlocks, 1>{}([&](auto i_row_sub_) {
|
||||
vector_type_maker_t<GammaDataType, RowVectorSize> gamma_vector;
|
||||
vector_type_maker_t<BetaDataType, RowVectorSize> beta_vector;
|
||||
|
||||
index_t thread_offset_gamma = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
|
||||
sizeof(GammaDataType) * RowVectorSize;
|
||||
index_t thread_offset_beta = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
|
||||
sizeof(BetaDataType) * RowVectorSize;
|
||||
|
||||
int32x4_t gamma_res = make_wave_buffer_resource_with_default_range(p_gamma);
|
||||
int32x4_t beta_res = make_wave_buffer_resource_with_default_range(p_beta);
|
||||
|
||||
gamma_vector.template AsType<typename decltype(gamma_vector)::type>()(I0) =
|
||||
amd_buffer_load_impl<GammaDataType, RowVectorSize>(
|
||||
gamma_res, thread_offset_gamma, 0);
|
||||
beta_vector.template AsType<typename decltype(beta_vector)::type>()(I0) =
|
||||
amd_buffer_load_impl<BetaDataType, RowVectorSize>(beta_res, thread_offset_beta, 0);
|
||||
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto offset =
|
||||
gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
|
||||
gamma_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
|
||||
gamma_vector.template AsType<GammaDataType>()[Number<i_row_vec_>{}]);
|
||||
beta_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
|
||||
beta_vector.template AsType<BetaDataType>()[Number<i_row_vec_>{}]);
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, thread_buf_size, 1>{}(
|
||||
[&](auto I) { acc_thread_buf(I) = type_convert<AccDataType>(0.0f); });
|
||||
|
||||
static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
|
||||
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
|
||||
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
|
||||
});
|
||||
|
||||
static_for<0, DimSubBlocks, 1>{}([&](auto i_dim_sub) {
|
||||
load_current_sub_row(i_dim_sub, Number<0>{});
|
||||
static_for<0, RowSubBlocks - 1, 1>{}([&](auto i_row) {
|
||||
load_current_sub_row(i_dim_sub, Number<1>{} + i_row);
|
||||
accumulate_current_sub_row(i_dim_sub, i_row);
|
||||
threadwise_welford_sub_row(i_dim_sub, i_row);
|
||||
});
|
||||
accumulate_current_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
|
||||
threadwise_welford_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
|
||||
|
||||
// blockwise welford
|
||||
static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
|
||||
if constexpr(I > 0)
|
||||
block_sync_lds();
|
||||
|
||||
BlockwiseWelford::Run(
|
||||
mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
|
||||
});
|
||||
|
||||
// store
|
||||
static_for<0, RowSubBlocks, 1>{}(
|
||||
[&](auto i_row) { threadwise_normalize_store_out(i_dim_sub, i_row); });
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -34,6 +34,21 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_
|
||||
return wave_buffer_resource.content;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave)
|
||||
{
|
||||
BufferResource<T> wave_buffer_resource;
|
||||
|
||||
// wavewise base address (64 bit)
|
||||
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
|
||||
// wavewise range (32 bit)
|
||||
wave_buffer_resource.range(Number<2>{}) = 0xffffffff; // max possible range
|
||||
// wavewise setting (32 bit)
|
||||
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
|
||||
|
||||
return wave_buffer_resource.content;
|
||||
}
|
||||
|
||||
// buffer load i8
|
||||
__device__ int8_t
|
||||
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
template <typename EmbType,
|
||||
typename IndexType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename OutType>
|
||||
struct ReferenceSparseEmbedding3ForwardLayernorm : public device::BaseOperator
|
||||
{
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(Tensor<OutType>& output,
|
||||
const Tensor<EmbType>& emb_a,
|
||||
const Tensor<EmbType>& emb_b,
|
||||
const Tensor<EmbType>& emb_c,
|
||||
const Tensor<IndexType>& index_a,
|
||||
const Tensor<IndexType>& index_b,
|
||||
const Tensor<IndexType>& index_c,
|
||||
const Tensor<GammaDataType>& gamma,
|
||||
const Tensor<BetaDataType>& beta,
|
||||
ck::index_t NumRows,
|
||||
ck::index_t EmbeddingDim,
|
||||
ck::index_t IndexLength,
|
||||
AccDataType epsilon)
|
||||
: output_(output),
|
||||
emb_a_(emb_a),
|
||||
emb_b_(emb_b),
|
||||
emb_c_(emb_c),
|
||||
index_a_(index_a),
|
||||
index_b_(index_b),
|
||||
index_c_(index_c),
|
||||
gamma_(gamma),
|
||||
beta_(beta),
|
||||
NumRows_(NumRows),
|
||||
EmbeddingDim_(EmbeddingDim),
|
||||
IndexLength_(IndexLength),
|
||||
epsilon_(epsilon)
|
||||
{
|
||||
}
|
||||
Tensor<OutType>& output_;
|
||||
const Tensor<EmbType> emb_a_;
|
||||
const Tensor<EmbType> emb_b_;
|
||||
const Tensor<EmbType> emb_c_;
|
||||
const Tensor<IndexType> index_a_;
|
||||
const Tensor<IndexType> index_b_;
|
||||
const Tensor<IndexType> index_c_;
|
||||
const Tensor<GammaDataType> gamma_;
|
||||
const Tensor<BetaDataType> beta_;
|
||||
ck::index_t NumRows_;
|
||||
ck::index_t EmbeddingDim_;
|
||||
ck::index_t IndexLength_;
|
||||
AccDataType epsilon_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
ck::index_t D = arg.EmbeddingDim_;
|
||||
ck::index_t L = arg.IndexLength_;
|
||||
ck::index_t E = arg.NumRows_;
|
||||
|
||||
Tensor<AccDataType> accumulator({L, D});
|
||||
|
||||
Tensor<AccDataType> mean({L});
|
||||
Tensor<AccDataType> var({L});
|
||||
|
||||
accumulator.SetZero();
|
||||
|
||||
auto f_emb_per_row = [&](auto idx) {
|
||||
IndexType idx_a = arg.index_a_(idx);
|
||||
IndexType idx_b = arg.index_b_(idx);
|
||||
IndexType idx_c = arg.index_c_(idx);
|
||||
|
||||
if(!((idx_a < E) && (idx_b < E) && (idx_c < E)))
|
||||
{
|
||||
throw(std::runtime_error("wrong! out of range"));
|
||||
}
|
||||
|
||||
for(auto d = 0; d < D; d++)
|
||||
{
|
||||
auto v_a = ck::type_convert<AccDataType>(arg.emb_a_(idx_a, d));
|
||||
auto v_b = ck::type_convert<AccDataType>(arg.emb_b_(idx_b, d));
|
||||
auto v_c = ck::type_convert<AccDataType>(arg.emb_c_(idx_c, d));
|
||||
|
||||
accumulator(idx, d) += v_a + v_b + v_c;
|
||||
}
|
||||
};
|
||||
make_ParallelTensorFunctor(f_emb_per_row, L)(std::thread::hardware_concurrency());
|
||||
|
||||
// layernorm
|
||||
for(auto idx = 0; idx < L; ++idx)
|
||||
{
|
||||
mean(idx) = 0;
|
||||
var(idx) = 0;
|
||||
|
||||
for(auto d = 0; d < D; ++d)
|
||||
{
|
||||
auto x_val = accumulator(idx, d);
|
||||
mean(idx) += x_val;
|
||||
var(idx) += x_val * x_val;
|
||||
}
|
||||
|
||||
mean(idx) = mean(idx) / D;
|
||||
var(idx) = (var(idx) / D) - (mean(idx) * mean(idx));
|
||||
}
|
||||
|
||||
for(auto idx = 0; idx < L; ++idx)
|
||||
{
|
||||
for(auto d = 0; d < D; ++d)
|
||||
{
|
||||
auto x_val = accumulator(idx, d);
|
||||
auto y_val = (x_val - mean(idx)) / sqrt(var(idx) + arg.epsilon_);
|
||||
y_val = (y_val * arg.gamma_(d)) + arg.beta_(d);
|
||||
arg.output_(idx, d) = ck::type_convert<OutType>(y_val);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(Tensor<OutType>& output,
|
||||
const Tensor<EmbType>& emb_a,
|
||||
const Tensor<EmbType>& emb_b,
|
||||
const Tensor<EmbType>& emb_c,
|
||||
const Tensor<IndexType>& index_a,
|
||||
const Tensor<IndexType>& index_b,
|
||||
const Tensor<IndexType>& index_c,
|
||||
const Tensor<GammaDataType>& gamma,
|
||||
const Tensor<BetaDataType>& beta,
|
||||
ck::index_t NumRows,
|
||||
ck::index_t EmbeddingDim,
|
||||
ck::index_t IndexLength,
|
||||
AccDataType epsilon)
|
||||
{
|
||||
return Argument(output,
|
||||
emb_a,
|
||||
emb_b,
|
||||
emb_c,
|
||||
index_a,
|
||||
index_b,
|
||||
index_c,
|
||||
gamma,
|
||||
beta,
|
||||
NumRows,
|
||||
EmbeddingDim,
|
||||
IndexLength,
|
||||
epsilon);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceSparseEmbedding3ForwardLayernorm"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user