diff --git a/example/36_sparse_embedding/CMakeLists.txt b/example/36_sparse_embedding/CMakeLists.txt
new file mode 100644
index 0000000000..9cbcf5540e
--- /dev/null
+++ b/example/36_sparse_embedding/CMakeLists.txt
@@ -0,0 +1 @@
+add_example_executable(example_sparse_embedding3_forward_layernorm sparse_embedding3_forward_layernorm.cpp)
diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp
new file mode 100644
index 0000000000..c6c12108ba
--- /dev/null
+++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp
@@ -0,0 +1,222 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstddef>
+#include <cstdlib>
+#include <ctime>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/device_sparse_embedding3_forward_layernorm.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_common_util.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp"
+
+// using EmbType       = float;
+// using IndexType     = int64_t;
+// using GammaDataType = float;
+// using BetaDataType  = float;
+// using AccDataType   = float;
+// using OutType       = float;
+
+using EmbType       = ck::half_t;
+using IndexType     = int64_t;
+using GammaDataType = ck::half_t;
+using BetaDataType  = ck::half_t;
+using AccDataType   = float;
+using OutType       = ck::half_t;
+
+// clang-format off
+// Each instance fixes EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType plus the tuning
+// parameters BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize.
+using DeviceInstance_fp32_e256   = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp32_e512   = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp32_e768   = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp32_e1024  = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp32_e1536  = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp32_e2048  = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp32_e4096  = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp32_e8192  = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp32_e16384 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+
+using DeviceInstance_fp16_e256   = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp16_e512   = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp16_e768   = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp16_e1024  = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp16_e1536  = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp16_e2048  = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp16_e4096  = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+using DeviceInstance_fp16_e8192  = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm;
+
+template <typename DataType, ck::index_t EmbeddingDim>
+struct emb_kernel{};
+
+template <> struct emb_kernel<float, 256>       { using kernel_type = DeviceInstance_fp32_e256;  };
+template <> struct emb_kernel<float, 512>       { using kernel_type = DeviceInstance_fp32_e512;  };
+template <> struct emb_kernel<float, 768>       { using kernel_type = DeviceInstance_fp32_e768;  };
+template <> struct emb_kernel<float, 1024>      { using kernel_type = DeviceInstance_fp32_e1024; };
+template <> struct emb_kernel<float, 1536>      { using kernel_type = DeviceInstance_fp32_e1536; };
+template <> struct emb_kernel<float, 2048>      { using kernel_type = DeviceInstance_fp32_e2048; };
+template <> struct emb_kernel<float, 4096>      { using kernel_type = DeviceInstance_fp32_e4096; };
+template <> struct emb_kernel<float, 8192>      { using kernel_type = DeviceInstance_fp32_e8192; };
+template <> struct emb_kernel<float, 16384>     { using kernel_type = DeviceInstance_fp32_e16384;};
+
+template <> struct emb_kernel<ck::half_t, 256>  { using kernel_type = DeviceInstance_fp16_e256;  };
+template <> struct emb_kernel<ck::half_t, 512>  { using kernel_type = DeviceInstance_fp16_e512;  };
+template <> struct emb_kernel<ck::half_t, 768>  { using kernel_type = DeviceInstance_fp16_e768;  };
+template <> struct emb_kernel<ck::half_t, 1024> { using kernel_type = DeviceInstance_fp16_e1024; };
+template <> struct emb_kernel<ck::half_t, 1536> { using kernel_type = DeviceInstance_fp16_e1536; };
+template <> struct emb_kernel<ck::half_t, 2048> { using kernel_type = DeviceInstance_fp16_e2048; };
+template <> struct emb_kernel<ck::half_t, 4096> { using kernel_type = DeviceInstance_fp16_e4096; };
+template <> struct emb_kernel<ck::half_t, 8192> { using kernel_type = DeviceInstance_fp16_e8192; };
+
+// clang-format on
+
+int main()
+{
+    bool time_kernel = true;
+
+    constexpr auto num_rows       = 65536;
+    constexpr auto dims           = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{};
+    // constexpr auto dims        = ck::Sequence<256, 512>{};
+    constexpr auto index_length   = 2048;
+    constexpr AccDataType epsilon = 1e-4;
+
+    auto f_host_tensor_desc_1d = [](std::size_t len_) {
+        return HostTensorDescriptor(std::vector<std::size_t>({len_}));
+    };
+
+    auto f_host_tensor_desc_2d = [](std::size_t rows_, std::size_t cols_) {
+        return HostTensorDescriptor(std::vector<std::size_t>({rows_, cols_}));
+    };
+
+    using ReferenceInstance = ck::tensor_operation::host::ReferenceSparseEmbedding3ForwardLayernorm<
+        EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType>;
+
+    ck::static_for<0, dims.Size(), 1>{}([&](auto I) {
+        std::srand(std::time(nullptr));
+        constexpr auto current_dim = dims.At(I);
+
+        Tensor<EmbType> emb_a(f_host_tensor_desc_2d(num_rows, current_dim));
+        Tensor<EmbType> emb_b(f_host_tensor_desc_2d(num_rows, current_dim));
+        Tensor<EmbType> emb_c(f_host_tensor_desc_2d(num_rows, current_dim));
+
+        Tensor<IndexType> index_a(f_host_tensor_desc_1d(index_length));
+        Tensor<IndexType> index_b(f_host_tensor_desc_1d(index_length));
+        Tensor<IndexType> index_c(f_host_tensor_desc_1d(index_length));
+
+        Tensor<GammaDataType> gamma(f_host_tensor_desc_1d(current_dim));
+        Tensor<BetaDataType> beta(f_host_tensor_desc_1d(current_dim));
+
+        Tensor<OutType> out(f_host_tensor_desc_2d(index_length, current_dim));
+
+        emb_a.GenerateTensorValue(GeneratorTensor_3<EmbType>{0.0, 1.0});
+        emb_b.GenerateTensorValue(GeneratorTensor_3<EmbType>{0.0, 1.0});
+        emb_c.GenerateTensorValue(GeneratorTensor_3<EmbType>{0.0, 1.0});
+
+        index_a.GenerateTensorValue(GeneratorTensor_2<IndexType>{0, num_rows});
+        index_b.GenerateTensorValue(GeneratorTensor_2<IndexType>{0, num_rows});
+        index_c.GenerateTensorValue(GeneratorTensor_2<IndexType>{0, num_rows});
+
+        gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
+        beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
+
+        DeviceMem emb_a_dev(sizeof(EmbType) * emb_a.mDesc.GetElementSpaceSize());
+        DeviceMem emb_b_dev(sizeof(EmbType) * emb_b.mDesc.GetElementSpaceSize());
+        DeviceMem emb_c_dev(sizeof(EmbType) * emb_c.mDesc.GetElementSpaceSize());
+
+        DeviceMem index_a_dev(sizeof(IndexType) * index_a.mDesc.GetElementSpaceSize());
+        DeviceMem index_b_dev(sizeof(IndexType) * index_b.mDesc.GetElementSpaceSize());
+        DeviceMem index_c_dev(sizeof(IndexType) * index_c.mDesc.GetElementSpaceSize());
+
+        DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
+        DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
+
+        DeviceMem out_dev(sizeof(OutType) * out.mDesc.GetElementSpaceSize());
+
+        emb_a_dev.ToDevice(emb_a.mData.data());
+        emb_b_dev.ToDevice(emb_b.mData.data());
+        emb_c_dev.ToDevice(emb_c.mData.data());
+
+        index_a_dev.ToDevice(index_a.mData.data());
+        index_b_dev.ToDevice(index_b.mData.data());
+        index_c_dev.ToDevice(index_c.mData.data());
+
+        gamma_dev.ToDevice(gamma.mData.data());
+        beta_dev.ToDevice(beta.mData.data());
+
+        auto device_instance = typename emb_kernel<EmbType, current_dim>::kernel_type{};
+        auto argument_ptr    = device_instance.MakeArgumentPointer(out_dev.GetDeviceBuffer(),
+                                                                   emb_a_dev.GetDeviceBuffer(),
+                                                                   emb_b_dev.GetDeviceBuffer(),
+                                                                   emb_c_dev.GetDeviceBuffer(),
+                                                                   index_a_dev.GetDeviceBuffer(),
+                                                                   index_b_dev.GetDeviceBuffer(),
+                                                                   index_c_dev.GetDeviceBuffer(),
+                                                                   gamma_dev.GetDeviceBuffer(),
+                                                                   beta_dev.GetDeviceBuffer(),
+                                                                   num_rows,
+                                                                   current_dim,
+                                                                   index_length,
+                                                                   epsilon);
+
+        std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString()
+                  << std::endl
+                  << std::flush;
+
+        bool is_supported = device_instance.IsSupportedArgument(argument_ptr.get());
+        if(!is_supported)
+        {
+            std::cout << "Runtime parameters are not supported" << std::endl;
+            return;
+        }
+
+        auto invoker_ptr = device_instance.MakeInvokerPointer();
+        float time_ms    = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+        bool pass = true;
+        {
+            Tensor<OutType> out_from_dev(f_host_tensor_desc_2d(index_length, current_dim));
+            ReferenceInstance ref;
+            auto ref_argument = ref.MakeArgument(out,
+                                                 emb_a, emb_b, emb_c,
+                                                 index_a, index_b, index_c,
+                                                 gamma, beta,
+                                                 num_rows, current_dim, index_length, epsilon);
+            auto ref_invoker  = ref.MakeInvoker();
+            ref_invoker.Run(ref_argument);
+
+            out_dev.FromDevice(out_from_dev.mData.data());
+            pass &= ck::utils::check_err(
+                out_from_dev.mData, out.mData, "Error: Incorrect results", 1e-3, 1e-3);
+        }
+
+        double total_read  = current_dim * index_length * 3 * sizeof(EmbType) +
+                             current_dim * sizeof(GammaDataType) + current_dim * sizeof(BetaDataType);
+        double total_write = current_dim * index_length * sizeof(OutType);
+        double gbps        = (total_read + total_write) / time_ms / 1e6;
+
+        std::cout << ", total bytes:" << (total_read + total_write) << ", time:" << time_ms
+                  << ", gbps:" << gbps << ", valid:" << (pass ? "y" : "n") << std::endl
+                  << std::flush;
+    });
+
+    return 0;
+}
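Example (not part of the patch): per output row i, the fused operator gathers one row from each of the three embedding tables, sums them, and applies an affine layernorm, i.e. out[i][d] = gamma[d] * (x[i][d] - mean_i) / sqrt(var_i + eps) + beta[d] with x[i][d] = emb_a[ia[i]][d] + emb_b[ib[i]][d] + emb_c[ic[i]][d]. A minimal host-side sketch of the same math in plain C++ (illustrative names, STL containers only, fp32 for brevity):

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // out: [L, D] row-major; emb_*: [E, D]; ia/ib/ic: [L]; gamma/beta: [D]
    void sparse_embedding3_forward_layernorm_host(
        std::vector<float>& out,
        const std::vector<float>& emb_a, const std::vector<float>& emb_b,
        const std::vector<float>& emb_c,
        const std::vector<std::int64_t>& ia, const std::vector<std::int64_t>& ib,
        const std::vector<std::int64_t>& ic,
        const std::vector<float>& gamma, const std::vector<float>& beta,
        std::size_t L, std::size_t D, float epsilon)
    {
        for(std::size_t i = 0; i < L; ++i)
        {
            // gather and add the three embedding rows
            std::vector<float> x(D);
            for(std::size_t d = 0; d < D; ++d)
                x[d] = emb_a[ia[i] * D + d] + emb_b[ib[i] * D + d] + emb_c[ic[i] * D + d];

            // mean and variance over the embedding dimension
            float mean = 0.f, var = 0.f;
            for(float v : x) mean += v;
            mean /= D;
            for(float v : x) var += (v - mean) * (v - mean);
            var /= D;

            // affine layernorm
            for(std::size_t d = 0; d < D; ++d)
                out[i * D + d] = gamma[d] * (x[d] - mean) / std::sqrt(var + epsilon) + beta[d];
        }
    }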
"y" : "n") << std::endl + << std::flush; + }); + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index d4c6199dcf..3f73a6a0a3 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -51,4 +51,5 @@ add_subdirectory(32_batched_gemm_scale_softmax_gemm) add_subdirectory(33_multiple_reduce) add_subdirectory(34_batchnorm) add_subdirectory(35_splitK_gemm) +add_subdirectory(36_sparse_embedding) add_subdirectory(41_grouped_conv_conv_fwd) diff --git a/include/ck/tensor_operation/gpu/device/device_sparse_embedding3_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/device/device_sparse_embedding3_forward_layernorm.hpp new file mode 100644 index 0000000000..1f2b46edd3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_sparse_embedding3_forward_layernorm.hpp @@ -0,0 +1,210 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator +{ + + static auto MakeOutputDescriptor(const index_t index_length, const index_t rows) + { + return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows)); + } + + struct Argument : public BaseArgument + { + Argument(OutType* p_out, + const EmbType* p_emb_a, + const EmbType* p_emb_b, + const EmbType* p_emb_c, + const IndexType* p_index_a, + const IndexType* p_index_b, + const IndexType* p_index_c, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + const ck::index_t NumRows, + const ck::index_t EmbeddingDim, + const ck::index_t IndexLength, + const AccDataType epsilon) + : p_out_(p_out), + p_emb_a_(p_emb_a), + p_emb_b_(p_emb_b), + p_emb_c_(p_emb_c), + p_index_a_(p_index_a), + p_index_b_(p_index_b), + p_index_c_(p_index_c), + p_gamma_(p_gamma), + p_beta_(p_beta), + NumRows_(NumRows), + EmbeddingDim_(EmbeddingDim), + IndexLength_(IndexLength), + epsilon_(epsilon) + { + grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize; + } + + OutType* p_out_; + const EmbType* p_emb_a_; + const EmbType* p_emb_b_; + const EmbType* p_emb_c_; + const IndexType* p_index_a_; + const IndexType* p_index_b_; + const IndexType* p_index_c_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + ck::index_t NumRows_; + ck::index_t EmbeddingDim_; + ck::index_t IndexLength_; + AccDataType epsilon_; + + size_t grid_size_; + }; + + virtual std::unique_ptr MakeArgumentPointer(void* p_out, + const void* p_emb_a, + const void* p_emb_b, + const void* p_emb_c, + const void* p_index_a, + const void* p_index_b, + const void* p_index_c, + const void* p_gamma, + const void* p_beta, + ck::index_t NumRows, + ck::index_t EmbeddingDim, + ck::index_t IndexLength, + const AccDataType epsilon) + { + return std::make_unique(reinterpret_cast(p_out), + reinterpret_cast(p_emb_a), + reinterpret_cast(p_emb_b), + reinterpret_cast(p_emb_c), + reinterpret_cast(p_index_a), + reinterpret_cast(p_index_b), + reinterpret_cast(p_index_c), + reinterpret_cast(p_gamma), + reinterpret_cast(p_beta), + 
NumRows, + EmbeddingDim, + IndexLength, + epsilon); + } + + using GridwiseSparseEmbedding = + GridwiseSparseEmbedding3ForwardLayernorm; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_); + const auto kernel_main = + kernel_sparse_embedding3_forward_layernorm; + float avg_time = 0; + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + arg.p_out_, + arg.p_emb_a_, + arg.p_emb_b_, + arg.p_emb_c_, + arg.p_index_a_, + arg.p_index_b_, + arg.p_index_c_, + arg.p_gamma_, + arg.p_beta_, + out_desc, + arg.epsilon_); + + return (avg_time); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + static bool IsSupportedArgument(const Argument* p_arg) + { + return (RowPerBlock == p_arg->EmbeddingDim_) && (p_arg->NumRows_ % DimPerBlock == 0); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(dynamic_cast(p_arg)); + } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceSparseEmbedding3ForwardLayernorm_"<< BlockSize << "_" << + DimClusterSize << "x" << RowClusterSize << "_" << + DimPerBlock << "x" << RowPerBlock << "_" << + DimThreadSize << "x" << RowVectorSize; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp new file mode 100644 index 0000000000..3de6aa08c4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp @@ -0,0 +1,344 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
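Example (not part of the patch): every tuning parameter of DeviceSparseEmbedding3ForwardLayernorm is a compile-time template argument, and IsSupportedArgument only accepts a problem whose embedding dimension equals the instance's RowPerBlock, so a caller has to bridge a runtime dim to one of the pre-compiled instances; the example's emb_kernel trait does this statically. A hypothetical sketch of the same bridge done at runtime (dispatch_by_dim and the dim set are assumptions for illustration, not CK API):

    #include <cstdint>
    #include <stdexcept>
    #include <string>
    #include <type_traits>

    // Map a runtime embedding dim onto a compile-time constant (hypothetical helper).
    template <typename F>
    void dispatch_by_dim(std::int64_t dim, F&& f)
    {
        switch(dim)
        {
        case 256: f(std::integral_constant<std::int64_t, 256>{}); break;
        case 512: f(std::integral_constant<std::int64_t, 512>{}); break;
        case 1024: f(std::integral_constant<std::int64_t, 1024>{}); break;
        default: throw std::runtime_error("no instance compiled for dim " + std::to_string(dim));
        }
    }

    // Usage sketch: the lambda sees the dim as a constant expression, so it can
    // name an instance, e.g. emb_kernel<ck::half_t, d.value>::kernel_type.
    // dispatch_by_dim(dim, [&](auto d) { run_instance<decltype(d)::value>(...); });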
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp
new file mode 100644
index 0000000000..3de6aa08c4
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp
@@ -0,0 +1,344 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
+
+namespace ck {
+
+template <typename GridwiseSparseEmbedding,
+          typename OutType,
+          typename EmbType,
+          typename IndexType,
+          typename GammaDataType,
+          typename BetaDataType,
+          typename AccDataType,
+          typename OutGridDesc>
+#if CK_USE_LAUNCH_BOUNDS
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    __global__ void kernel_sparse_embedding3_forward_layernorm(OutType* p_out,
+                                                               const EmbType* p_emb_a,
+                                                               const EmbType* p_emb_b,
+                                                               const EmbType* p_emb_c,
+                                                               const IndexType* p_index_a,
+                                                               const IndexType* p_index_b,
+                                                               const IndexType* p_index_c,
+                                                               const GammaDataType* p_gamma,
+                                                               const BetaDataType* p_beta,
+                                                               const OutGridDesc out_grid_desc,
+                                                               const AccDataType epsilon)
+{
+    GridwiseSparseEmbedding::Run(p_out,
+                                 p_emb_a,
+                                 p_emb_b,
+                                 p_emb_c,
+                                 p_index_a,
+                                 p_index_b,
+                                 p_index_c,
+                                 p_gamma,
+                                 p_beta,
+                                 out_grid_desc,
+                                 epsilon);
+}
+
+template <typename EmbType,
+          typename IndexType,
+          typename GammaDataType,
+          typename BetaDataType,
+          typename AccDataType,
+          typename OutType,
+          typename OutGridDesc,
+          index_t BlockSize,
+          index_t DimClusterSize,
+          index_t RowClusterSize,
+          index_t DimPerBlock,
+          index_t RowPerBlock,
+          index_t DimThreadSize,
+          index_t RowVectorSize>
+struct GridwiseSparseEmbedding3ForwardLayernorm
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+    static constexpr index_t WaveSize = 64;
+
+    static_assert(BlockSize == RowClusterSize * DimClusterSize,
+                  "Invalid cluster distribution within block");
+    static_assert(RowClusterSize % WaveSize == 0,
+                  "RowClusterSize must be a multiple of the wave size");
+
+    static_assert(DimPerBlock % (DimClusterSize * DimThreadSize) == 0, "");
+    static_assert(RowPerBlock % (RowClusterSize * RowVectorSize) == 0, "");
+
+    static constexpr auto DimSubBlocks = DimPerBlock / (DimClusterSize * DimThreadSize);
+    static constexpr auto RowSubBlocks = RowPerBlock / (RowClusterSize * RowVectorSize);
+
+    static_assert((DimPerBlock % DimSubBlocks == 0) && (RowPerBlock % RowSubBlocks == 0), "");
+    static constexpr auto DimPerSubBlock = DimPerBlock / DimSubBlocks;
+    static constexpr auto RowPerSubBlock = RowPerBlock / RowSubBlocks;
+
+    using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple(
+        Number<DimSubBlocks * DimThreadSize>{}, Number<RowSubBlocks * RowVectorSize>{})));
+
+    using ThreadwiseWolfordDescReduce = decltype(
+        make_naive_tensor_descriptor_packed(make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
+
+    using ThreadwiseWelford =
+        ThreadwiseWelford<AccDataType, ThreadwiseWolfordDesc2D, ThreadwiseWolfordDescReduce>;
+
+    using ThreadClusterLength = Sequence<DimClusterSize, RowClusterSize>;
+
+    using BlockwiseWelford =
+        BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLength, Sequence<0, 1>>;
+
+    __device__ static void Run(OutType* p_out,
+                               const EmbType* p_emb_a,
+                               const EmbType* p_emb_b,
+                               const EmbType* p_emb_c,
+                               const IndexType* p_index_a,
+                               const IndexType* p_index_b,
+                               const IndexType* p_index_c,
+                               const GammaDataType* p_gamma,
+                               const BetaDataType* p_beta,
+                               const OutGridDesc,
+                               const AccDataType epsilon)
+    {
+        const index_t thread_local_id = get_thread_local_1d_id();
+        const index_t block_global_id = get_block_1d_id();
+
+        // const auto index_length = out_grid_desc.GetLength(I0);
+        // const auto emb_dim      = out_grid_desc.GetLength(I1);
+
+        constexpr auto thread_cluster_desc =
+            make_cluster_descriptor(ThreadClusterLength{}, Sequence<0, 1>{});
+
+        const auto thread_cluster_idx =
+            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+
+        const auto thread_dim_cluster_id = thread_cluster_idx[I0];
+        const auto thread_row_cluster_id = thread_cluster_idx[I1];
+
+        const auto wave_dim_id = __builtin_amdgcn_readfirstlane(thread_dim_cluster_id / WaveSize);
+
+        const auto index_start = block_global_id * DimPerBlock + wave_dim_id * DimThreadSize;
+
+        auto threadwise_welford       = ThreadwiseWelford();
+        threadwise_welford.max_count_ = RowSubBlocks * RowVectorSize;
+
+        constexpr auto thread_buf_size =
+            DimSubBlocks * DimThreadSize * RowSubBlocks * RowVectorSize;
+        constexpr auto thread_buf_desc = make_naive_tensor_descriptor_packed(
+            make_tuple(DimSubBlocks, DimThreadSize, RowSubBlocks, RowVectorSize));
+        constexpr auto mean_var_buf_size = DimSubBlocks * DimThreadSize;
+        constexpr auto mean_var_buf_desc =
+            make_naive_tensor_descriptor_packed(make_tuple(DimSubBlocks, DimThreadSize));
+        constexpr auto gamma_beta_buf_size = RowSubBlocks * RowVectorSize;
+        constexpr auto gamma_beta_buf_desc =
+            make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize));
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_a;
+        StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_b;
+        StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_c;
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, IndexType, DimPerBlock, true> index_buf_a;
+        StaticBuffer<AddressSpaceEnum::Vgpr, IndexType, DimPerBlock, true> index_buf_b;
+        StaticBuffer<AddressSpaceEnum::Vgpr, IndexType, DimPerBlock, true> index_buf_c;
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true> acc_thread_buf;
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, gamma_beta_buf_size, true>
+            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, gamma_beta_buf_size, true>
+            beta_thread_buf;
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> mean_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> var_thread_buf;
+
+        auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
+            vector_type_maker_t<EmbType, RowVectorSize> emb_vector_a;
+            vector_type_maker_t<EmbType, RowVectorSize> emb_vector_b;
+            vector_type_maker_t<EmbType, RowVectorSize> emb_vector_c;
+
+            using src_vector_t = typename decltype(emb_vector_a)::type;
+            static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
+                constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
+                IndexType index_a = index_buf_a[Number<current_dim>{}];
+                IndexType index_b = index_buf_b[Number<current_dim>{}];
+                IndexType index_c = index_buf_c[Number<current_dim>{}];
+
+                auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
+                                     sizeof(EmbType) * RowVectorSize;
+
+                int32x4_t emb_res_a =
+                    make_wave_buffer_resource_with_default_range(p_emb_a + index_a * RowPerBlock);
+                int32x4_t emb_res_b =
+                    make_wave_buffer_resource_with_default_range(p_emb_b + index_b * RowPerBlock);
+                int32x4_t emb_res_c =
+                    make_wave_buffer_resource_with_default_range(p_emb_c + index_c * RowPerBlock);
+                emb_vector_a.template AsType<src_vector_t>()(I0) =
+                    amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_a, thread_offset, 0);
+                emb_vector_b.template AsType<src_vector_t>()(I0) =
+                    amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_b, thread_offset, 0);
+                emb_vector_c.template AsType<src_vector_t>()(I0) =
+                    amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_c, thread_offset, 0);
+
+                static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
+                    constexpr auto register_offset = thread_buf_desc.CalculateOffset(
+                        make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
+                    in_thread_buf_a(Number<register_offset>{}) =
+                        emb_vector_a.template AsType<EmbType>()[i_row_vec_];
+                    in_thread_buf_b(Number<register_offset>{}) =
+                        emb_vector_b.template AsType<EmbType>()[i_row_vec_];
+                    in_thread_buf_c(Number<register_offset>{}) =
+                        emb_vector_c.template AsType<EmbType>()[i_row_vec_];
+                });
+            });
+        };
+
+        auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
+            static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
+                static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
+                    constexpr auto register_offset = thread_buf_desc.CalculateOffset(
+                        make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
+                    AccDataType va =
+                        ck::type_convert<AccDataType>(in_thread_buf_a(Number<register_offset>{}));
+                    AccDataType vb =
+                        ck::type_convert<AccDataType>(in_thread_buf_b(Number<register_offset>{}));
+                    AccDataType vc =
+                        ck::type_convert<AccDataType>(in_thread_buf_c(Number<register_offset>{}));
+
+                    acc_thread_buf(Number<register_offset>{}) += va + vb + vc;
+                });
+            });
+        };
+
+        auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
+            static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
+                static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
+                    constexpr auto register_offset = thread_buf_desc.CalculateOffset(
+                        make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
+                    constexpr auto mean_var_offset =
+                        mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
+
+                    threadwise_welford.cur_count_++;
+                    threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
+                                              var_thread_buf(Number<mean_var_offset>{}),
+                                              acc_thread_buf(Number<register_offset>{}));
+                });
+            });
+        };
+
+        auto threadwise_normalize_store_out = [&](auto i_dim_sub_, auto i_row_sub_) {
+            int32x4_t out_res =
+                make_wave_buffer_resource_with_default_range(p_out + index_start * RowPerBlock);
+            static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
+                vector_type_maker_t<OutType, RowVectorSize> out_vector;
+                using dst_vector_t = typename decltype(out_vector)::type;
+
+                constexpr auto mean_var_offset =
+                    mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
+
+                static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
+                    constexpr auto register_offset = thread_buf_desc.CalculateOffset(
+                        make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
+                    constexpr auto gamma_beta_offset =
+                        gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
+
+                    auto acc_val = acc_thread_buf[Number<register_offset>{}];
+                    acc_val      = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) /
+                              sqrt(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
+                    acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
+                              beta_thread_buf[Number<gamma_beta_offset>{}];
+
+                    out_vector.template AsType<OutType>()(i_row_vec_) =
+                        type_convert<OutType>(acc_val);
+                });
+
+                index_t thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
+                                        sizeof(OutType) * RowVectorSize;
+
+                amd_buffer_store_impl<OutType, RowVectorSize>(
+                    out_vector.template AsType<dst_vector_t>()[Number<0>{}],
+                    out_res,
+                    thread_offset,
+                    0);
+            });
+        };
+
+        // first load index
+        ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
+            // prefer s_load here
+            index_buf_a(i_idx_) = p_index_a[index_start + i_idx_.value];
+            index_buf_b(i_idx_) = p_index_b[index_start + i_idx_.value];
+            index_buf_c(i_idx_) = p_index_c[index_start + i_idx_.value];
+        });
+
+        // load gamma/beta
+        static_for<0, RowSubBlocks, 1>{}([&](auto i_row_sub_) {
+            vector_type_maker_t<GammaDataType, RowVectorSize> gamma_vector;
+            vector_type_maker_t<BetaDataType, RowVectorSize> beta_vector;
+
+            index_t thread_offset_gamma = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
+                                          sizeof(GammaDataType) * RowVectorSize;
+            index_t thread_offset_beta = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
+                                         sizeof(BetaDataType) * RowVectorSize;
+
+            int32x4_t gamma_res = make_wave_buffer_resource_with_default_range(p_gamma);
+            int32x4_t beta_res  = make_wave_buffer_resource_with_default_range(p_beta);
+
+            gamma_vector.template AsType<typename decltype(gamma_vector)::type>()(I0) =
+                amd_buffer_load_impl<GammaDataType, RowVectorSize>(
+                    gamma_res, thread_offset_gamma, 0);
+            beta_vector.template AsType<typename decltype(beta_vector)::type>()(I0) =
+                amd_buffer_load_impl<BetaDataType, RowVectorSize>(beta_res, thread_offset_beta, 0);
+
+            static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
+                constexpr auto offset =
+                    gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
+                gamma_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
+                    gamma_vector.template AsType<GammaDataType>()[i_row_vec_]);
+                beta_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
+                    beta_vector.template AsType<BetaDataType>()[i_row_vec_]);
+            });
+        });
+
+        static_for<0, thread_buf_size, 1>{}(
+            [&](auto I) { acc_thread_buf(I) = type_convert<AccDataType>(0.0f); });
+
+        static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
+            mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
+            var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
+        });
+
+        static_for<0, DimSubBlocks, 1>{}([&](auto i_dim_sub) {
+            load_current_sub_row(i_dim_sub, Number<0>{});
+            static_for<0, RowSubBlocks - 1, 1>{}([&](auto i_row) {
+                load_current_sub_row(i_dim_sub, Number<1>{} + i_row);
+                accumulate_current_sub_row(i_dim_sub, i_row);
+                threadwise_welford_sub_row(i_dim_sub, i_row);
+            });
+            accumulate_current_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
+            threadwise_welford_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
+
+            // blockwise welford
+            static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
+                if constexpr(I > 0)
+                    block_sync_lds();
+
+                BlockwiseWelford::Run(
+                    mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
+            });
+
+            // store
+            static_for<0, RowSubBlocks, 1>{}(
+                [&](auto i_row) { threadwise_normalize_store_out(i_dim_sub, i_row); });
+        });
+    }
+};
+
+} // namespace ck
diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp
index cc503cf0e5..79295356df 100644
--- a/include/ck/utility/amd_buffer_addressing.hpp
+++ b/include/ck/utility/amd_buffer_addressing.hpp
@@ -34,6 +34,21 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
     return wave_buffer_resource.content;
 }
 
+template <typename T>
+__device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave)
+{
+    BufferResource wave_buffer_resource;
+
+    // wavewise base address (64 bit)
+    wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
+    // wavewise range (32 bit)
+    wave_buffer_resource.range(Number<2>{}) = 0xffffffff; // max possible range
+    // wavewise setting (32 bit)
+    wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
+
+    return wave_buffer_resource.content;
+}
+
 // buffer load i8
 __device__ int8_t
 llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
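Example (not part of the patch): the Threadwise/BlockwiseWelford helpers above implement Welford's online mean/variance scheme: each thread folds the values it loads into a running mean and a running sum of squared deviations, and the blockwise stage merges the per-thread partials across the thread cluster. A scalar sketch of both the update and the merge rule:

    // Scalar model of Welford's algorithm; the kernel does the same per lane,
    // then merges partials across lanes (Chan et al. pairwise merge).
    struct Welford
    {
        int count   = 0;
        double mean = 0.0;
        double m2   = 0.0; // sum of squared deviations from the running mean

        void update(double x)
        {
            ++count;
            double delta = x - mean;
            mean += delta / count;
            m2 += delta * (x - mean);
        }

        // merge another partial result (what the blockwise step does in LDS)
        void merge(const Welford& o)
        {
            if(o.count == 0)
                return;
            int n        = count + o.count;
            double delta = o.mean - mean;
            mean += delta * o.count / n;
            m2 += o.m2 + delta * delta * static_cast<double>(count) * o.count / n;
            count = n;
        }

        double variance() const { return count > 0 ? m2 / count : 0.0; }
    };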
diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp
new file mode 100644
index 0000000000..b6a9b0fb5e
--- /dev/null
+++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp
@@ -0,0 +1,205 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <cmath>
+#include <memory>
+#include <sstream>
+#include <thread>
+
+#include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace host {
+
+template <typename EmbType,
+          typename IndexType,
+          typename GammaDataType,
+          typename BetaDataType,
+          typename AccDataType,
+          typename OutType>
+struct ReferenceSparseEmbedding3ForwardLayernorm : public device::BaseOperator
+{
+    struct Argument : public device::BaseArgument
+    {
+        Argument(Tensor<OutType>& output,
+                 const Tensor<EmbType>& emb_a,
+                 const Tensor<EmbType>& emb_b,
+                 const Tensor<EmbType>& emb_c,
+                 const Tensor<IndexType>& index_a,
+                 const Tensor<IndexType>& index_b,
+                 const Tensor<IndexType>& index_c,
+                 const Tensor<GammaDataType>& gamma,
+                 const Tensor<BetaDataType>& beta,
+                 ck::index_t NumRows,
+                 ck::index_t EmbeddingDim,
+                 ck::index_t IndexLength,
+                 AccDataType epsilon)
+            : output_(output),
+              emb_a_(emb_a),
+              emb_b_(emb_b),
+              emb_c_(emb_c),
+              index_a_(index_a),
+              index_b_(index_b),
+              index_c_(index_c),
+              gamma_(gamma),
+              beta_(beta),
+              NumRows_(NumRows),
+              EmbeddingDim_(EmbeddingDim),
+              IndexLength_(IndexLength),
+              epsilon_(epsilon)
+        {
+        }
+        Tensor<OutType>& output_;
+        const Tensor<EmbType> emb_a_;
+        const Tensor<EmbType> emb_b_;
+        const Tensor<EmbType> emb_c_;
+        const Tensor<IndexType> index_a_;
+        const Tensor<IndexType> index_b_;
+        const Tensor<IndexType> index_c_;
+        const Tensor<GammaDataType> gamma_;
+        const Tensor<BetaDataType> beta_;
+        ck::index_t NumRows_;
+        ck::index_t EmbeddingDim_;
+        ck::index_t IndexLength_;
+        AccDataType epsilon_;
+    };
+
+    // Invoker
+    struct Invoker : public device::BaseInvoker
+    {
+        float Run(const Argument& arg)
+        {
+            ck::index_t D = arg.EmbeddingDim_;
+            ck::index_t L = arg.IndexLength_;
+            ck::index_t E = arg.NumRows_;
+
+            Tensor<AccDataType> accumulator({L, D});
+
+            Tensor<AccDataType> mean({L});
+            Tensor<AccDataType> var({L});
+
+            accumulator.SetZero();
+
+            auto f_emb_per_row = [&](auto idx) {
+                IndexType idx_a = arg.index_a_(idx);
+                IndexType idx_b = arg.index_b_(idx);
+                IndexType idx_c = arg.index_c_(idx);
+
+                if(!((idx_a < E) && (idx_b < E) && (idx_c < E)))
+                {
+                    throw(std::runtime_error("wrong! out of range"));
+                }
+
+                for(auto d = 0; d < D; d++)
+                {
+                    auto v_a = ck::type_convert<AccDataType>(arg.emb_a_(idx_a, d));
+                    auto v_b = ck::type_convert<AccDataType>(arg.emb_b_(idx_b, d));
+                    auto v_c = ck::type_convert<AccDataType>(arg.emb_c_(idx_c, d));
+
+                    accumulator(idx, d) += v_a + v_b + v_c;
+                }
+            };
+            make_ParallelTensorFunctor(f_emb_per_row, L)(std::thread::hardware_concurrency());
+
+            // layernorm
+            for(auto idx = 0; idx < L; ++idx)
+            {
+                mean(idx) = 0;
+                var(idx)  = 0;
+
+                for(auto d = 0; d < D; ++d)
+                {
+                    auto x_val = accumulator(idx, d);
+                    mean(idx) += x_val;
+                    var(idx) += x_val * x_val;
+                }
+
+                mean(idx) = mean(idx) / D;
+                var(idx)  = (var(idx) / D) - (mean(idx) * mean(idx));
+            }
+
+            for(auto idx = 0; idx < L; ++idx)
+            {
+                for(auto d = 0; d < D; ++d)
+                {
+                    auto x_val = accumulator(idx, d);
+                    auto y_val = (x_val - mean(idx)) / std::sqrt(var(idx) + arg.epsilon_);
+                    y_val      = (y_val * arg.gamma_(d)) + arg.beta_(d);
+                    arg.output_(idx, d) = ck::type_convert<OutType>(y_val);
+                }
+            }
+            return 0;
+        }
+
+        float Run(const device::BaseArgument* p_arg,
+                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg));
+        }
+    };
+
+    static constexpr bool IsValidCompilationParameter()
+    {
+        // TODO: properly implement this check
+        return true;
+    }
+
+    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
+
+    static auto MakeArgument(Tensor<OutType>& output,
+                             const Tensor<EmbType>& emb_a,
+                             const Tensor<EmbType>& emb_b,
+                             const Tensor<EmbType>& emb_c,
+                             const Tensor<IndexType>& index_a,
+                             const Tensor<IndexType>& index_b,
+                             const Tensor<IndexType>& index_c,
+                             const Tensor<GammaDataType>& gamma,
+                             const Tensor<BetaDataType>& beta,
+                             ck::index_t NumRows,
+                             ck::index_t EmbeddingDim,
+                             ck::index_t IndexLength,
+                             AccDataType epsilon)
+    {
+        return Argument(output,
+                        emb_a,
+                        emb_b,
+                        emb_c,
+                        index_a,
+                        index_b,
+                        index_c,
+                        gamma,
+                        beta,
+                        NumRows,
+                        EmbeddingDim,
+                        IndexLength,
+                        epsilon);
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "ReferenceSparseEmbedding3ForwardLayernorm"
+            << std::endl;
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace host
+} // namespace tensor_operation
+} // namespace ck