elementwise op (#238)
* Add elementwise operation kernel and example
* Add comment
* Add template argument for dim; prepare to support multiple dimensions
* Rename example
* Support 1 dimension
* Add static assert
* Add comment
* Extract padding
* Remove redundant argument
* Support any dimension for the elementwise operation
* Remove line
* Make the grid size a multiple of the CU count
* Move threads per block to a constructor parameter
* Rename threadPerBlock to blockSize
* Support double
* Rename kernel function
* Remove redundant include header
* Refine types
* Handle the final dimension
* Refine variable names
* Refine types
* Use index_t instead of int in the API

Co-authored-by: rocking <chunylai@amd.com>
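For context, here is a minimal host-side sketch of how the new device op is used. The header paths and the 2-D float configuration are assumptions for illustration (they are not named in this diff), and device buffer allocation via hipMalloc/hipMemcpy is elided.

// Minimal usage sketch (assumed header names; device buffers allocated elsewhere).
#include <vector>
#include "device_binary_elementwise.hpp"     // file 1 of this diff (assumed path)
#include "binary_element_wise_operation.hpp" // file 2 of this diff (assumed path)

using Add      = ck::tensor_operation::binary_element_wise::Add;
using DeviceOp = ck::tensor_operation::device::DeviceBinaryElementwise<
    float, float, float, float, Add, 2 /*Dim*/, 4 /*ScalarPerVector*/>;

void add_tensors(const float* p_a_dev, const float* p_b_dev, float* p_c_dev)
{
    const std::vector<ck::index_t> shape{1024, 1024};
    const std::vector<ck::index_t> stride{1024, 1}; // row-major, innermost stride 1

    auto op  = DeviceOp{};
    auto arg = op.MakeArgumentPointer(
        p_a_dev, p_b_dev, p_c_dev, shape, stride, stride, stride, Add{});

    if(op.IsSupportedArgument(arg.get()))
        op.MakeInvokerPointer()->Run(arg.get());
}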
@@ -0,0 +1,204 @@
#pragma once
#include <iostream>
#include <memory>
#include <sstream>
#include <vector>

#include "device.hpp"
#include "device_base.hpp"
#include "gridwise_binary_elementwise_1d.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ComputeDataType,
          typename ElementwiseFunctor,
          index_t Dim,
          index_t ScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator
{
    DeviceBinaryElementwise(index_t blockSize = 256) : BaseOperator(), blockSize_(blockSize) {}

    static constexpr auto I0 = Number<0>{};

    template <typename Desc_M0>
    static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
    {
        const auto m0           = desc_m0.GetLength(I0);
        const index_t loop_step = gridSize * blockSize * ScalarPerVector;
        const auto pad          = math::integer_least_multiple(m0, loop_step) - m0;
        const auto desc_m0_pad =
            transform_tensor_descriptor(desc_m0,
                                        make_tuple(make_right_pad_transform(m0, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m0_pad;
    }

    static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
                                  const std::vector<index_t>& stride,
                                  index_t gridSize,
                                  index_t blockSize)
    {
        auto tupleOfShape  = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});

        // nd desc - [s0, s1, s2, ...]
        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);

        // merge nd to 1d desc - [s0 * s1 * ...]
        if constexpr(Dim > 1)
        {
            const auto desc_m0 = transform_tensor_descriptor(
                desc,
                make_tuple(make_merge_transform(tupleOfShape)),
                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
                make_tuple(Sequence<0>{}));

            return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
        }
        else
            return PadDescriptor_M0_1d(desc, gridSize, blockSize);
    }

    using GridDesc_M0        = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
    using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
                                                            BDataType,
                                                            CDataType,
                                                            ComputeDataType,
                                                            GridDesc_M0,
                                                            ElementwiseFunctor,
                                                            ScalarPerVector>;

    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a,
                 const BDataType* p_b,
                 CDataType* p_c,
                 const std::vector<index_t>& shape,
                 const std::vector<index_t>& stride_a,
                 const std::vector<index_t>& stride_b,
                 const std::vector<index_t>& stride_c,
                 ElementwiseFunctor functor,
                 index_t blockSize)
            : p_a_(p_a),
              p_b_(p_b),
              p_c_(p_c),
              shape_(shape),
              functor_(functor),
              gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
        {
            a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize);
            b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize);
            c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize);
        }

        const ADataType* p_a_;
        const BDataType* p_b_;
        CDataType* p_c_;
        std::vector<index_t> shape_;
        GridDesc_M0 a_grid_desc_m0_;
        GridDesc_M0 b_grid_desc_m0_;
        GridDesc_M0 c_grid_desc_m0_;
        ElementwiseFunctor functor_;
        index_t gridSize_;
    };

    struct Invoker : public BaseInvoker
    {
        Invoker(index_t blockSize) : BaseInvoker(), blockSize_(blockSize) {}

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto kernel = kernel_binary_elementwise_1d<GridwiseBinEltwise,
                                                             ADataType,
                                                             BDataType,
                                                             CDataType,
                                                             GridDesc_M0,
                                                             ElementwiseFunctor>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(arg.gridSize_),
                                                        dim3(blockSize_),
                                                        0,
                                                        arg.p_a_,
                                                        arg.p_b_,
                                                        arg.p_c_,
                                                        arg.a_grid_desc_m0_,
                                                        arg.b_grid_desc_m0_,
                                                        arg.c_grid_desc_m0_,
                                                        arg.functor_);
            return elapsed_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }

        index_t blockSize_;
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if(pArg == nullptr)
            return false;

        // vectorized access requires the innermost dimension to be divisible
        if(pArg->shape_.back() % ScalarPerVector != 0)
            return false;

        return true;
    }

    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      std::vector<index_t> shape,
                                                      std::vector<index_t> stride_a,
                                                      std::vector<index_t> stride_b,
                                                      std::vector<index_t> stride_c,
                                                      ElementwiseFunctor functor)
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          shape,
                                          stride_a,
                                          stride_b,
                                          stride_c,
                                          functor,
                                          blockSize_);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{blockSize_});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBinaryElementwise"
            << "<"
            << "ScalarPerVector = " << ScalarPerVector
            << ">";
        // clang-format on

        return str.str();
    }

    index_t blockSize_;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
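The right-pad above is what lets the gridwise kernel run a fixed-trip-count loop with no bounds check in its hot path. A worked instance of the arithmetic, using the hard-coded gridSize_ of 120 and illustrative values for the other parameters:

// Worked example of PadDescriptor_M0_1d's arithmetic (illustrative values):
//   gridSize = 120, blockSize = 256, ScalarPerVector = 4
//   loop_step = 120 * 256 * 4 = 122880        // elements consumed per grid iteration
// For a flattened length m0 = 1024 * 1024 = 1048576:
//   integer_least_multiple(m0, loop_step) = 9 * 122880 = 1105920
//   pad = 1105920 - 1048576 = 57344
// Every thread then executes the same trip count (9 here); the right-pad
// transform keeps accesses into the padded tail well-defined for the
// threadwise transfers instead of running out of bounds.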
@@ -0,0 +1,25 @@
#pragma once
#include "data_type.hpp"

namespace ck {
namespace tensor_operation {
namespace binary_element_wise {

struct Add
{
    __host__ __device__ constexpr void
    operator()(double& dst, const double& src1, const double& src2) const
    {
        dst = src1 + src2;
    }

    __host__ __device__ constexpr void
    operator()(float& dst, const float& src1, const float& src2) const
    {
        dst = src1 + src2;
    }
};

} // namespace binary_element_wise
} // namespace tensor_operation
} // namespace ck
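The functor is the customization point of the whole op: any type whose operator()(dst, src1, src2) overloads match the chosen ComputeDataType can be passed as ElementwiseFunctor. For example, a hypothetical Mul written in the same style as Add above (not part of this diff):

struct Mul
{
    __host__ __device__ constexpr void
    operator()(float& dst, const float& src1, const float& src2) const
    {
        dst = src1 * src2;
    }
};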
@@ -0,0 +1,150 @@
#pragma once

#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseBinEltwise,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GridDesc_M0,
          typename ElementwiseFunctor>
__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                             const BDataType* __restrict__ p_b_global,
                                             CDataType* __restrict__ p_c_global,
                                             const GridDesc_M0 a_grid_desc_m0,
                                             const GridDesc_M0 b_grid_desc_m0,
                                             const GridDesc_M0 c_grid_desc_m0,
                                             const ElementwiseFunctor functor)
{
    GridwiseBinEltwise::Run(p_a_global,
                            p_b_global,
                            p_c_global,
                            a_grid_desc_m0,
                            b_grid_desc_m0,
                            c_grid_desc_m0,
                            functor);
}

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ComputeDataType,
          typename GridDesc_M0,
          typename ElementwiseFunctor,
          index_t ScalarPerVector>
struct GridwiseBinaryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto thread_desc_m0 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    // each thread starts at its own ScalarPerVector-wide slot of the flattened tensor
    static __device__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
        return make_multi_index(global_thread_id * ScalarPerVector);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               CDataType* __restrict__ p_c_global,
                               const GridDesc_M0 a_grid_desc_m0,
                               const GridDesc_M0 b_grid_desc_m0,
                               const GridDesc_M0 c_grid_desc_m0,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m0.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m0.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m0.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf;

        const auto thread_store_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ComputeDataType,
                                             GridDesc_M0,
                                             decltype(thread_desc_m0),
                                             Sequence<ScalarPerVector>, // SliceLengths
                                             Sequence<0>,               // DimAccessOrder
                                             0,                         // SrcVectorDim
                                             ScalarPerVector,
                                             1, // SrcScalarStrideInVector
                                             false>{a_grid_desc_m0, thread_store_global_offset};

        auto b_global_load =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             ComputeDataType,
                                             GridDesc_M0,
                                             decltype(thread_desc_m0),
                                             Sequence<ScalarPerVector>, // SliceLengths
                                             Sequence<0>,               // DimAccessOrder
                                             0,                         // SrcVectorDim
                                             ScalarPerVector,
                                             1, // SrcScalarStrideInVector
                                             false>{b_grid_desc_m0, thread_store_global_offset};

        auto c_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               CDataType,
                                               decltype(thread_desc_m0),
                                               GridDesc_M0,
                                               PassThrough,
                                               Sequence<ScalarPerVector>, // SliceLengths
                                               Sequence<0>,               // DimAccessOrder
                                               0,                         // DstVectorDim
                                               ScalarPerVector,
                                               InMemoryDataOperationEnum::Set,
                                               1, // DstScalarStrideInVector
                                               false>{
                c_grid_desc_m0, thread_store_global_offset, PassThrough{}};

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto m0              = c_grid_desc_m0.GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * ScalarPerVector;
        const auto loop_step_index = make_multi_index(loop_step);

        // grid-stride loop; the padded descriptor guarantees m0 is a multiple of loop_step
        index_t num_iter = m0 / loop_step;
        do
        {
            // read and process ScalarPerVector elements
            a_global_load.Run(
                a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);

            b_global_load.Run(
                b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf);

            static_for<0, ScalarPerVector, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
                functor(c_thread_buf(Number<offset>{}),
                        a_thread_buf(Number<offset>{}),
                        b_thread_buf(Number<offset>{}));
            });

            c_global_write.Run(thread_desc_m0,
                               make_tuple(I0), // SrcSliceOriginIdx
                               c_thread_buf,
                               c_grid_desc_m0,
                               c_global_buf);

            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
            b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index);
            c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index);
        } while(--num_iter);
    }
};

} // namespace ck
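Stripped of the descriptor and transfer machinery, each thread of the gridwise kernel executes the grid-stride pattern below (a plain HIP sketch for intuition, not part of the diff). The ThreadwiseTensorSliceTransfer classes add vectorized, stride-aware loads and stores on top of this shape.

// Naive grid-stride analogue of GridwiseBinaryElementwise_1D::Run.
// Assumes m0 has been padded to a multiple of the grid-wide step, as the
// padded descriptors guarantee in the real kernel.
__global__ void binary_elementwise_naive(const float* a, const float* b, float* c, int m0)
{
    constexpr int scalar_per_vector = 4; // illustrative

    const int begin = (blockIdx.x * blockDim.x + threadIdx.x) * scalar_per_vector;
    const int step  = gridDim.x * blockDim.x * scalar_per_vector;

    for(int i = begin; i < m0; i += step)        // MoveSrc/DstSliceWindow
        for(int j = 0; j < scalar_per_vector; ++j) // static_for over the thread buffer
            c[i + j] = a[i + j] + b[i + j];        // functor(dst, src1, src2)
}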