mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Hotfix binary elementwise (for broadcast on fastest axis) (#254)
* Support different length of ScalarPerVector
* Add example of broadcast on fastest axis
* Typo
* Refine fastest example
* Add dimension check
* Modify fastest broadcast example to 3d
* Enforce users give scalarPerVector explicitely
* 1. Add CscalarPerVedctor
2. Not only broadcast on fastest need to set scalarPerVector to 1
* Rename var
* Move IsScalarPerVectorValid() inside IsSupportedArgument()
* Separate GridDesc_M0 into A, B and C
* rename var
* Rename var of length
Co-authored-by: rocking <chunylai@amd.com>
[ROCm/composable_kernel commit: 82d7d9938f]
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
add_example_executable(example_broadcast_add_2d broadcast_add_2d.cpp)
|
||||
add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp)
|
||||
add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp)
|
||||
add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp)
|
||||
add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp)
|
||||
@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
|
||||
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 2, 8>;
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
ABDataType,
|
||||
CDataType,
|
||||
EltwiseComputeDataType,
|
||||
Add,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
8,
|
||||
8>;
|
||||
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
@@ -100,7 +109,7 @@ int main()
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error("The runtime parameters seems not supported by the "
|
||||
"DeviceBinaryElementwise_2D instance, exiting!");
|
||||
"DeviceBinaryElementwise instance, exiting!");
|
||||
};
|
||||
|
||||
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
|
||||
@@ -123,7 +132,7 @@ int main()
|
||||
0>(host_c_m_n, a_m_n, b_n, M, N, Add{});
|
||||
|
||||
pass &= ck::utils::check_err(
|
||||
c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
|
||||
c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
123
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
Normal file
123
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
Normal file
@@ -0,0 +1,123 @@
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
|
||||
#include "device_tensor.hpp"
|
||||
#include "binary_element_wise_operation.hpp"
|
||||
#include "device_binary_elementwise.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
ABDataType,
|
||||
CDataType,
|
||||
EltwiseComputeDataType,
|
||||
Add,
|
||||
3,
|
||||
8,
|
||||
1,
|
||||
8,
|
||||
8>;
|
||||
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
typename HostTensorC,
|
||||
typename ComputeDataType,
|
||||
typename Functor>
|
||||
void host_broadcast3D_am_bmnk(HostTensorC& C,
|
||||
const HostTensorA& A,
|
||||
const HostTensorB& B,
|
||||
const std::vector<std::size_t>& shape,
|
||||
Functor functor)
|
||||
{
|
||||
using ctype = ck::remove_reference_t<decltype(C(0, 0))>;
|
||||
|
||||
for(std::size_t m = 0; m < shape[0]; ++m)
|
||||
for(std::size_t n = 0; n < shape[1]; ++n)
|
||||
for(std::size_t k = 0; k < shape[2]; ++k)
|
||||
{
|
||||
ComputeDataType a_val = static_cast<ComputeDataType>(A(m));
|
||||
ComputeDataType b_val = static_cast<ComputeDataType>(B(m, n, k));
|
||||
ComputeDataType c_val = 0;
|
||||
functor(c_val, a_val, b_val);
|
||||
C(m, n, k) = static_cast<ctype>(c_val);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
std::vector<std::size_t> mnk = {4, 16, 32};
|
||||
ck::index_t M = mnk[0];
|
||||
|
||||
Tensor<ABDataType> a_m({M});
|
||||
Tensor<ABDataType> b_m_n_k(mnk);
|
||||
Tensor<CDataType> c_m_n_k(mnk);
|
||||
|
||||
a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
|
||||
b_m_n_k.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace());
|
||||
|
||||
a_m_device_buf.ToDevice(a_m.mData.data());
|
||||
b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());
|
||||
|
||||
auto broadcastAdd = DeviceElementwiseAddInstance{};
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(
|
||||
a_m_device_buf.GetDeviceBuffer(),
|
||||
b_m_n_k_device_buf.GetDeviceBuffer(),
|
||||
c_m_n_k_device_buf.GetDeviceBuffer(),
|
||||
std::vector<ck::index_t>{mnk.begin(), mnk.end()},
|
||||
{1, 0, 0}, // broadcast A on second and third dimension
|
||||
std::vector<ck::index_t>{b_m_n_k.mDesc.GetStrides().begin(),
|
||||
b_m_n_k.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{c_m_n_k.mDesc.GetStrides().begin(),
|
||||
c_m_n_k.mDesc.GetStrides().end()},
|
||||
Add{});
|
||||
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error("The runtime parameters seems not supported by the "
|
||||
"DeviceBinaryElementwise instance, exiting!");
|
||||
};
|
||||
|
||||
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
|
||||
float ave_time =
|
||||
broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
if(do_verification)
|
||||
{
|
||||
c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data());
|
||||
Tensor<CDataType> host_c_m_n_k(mnk);
|
||||
|
||||
host_broadcast3D_am_bmnk<Tensor<ABDataType>,
|
||||
Tensor<ABDataType>,
|
||||
Tensor<CDataType>,
|
||||
EltwiseComputeDataType,
|
||||
Add>(host_c_m_n_k, a_m, b_m_n_k, mnk, Add{});
|
||||
|
||||
pass &= ck::utils::check_err(
|
||||
c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
|
||||
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 1, 8>;
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
ABDataType,
|
||||
CDataType,
|
||||
EltwiseComputeDataType,
|
||||
Add,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
8,
|
||||
8>;
|
||||
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
@@ -81,7 +90,7 @@ int main()
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error("The runtime parameters seems not supported by the "
|
||||
"DeviceBinaryElementwise_2D instance, exiting!");
|
||||
"DeviceBinaryElementwise instance, exiting!");
|
||||
};
|
||||
|
||||
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
|
||||
@@ -103,7 +112,7 @@ int main()
|
||||
Add>(host_c_m, a_m, b_m, M, Add{});
|
||||
|
||||
pass &= ck::utils::check_err(
|
||||
c_m.mData, host_c_m.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
|
||||
c_m.mData, host_c_m.mData, "Error: Incorrect results c", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
|
||||
@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance = ck::tensor_operation::device::
|
||||
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 4, 8>;
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
ABDataType,
|
||||
CDataType,
|
||||
EltwiseComputeDataType,
|
||||
Add,
|
||||
4,
|
||||
8,
|
||||
8,
|
||||
8,
|
||||
8>;
|
||||
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
@@ -83,7 +92,7 @@ int main()
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error("The runtime parameters seems not supported by the "
|
||||
"DeviceBinaryElementwise_2D instance, exiting!");
|
||||
"DeviceBinaryElementwise instance, exiting!");
|
||||
};
|
||||
|
||||
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
|
||||
@@ -105,7 +114,7 @@ int main()
|
||||
Add>(host_c, a, b, nchw, Add{});
|
||||
|
||||
pass &=
|
||||
ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
|
||||
ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results c", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
|
||||
@@ -15,91 +15,107 @@ template <typename ADataType,
|
||||
typename CDataType,
|
||||
typename ComputeDataType,
|
||||
typename ElementwiseFunctor,
|
||||
index_t Dim,
|
||||
index_t ScalarPerVector>
|
||||
index_t NDim,
|
||||
index_t MPerThread,
|
||||
index_t AScalarPerVector,
|
||||
index_t BScalarPerVector,
|
||||
index_t CScalarPerVector>
|
||||
struct DeviceBinaryElementwise : public BaseOperator
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
template <typename Desc_M0>
|
||||
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
|
||||
template <typename Desc_M>
|
||||
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
|
||||
{
|
||||
const auto m0 = desc_m0.GetLength(I0);
|
||||
const index_t loop_step = gridSize * blockSize * ScalarPerVector;
|
||||
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
|
||||
const auto desc_m0_pad =
|
||||
transform_tensor_descriptor(desc_m0,
|
||||
make_tuple(make_right_pad_transform(m0, pad)),
|
||||
const auto M = desc_m.GetLength(I0);
|
||||
const index_t loop_step = gridSize * blockSize * MPerThread;
|
||||
const auto pad = math::integer_least_multiple(M, loop_step) - M;
|
||||
const auto desc_m_pad =
|
||||
transform_tensor_descriptor(desc_m,
|
||||
make_tuple(make_right_pad_transform(M, pad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return desc_m0_pad;
|
||||
return desc_m_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
|
||||
const std::vector<index_t>& stride,
|
||||
index_t gridSize,
|
||||
index_t blockSize)
|
||||
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
|
||||
const std::vector<index_t>& strides,
|
||||
index_t gridSize,
|
||||
index_t blockSize)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<NDim>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
|
||||
// merge nd to 1d desc - [s0 * s1 * ...]
|
||||
if constexpr(Dim > 1)
|
||||
if constexpr(NDim > 1)
|
||||
{
|
||||
const auto desc_m0 = transform_tensor_descriptor(
|
||||
const auto desc_m = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(tupleOfShape)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
|
||||
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
|
||||
}
|
||||
else
|
||||
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
|
||||
return PadDescriptor_M_1d(desc, gridSize, blockSize);
|
||||
}
|
||||
|
||||
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
|
||||
using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
ComputeDataType,
|
||||
GridDesc_M0,
|
||||
AGridDesc_M,
|
||||
BGridDesc_M,
|
||||
CGridDesc_M,
|
||||
ElementwiseFunctor,
|
||||
ScalarPerVector>;
|
||||
MPerThread,
|
||||
AScalarPerVector,
|
||||
BScalarPerVector,
|
||||
CScalarPerVector>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
const std::vector<index_t>& shape,
|
||||
const std::vector<index_t>& stride_a,
|
||||
const std::vector<index_t>& stride_b,
|
||||
const std::vector<index_t>& stride_c,
|
||||
const std::vector<index_t>& lengths,
|
||||
const std::vector<index_t>& a_strides,
|
||||
const std::vector<index_t>& b_strides,
|
||||
const std::vector<index_t>& c_strides,
|
||||
ElementwiseFunctor functor)
|
||||
: p_a_(p_a),
|
||||
p_b_(p_b),
|
||||
p_c_(p_c),
|
||||
shape_(shape),
|
||||
lengths_(lengths),
|
||||
a_strides_(a_strides),
|
||||
b_strides_(b_strides),
|
||||
c_strides_(c_strides),
|
||||
functor_(functor),
|
||||
blockSize_(256),
|
||||
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
|
||||
{
|
||||
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
|
||||
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
|
||||
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize_);
|
||||
a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
|
||||
b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
|
||||
c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
|
||||
}
|
||||
|
||||
const ADataType* p_a_;
|
||||
const BDataType* p_b_;
|
||||
CDataType* p_c_;
|
||||
std::vector<int> shape_;
|
||||
GridDesc_M0 a_grid_desc_m0_;
|
||||
GridDesc_M0 b_grid_desc_m0_;
|
||||
GridDesc_M0 c_grid_desc_m0_;
|
||||
std::vector<int> lengths_;
|
||||
AGridDesc_M a_grid_desc_m_;
|
||||
BGridDesc_M b_grid_desc_m_;
|
||||
CGridDesc_M c_grid_desc_m_;
|
||||
std::vector<index_t> a_strides_;
|
||||
std::vector<index_t> b_strides_;
|
||||
std::vector<index_t> c_strides_;
|
||||
ElementwiseFunctor functor_;
|
||||
index_t blockSize_;
|
||||
index_t gridSize_;
|
||||
@@ -113,7 +129,9 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
GridDesc_M0,
|
||||
AGridDesc_M,
|
||||
BGridDesc_M,
|
||||
CGridDesc_M,
|
||||
ElementwiseFunctor>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
@@ -124,9 +142,9 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
arg.p_a_,
|
||||
arg.p_b_,
|
||||
arg.p_c_,
|
||||
arg.a_grid_desc_m0_,
|
||||
arg.b_grid_desc_m0_,
|
||||
arg.c_grid_desc_m0_,
|
||||
arg.a_grid_desc_m_,
|
||||
arg.b_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.functor_);
|
||||
return elapsed_time;
|
||||
}
|
||||
@@ -146,7 +164,30 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
if(pArg == nullptr)
|
||||
return false;
|
||||
|
||||
if(pArg->shape_.back() % ScalarPerVector != 0)
|
||||
if(pArg->lengths_.size() != NDim)
|
||||
return false;
|
||||
|
||||
if(pArg->lengths_.back() % MPerThread != 0)
|
||||
return false;
|
||||
|
||||
auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
|
||||
bool ret = true;
|
||||
|
||||
if(!isLastDimensionCoalesced)
|
||||
ret = scalarPerVector == 1;
|
||||
else
|
||||
ret = MPerThread % scalarPerVector == 0;
|
||||
|
||||
return ret;
|
||||
};
|
||||
|
||||
if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
|
||||
return false;
|
||||
|
||||
if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
|
||||
return false;
|
||||
|
||||
if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
@@ -155,19 +196,19 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
std::vector<index_t> shape,
|
||||
std::vector<index_t> stride_a,
|
||||
std::vector<index_t> stride_b,
|
||||
std::vector<index_t> stride_c,
|
||||
std::vector<index_t> lengths,
|
||||
std::vector<index_t> a_strides,
|
||||
std::vector<index_t> b_strides,
|
||||
std::vector<index_t> c_strides,
|
||||
ElementwiseFunctor functor)
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
shape,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
lengths,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
functor);
|
||||
}
|
||||
|
||||
@@ -180,7 +221,7 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
// clang-format off
|
||||
str << "DeviceBinaryElementwise"
|
||||
<< "<"
|
||||
<< "ScalarPerVector = " << ScalarPerVector
|
||||
<< "MPerThread = " << MPerThread
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -11,138 +11,140 @@ template <typename GridwiseBinEltwise,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename GridDesc_M0,
|
||||
typename AGridDesc_M,
|
||||
typename BGridDesc_M,
|
||||
typename CGridDesc_M,
|
||||
typename ElementwiseFunctor>
|
||||
__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global,
|
||||
const BDataType* __restrict__ p_b_global,
|
||||
CDataType* __restrict__ p_c_global,
|
||||
const GridDesc_M0 a_grid_desc_m0,
|
||||
const GridDesc_M0 b_grid_desc_m0,
|
||||
const GridDesc_M0 c_grid_desc_m0,
|
||||
const AGridDesc_M a_grid_desc_m,
|
||||
const BGridDesc_M b_grid_desc_m,
|
||||
const CGridDesc_M c_grid_desc_m,
|
||||
const ElementwiseFunctor functor)
|
||||
{
|
||||
GridwiseBinEltwise::Run(p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_grid_desc_m0,
|
||||
b_grid_desc_m0,
|
||||
c_grid_desc_m0,
|
||||
functor);
|
||||
GridwiseBinEltwise::Run(
|
||||
p_a_global, p_b_global, p_c_global, a_grid_desc_m, b_grid_desc_m, c_grid_desc_m, functor);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename ComputeDataType,
|
||||
typename GridDesc_M0,
|
||||
typename AGridDesc_M,
|
||||
typename BGridDesc_M,
|
||||
typename CGridDesc_M,
|
||||
typename ElementwiseFunctor,
|
||||
index_t ScalarPerVector>
|
||||
index_t MPerThread,
|
||||
index_t AScalarPerVector,
|
||||
index_t BScalarPerVector,
|
||||
index_t CScalarPerVector>
|
||||
struct GridwiseBinaryElementwise_1D
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto thread_desc_m0 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
|
||||
static constexpr auto thread_desc_m =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
|
||||
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static __device__ auto CalculateElementwiseIndex()
|
||||
{
|
||||
const index_t global_thread_id = get_thread_global_1d_id();
|
||||
return make_multi_index(global_thread_id * ScalarPerVector);
|
||||
return make_multi_index(global_thread_id * MPerThread);
|
||||
}
|
||||
|
||||
__device__ static void Run(const ADataType* __restrict__ p_a_global,
|
||||
const BDataType* __restrict__ p_b_global,
|
||||
CDataType* __restrict__ p_c_global,
|
||||
const GridDesc_M0 a_grid_desc_m0,
|
||||
const GridDesc_M0 b_grid_desc_m0,
|
||||
const GridDesc_M0 c_grid_desc_m0,
|
||||
const AGridDesc_M a_grid_desc_m,
|
||||
const BGridDesc_M b_grid_desc_m,
|
||||
const CGridDesc_M c_grid_desc_m,
|
||||
const ElementwiseFunctor functor)
|
||||
{
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_global, a_grid_desc_m0.GetElementSpaceSize());
|
||||
p_a_global, a_grid_desc_m.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_global, b_grid_desc_m0.GetElementSpaceSize());
|
||||
p_b_global, b_grid_desc_m.GetElementSpaceSize());
|
||||
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_global, c_grid_desc_m0.GetElementSpaceSize());
|
||||
p_c_global, c_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
|
||||
|
||||
const auto thread_store_global_offset = CalculateElementwiseIndex();
|
||||
|
||||
auto a_global_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<ADataType,
|
||||
ComputeDataType,
|
||||
GridDesc_M0,
|
||||
decltype(thread_desc_m0),
|
||||
Sequence<ScalarPerVector>, // SliceLengths
|
||||
Sequence<0>, // DimAccessOrder
|
||||
0, // SrcVectorDim
|
||||
ScalarPerVector,
|
||||
1, // SrcScalarStrideInVector
|
||||
false>{a_grid_desc_m0, thread_store_global_offset};
|
||||
AGridDesc_M,
|
||||
decltype(thread_desc_m),
|
||||
Sequence<MPerThread>, // SliceLengths
|
||||
Sequence<0>, // DimAccessOrder
|
||||
0, // SrcVectorDim
|
||||
AScalarPerVector, // ScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
false>{a_grid_desc_m, thread_store_global_offset};
|
||||
|
||||
auto b_global_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<BDataType,
|
||||
ComputeDataType,
|
||||
GridDesc_M0,
|
||||
decltype(thread_desc_m0),
|
||||
Sequence<ScalarPerVector>, // SliceLengths
|
||||
Sequence<0>, // DimAccessOrder
|
||||
0, // SrcVectorDim
|
||||
ScalarPerVector,
|
||||
1, // SrcScalarStrideInVector
|
||||
false>{b_grid_desc_m0, thread_store_global_offset};
|
||||
BGridDesc_M,
|
||||
decltype(thread_desc_m),
|
||||
Sequence<MPerThread>, // SliceLengths
|
||||
Sequence<0>, // DimAccessOrder
|
||||
0, // SrcVectorDim
|
||||
BScalarPerVector, // ScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
false>{b_grid_desc_m, thread_store_global_offset};
|
||||
|
||||
auto c_global_write =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
|
||||
CDataType,
|
||||
decltype(thread_desc_m0),
|
||||
GridDesc_M0,
|
||||
decltype(thread_desc_m),
|
||||
CGridDesc_M,
|
||||
PassThrough,
|
||||
Sequence<ScalarPerVector>, // SliceLengths
|
||||
Sequence<0>, // DimAccessOrder
|
||||
0, // DstVectorDim
|
||||
ScalarPerVector,
|
||||
Sequence<MPerThread>, // SliceLengths
|
||||
Sequence<0>, // DimAccessOrder
|
||||
0, // DstVectorDim
|
||||
CScalarPerVector, // ScalarPerVector
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1, // DstScalarStrideInVector
|
||||
false>{
|
||||
c_grid_desc_m0, thread_store_global_offset, PassThrough{}};
|
||||
c_grid_desc_m, thread_store_global_offset, PassThrough{}};
|
||||
|
||||
const index_t blockSize = get_block_size();
|
||||
const index_t blockPerGrid = get_grid_size();
|
||||
const auto m0 = c_grid_desc_m0.GetLength(I0);
|
||||
const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector;
|
||||
const auto M = c_grid_desc_m.GetLength(I0);
|
||||
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
|
||||
const auto loop_step_index = make_multi_index(loop_step);
|
||||
|
||||
index_t num_iter = m0 / (loop_step);
|
||||
index_t num_iter = M / (loop_step);
|
||||
do
|
||||
{
|
||||
// read and process ScalarPerVector elements
|
||||
// read and process MPerThread elements
|
||||
a_global_load.Run(
|
||||
a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
|
||||
a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
|
||||
|
||||
b_global_load.Run(
|
||||
b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf);
|
||||
b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
|
||||
|
||||
static_for<0, ScalarPerVector, 1>{}([&](auto m) {
|
||||
constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
|
||||
static_for<0, MPerThread, 1>{}([&](auto m) {
|
||||
constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
|
||||
functor(c_thread_buf(Number<offset>{}),
|
||||
a_thread_buf(Number<offset>{}),
|
||||
b_thread_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
c_global_write.Run(thread_desc_m0,
|
||||
c_global_write.Run(thread_desc_m,
|
||||
make_tuple(I0), // SrcSliceOriginIdx
|
||||
c_thread_buf,
|
||||
c_grid_desc_m0,
|
||||
c_grid_desc_m,
|
||||
c_global_buf);
|
||||
|
||||
a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
|
||||
b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index);
|
||||
c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index);
|
||||
a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
|
||||
b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
|
||||
c_global_write.MoveDstSliceWindow(c_grid_desc_m, loop_step_index);
|
||||
} while(--num_iter);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user