mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Hip tensor permute (#1002)
* adding files for F32 example
* adding functioning implementation with scalar multiplication and unary operator support
* added fp 16 type check in unary square
* updating scalar multiplication as an operator
* functioning version with scalar operator
* changing strides for col major
* updated column major implementation
* working column major implementation
* cleaned up comments, rearranged/renamed files
[ROCm/composable_kernel commit: 454cf7bd1f]
This commit is contained in:
@@ -0,0 +1,55 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <array>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
typename UnaryOperation,
|
||||
typename Scale,
|
||||
index_t NumDim>
|
||||
struct DeviceElementwise : public BaseOperator
|
||||
{
|
||||
static constexpr int NumInput = InDataTypeTuple::Size();
|
||||
static constexpr int NumOutput = OutDataTypeTuple::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op,
|
||||
UnaryOperation unary_op,
|
||||
Scale scale_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
}; // namespace device
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
typename UnaryOperation,
|
||||
typename Scale,
|
||||
index_t NumDim>
|
||||
using DeviceElementwisePtr = std::unique_ptr<DeviceElementwise<InDataTypeTuple,
|
||||
OutDataTypeTuple,
|
||||
ElementwiseOperation,
|
||||
UnaryOperation,
|
||||
Scale,
|
||||
NumDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,329 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/stream_utility.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
typename UnaryOperation,
|
||||
typename Scale,
|
||||
index_t NumDim,
|
||||
index_t MPerThread,
|
||||
typename InScalarPerVectorSeq,
|
||||
typename OutScalarPerVectorSeq>
|
||||
struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
|
||||
OutDataTypeTuple,
|
||||
ElementwiseOperation,
|
||||
UnaryOperation,
|
||||
Scale,
|
||||
NumDim>
|
||||
{
|
||||
static constexpr int NumInput = InDataTypeTuple::Size();
|
||||
static constexpr int NumOutput = OutDataTypeTuple::Size();
|
||||
|
||||
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
|
||||
NumOutput == OutScalarPerVectorSeq::Size(),
|
||||
"Tuple size is inconsistent with the number of in/out!");
|
||||
|
||||
static auto GenerateInDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<const DataType*>(nullptr);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
};
|
||||
|
||||
static auto GenerateOutDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<DataType*>(nullptr);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
};
|
||||
|
||||
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
|
||||
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
|
||||
|
||||
template <typename Desc_M>
|
||||
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
|
||||
const auto m = desc_m.GetLength(I0);
|
||||
const index_t loop_step = gridSize * blockSize * MPerThread;
|
||||
const auto pad = math::integer_least_multiple(m, loop_step) - m;
|
||||
const auto desc_m_pad =
|
||||
transform_tensor_descriptor(desc_m,
|
||||
make_tuple(make_right_pad_transform(m, pad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return desc_m_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M(const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& stride,
|
||||
index_t gridSize,
|
||||
index_t blockSize)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
|
||||
// merge nd to 1d desc - [s0 * s1 * ...]
|
||||
if constexpr(NumDim > 1)
|
||||
{
|
||||
const auto desc_m = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(tupleOfShape)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim>{})),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
|
||||
}
|
||||
else
|
||||
return PadDescriptor_M_1d(desc, gridSize, blockSize);
|
||||
}
|
||||
|
||||
template <index_t TupleSize>
|
||||
static auto GenerateInOutGrid1dDescTuple(Number<TupleSize>)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto) {
|
||||
if constexpr(NumDim > 1)
|
||||
{
|
||||
return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return MakeDescriptor_M({1}, {1}, 1, 1);
|
||||
};
|
||||
},
|
||||
Number<TupleSize>{});
|
||||
};
|
||||
|
||||
using InGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumInput>{}));
|
||||
using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumOutput>{}));
|
||||
|
||||
using GridwiseElementwise = GridwiseElementwise_1D<InGrid1dDescTuple,
|
||||
OutGrid1dDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
ElementwiseOperation,
|
||||
UnaryOperation,
|
||||
Scale,
|
||||
MPerThread,
|
||||
InScalarPerVectorSeq,
|
||||
OutScalarPerVectorSeq>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op,
|
||||
UnaryOperation unary_op,
|
||||
Scale scale_op)
|
||||
|
||||
: lengths_(lengths),
|
||||
inStridesArray_(inStridesArray),
|
||||
outStridesArray_(outStridesArray),
|
||||
elementwise_op_(elementwise_op),
|
||||
unary_op_(unary_op),
|
||||
scale_op_(scale_op),
|
||||
blockSize_(256)
|
||||
{
|
||||
in_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
return static_cast<const DataType*>(in_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
out_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
return static_cast<DataType*>(out_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
}
|
||||
|
||||
InDataTypePointerTuple in_dev_buffers_;
|
||||
OutDataTypePointerTuple out_dev_buffers_;
|
||||
|
||||
std::array<index_t, NumDim> lengths_;
|
||||
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
|
||||
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
|
||||
|
||||
ElementwiseOperation elementwise_op_;
|
||||
UnaryOperation unary_op_;
|
||||
Scale scale_op_;
|
||||
index_t blockSize_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
index_t gridSize = getAvailableComputeUnitCount(stream_config);
|
||||
|
||||
auto in_grid_1d_desc_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
return MakeDescriptor_M(
|
||||
arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_grid_1d_desc_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
return MakeDescriptor_M(
|
||||
arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
const auto kernel = kernel_elementwise_1d<GridwiseElementwise,
|
||||
InGrid1dDescTuple,
|
||||
OutGrid1dDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
ElementwiseOperation,
|
||||
UnaryOperation,
|
||||
Scale>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gridSize),
|
||||
dim3(arg.blockSize_),
|
||||
0,
|
||||
in_grid_1d_desc_tuple,
|
||||
out_grid_1d_desc_tuple,
|
||||
arg.in_dev_buffers_,
|
||||
arg.out_dev_buffers_,
|
||||
arg.elementwise_op_,
|
||||
arg.unary_op_,
|
||||
arg.scale_op_);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(arg.lengths_.back() % MPerThread != 0)
|
||||
return false;
|
||||
|
||||
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& strides,
|
||||
index_t scalarPerVector) {
|
||||
if(strides.back() == 1 && lengths.back() % scalarPerVector == 0)
|
||||
return true;
|
||||
|
||||
if(strides.back() != 1 && scalarPerVector == 1)
|
||||
return true;
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
bool valid = true;
|
||||
static_for<0, NumInput, 1>{}([&](auto I) {
|
||||
if(!IsScalarPerVectorValid(
|
||||
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
|
||||
valid = false;
|
||||
});
|
||||
|
||||
static_for<0, NumOutput, 1>{}([&](auto I) {
|
||||
if(!IsScalarPerVectorValid(
|
||||
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
|
||||
valid = false;
|
||||
});
|
||||
|
||||
return valid;
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op,
|
||||
UnaryOperation unary_op,
|
||||
Scale scale_op)
|
||||
{
|
||||
return Argument{lengths,
|
||||
inStridesArray,
|
||||
outStridesArray,
|
||||
in_dev_buffers,
|
||||
out_dev_buffers,
|
||||
elementwise_op,
|
||||
unary_op,
|
||||
scale_op};
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op,
|
||||
UnaryOperation unary_op,
|
||||
Scale scale_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(lengths,
|
||||
inStridesArray,
|
||||
outStridesArray,
|
||||
in_dev_buffers,
|
||||
out_dev_buffers,
|
||||
elementwise_op,
|
||||
unary_op,
|
||||
scale_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -355,8 +355,8 @@ struct UnarySquare
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, int32_t> ||
|
||||
is_same_v<T, int8_t>
|
||||
static_assert(is_same_v<T, float> || is_same_v<T, half_t> || is_same_v<T, double> ||
|
||||
is_same_v<T, int32_t> || is_same_v<T, int8_t>
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|| is_same_v<T, int4_t>
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseElementwise1dFunctor,
|
||||
typename InGrid1dDescTuple,
|
||||
typename OutGrid1dDescTuple,
|
||||
typename InDataTypePointerTuple,
|
||||
typename OutDataTypePointerTuple,
|
||||
typename ElementwiseOperation,
|
||||
typename UnaryOperation,
|
||||
typename Scale>
|
||||
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
|
||||
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
|
||||
const InDataTypePointerTuple p_in_global_tuple,
|
||||
const OutDataTypePointerTuple p_out_global_tuple,
|
||||
const ElementwiseOperation elementwise_op,
|
||||
const UnaryOperation unary_op,
|
||||
const Scale scale_op)
|
||||
{
|
||||
GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
|
||||
out_grid_1d_desc_tuple,
|
||||
p_in_global_tuple,
|
||||
p_out_global_tuple,
|
||||
elementwise_op,
|
||||
unary_op,
|
||||
scale_op);
|
||||
}
|
||||
|
||||
template <typename InGrid1dDescTuple,
|
||||
typename OutGrid1dDescTuple,
|
||||
typename InDataTypePointerTuple,
|
||||
typename OutDataTypePointerTuple,
|
||||
typename ElementwiseOperation,
|
||||
typename UnaryOperation,
|
||||
typename Scale,
|
||||
index_t MPerThread,
|
||||
typename InScalarPerVectorSeq,
|
||||
typename OutScalarPerVectorSeq>
|
||||
struct GridwiseElementwise_1D
|
||||
{
|
||||
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
|
||||
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
|
||||
|
||||
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
|
||||
NumOutput == OutScalarPerVectorSeq::Size() &&
|
||||
NumInput == InGrid1dDescTuple::Size() &&
|
||||
NumOutput == OutGrid1dDescTuple::Size(),
|
||||
"Tuple size is inconsistent with the number of in/out!");
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
static constexpr auto thread_buffer_desc_m =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
__device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple,
|
||||
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
|
||||
const InDataTypePointerTuple p_in_global_tuple,
|
||||
const OutDataTypePointerTuple p_out_global_tuple,
|
||||
const ElementwiseOperation elementwise_op,
|
||||
const UnaryOperation unary_op,
|
||||
const Scale scale_op)
|
||||
{
|
||||
const index_t thread_global_id = get_thread_global_1d_id();
|
||||
|
||||
auto in_thread_buf_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
|
||||
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
|
||||
|
||||
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_thread_buf_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
|
||||
using DataType = remove_pointer_t<DataTypePointer>;
|
||||
|
||||
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
auto in_global_buf_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
|
||||
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_global_buf_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
|
||||
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread);
|
||||
|
||||
const index_t blockSize = get_block_size();
|
||||
const index_t blockPerGrid = get_grid_size();
|
||||
const auto M = in_grid_1d_desc_tuple[I0].GetLength(I0);
|
||||
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
|
||||
const auto loop_step_index = make_multi_index(loop_step);
|
||||
|
||||
auto in_global_load_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
|
||||
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
|
||||
|
||||
return ThreadwiseTensorSliceTransfer_v2<DataType,
|
||||
DataType,
|
||||
decltype(in_grid_1d_desc_tuple[I]),
|
||||
decltype(thread_buffer_desc_m),
|
||||
Sequence<MPerThread>, // SliceLengths
|
||||
Sequence<0>, // DimAccessOrder
|
||||
0, // SrcVectorDim
|
||||
InScalarPerVectorSeq::At(
|
||||
I), // ScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
false>{in_grid_1d_desc_tuple[I],
|
||||
thread_global_offset};
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_global_store_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
|
||||
using DataType = remove_pointer_t<DataTypePointer>;
|
||||
|
||||
return ThreadwiseTensorSliceTransfer_v1r3<DataType,
|
||||
DataType,
|
||||
decltype(thread_buffer_desc_m),
|
||||
decltype(out_grid_1d_desc_tuple[I]),
|
||||
PassThroughOp,
|
||||
Sequence<MPerThread>, // SliceLengths
|
||||
Sequence<0>, // DimAccessOrder
|
||||
0, // SrcVectorDim
|
||||
OutScalarPerVectorSeq::At(I),
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
false>(
|
||||
out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{});
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
index_t num_iter = M / (loop_step);
|
||||
do
|
||||
{
|
||||
static_for<0, NumInput, 1>{}([&](auto I) {
|
||||
in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I],
|
||||
in_global_buf_tuple[I],
|
||||
thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
in_thread_buf_tuple(I));
|
||||
|
||||
in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I],
|
||||
loop_step_index);
|
||||
});
|
||||
|
||||
static_for<0, MPerThread, 1>{}([&](auto iM) {
|
||||
// get reference to in data
|
||||
auto uop_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
|
||||
Number<NumInput>{});
|
||||
|
||||
// get reference to dst data
|
||||
auto out_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); },
|
||||
Number<NumOutput>{});
|
||||
|
||||
unpack2(unary_op, uop_data_refs, uop_data_refs);
|
||||
|
||||
auto sop_in_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
|
||||
Number<NumInput>{});
|
||||
|
||||
auto sop_out_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
|
||||
Number<NumInput>{});
|
||||
|
||||
unpack2(scale_op, sop_out_data_refs, sop_in_data_refs);
|
||||
|
||||
const auto in_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
|
||||
Number<NumInput>{});
|
||||
|
||||
unpack2(elementwise_op, out_data_refs, in_data_refs);
|
||||
});
|
||||
|
||||
static_for<0, NumOutput, 1>{}([&](auto I) {
|
||||
out_global_store_tuple(I).Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
out_thread_buf_tuple[I],
|
||||
out_grid_1d_desc_tuple[I],
|
||||
out_global_buf_tuple(I));
|
||||
|
||||
out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I],
|
||||
loop_step_index);
|
||||
});
|
||||
} while(--num_iter);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user