add multi embeddings support (#542)

* add multi embeddings support

* fix format

* optimize sqrt

* add reduce operation

* change to elementwise op

* fix name

* rename

* run ci cd

* format example

* format code

* format code
This commit is contained in:
who who who
2023-01-19 01:32:12 +08:00
committed by GitHub
parent 55236709e2
commit 147b7db561
3 changed files with 131 additions and 196 deletions

View File

@@ -17,33 +17,24 @@ template <typename GridwiseSparseEmbedding,
typename BetaDataType,
typename AccDataType,
typename OutType,
typename OutGridDesc>
typename OutGridDesc,
typename EmbElementwiseOperation,
ck::index_t NumEmbeddings>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
__global__ void kernel_sparse_embedding3_forward_layernorm(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc out_grid_desc,
const AccDataType epsilon)
__global__ void kernel_sparse_embeddings_forward_layernorm(
OutType* p_out,
const ck::Array<EmbType*, NumEmbeddings> p_embs,
const ck::Array<IndexType*, NumEmbeddings> p_indexes,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc out_grid_desc,
const AccDataType epsilon,
const EmbElementwiseOperation emb_elementwise_op)
{
GridwiseSparseEmbedding::Run(p_out,
p_emb_a,
p_emb_b,
p_emb_c,
p_index_a,
p_index_b,
p_index_c,
p_gamma,
p_beta,
out_grid_desc,
epsilon);
GridwiseSparseEmbedding::Run(
p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, emb_elementwise_op);
}
template <typename EmbType,
@@ -53,14 +44,16 @@ template <typename EmbType,
typename AccDataType,
typename OutType,
typename OutGridDesc,
typename EmbElementwiseOperation,
ck::index_t BlockSize,
ck::index_t DimClusterSize,
ck::index_t RowClusterSize,
ck::index_t DimPerBlock, // Row x Dim, along Dim
ck::index_t RowPerBlock, // Row x Dim, along Row
ck::index_t DimThreadSize, // this is actually not vector, but number of registers
ck::index_t RowVectorSize>
struct GridwiseSparseEmbedding3ForwardLayernorm
ck::index_t RowVectorSize,
ck::index_t NumEmbeddings>
struct GridwiseSparseEmbeddingsForwardLayernorm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -97,23 +90,17 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLength, Sequence<0, 1>>;
__device__ static void Run(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const ck::Array<EmbType*, NumEmbeddings> p_embs,
const ck::Array<IndexType*, NumEmbeddings> p_indexes,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const OutGridDesc,
const AccDataType epsilon)
const AccDataType epsilon,
const EmbElementwiseOperation emb_elementwise_op)
{
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
// const auto index_length = out_grid_desc.GetLength(I0);
// const auto emb_dim = out_grid_desc.GetLength(I1);
constexpr auto thread_cluster_desc =
make_cluster_descriptor(Sequence<DimClusterSize, RowClusterSize>{}, Sequence<0, 1>{});
@@ -141,13 +128,11 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
constexpr auto gamma_beta_buf_desc =
make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize));
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_a;
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_b;
StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_c;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_a;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_b;
StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_c;
ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true>,
NumEmbeddings>
in_thread_bufs;
ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, IndexType, DimPerBlock, true>, NumEmbeddings>
index_bufs;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true> acc_thread_buf;
@@ -160,42 +145,31 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> var_thread_buf;
auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_a;
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_b;
vector_type_maker_t<EmbType, RowVectorSize> emb_vector_c;
using src_vector_t = typename decltype(emb_vector_a)::type;
ck::Array<vector_type_maker_t<EmbType, RowVectorSize>, NumEmbeddings> emb_vectors;
auto emb_a = emb_vectors[0];
using src_vector_t = typename decltype(emb_a)::type;
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
IndexType index_a = index_buf_a[Number<current_dim>{}];
IndexType index_b = index_buf_b[Number<current_dim>{}];
IndexType index_c = index_buf_c[Number<current_dim>{}];
auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
sizeof(EmbType) * RowVectorSize;
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
IndexType index = index_bufs[i_embedding_][Number<current_dim>{}];
int32x4_t emb_res_a =
make_wave_buffer_resource_with_default_range(p_emb_a + index_a * RowPerBlock);
int32x4_t emb_res_b =
make_wave_buffer_resource_with_default_range(p_emb_b + index_b * RowPerBlock);
int32x4_t emb_res_c =
make_wave_buffer_resource_with_default_range(p_emb_c + index_c * RowPerBlock);
emb_vector_a.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_a, thread_offset, 0);
emb_vector_b.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_b, thread_offset, 0);
emb_vector_c.template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_c, thread_offset, 0);
int32x4_t emb_res = make_wave_buffer_resource_with_default_range(
p_embs[i_embedding_] + index * RowPerBlock);
emb_vectors(i_embedding_).template AsType<src_vector_t>()(I0) =
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res, thread_offset, 0);
});
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
in_thread_buf_a(Number<register_offset>{}) =
emb_vector_a.template AsType<EmbType>()[i_row_vec_];
in_thread_buf_b(Number<register_offset>{}) =
emb_vector_b.template AsType<EmbType>()[i_row_vec_];
in_thread_buf_c(Number<register_offset>{}) =
emb_vector_c.template AsType<EmbType>()[i_row_vec_];
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
in_thread_bufs(i_embedding_)(Number<register_offset>{}) =
ck::type_convert<AccDataType>(
emb_vectors[i_embedding_].template AsType<EmbType>()[i_row_vec_]);
});
});
});
};
@@ -205,14 +179,17 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
AccDataType va =
ck::type_convert<AccDataType>(in_thread_buf_a(Number<register_offset>{}));
AccDataType vb =
ck::type_convert<AccDataType>(in_thread_buf_b(Number<register_offset>{}));
AccDataType vc =
ck::type_convert<AccDataType>(in_thread_buf_c(Number<register_offset>{}));
acc_thread_buf(Number<register_offset>{}) += va + vb + vc;
auto in_data_refs = generate_tie(
[&](auto i_embedding_) -> const auto& {
return in_thread_bufs(i_embedding_)(Number<register_offset>{});
},
Number<NumEmbeddings>{});
auto out_data_refs = generate_tie(
[&](auto output_index_) -> auto& {
return acc_thread_buf(Number<register_offset>{});
},
Number<1>{});
unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
});
});
};
@@ -242,7 +219,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
constexpr auto mean_var_offset =
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
auto divisor =
1 / __builtin_amdgcn_sqrtf(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
@@ -250,9 +228,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
auto acc_val = acc_thread_buf[Number<register_offset>{}];
acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) /
sqrt(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) * divisor;
acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
beta_thread_buf[Number<gamma_beta_offset>{}];
out_vector.template AsType<OutType>()(Number<i_row_vec_>{}) =
@@ -273,9 +250,10 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
// first load index
ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
// prefer use s_load
index_buf_a(i_idx_) = p_index_a[index_start + i_idx_.value];
index_buf_b(i_idx_) = p_index_b[index_start + i_idx_.value];
index_buf_c(i_idx_) = p_index_c[index_start + i_idx_.value];
ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
index_bufs(i_embedding_)(i_idx_) =
p_indexes[i_embedding_][index_start + i_idx_.value];
});
});
// load gamma/beta
@@ -329,7 +307,6 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(
mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
});