From 75c29a7c905048a55f28bca11ceaa7c5aec724f1 Mon Sep 17 00:00:00 2001 From: who who who Date: Thu, 19 Jan 2023 01:32:12 +0800 Subject: [PATCH] add multi embeddings support (#542) * add multi embeddings support * fix format * optimize sqrt * add reduce operation * change to elementwise op * fix name * rename * run ci cd * format example * format code * format code [ROCm/composable_kernel commit: 147b7db5614a4b63b3fc2b6af32268aba990a7d0] --- .../sparse_embedding3_forward_layernorm.cpp | 77 ++++------ ...e_sparse_embeddings_forward_layernorm.hpp} | 105 ++++++------- ...e_sparse_embeddings_forward_layernorm.hpp} | 145 ++++++++---------- 3 files changed, 131 insertions(+), 196 deletions(-) rename include/ck/tensor_operation/gpu/device/impl/{device_sparse_embedding3_forward_layernorm.hpp => device_sparse_embeddings_forward_layernorm.hpp} (61%) rename include/ck/tensor_operation/gpu/grid/{gridwise_sparse_embedding3_forward_layernorm.hpp => gridwise_sparse_embeddings_forward_layernorm.hpp} (69%) diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index f5eb4c3b6b..f0a0cdf6f1 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -9,7 +9,8 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_sparse_embedding3_forward_layernorm.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -18,53 +19,26 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp" -// using EmbType = float; -// using IndexType = int64_t; -// using GammaDataType = float; -// using BetaDataType = float; -// using AccDataType = float; -// using OutType = float; - +// clang-format off using EmbType = ck::half_t; using IndexType = int64_t; using GammaDataType = ck::half_t; using BetaDataType = ck::half_t; using AccDataType = float; using OutType = ck::half_t; +using EmbElementwiseOperation = ck::tensor_operation::element_wise::AddAdd; -// clang-format off -// BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize -using DeviceInstance_fp32_e256 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp32_e512 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp32_e768 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp32_e1024 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp32_e1536 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp32_e2048 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp32_e4096 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp32_e8192 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp32_e16384 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; - -using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; -using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm; +using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm; +using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm; +using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm; +using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm; +using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm; +using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm; +using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm; template struct emb_kernel{}; -template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e256; }; -template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e512; }; -template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e768; }; -template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e1024;}; -template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e1536;}; -template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e2048;}; -template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e4096;}; -template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e8192;}; -template<> struct emb_kernel{ using kernel_type = DeviceInstance_fp32_e16384;}; - template<> struct emb_kernel { using kernel_type = DeviceInstance_fp16_e256; }; template<> struct emb_kernel { using kernel_type = DeviceInstance_fp16_e512; }; template<> struct emb_kernel { using kernel_type = DeviceInstance_fp16_e768; }; @@ -152,19 +126,20 @@ int main() beta_dev.ToDevice(beta.mData.data()); auto device_instance = typename emb_kernel::kernel_type{}; - auto argument_ptr = device_instance.MakeArgumentPointer(out_dev.GetDeviceBuffer(), - emb_a_dev.GetDeviceBuffer(), - emb_b_dev.GetDeviceBuffer(), - emb_c_dev.GetDeviceBuffer(), - index_a_dev.GetDeviceBuffer(), - index_b_dev.GetDeviceBuffer(), - index_c_dev.GetDeviceBuffer(), - gamma_dev.GetDeviceBuffer(), - beta_dev.GetDeviceBuffer(), - num_rows, - current_dim, - index_length, - epsilon); + auto argument_ptr = device_instance.MakeArgumentPointer( + out_dev.GetDeviceBuffer(), + {ck::type_convert(emb_a_dev.GetDeviceBuffer()), + ck::type_convert(emb_b_dev.GetDeviceBuffer()), + ck::type_convert(emb_c_dev.GetDeviceBuffer())}, + {ck::type_convert(index_a_dev.GetDeviceBuffer()), + ck::type_convert(index_b_dev.GetDeviceBuffer()), + ck::type_convert(index_c_dev.GetDeviceBuffer())}, + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + current_dim, + index_length, + epsilon, + EmbElementwiseOperation{}); std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString() << std::endl << std::flush; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_sparse_embedding3_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp similarity index 61% rename from include/ck/tensor_operation/gpu/device/impl/device_sparse_embedding3_forward_layernorm.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp index 1f2b46edd3..2f29224a75 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_sparse_embedding3_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp @@ -12,7 +12,7 @@ #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp" namespace ck { namespace tensor_operation { @@ -24,16 +24,17 @@ template -struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator + ck::index_t RowVectorSize, + ck::index_t NumEmbeddings> +struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator { - static auto MakeOutputDescriptor(const index_t index_length, const index_t rows) { return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows)); @@ -42,96 +43,79 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator struct Argument : public BaseArgument { Argument(OutType* p_out, - const EmbType* p_emb_a, - const EmbType* p_emb_b, - const EmbType* p_emb_c, - const IndexType* p_index_a, - const IndexType* p_index_b, - const IndexType* p_index_c, + const ck::Array& p_embs, + const ck::Array& p_indexs, const GammaDataType* p_gamma, const BetaDataType* p_beta, - const ck::index_t NumRows, const ck::index_t EmbeddingDim, const ck::index_t IndexLength, - const AccDataType epsilon) + const AccDataType epsilon, + const EmbElementwiseOperation emb_elementwise_op) : p_out_(p_out), - p_emb_a_(p_emb_a), - p_emb_b_(p_emb_b), - p_emb_c_(p_emb_c), - p_index_a_(p_index_a), - p_index_b_(p_index_b), - p_index_c_(p_index_c), + p_embs_(p_embs), + p_indexs_(p_indexs), p_gamma_(p_gamma), p_beta_(p_beta), - NumRows_(NumRows), EmbeddingDim_(EmbeddingDim), IndexLength_(IndexLength), - epsilon_(epsilon) + epsilon_(epsilon), + emb_elementwise_op_(emb_elementwise_op) { grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize; } OutType* p_out_; - const EmbType* p_emb_a_; - const EmbType* p_emb_b_; - const EmbType* p_emb_c_; - const IndexType* p_index_a_; - const IndexType* p_index_b_; - const IndexType* p_index_c_; + ck::Array p_embs_; + ck::Array p_indexs_; const GammaDataType* p_gamma_; const BetaDataType* p_beta_; - ck::index_t NumRows_; ck::index_t EmbeddingDim_; ck::index_t IndexLength_; AccDataType epsilon_; + EmbElementwiseOperation emb_elementwise_op_; size_t grid_size_; }; - virtual std::unique_ptr MakeArgumentPointer(void* p_out, - const void* p_emb_a, - const void* p_emb_b, - const void* p_emb_c, - const void* p_index_a, - const void* p_index_b, - const void* p_index_c, - const void* p_gamma, - const void* p_beta, - ck::index_t NumRows, - ck::index_t EmbeddingDim, - ck::index_t IndexLength, - const AccDataType epsilon) + std::unique_ptr + MakeArgumentPointer(void* p_out, + const ck::Array& p_embs, + const ck::Array& p_indexs, + const void* p_gamma, + const void* p_beta, + ck::index_t EmbeddingDim, + ck::index_t IndexLength, + const AccDataType epsilon, + const EmbElementwiseOperation emb_elementwise_op) { return std::make_unique(reinterpret_cast(p_out), - reinterpret_cast(p_emb_a), - reinterpret_cast(p_emb_b), - reinterpret_cast(p_emb_c), - reinterpret_cast(p_index_a), - reinterpret_cast(p_index_b), - reinterpret_cast(p_index_c), + p_embs, + p_indexs, reinterpret_cast(p_gamma), reinterpret_cast(p_beta), - NumRows, EmbeddingDim, IndexLength, - epsilon); + epsilon, + emb_elementwise_op); } using GridwiseSparseEmbedding = - GridwiseSparseEmbedding3ForwardLayernorm; + RowVectorSize, + NumEmbeddings>; struct Invoker : public BaseInvoker { @@ -139,14 +123,16 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator { auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_); const auto kernel_main = - kernel_sparse_embedding3_forward_layernorm; + decltype(out_desc), + EmbElementwiseOperation, + NumEmbeddings>; float avg_time = 0; avg_time += launch_and_time_kernel(stream_config, kernel_main, @@ -154,16 +140,13 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator dim3(BlockSize), 0, arg.p_out_, - arg.p_emb_a_, - arg.p_emb_b_, - arg.p_emb_c_, - arg.p_index_a_, - arg.p_index_b_, - arg.p_index_c_, + arg.p_embs_, + arg.p_indexs_, arg.p_gamma_, arg.p_beta_, out_desc, - arg.epsilon_); + arg.epsilon_, + arg.emb_elementwise_op_); return (avg_time); } @@ -177,7 +160,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator static bool IsSupportedArgument(const Argument* p_arg) { - return (RowPerBlock == p_arg->EmbeddingDim_) && (p_arg->NumRows_ % DimPerBlock == 0); + return (RowPerBlock == p_arg->EmbeddingDim_); } bool IsSupportedArgument(const BaseArgument* p_arg) override @@ -195,7 +178,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator auto str = std::stringstream(); // clang-format off - str << "DeviceSparseEmbedding3ForwardLayernorm_"<< BlockSize << "_" << + str << "DeviceSparseEmbeddingsForwardLayernorm_"<< BlockSize << "_" << DimClusterSize << "x" << RowClusterSize << "_" << DimPerBlock << "x" << RowPerBlock << "_" << DimThreadSize << "x" << RowVectorSize; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp similarity index 69% rename from include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp index 3de6aa08c4..53942b9952 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp @@ -17,33 +17,24 @@ template + typename OutGridDesc, + typename EmbElementwiseOperation, + ck::index_t NumEmbeddings> #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - __global__ void kernel_sparse_embedding3_forward_layernorm(OutType* p_out, - const EmbType* p_emb_a, - const EmbType* p_emb_b, - const EmbType* p_emb_c, - const IndexType* p_index_a, - const IndexType* p_index_b, - const IndexType* p_index_c, - const GammaDataType* p_gamma, - const BetaDataType* p_beta, - const OutGridDesc out_grid_desc, - const AccDataType epsilon) + __global__ void kernel_sparse_embeddings_forward_layernorm( + OutType* p_out, + const ck::Array p_embs, + const ck::Array p_indexes, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + const OutGridDesc out_grid_desc, + const AccDataType epsilon, + const EmbElementwiseOperation emb_elementwise_op) { - GridwiseSparseEmbedding::Run(p_out, - p_emb_a, - p_emb_b, - p_emb_c, - p_index_a, - p_index_b, - p_index_c, - p_gamma, - p_beta, - out_grid_desc, - epsilon); + GridwiseSparseEmbedding::Run( + p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, emb_elementwise_op); } template -struct GridwiseSparseEmbedding3ForwardLayernorm + ck::index_t RowVectorSize, + ck::index_t NumEmbeddings> +struct GridwiseSparseEmbeddingsForwardLayernorm { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -97,23 +90,17 @@ struct GridwiseSparseEmbedding3ForwardLayernorm BlockwiseWelford>; __device__ static void Run(OutType* p_out, - const EmbType* p_emb_a, - const EmbType* p_emb_b, - const EmbType* p_emb_c, - const IndexType* p_index_a, - const IndexType* p_index_b, - const IndexType* p_index_c, + const ck::Array p_embs, + const ck::Array p_indexes, const GammaDataType* p_gamma, const BetaDataType* p_beta, const OutGridDesc, - const AccDataType epsilon) + const AccDataType epsilon, + const EmbElementwiseOperation emb_elementwise_op) { const index_t thread_local_id = get_thread_local_1d_id(); const index_t block_global_id = get_block_1d_id(); - // const auto index_length = out_grid_desc.GetLength(I0); - // const auto emb_dim = out_grid_desc.GetLength(I1); - constexpr auto thread_cluster_desc = make_cluster_descriptor(Sequence{}, Sequence<0, 1>{}); @@ -141,13 +128,11 @@ struct GridwiseSparseEmbedding3ForwardLayernorm constexpr auto gamma_beta_buf_desc = make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize)); - StaticBuffer in_thread_buf_a; - StaticBuffer in_thread_buf_b; - StaticBuffer in_thread_buf_c; - - StaticBuffer index_buf_a; - StaticBuffer index_buf_b; - StaticBuffer index_buf_c; + ck::Array, + NumEmbeddings> + in_thread_bufs; + ck::Array, NumEmbeddings> + index_bufs; StaticBuffer acc_thread_buf; @@ -160,42 +145,31 @@ struct GridwiseSparseEmbedding3ForwardLayernorm StaticBuffer var_thread_buf; auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { - vector_type_maker_t emb_vector_a; - vector_type_maker_t emb_vector_b; - vector_type_maker_t emb_vector_c; - - using src_vector_t = typename decltype(emb_vector_a)::type; + ck::Array, NumEmbeddings> emb_vectors; + auto emb_a = emb_vectors[0]; + using src_vector_t = typename decltype(emb_a)::type; static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_; - IndexType index_a = index_buf_a[Number{}]; - IndexType index_b = index_buf_b[Number{}]; - IndexType index_c = index_buf_c[Number{}]; auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * sizeof(EmbType) * RowVectorSize; + static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { + IndexType index = index_bufs[i_embedding_][Number{}]; - int32x4_t emb_res_a = - make_wave_buffer_resource_with_default_range(p_emb_a + index_a * RowPerBlock); - int32x4_t emb_res_b = - make_wave_buffer_resource_with_default_range(p_emb_b + index_b * RowPerBlock); - int32x4_t emb_res_c = - make_wave_buffer_resource_with_default_range(p_emb_c + index_c * RowPerBlock); - emb_vector_a.template AsType()(I0) = - amd_buffer_load_impl(emb_res_a, thread_offset, 0); - emb_vector_b.template AsType()(I0) = - amd_buffer_load_impl(emb_res_b, thread_offset, 0); - emb_vector_c.template AsType()(I0) = - amd_buffer_load_impl(emb_res_c, thread_offset, 0); + int32x4_t emb_res = make_wave_buffer_resource_with_default_range( + p_embs[i_embedding_] + index * RowPerBlock); + emb_vectors(i_embedding_).template AsType()(I0) = + amd_buffer_load_impl(emb_res, thread_offset, 0); + }); static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { constexpr auto register_offset = thread_buf_desc.CalculateOffset( make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); - in_thread_buf_a(Number{}) = - emb_vector_a.template AsType()[i_row_vec_]; - in_thread_buf_b(Number{}) = - emb_vector_b.template AsType()[i_row_vec_]; - in_thread_buf_c(Number{}) = - emb_vector_c.template AsType()[i_row_vec_]; + static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { + in_thread_bufs(i_embedding_)(Number{}) = + ck::type_convert( + emb_vectors[i_embedding_].template AsType()[i_row_vec_]); + }); }); }); }; @@ -205,14 +179,17 @@ struct GridwiseSparseEmbedding3ForwardLayernorm static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { constexpr auto register_offset = thread_buf_desc.CalculateOffset( make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); - AccDataType va = - ck::type_convert(in_thread_buf_a(Number{})); - AccDataType vb = - ck::type_convert(in_thread_buf_b(Number{})); - AccDataType vc = - ck::type_convert(in_thread_buf_c(Number{})); - - acc_thread_buf(Number{}) += va + vb + vc; + auto in_data_refs = generate_tie( + [&](auto i_embedding_) -> const auto& { + return in_thread_bufs(i_embedding_)(Number{}); + }, + Number{}); + auto out_data_refs = generate_tie( + [&](auto output_index_) -> auto& { + return acc_thread_buf(Number{}); + }, + Number<1>{}); + unpack2(emb_elementwise_op, out_data_refs, in_data_refs); }); }); }; @@ -242,7 +219,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm constexpr auto mean_var_offset = mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); - + auto divisor = + 1 / __builtin_amdgcn_sqrtf(var_thread_buf(Number{}) + epsilon); static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { constexpr auto register_offset = thread_buf_desc.CalculateOffset( make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); @@ -250,9 +228,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_)); auto acc_val = acc_thread_buf[Number{}]; - acc_val = (acc_val - mean_thread_buf(Number{})) / - sqrt(var_thread_buf(Number{}) + epsilon); - acc_val = acc_val * gamma_thread_buf[Number{}] + + acc_val = (acc_val - mean_thread_buf(Number{})) * divisor; + acc_val = acc_val * gamma_thread_buf[Number{}] + beta_thread_buf[Number{}]; out_vector.template AsType()(Number{}) = @@ -273,9 +250,10 @@ struct GridwiseSparseEmbedding3ForwardLayernorm // first load index ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) { // prefer use s_load - index_buf_a(i_idx_) = p_index_a[index_start + i_idx_.value]; - index_buf_b(i_idx_) = p_index_b[index_start + i_idx_.value]; - index_buf_c(i_idx_) = p_index_c[index_start + i_idx_.value]; + ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { + index_bufs(i_embedding_)(i_idx_) = + p_indexes[i_embedding_][index_start + i_idx_.value]; + }); }); // load gamma/beta @@ -329,7 +307,6 @@ struct GridwiseSparseEmbedding3ForwardLayernorm static_for<0, mean_var_buf_size, 1>{}([&](auto I) { if constexpr(I > 0) block_sync_lds(); - BlockwiseWelford::Run( mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_); });