mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Gridwise elementwise 2d (#466)
* added 2d gridwise elementwise
* added 2d version of device elementwise
* added example file with updated device elementwise call
* added Cmake file
* changed NumDim into 2D
* fixed compiler issues
* fixed indexing for loop step
* fixed NumDim dimension error
* changed blockID to 2D
* updated Grid Desc
* updated kernel call
* fixed 2d thread indexing
* added dimensions for example file
* commented out unused code
* changed vector load
* removed extra code
* temporarily removing vector load on 2nd dim
* changed vector load back, still causing errors
* altered indexing
* changed isSupportedArgument for 2D
* changed indexing + do/while
* fixed isSupportedArgument
* changed dimension for debugging
* fixed
* added testing printouts
* testing change
* added variables to distribute threads through both dimensions
* testing changes
* integrated variable for thread distribution into device elementwise and added as parameter for gridwise elementwise
* removed most of the extraneous code, testing with different dimensions
* testing
* removed debugging print statements
* moved 2d elementwise permute into elementwise permute directory
* fixed formatting
* removed debugging comments from threadwise transfer
Co-authored-by: Jing Zhang <jizhan@amd.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
[ROCm/composable_kernel commit: 0e5c264c3e]
This commit is contained in:
@@ -1 +1,2 @@
|
||||
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
|
||||
add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise_2d.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using DeviceElementwisePermuteInstance =
|
||||
ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ADataType>,
|
||||
ck::Tuple<BDataType>,
|
||||
PassThrough,
|
||||
3, // NumDim_M
|
||||
1, // NumDim_N
|
||||
8,
|
||||
8,
|
||||
ck::Sequence<8>,
|
||||
ck::Sequence<8>>;
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc,
|
||||
const HostTensorA& A_nchw,
|
||||
const std::vector<std::size_t>& shape_nchw,
|
||||
Functor functor)
|
||||
{
|
||||
for(std::size_t n = 0; n < shape_nchw[0]; ++n)
|
||||
for(std::size_t c = 0; c < shape_nchw[1]; ++c)
|
||||
for(std::size_t h = 0; h < shape_nchw[2]; ++h)
|
||||
for(std::size_t w = 0; w < shape_nchw[3]; ++w)
|
||||
{
|
||||
auto a_val = A_nchw(n, c, h, w);
|
||||
functor(B_nhwc(n, h, w, c), a_val);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
|
||||
const int N = 120;
|
||||
const int C = 128;
|
||||
const int H = 32;
|
||||
const int W = 1024;
|
||||
|
||||
/**const int N = 120;
|
||||
const int H = 32;
|
||||
const int W = 64;
|
||||
|
||||
const int C = 128;**/
|
||||
|
||||
std::vector<std::size_t> nchw = {N, C, H, W};
|
||||
std::vector<std::size_t> nhwc = {N, H, W, C};
|
||||
|
||||
Tensor<ADataType> a(nchw);
|
||||
Tensor<BDataType> b(nhwc);
|
||||
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a.mData.data());
|
||||
// LogRangeAsType<float>(std::cout << "Tensor a : ", a.mData, ",") << std::endl;
|
||||
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
|
||||
|
||||
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
|
||||
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
|
||||
|
||||
if(!broadcastPermute.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The runtime parameters seems not supported by the device instance, exiting!");
|
||||
};
|
||||
|
||||
std::cout << "A (nchw): " << a.mDesc << std::endl;
|
||||
std::cout << "B (nhwc): " << b.mDesc << std::endl;
|
||||
|
||||
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
|
||||
float ave_time =
|
||||
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
|
||||
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
// LogRangeAsType<float>(std::cout << "Tensor b : ", b.mData, ",") << std::endl;
|
||||
|
||||
Tensor<BDataType> host_b(nhwc);
|
||||
host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
|
||||
host_b, a, nchw, PassThrough{});
|
||||
|
||||
// LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl;
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
341
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
Normal file
341
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp
Normal file
@@ -0,0 +1,341 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
index_t NumDim_m,
|
||||
index_t NumDim_n,
|
||||
index_t MPerThread,
|
||||
index_t NPerThread,
|
||||
typename InScalarPerVectorSeq,
|
||||
typename OutScalarPerVectorSeq>
|
||||
struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
|
||||
OutDataTypeTuple,
|
||||
ElementwiseOperation,
|
||||
NumDim_m + NumDim_n>
|
||||
{
|
||||
static constexpr index_t NumDim = NumDim_m + NumDim_n;
|
||||
|
||||
static constexpr int NumInput = InDataTypeTuple::Size();
|
||||
static constexpr int NumOutput = OutDataTypeTuple::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
|
||||
NumOutput == OutScalarPerVectorSeq::Size(),
|
||||
"Tuple size is inconsistent with the number of in/out!");
|
||||
|
||||
static auto GenerateInDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<const DataType*>(nullptr);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
};
|
||||
|
||||
static auto GenerateOutDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<DataType*>(nullptr);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
};
|
||||
|
||||
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
|
||||
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
|
||||
|
||||
template <typename Desc_MN>
|
||||
static auto PadDescriptor_MN_2d(Desc_MN desc_mn,
|
||||
index_t gridSize,
|
||||
index_t blockSize,
|
||||
index_t num_threads_m,
|
||||
index_t num_threads_n)
|
||||
{
|
||||
std::ignore = blockSize;
|
||||
std::ignore = gridSize;
|
||||
const auto m = desc_mn.GetLength(I0);
|
||||
const auto n = desc_mn.GetLength(I1);
|
||||
const index_t loop_step_m = num_threads_m * MPerThread;
|
||||
const index_t loop_step_n = num_threads_n * NPerThread;
|
||||
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
|
||||
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
|
||||
|
||||
const auto desc_mn_pad = transform_tensor_descriptor(
|
||||
desc_mn,
|
||||
make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return desc_mn_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_MN(const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& stride,
|
||||
index_t gridSize,
|
||||
index_t blockSize,
|
||||
index_t num_threads_m,
|
||||
index_t num_threads_n)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
|
||||
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
|
||||
constexpr auto nDimIds =
|
||||
typename arithmetic_sequence_gen<NumDim_m, NumDim_m + NumDim_n, 1>::type();
|
||||
|
||||
const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
|
||||
const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
|
||||
|
||||
// merge nd to 2d desc - [s0 * s1 * ...]
|
||||
|
||||
if constexpr(NumDim > 2)
|
||||
{
|
||||
const auto desc_mn = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
|
||||
make_tuple(mDimIds, nDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize, num_threads_m, num_threads_n);
|
||||
}
|
||||
else
|
||||
return PadDescriptor_MN_2d(desc, gridSize, blockSize, num_threads_m, num_threads_n);
|
||||
}
|
||||
|
||||
template <index_t TupleSize>
|
||||
static auto GenerateInOutGrid2dDescTuple(Number<TupleSize>)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto) {
|
||||
if constexpr(NumDim > 2)
|
||||
{
|
||||
return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1, 1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return MakeDescriptor_MN({1}, {1}, 1, 1, 1, 1);
|
||||
};
|
||||
},
|
||||
Number<TupleSize>{});
|
||||
};
|
||||
|
||||
using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumOutput>{}));
|
||||
using InGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumInput>{}));
|
||||
|
||||
using GridwiseElementwise = GridwiseElementwise_2D<InGrid2dDescTuple,
|
||||
OutGrid2dDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
ElementwiseOperation,
|
||||
MPerThread,
|
||||
NPerThread,
|
||||
InScalarPerVectorSeq,
|
||||
OutScalarPerVectorSeq>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op)
|
||||
|
||||
: lengths_(lengths),
|
||||
inStridesArray_(inStridesArray),
|
||||
outStridesArray_(outStridesArray),
|
||||
elementwise_op_(elementwise_op),
|
||||
blockSize_(256),
|
||||
gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
|
||||
num_threads_m_((gridSize_ * blockSize_) / 16),
|
||||
num_threads_n_(16)
|
||||
{
|
||||
static_assert(NumDim_m > 0, "");
|
||||
static_assert(NumDim_n > 0, "");
|
||||
|
||||
in_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
return static_cast<const DataType*>(in_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
out_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
return static_cast<DataType*>(out_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
in_grid_2d_desc_tuple_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
return MakeDescriptor_MN(lengths,
|
||||
inStridesArray[I.value],
|
||||
gridSize_,
|
||||
blockSize_,
|
||||
num_threads_m_,
|
||||
num_threads_n_);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
out_grid_2d_desc_tuple_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
return MakeDescriptor_MN(lengths,
|
||||
outStridesArray[I.value],
|
||||
gridSize_,
|
||||
blockSize_,
|
||||
num_threads_m_,
|
||||
num_threads_n_);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
}
|
||||
|
||||
InDataTypePointerTuple in_dev_buffers_;
|
||||
OutDataTypePointerTuple out_dev_buffers_;
|
||||
InGrid2dDescTuple in_grid_2d_desc_tuple_;
|
||||
OutGrid2dDescTuple out_grid_2d_desc_tuple_;
|
||||
|
||||
std::array<index_t, NumDim> lengths_;
|
||||
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
|
||||
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
|
||||
|
||||
ElementwiseOperation elementwise_op_;
|
||||
index_t blockSize_;
|
||||
index_t gridSize_;
|
||||
index_t num_threads_m_;
|
||||
index_t num_threads_n_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto kernel = kernel_elementwise_2d<GridwiseElementwise,
|
||||
InGrid2dDescTuple,
|
||||
OutGrid2dDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
ElementwiseOperation>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(arg.blockSize_),
|
||||
0,
|
||||
arg.in_grid_2d_desc_tuple_,
|
||||
arg.out_grid_2d_desc_tuple_,
|
||||
arg.in_dev_buffers_,
|
||||
arg.out_dev_buffers_,
|
||||
arg.elementwise_op_,
|
||||
arg.num_threads_m_,
|
||||
arg.num_threads_n_);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(pArg == nullptr)
|
||||
return false;
|
||||
|
||||
if(pArg->lengths_.back() % MPerThread != 0)
|
||||
return false;
|
||||
|
||||
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& strides,
|
||||
index_t scalarPerVector,
|
||||
index_t vectorDim) {
|
||||
if(strides[vectorDim] == 1 &&
|
||||
(lengths[vectorDim] % scalarPerVector == 0 ||
|
||||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim])
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
bool valid = true;
|
||||
static_for<0, NumInput, 1>{}([&](auto I) {
|
||||
if(!IsScalarPerVectorValid(pArg->lengths_,
|
||||
pArg->inStridesArray_[I.value],
|
||||
InScalarPerVectorSeq::At(I),
|
||||
NumDim_m - 1))
|
||||
valid = false;
|
||||
});
|
||||
|
||||
static_for<0, NumOutput, 1>{}([&](auto I) {
|
||||
if(!IsScalarPerVectorValid(pArg->lengths_,
|
||||
pArg->outStridesArray_[I.value],
|
||||
OutScalarPerVectorSeq::At(I),
|
||||
NumDim - 1))
|
||||
valid = false;
|
||||
});
|
||||
|
||||
return valid;
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(lengths,
|
||||
inStridesArray,
|
||||
outStridesArray,
|
||||
in_dev_buffers,
|
||||
out_dev_buffers,
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
230
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
Normal file
230
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
Normal file
@@ -0,0 +1,230 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseElementwise2dFunctor,
|
||||
typename InGrid2dDescTuple,
|
||||
typename OutGrid2dDescTuple,
|
||||
typename InDataTypePointerTuple,
|
||||
typename OutDataTypePointerTuple,
|
||||
typename ElementwiseOperation>
|
||||
__global__ void kernel_elementwise_2d(const InGrid2dDescTuple in_grid_2d_desc_tuple,
|
||||
const OutGrid2dDescTuple out_grid_2d_desc_tuple,
|
||||
const InDataTypePointerTuple p_in_global_tuple,
|
||||
const OutDataTypePointerTuple p_out_global_tuple,
|
||||
const ElementwiseOperation elementwise_op,
|
||||
const index_t num_threads_m,
|
||||
const index_t num_threads_n)
|
||||
{
|
||||
GridwiseElementwise2dFunctor::Run(in_grid_2d_desc_tuple,
|
||||
out_grid_2d_desc_tuple,
|
||||
p_in_global_tuple,
|
||||
p_out_global_tuple,
|
||||
elementwise_op,
|
||||
num_threads_m,
|
||||
num_threads_n);
|
||||
}
|
||||
|
||||
template <typename InGrid2dDescTuple,
|
||||
typename OutGrid2dDescTuple,
|
||||
typename InDataTypePointerTuple,
|
||||
typename OutDataTypePointerTuple,
|
||||
typename ElementwiseOperation,
|
||||
index_t MPerThread,
|
||||
index_t NPerThread,
|
||||
typename InScalarPerVectorSeq,
|
||||
typename OutScalarPerVectorSeq>
|
||||
struct GridwiseElementwise_2D
|
||||
{
|
||||
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
|
||||
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
|
||||
|
||||
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
|
||||
NumOutput == OutScalarPerVectorSeq::Size() &&
|
||||
NumInput == InGrid2dDescTuple::Size() &&
|
||||
NumOutput == OutGrid2dDescTuple::Size(),
|
||||
"Tuple size is inconsistent with the number of in/out!");
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr auto thread_buffer_desc_mn =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}, Number<NPerThread>{}));
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
__device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple,
|
||||
const OutGrid2dDescTuple out_grid_2d_desc_tuple,
|
||||
const InDataTypePointerTuple p_in_global_tuple,
|
||||
const OutDataTypePointerTuple p_out_global_tuple,
|
||||
const ElementwiseOperation elementwise_op,
|
||||
const index_t num_threads_m,
|
||||
const index_t num_threads_n)
|
||||
{
|
||||
auto in_thread_buf_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
|
||||
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
|
||||
|
||||
return StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
DataType,
|
||||
MPerThread * NPerThread,
|
||||
true>{};
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_thread_buf_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
|
||||
using DataType = remove_pointer_t<DataTypePointer>;
|
||||
|
||||
return StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
DataType,
|
||||
MPerThread * NPerThread,
|
||||
true>{};
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
auto in_global_buf_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_global_buf_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_global_tuple[I], out_grid_2d_desc_tuple[I].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0);
|
||||
const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1);
|
||||
|
||||
const index_t loop_step_m = num_threads_m * MPerThread;
|
||||
const index_t loop_step_n = num_threads_n * NPerThread;
|
||||
|
||||
const index_t thread_1d_id = get_thread_global_1d_id();
|
||||
index_t tid_m = thread_1d_id / num_threads_n;
|
||||
index_t tid_n = thread_1d_id % num_threads_n;
|
||||
|
||||
const auto thread_global_offset = make_multi_index(tid_m * MPerThread, tid_n * NPerThread);
|
||||
|
||||
auto in_global_load_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
|
||||
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
|
||||
|
||||
return ThreadwiseTensorSliceTransfer_v2<
|
||||
DataType,
|
||||
DataType,
|
||||
decltype(in_grid_2d_desc_tuple[I]),
|
||||
decltype(thread_buffer_desc_mn),
|
||||
Sequence<MPerThread, NPerThread>, // SliceLengths
|
||||
Sequence<0, 1>, // DimAccessOrder
|
||||
0, // SrcVectorDim
|
||||
InScalarPerVectorSeq::At(I), // ScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>{in_grid_2d_desc_tuple[I], thread_global_offset};
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_global_store_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
|
||||
using DataType = remove_pointer_t<DataTypePointer>;
|
||||
|
||||
return ThreadwiseTensorSliceTransfer_v1r3<
|
||||
DataType,
|
||||
DataType,
|
||||
decltype(thread_buffer_desc_mn),
|
||||
decltype(out_grid_2d_desc_tuple[I]),
|
||||
PassThroughOp,
|
||||
Sequence<MPerThread, NPerThread>, // SliceLengths
|
||||
Sequence<0, 1>, // DimAccessOrder
|
||||
1, // SrcVectorDim
|
||||
1, // OutScalarPerVectorSeq::At(I),
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(out_grid_2d_desc_tuple[I], thread_global_offset, PassThroughOp{});
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
index_t num_iter_m = M / (loop_step_m);
|
||||
do
|
||||
{
|
||||
index_t num_iter_n = N / (loop_step_n);
|
||||
do
|
||||
{
|
||||
static_for<0, NumInput, 1>{}([&](auto I) {
|
||||
in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I],
|
||||
in_global_buf_tuple[I],
|
||||
thread_buffer_desc_mn,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf_tuple(I));
|
||||
|
||||
in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I],
|
||||
make_multi_index(0, loop_step_n));
|
||||
});
|
||||
|
||||
static_for<0, MPerThread, 1>{}([&](auto iM) {
|
||||
static_for<0, NPerThread, 1>{}([&](auto iN) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc_mn.CalculateOffset(make_tuple(iM, iN));
|
||||
// get reference to in data
|
||||
const auto in_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto I) -> const auto& {
|
||||
return in_thread_buf_tuple(I)(Number<offset>{});
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
// get referenec to dst data
|
||||
auto out_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto I) -> auto& {
|
||||
return out_thread_buf_tuple(I)(Number<offset>{});
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
unpack2(elementwise_op, out_data_refs, in_data_refs);
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NumOutput, 1>{}([&](auto I) {
|
||||
out_global_store_tuple(I).Run(thread_buffer_desc_mn,
|
||||
make_tuple(I0, I0),
|
||||
out_thread_buf_tuple[I],
|
||||
out_grid_2d_desc_tuple[I],
|
||||
out_global_buf_tuple(I));
|
||||
|
||||
out_global_store_tuple(I).MoveDstSliceWindow(out_grid_2d_desc_tuple[I],
|
||||
make_multi_index(0, loop_step_n));
|
||||
});
|
||||
|
||||
} while(--num_iter_n);
|
||||
|
||||
static_for<0, NumInput, 1>{}([&](auto I) {
|
||||
in_global_load_tuple(I).MoveSrcSliceWindow(
|
||||
in_grid_2d_desc_tuple[I],
|
||||
make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n));
|
||||
});
|
||||
|
||||
static_for<0, NumOutput, 1>{}([&](auto I) {
|
||||
out_global_store_tuple(I).MoveDstSliceWindow(
|
||||
out_grid_2d_desc_tuple[I],
|
||||
make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n));
|
||||
});
|
||||
} while(--num_iter_m);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user