mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Add 'Permute' device op & example (#408)
* Add example folder for 'DeviceElementwise'
* Re-structure example files
* Move common parts into common.hpp
* Use more strict input
* Add more helper methods in 'DeviceElementwise'
* Use more specific method to write example
* Allow specify problem through command line argument
* Allow specify problem 'axes' through command line argument
* Add check to template type argument
* Add transpose_shape() to generalize shape permute
* Generalize transpose utility functions
* Use better name for tensor indices
* Add checks in helper functions
* Remove debug messages
* Refine error message for check_err()
* Generalize variable naming in example code
* Add device op 'DevicePermute'
This device op is clone of 'DeviceElementwise'
* Use 'DevicePermute' device op in example
* Remove 'elementwise' from identifiers
* Remove 'elementwise' from file paths
* Remove base class of 'DevicePermute'
* Let 'DevicePermute' inherit from 'BaseOperator'
* Add simple type traits to validate device op type
* Add static_assert() to check type constraints
* Create 'DevicePermuteBase' to generate methods
* Use indirect base type to generate methods
* Remove 'is_device_op<>' type traits
* Only accept single-input-single-output for 'DevicePermute'
* Simplify 'DevicePermute' interface
* Re-format 'DeviceElementwise'
* Use CRTP to generate overridden virtual method
* Remove unnecessary include directives
* Distinguish input & output shape in 'DevicePermute'
* Passing 'axes' to 'DevicePermute'
* Use more reasonable return value for Invoker::Run()
* Add 'GridwisePermute' kernel
This kernel is a clone of 'GridwiseElementwise_1D'
* Remove no-longer used type argument
* Check if input/output shape meet the requirement
* Remove no-longer used method
* Remove never-entered-if-clause
* Change problem description for 'DevicePermute'
* Transform descriptor into 3 dimensions
* Add debug code to verify result
* Add comment to indicate template argument location
* Add N/H/WPerBlock template parameter to 'DevicePermute'
* Rename 'GridwisePermute' to 'GridwiseCopy'
* Check tensor descriptor dimensions in 'GridwiseElementwise_1D'
* Add missing include directive
* Add 'BlockSize' parameter to 'DevicePermute'
* Remove no-longer used method
* Add 'BlockToTileMap' for 'GridwiseCopy'
* Use the normal Block2TileMap convention
* Rename 'BlockToTileMap' as 'Block2TileMap'
* Fix most of compilation errors
* Let 'Block2TileMap' map block to 2d coordinate
* Allow data transfer in 'GridwiseCopy'
* Fix wrong output descriptor for 2nd blockwise copy
* Rename 'GridwiseCopy' as 'GridwisePermute'
* Remove '1d' in identifiers
* Remove commented-out codes
* Remove 'MPerThread' template parameter
* Separate template parameters
* Unify variable naming convention
* Use more verbose way to create expressions
* Add template parameter 'InBlockLdsExtraW'
* Release the constraint on In/OutGridDesc
* Use date type directly as template argument
* Re-arrange template arguments for blockwise copy
* Remove no-longer used template parameters
* Embed layout in the variable names
* Add GridwisePermute::CheckValidity()
* Extract local types as template parameters
* Rename local type alias
* Add more template parameters (vector width related)
* Calculate new SrcVectorDim/DstVectorDim after merge descriptor dimensions
* Fill tensor values start from 1
* Re-format example code
* Avoid too-large block id
* Add comment
* Make sure 'SrcVectorDim' is not same as 'DstVectorDim'
* Add check for the 'VectorDim' & 'ScalarPerVector' template params
* Let 'DstVectorDim' equals 'SrcVectorDim' after transpose out grid desc
* Remove no-longer used template parameter 'NPerBlock'
* Fix wrong descriptor creation logics
* Specify problem in each examples
* Use better example name
* Add new example 'example_permute_NxHxW_fp32'
* Add example for demonstrating bundle multiple elems in tensor
* Add support to permute multiple elements together
* Change the default problem size
* Add span<> class template
* Use span<> to generalize check_err() interface
* Fix ambiguous ctor call
* Avoid creating unnecessary objects
* Use helper functions to simplify example code
* Add example for 4xfp16 permute
* Disable failed-to-compile example
* Add check for the NUM_ELEMS_IN_BUNDLE
* Remove redundant parameter in helper lambda function
* Add check for the input tensor type's byte-size
* Check scalar-per-vector with padded length
* Use more verbose name to avoid name collision
* Use fixed 'VectorDim' & 'ScalarPerVector' for LDS
* Embed shape info in name of descriptor constructor
* Rename example folder '36_permute' into '37_permute'
* Avoid using too-large LDS in kernel code
* Remove redundant example
* Use switch() to group similar codes
* Add const to the span<> type argument
* Simply initialize tensor with floating point values
* Use fp16 as data type in all examples
* Enlarge tensor size in example
* Enlarge N-dim in example
* Add check for the bundled type in example
* Use stricter error threshold
* Remove global load/store loop in kernel code
* Measure execution time by default
* Use faster device op config for example 'NxHxW_fp16'
* Use faster device op config for example '1xHxW_fp16'
* Use faster device op config for example 'HxWx4_fp16'
* Remove cmd arg parsing logics
* Rename functions
* Extract bundle permutation logic out
* Simplify permute bundle example
* Add Tensor<>::GetElementSpaceSizeInBytes()
* Add Tensor<>::data()
* Use new methods to simplify code
* Use type alias to replace duplicated code
* Use existing method to shorten code
* Allow FillUniformDistribution accept range argument
* Initialize random values in range
* Add Tensor<>::size()
* Use more meaningful names in permute bundle example
* Use more meaningful names in permute element examples
* Use rangified copy() to copy elements
* Use function return value directly to eliminate variables
* Add to_array() conversion tool to eliminate more variables
* Add Tensor<>::AsSpan<>() to create view of tensor values
* Use AsSpan() to shorten check_err() calls
* Remove no-longer-used 'using' directives
* Move 'using' directive to proper code position
* Remove redundant variables
* Remove useless static_assert()
* Add check for range types
* Declare variable right before first use
* Move long return type as tailing return type
* Add BaseInvokerCRTP<> class template to generate method
* Create new base type for 'DevicePermute' implementations
* Move 'NumDim' template param to the first
* Rename 'DevicePermute' to 'DevicePermuteImpl'
* Add 'noexcept' specifier to CRTP generated method
* Move 'Block2TileMap' definition into 'GridwisePermute'
* Use type alias to reduce code
* Unify naming style in 'DevicePermute'
* Add comments in 'GridwisePermute'
* Rename permute example folder
* Use std::cerr to report error
* Use larger shape in examples
* Rename '38_permute' to '39_permute'
* Make sure we use unsigned type for shape & indices
* Remove opt-ed out assertion
* Remove template BaseInvokerCRTP<>
[ROCm/composable_kernel commit: f584ab0c54]
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
|
||||
#include "ck/stream_config.hpp"
|
||||
|
||||
@@ -222,14 +222,9 @@ struct DeviceElementwise
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(pArg == nullptr)
|
||||
return false;
|
||||
|
||||
if(pArg->lengths_.back() % MPerThread != 0)
|
||||
if(arg.lengths_.back() % MPerThread != 0)
|
||||
return false;
|
||||
|
||||
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
|
||||
@@ -247,19 +242,40 @@ struct DeviceElementwise
|
||||
bool valid = true;
|
||||
static_for<0, NumInput, 1>{}([&](auto I) {
|
||||
if(!IsScalarPerVectorValid(
|
||||
pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
|
||||
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
|
||||
valid = false;
|
||||
});
|
||||
|
||||
static_for<0, NumOutput, 1>{}([&](auto I) {
|
||||
if(!IsScalarPerVectorValid(
|
||||
pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
|
||||
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
|
||||
valid = false;
|
||||
});
|
||||
|
||||
return valid;
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op)
|
||||
{
|
||||
return Argument{lengths,
|
||||
inStridesArray,
|
||||
outStridesArray,
|
||||
in_dev_buffers,
|
||||
out_dev_buffers,
|
||||
elementwise_op};
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
|
||||
37
include/ck/tensor_operation/gpu/device/device_permute.hpp
Normal file
37
include/ck/tensor_operation/gpu/device/device_permute.hpp
Normal file
@@ -0,0 +1,37 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t NumDim, typename InDataType, typename OutDataType, typename ElementwiseOperation>
|
||||
struct DevicePermute : BaseOperator
|
||||
{
|
||||
using Lengths = std::array<index_t, NumDim>;
|
||||
using Strides = Lengths;
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const Lengths& in_lengths,
|
||||
const Strides& in_strides,
|
||||
const Lengths& out_lengths,
|
||||
const Strides& out_strides,
|
||||
const void* in_dev_buffer,
|
||||
void* out_dev_buffer,
|
||||
ElementwiseOperation elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,282 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Swap last 2 dimensions
|
||||
// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
|
||||
// ^^^^^^^^^^^
|
||||
// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
|
||||
// ^^^^^^^^^^^
|
||||
template <index_t NumDim,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename ElementwiseOperation,
|
||||
index_t BlockSize,
|
||||
index_t NPerBlock,
|
||||
index_t HPerBlock,
|
||||
index_t WPerBlock,
|
||||
index_t InBlockLdsExtraW,
|
||||
typename InBlockTransferThreadClusterLengths,
|
||||
typename InBlockTransferThreadClusterArrangeOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector>
|
||||
struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>
|
||||
{
|
||||
using BaseType = DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>;
|
||||
using typename BaseType::Lengths;
|
||||
using typename BaseType::Strides;
|
||||
|
||||
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
|
||||
static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
|
||||
static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
|
||||
static_assert(SrcVectorDim != DstVectorDim);
|
||||
|
||||
template <index_t N = NumDim>
|
||||
static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
|
||||
{
|
||||
static_assert(1 <= N && N <= NumDim);
|
||||
|
||||
return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides& stride)
|
||||
{
|
||||
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
|
||||
// d[NumDim-1]]
|
||||
const auto desc =
|
||||
make_naive_tensor_descriptor(ConvertArrayToTuple(lengths), ConvertArrayToTuple(stride));
|
||||
|
||||
// merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
|
||||
// d[NumDim-1]]
|
||||
// => [N, H, W]
|
||||
const index_t H = *std::next(rbegin(lengths));
|
||||
const index_t W = *rbegin(lengths);
|
||||
const auto desc_n_h_w = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(ConvertArrayToTuple<NumDim - 2>(lengths)),
|
||||
make_pass_through_transform(H),
|
||||
make_pass_through_transform(W)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
|
||||
Sequence<NumDim - 2>{},
|
||||
Sequence<NumDim - 1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return PadTensorDescriptor(
|
||||
desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{});
|
||||
}
|
||||
|
||||
using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
|
||||
using OutGridDesc = InGridDesc;
|
||||
|
||||
using GridwisePermute = GridwisePermute<
|
||||
InGridDesc,
|
||||
OutGridDesc,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
ElementwiseOperation,
|
||||
BlockSize,
|
||||
NPerBlock,
|
||||
HPerBlock,
|
||||
WPerBlock,
|
||||
InBlockLdsExtraW,
|
||||
InBlockTransferThreadClusterLengths,
|
||||
InBlockTransferThreadClusterArrangeOrder,
|
||||
SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor
|
||||
DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector>;
|
||||
|
||||
using Block2TileMap = typename GridwisePermute::DefaultBlock2TileMap;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const Lengths& in_lengths,
|
||||
const Strides& in_strides,
|
||||
const Lengths& out_lengths,
|
||||
const Strides& out_strides,
|
||||
const void* in_dev_buffer,
|
||||
void* out_dev_buffer,
|
||||
ElementwiseOperation elementwise_op)
|
||||
: in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
|
||||
out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
|
||||
in_grid_desc_(MakeDescriptor_N_H_W(in_lengths, in_strides)),
|
||||
out_grid_desc_(MakeDescriptor_N_H_W(out_lengths, out_strides)),
|
||||
in_lengths_(in_lengths),
|
||||
in_strides_(in_strides),
|
||||
out_lengths_(out_lengths),
|
||||
out_strides_(out_strides),
|
||||
elementwise_op_(elementwise_op),
|
||||
block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_))
|
||||
{
|
||||
}
|
||||
|
||||
const InDataType* in_dev_buffer_;
|
||||
OutDataType* out_dev_buffer_;
|
||||
InGridDesc in_grid_desc_;
|
||||
OutGridDesc out_grid_desc_;
|
||||
|
||||
Lengths in_lengths_;
|
||||
Strides in_strides_;
|
||||
Lengths out_lengths_;
|
||||
Strides out_strides_;
|
||||
|
||||
ElementwiseOperation elementwise_op_;
|
||||
|
||||
Block2TileMap block_2_tile_map_;
|
||||
};
|
||||
|
||||
struct Invoker : BaseInvoker
|
||||
{
|
||||
static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_);
|
||||
|
||||
const auto kernel = kernel_nd_permute<GridwisePermute,
|
||||
InGridDesc,
|
||||
OutGridDesc,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
ElementwiseOperation,
|
||||
Block2TileMap>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.in_grid_desc_,
|
||||
arg.out_grid_desc_,
|
||||
arg.in_dev_buffer_,
|
||||
arg.out_dev_buffer_,
|
||||
arg.elementwise_op_,
|
||||
arg.block_2_tile_map_);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override final
|
||||
{
|
||||
const auto* const argument = dynamic_cast<const Argument*>(arg);
|
||||
if(!argument)
|
||||
{
|
||||
return NAN;
|
||||
}
|
||||
|
||||
return Run(*argument, stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) {
|
||||
return math::integer_divide_ceil(length, tile_length) * tile_length;
|
||||
};
|
||||
|
||||
constexpr auto IsScalarPerVectorValid =
|
||||
[](index_t length, index_t stride, index_t scalar_per_vector) {
|
||||
if(stride == 1 && length % scalar_per_vector == 0)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if(stride != 1 && scalar_per_vector == 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
return IsScalarPerVectorValid(arg.in_lengths_[SrcVectorDim],
|
||||
arg.in_strides_[SrcVectorDim],
|
||||
SrcScalarPerVector) &&
|
||||
IsScalarPerVectorValid(
|
||||
GetPaddedLength(arg.in_lengths_[SrcVectorDim],
|
||||
(SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
|
||||
arg.in_strides_[SrcVectorDim],
|
||||
SrcScalarPerVector) &&
|
||||
IsScalarPerVectorValid(arg.out_lengths_[DstVectorDim],
|
||||
arg.out_strides_[DstVectorDim],
|
||||
DstScalarPerVector) &&
|
||||
IsScalarPerVectorValid(
|
||||
GetPaddedLength(arg.out_lengths_[DstVectorDim],
|
||||
(DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
|
||||
arg.in_strides_[DstVectorDim],
|
||||
DstScalarPerVector) &&
|
||||
GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
|
||||
};
|
||||
|
||||
// override methods inherited from 'BaseOperator'
|
||||
bool IsSupportedArgument(const BaseArgument* arg) override final
|
||||
{
|
||||
const auto* const argument = dynamic_cast<const Argument*>(arg);
|
||||
if(!argument)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return IsSupportedArgument(*argument);
|
||||
}
|
||||
|
||||
// override methods inherited from 'DevicePermute'
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const Lengths& in_lengths,
|
||||
const Strides& in_strides,
|
||||
const Lengths& out_lengths,
|
||||
const Strides& out_strides,
|
||||
const void* in_dev_buffer,
|
||||
void* out_dev_buffer,
|
||||
ElementwiseOperation elementwise_op) override final
|
||||
{
|
||||
return std::make_unique<Argument>(in_lengths,
|
||||
in_strides,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
in_dev_buffer,
|
||||
out_dev_buffer,
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
// other constructor methods
|
||||
template <typename... Args>
|
||||
static std::enable_if_t<std::is_constructible_v<Argument, Args...>, Argument>
|
||||
MakeArgument(Args&&... args) noexcept(std::is_nothrow_constructible_v<Argument, Args...>)
|
||||
{
|
||||
return Argument{std::forward<Args>(args)...};
|
||||
}
|
||||
|
||||
static std::enable_if_t<std::is_default_constructible_v<Invoker>, Invoker>
|
||||
MakeInvoker() noexcept(std::is_nothrow_default_constructible_v<Invoker>)
|
||||
{
|
||||
return Invoker{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -83,6 +83,8 @@ struct GridwiseElementwise_1D
|
||||
|
||||
auto in_global_buf_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
|
||||
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
|
||||
},
|
||||
@@ -90,6 +92,8 @@ struct GridwiseElementwise_1D
|
||||
|
||||
auto out_global_buf_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
|
||||
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
|
||||
},
|
||||
|
||||
339
include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
Normal file
339
include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp
Normal file
@@ -0,0 +1,339 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <iterator>
|
||||
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwisePermute,
|
||||
typename InGridDesc,
|
||||
typename OutGridDesc,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename ElementwiseOperation,
|
||||
typename Block2TileMap>
|
||||
__global__ void kernel_nd_permute(const InGridDesc in_grid_desc,
|
||||
const OutGridDesc out_grid_desc,
|
||||
const InDataType* p_in_global,
|
||||
OutDataType* p_out_global,
|
||||
const ElementwiseOperation elementwise_op,
|
||||
const Block2TileMap block_2_tile_map)
|
||||
{
|
||||
__shared__ char p_shared[GridwisePermute::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwisePermute::Run(in_grid_desc,
|
||||
out_grid_desc,
|
||||
p_in_global,
|
||||
p_out_global,
|
||||
p_shared,
|
||||
elementwise_op,
|
||||
block_2_tile_map);
|
||||
}
|
||||
|
||||
template <typename InGridDesc,
|
||||
typename OutGridDesc,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename ElementwiseOperation,
|
||||
index_t BlockSize,
|
||||
index_t NPerBlock,
|
||||
index_t HPerBlock,
|
||||
index_t WPerBlock,
|
||||
index_t InBlockLdsExtraW,
|
||||
typename InBlockTransferThreadClusterLengths,
|
||||
typename InBlockTransferThreadClusterArrangeOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector>
|
||||
struct GridwisePermute
|
||||
{
|
||||
static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
|
||||
static_assert(3 <= InGridDesc::GetNumOfDimension());
|
||||
static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim &&
|
||||
SrcVectorDim < InGridDesc::GetNumOfDimension());
|
||||
static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim &&
|
||||
DstVectorDim < OutGridDesc::GetNumOfDimension());
|
||||
static_assert(SrcVectorDim != DstVectorDim);
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
struct Block2TileMap
|
||||
{
|
||||
static constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
|
||||
static_assert(3 <= NumDim);
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
Block2TileMap() = delete;
|
||||
Block2TileMap(const Block2TileMap&) = default;
|
||||
Block2TileMap(Block2TileMap&&) = delete;
|
||||
|
||||
~Block2TileMap() = default;
|
||||
|
||||
Block2TileMap& operator=(const Block2TileMap&) = delete;
|
||||
Block2TileMap& operator=(Block2TileMap&&) = delete;
|
||||
|
||||
explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {}
|
||||
|
||||
__host__ constexpr index_t CalculateGridSize(const InGridDesc& desc) const
|
||||
{
|
||||
const auto N0 =
|
||||
math::integer_divide_ceil(desc.GetLength(Number<NumDim - 3>{}), NPerBlock);
|
||||
const auto H0 =
|
||||
math::integer_divide_ceil(desc.GetLength(Number<NumDim - 2>{}), HPerBlock);
|
||||
const auto W0 =
|
||||
math::integer_divide_ceil(desc.GetLength(Number<NumDim - 1>{}), WPerBlock);
|
||||
|
||||
const index_t grid_size = N0 * H0 * W0;
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
static_assert(TopIdx::Size() == 1);
|
||||
|
||||
auto block_1d_id = idx_top[I0];
|
||||
|
||||
const auto N0 =
|
||||
math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 3>{}), NPerBlock);
|
||||
const auto H0 =
|
||||
math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 2>{}), HPerBlock);
|
||||
const auto W0 =
|
||||
math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 1>{}), WPerBlock);
|
||||
|
||||
block_1d_id = block_1d_id % (N0 * H0 * W0);
|
||||
|
||||
index_t idx_N0 = block_1d_id / (H0 * W0);
|
||||
index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0;
|
||||
index_t idx_W0 = block_1d_id % W0;
|
||||
|
||||
return make_tuple(idx_N0, idx_H0, idx_W0);
|
||||
}
|
||||
|
||||
private:
|
||||
const InGridDesc desc_;
|
||||
};
|
||||
|
||||
using DefaultBlock2TileMap = Block2TileMap;
|
||||
|
||||
// use an [NPerBlock, HPerBlock, WPerBlock] tensor as element-copy relay
|
||||
__host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock()
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<NPerBlock>{}, Number<HPerBlock>{}, Number<WPerBlock>{}),
|
||||
make_tuple(Number<HPerBlock*(WPerBlock + InBlockLdsExtraW)>{},
|
||||
Number<WPerBlock + InBlockLdsExtraW>{},
|
||||
I1));
|
||||
}
|
||||
|
||||
// for N-dimension descriptor, reserve its last 2 dimensions, then merge its leading dimensions
|
||||
// into single one. finally, form a 3D descriptor: [d(0), d(1), ..., d(N - 2), d(N - 1)] ->
|
||||
// [(d(0) x d(1) x ...), d(N - 2), d(N - 1)]
|
||||
template <typename GridDesc>
|
||||
__host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc)
|
||||
{
|
||||
constexpr index_t NumDim = GridDesc::GetNumOfDimension();
|
||||
static_assert(3 <= NumDim);
|
||||
|
||||
const auto merged_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(generate_tuple(
|
||||
[&](auto I) { return desc.GetLength(I); }, Number<NumDim - 2>{})),
|
||||
make_pass_through_transform(desc.GetLength(Number<NumDim - 2>{})),
|
||||
make_pass_through_transform(desc.GetLength(Number<NumDim - 1>{}))),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
|
||||
Sequence<NumDim - 2>{},
|
||||
Sequence<NumDim - 1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
return merged_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto in_block_desc_nperblock_hperblock_wperblock =
|
||||
GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();
|
||||
|
||||
return in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize() *
|
||||
sizeof(InDataType);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc)
|
||||
{
|
||||
return DefaultBlock2TileMap{desc};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CheckValidity(const InGridDesc& in_grid_desc,
|
||||
const OutGridDesc& out_grid_desc)
|
||||
{
|
||||
constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
|
||||
|
||||
// check if we only swap last 2 dimensions
|
||||
bool valid = true;
|
||||
static_for<0, NumDim - 2, 1>{}([&](auto I) {
|
||||
if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I))
|
||||
{
|
||||
valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
return valid &&
|
||||
(in_grid_desc.GetLength(Number<NumDim - 1>{}) ==
|
||||
out_grid_desc.GetLength(Number<NumDim - 2>{})) &&
|
||||
(in_grid_desc.GetLength(Number<NumDim - 2>{}) ==
|
||||
out_grid_desc.GetLength(Number<NumDim - 1>{}));
|
||||
}
|
||||
|
||||
template <typename Block2TileMap>
|
||||
__device__ static void Run(const InGridDesc in_grid_desc,
|
||||
const OutGridDesc out_grid_desc,
|
||||
const InDataType* p_in_global,
|
||||
OutDataType* p_out_global,
|
||||
void* __restrict__ p_shared,
|
||||
const ElementwiseOperation elementwise_op,
|
||||
const Block2TileMap& block_2_tile_map)
|
||||
{
|
||||
auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_global, in_grid_desc.GetElementSpaceSize());
|
||||
|
||||
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_global, out_grid_desc.GetElementSpaceSize());
|
||||
|
||||
// each workgroup handles an [NPerBlock, HPerBlock, WPerBLock] slice-transpose problem
|
||||
const auto block_work_idx =
|
||||
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);
|
||||
|
||||
const index_t h_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * HPerBlock);
|
||||
|
||||
const index_t w_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock);
|
||||
|
||||
// create [NPerBlock, HPerBlock, WPerBLock] shaped LDS buffer
|
||||
constexpr auto in_block_desc_nperblock_hperblock_wperblock =
|
||||
GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();
|
||||
|
||||
auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<InDataType*>(p_shared),
|
||||
in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize());
|
||||
|
||||
using BlockSliceLengths = Sequence<NPerBlock, HPerBlock, WPerBlock>;
|
||||
using InBlockTransferAccessOrder = Sequence<0, 1, 2>;
|
||||
|
||||
constexpr index_t SrcVectorDimAfterMerge =
|
||||
SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3);
|
||||
constexpr index_t DstVectorDimAfterMerge = SrcVectorDimAfterMerge;
|
||||
|
||||
using ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// merge input descriptor into [(in_grid_desc.GetLength(0) x in_grid_desc.GetLength(1) x
|
||||
// ...), in_grid_desc.GetLength(NumDim - 2), in_grid_desc.GetLength(NumDim - 1)]
|
||||
const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);
|
||||
|
||||
// a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from global memory to LDS
|
||||
auto in_global_load = ThreadGroupTensorSliceTransfer_v4r1<
|
||||
ThisThreadBlock,
|
||||
ElementwiseOperation,
|
||||
PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
BlockSliceLengths,
|
||||
InBlockTransferThreadClusterLengths,
|
||||
InBlockTransferThreadClusterArrangeOrder,
|
||||
InDataType,
|
||||
InDataType,
|
||||
decltype(in_grid_desc_n_h_w),
|
||||
decltype(in_block_desc_nperblock_hperblock_wperblock),
|
||||
InBlockTransferAccessOrder,
|
||||
InBlockTransferAccessOrder,
|
||||
SrcVectorDimAfterMerge,
|
||||
2,
|
||||
SrcScalarPerVector,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
true,
|
||||
true>(in_grid_desc_n_h_w,
|
||||
make_multi_index(
|
||||
n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
|
||||
PassThrough{},
|
||||
in_block_desc_nperblock_hperblock_wperblock,
|
||||
make_multi_index(0, 0, 0),
|
||||
PassThrough{});
|
||||
|
||||
// merge output descriptor into [(out_grid_desc.GetLength(0) x out_grid_desc.GetLength(1) x
|
||||
// ...), out_grid_desc.GetLength(NumDim - 2), out_grid_desc.GetLength(NumDim - 1)]
|
||||
const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);
|
||||
|
||||
// create transposed view of output tensor
|
||||
const auto out_grid_desc_n_h_w = transform_tensor_descriptor(
|
||||
out_grid_desc_n_w_h,
|
||||
make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)),
|
||||
make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I1)),
|
||||
make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I2))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));
|
||||
|
||||
// a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from LDS to global memory
|
||||
auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1<
|
||||
ThisThreadBlock,
|
||||
ElementwiseOperation,
|
||||
PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
BlockSliceLengths,
|
||||
InBlockTransferThreadClusterLengths,
|
||||
InBlockTransferThreadClusterArrangeOrder,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
decltype(in_block_desc_nperblock_hperblock_wperblock),
|
||||
decltype(out_grid_desc_n_h_w),
|
||||
InBlockTransferAccessOrder,
|
||||
InBlockTransferAccessOrder,
|
||||
2,
|
||||
DstVectorDimAfterMerge,
|
||||
1,
|
||||
DstScalarPerVector,
|
||||
1,
|
||||
1,
|
||||
true,
|
||||
true>(in_block_desc_nperblock_hperblock_wperblock,
|
||||
make_multi_index(0, 0, 0),
|
||||
PassThrough{},
|
||||
out_grid_desc_n_h_w,
|
||||
make_multi_index(
|
||||
n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
|
||||
elementwise_op);
|
||||
|
||||
in_global_load.Run(in_grid_desc_n_h_w,
|
||||
in_global_buf,
|
||||
in_block_desc_nperblock_hperblock_wperblock,
|
||||
in_block_buf,
|
||||
I0);
|
||||
|
||||
out_global_store.Run(in_block_desc_nperblock_hperblock_wperblock,
|
||||
in_block_buf,
|
||||
out_grid_desc_n_h_w,
|
||||
out_global_buf,
|
||||
I0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
67
include/ck/utility/span.hpp
Normal file
67
include/ck/utility/span.hpp
Normal file
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <array>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck {
|
||||
|
||||
/// Non-owning view over a contiguous sequence of objects of type T; a
/// minimal stand-in for C++20 std::span on pre-C++20 toolchains. Stores
/// only a pointer and an element count, so the viewed storage must
/// outlive the span.
template <typename T>
class span
{
    public:
    using element_type    = T;
    using value_type      = std::remove_cv_t<element_type>;
    using size_type       = std::size_t;
    using difference_type = std::ptrdiff_t;
    using pointer         = element_type*;
    using const_pointer   = const element_type*;
    using reference       = element_type&;
    using const_reference = const element_type&;
    using iterator        = pointer;
    // FIX: was aliased to 'pointer', which let cbegin()/cend() hand out
    // mutable access; const_pointer restores the expected const-ness
    // (begin()/end() still convert implicitly T* -> const T*).
    using const_iterator = const_pointer;

    /// Empty view (null pointer, zero length).
    constexpr span() : span(nullptr, size_type{0}) {}

    /// View the half-open range [first, first + count).
    constexpr span(pointer first, size_type count) : ptr_(first), size_(count) {}

    /// View the half-open range [first, last).
    constexpr span(pointer first, pointer last) : span(first, last - first) {}

    /// View an entire C array; extent deduced at compile time.
    template <std::size_t N>
    constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
    {
    }

    /// View an entire std::array.
    template <std::size_t N>
    constexpr span(std::array<value_type, N>& arr) noexcept : span(arr.data(), N)
    {
    }

    /// View any contiguous container exposing data()/size().
    /// NOTE(review): binds the container as const, so this overload only
    /// compiles when element_type is const-qualified (or the container's
    /// data() is non-const) — presumably intended for span<const T>; verify
    /// against callers.
    template <typename Container>
    constexpr span(const Container& container) : span(container.data(), container.size())
    {
    }

    constexpr iterator begin() const noexcept { return ptr_; }
    constexpr const_iterator cbegin() const noexcept { return begin(); }

    constexpr iterator end() const noexcept { return begin() + size(); }
    constexpr const_iterator cend() const noexcept { return end(); }

    constexpr reference front() const { return *begin(); }

    // FIX: was 'return *(--end());' — prefix '--' on the prvalue pointer
    // returned by end() is ill-formed for built-in types, so back() failed
    // to compile on instantiation.
    constexpr reference back() const { return *(end() - 1); }

    constexpr reference operator[](size_type idx) const { return *(begin() + idx); }
    constexpr pointer data() const noexcept { return ptr_; }

    constexpr size_type size() const noexcept { return size_; }

    /// True when the span views zero elements (backward-compatible addition,
    /// mirroring std::span::empty()).
    constexpr bool empty() const noexcept { return size() == 0; }

    private:
    pointer ptr_;    // first element of the viewed sequence (null when empty)
    size_type size_; // number of elements viewed
};
|
||||
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user