mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
add multi embeddings support (#542)
* add multi embeddings support
* fix format
* optimize sqrt
* add reduce operation
* change to elementwise op
* fix name
* rename
* run ci cd
* format example
* format code
* format code
[ROCm/composable_kernel commit: 147b7db561]
This commit is contained in:
@@ -9,7 +9,8 @@
|
||||
#include <ctime>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_sparse_embedding3_forward_layernorm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -18,53 +19,26 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp"
|
||||
|
||||
// using EmbType = float;
|
||||
// using IndexType = int64_t;
|
||||
// using GammaDataType = float;
|
||||
// using BetaDataType = float;
|
||||
// using AccDataType = float;
|
||||
// using OutType = float;
|
||||
|
||||
// clang-format off
|
||||
using EmbType = ck::half_t;
|
||||
using IndexType = int64_t;
|
||||
using GammaDataType = ck::half_t;
|
||||
using BetaDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using OutType = ck::half_t;
|
||||
// Element-wise operation type used by this example: it is supplied as a template
// argument to the DeviceSparseEmbeddingsForwardLayernorm instances and a
// default-constructed object of it is passed as the last MakeArgumentPointer argument.
// NOTE(review): presumably combines the looked-up embeddings by addition (name "AddAdd") — confirm
// against element_wise_operation.hpp.
using EmbElementwiseOperation = ck::tensor_operation::element_wise::AddAdd;
|
||||
|
||||
// clang-format off
|
||||
// BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize
|
||||
using DeviceInstance_fp32_e256 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 256, 1, 1>;
|
||||
using DeviceInstance_fp32_e512 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 512, 1, 1>;
|
||||
using DeviceInstance_fp32_e768 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 768, 1, 1>;
|
||||
using DeviceInstance_fp32_e1024 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1024, 1, 1>;
|
||||
using DeviceInstance_fp32_e1536 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1536, 1, 1>;
|
||||
using DeviceInstance_fp32_e2048 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 2048, 1, 4>;
|
||||
using DeviceInstance_fp32_e4096 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 4096, 1, 4>;
|
||||
using DeviceInstance_fp32_e8192 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 8192, 1, 4>;
|
||||
using DeviceInstance_fp32_e16384 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 16384, 1, 4>;
|
||||
|
||||
using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 256, 1, 1>;
|
||||
using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 512, 1, 2>;
|
||||
using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 768, 1, 1>;
|
||||
using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1024, 1, 2>;
|
||||
using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1536, 1, 2>;
|
||||
using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 2048, 1, 2>;
|
||||
using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 4096, 1, 8>;
|
||||
using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 8192, 1, 8>;
|
||||
using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 256, 1, 1, 3>;
|
||||
using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 512, 1, 2, 3>;
|
||||
using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 768, 1, 1, 3>;
|
||||
using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 1024, 1, 2, 3>;
|
||||
using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 1536, 1, 2, 3>;
|
||||
using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 2048, 1, 2, 3>;
|
||||
using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 4096, 1, 8, 3>;
|
||||
using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 8192, 1, 8, 3>;
|
||||
|
||||
// Compile-time lookup from (embedding element type, embedding dimension) to a device
// instance type. The primary template is intentionally empty; each explicit
// specialization exposes the selected instance as the nested alias `kernel_type`,
// so an unsupported (type, dim) pair fails to compile when `kernel_type` is requested.
template<typename emb_type, ck::index_t dim> struct emb_kernel{};
|
||||
|
||||
template<> struct emb_kernel<float, 256> { using kernel_type = DeviceInstance_fp32_e256; };
|
||||
template<> struct emb_kernel<float, 512> { using kernel_type = DeviceInstance_fp32_e512; };
|
||||
template<> struct emb_kernel<float, 768> { using kernel_type = DeviceInstance_fp32_e768; };
|
||||
template<> struct emb_kernel<float, 1024> { using kernel_type = DeviceInstance_fp32_e1024;};
|
||||
template<> struct emb_kernel<float, 1536> { using kernel_type = DeviceInstance_fp32_e1536;};
|
||||
template<> struct emb_kernel<float, 2048> { using kernel_type = DeviceInstance_fp32_e2048;};
|
||||
template<> struct emb_kernel<float, 4096> { using kernel_type = DeviceInstance_fp32_e4096;};
|
||||
template<> struct emb_kernel<float, 8192> { using kernel_type = DeviceInstance_fp32_e8192;};
|
||||
template<> struct emb_kernel<float, 16384>{ using kernel_type = DeviceInstance_fp32_e16384;};
|
||||
|
||||
template<> struct emb_kernel<ck::half_t, 256> { using kernel_type = DeviceInstance_fp16_e256; };
|
||||
template<> struct emb_kernel<ck::half_t, 512> { using kernel_type = DeviceInstance_fp16_e512; };
|
||||
template<> struct emb_kernel<ck::half_t, 768> { using kernel_type = DeviceInstance_fp16_e768; };
|
||||
@@ -152,19 +126,20 @@ int main()
|
||||
beta_dev.ToDevice(beta.mData.data());
|
||||
|
||||
auto device_instance = typename emb_kernel<EmbType, current_dim>::kernel_type{};
|
||||
auto argument_ptr = device_instance.MakeArgumentPointer(out_dev.GetDeviceBuffer(),
|
||||
emb_a_dev.GetDeviceBuffer(),
|
||||
emb_b_dev.GetDeviceBuffer(),
|
||||
emb_c_dev.GetDeviceBuffer(),
|
||||
index_a_dev.GetDeviceBuffer(),
|
||||
index_b_dev.GetDeviceBuffer(),
|
||||
index_c_dev.GetDeviceBuffer(),
|
||||
gamma_dev.GetDeviceBuffer(),
|
||||
beta_dev.GetDeviceBuffer(),
|
||||
num_rows,
|
||||
current_dim,
|
||||
index_length,
|
||||
epsilon);
|
||||
auto argument_ptr = device_instance.MakeArgumentPointer(
|
||||
out_dev.GetDeviceBuffer(),
|
||||
{ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
|
||||
{ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
|
||||
gamma_dev.GetDeviceBuffer(),
|
||||
beta_dev.GetDeviceBuffer(),
|
||||
current_dim,
|
||||
index_length,
|
||||
epsilon,
|
||||
EmbElementwiseOperation{});
|
||||
std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString()
|
||||
<< std::endl
|
||||
<< std::flush;
|
||||
|
||||
Reference in New Issue
Block a user