mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Add 'Permute' device op & example (#408)
* Add example folder for 'DeviceElementwise'
* Re-structure example files
* Move common parts into common.hpp
* Use more strict input
* Add more helper methods in 'DeviceElementwise'
* Use more specific method to write example
* Allow specify problem through command line argument
* Allow specify problem 'axes' through command line argument
* Add check to template type argument
* Add transpose_shape() to generalize shape permute
* Generalize transpose utility functions
* Use better name for tensor indices
* Add checks in helper functions
* Remove debug messages
* Refine error message for check_err()
* Generalize variable naming in example code
* Add device op 'DevicePermute'
This device op is clone of 'DeviceElementwise'
* Use 'DevicePermute' device op in example
* Remove 'elementwise' from identifiers
* Remove 'elementwise' from file paths
* Remove base class of 'DevicePermute'
* Let 'DevicePermute' inherit from 'BaseOperator'
* Add simple type traits to validate device op type
* Add static_assert() to check type constraints
* Create 'DevicePermuteBase' to generate methods
* Use indirect base type to generate methods
* Remove 'is_device_op<>' type traits
* Only accept single-input-single-output for 'DervicePermute'
* Simplify 'DevicePermute' interface
* Re-format 'DeviceElementwise'
* Use CRTP to generate overridden virtual method
* Remove unnecessary include directives
* Distinguish input & output shape in 'DevicePermute'
* Passing 'axes' to 'DevicePermute'
* Use more reasonable return value for Invoker::Run()
* Add 'GridwisePermute' kernel
This kernel is a clone of 'GridwiseElementwise_1D'
* Remove no-longer used type argument
* Check if input/output shape meet the requirement
* Remove no-longer used method
* Remove never-entered-if-clause
* Change problem description for 'DevicePermute'
* Transform descriptor into 3 dimensions
* Add debug code the verify result
* Add comment to indicate template argument location
* Add N/H/WPerBlock template parameter to 'DevicePermute'
* Rename 'GridwisePermute' to 'GridwiseCopy'
* Check tensor descriptor dimensions in 'GridwiseElementwise_1D'
* Add missing include directive
* Add 'BlockSize' parameter to 'DevicePermute'
* Remove no-longer used method
* Add 'BlockToTileMap' for 'GridwiseCopy'
* Use the normal Block2TileMap convention
* Rename 'BlockToTileMap' as 'Block2TileMap'
* Fix most of compilation errors
* Let 'Block2TileMap' map block to 2d coordinate
* Allow data transfer in 'GridwiseCopy'
* Fix wrong output descriptor for 2nd blockwise copy
* Rename 'GridwiseCopy' as 'GridwisePermute'
* Remove '1d' in identifiers
* Remove commented-out codes
* Remove 'MPerThread' template parameter
* Seperate template parameters
* Unify variable namming convention
* Use more verbose way to create expressions
* Add template parameter 'InBlockLdsExtraW'
* Release the constraint on In/OutGridDesc
* Use date type directly as template argument
* Re-arrange template arguments for blockwise copy
* Remove no-longer used template parameters
* Embed layout in the variable names
* Add GridwisePermute::CheckValidity()
* Extract local types as template parameters
* Rename local type alias
* Add more template parameters (vector width related)
* Calculate new SrcVectorDim/DstVectorDim after merge descriptor dimensions
* Fill tensor values start from 1
* Re-formate example code
* Avoid too-large block id
* Add comment
* Make sure 'SrcVectorDim' is not same as 'DstVectorDim'
* Add check for the 'VectorDim' & 'ScalarPerVector' template params
* Let 'DstVectorDim' equals 'SrcVectorDim' after transpose out grid desc
* Remove no-longer used template parameter 'NPerBlock'
* Fix wrong descriptor creation logics
* Specify problem in each examples
* Use better example name
* Add new example 'example_permute_NxHxW_fp32'
* Add example for demonstrating bundle multiple elems in tensor
* Add support to permute multiple elements together
* Change the default problem size
* Add span<> class template
* Use span<> to generalize check_err() interface
* Fix ambiguous ctor call
* Avoid create necessary objects
* Use helper functions to simplify example code
* Add example for 4xfp16 permute
* Disable failed-to-compile example
* Add check for the NUM_ELEMS_IN_BUNDLE
* Remove redundant parameter in helper lambda function
* Add check for the input tensor type's byte-size
* Check scalar-per-vector with padded length
* Use more verbose name to avoid name collision
* Use fixed 'VectorDim' & 'ScalarPerVector' for LDS
* Embed shape info in name of descriptor constructor
* Rename example folder '36_permute' into '37_permute'
* Avoid using too-large LDS in kernel code
* Remove redundant example
* Usw switch() to group similar codes
* Add const to the span<> type arguement
* Simply initialize tensor with floating point values
* Use fp16 as data type in all examples
* Enlarge tensor size in example
* Enalrge N-dim in example
* Add check for the bundled type in example
* Use more stricter error threshold
* Remove global load/store loop in kernel code
* Measure execution time by default
* Use faster device op config for example 'NxHxW_fp16'
* Use faster device op config for example '1xHxW_fp16'
* Use faster device op config for example 'HxWx4_fp16'
* Remove cmd arg parsing logics
* Rename functions
* Extract bundle permutation logic out
* Simplify permute bundle example
* Add Tensor<>::GetElementSpaceSizeInBytes()
* Add Tensor<>::data()
* Use new methods to simplify code
* Use type alias to replace duplicated code
* Use existing method to shorten code
* Allow FillUniformDistribution accept range arugment
* Intialize random values in range
* Add Tensor<>::size()
* Use more meaningful names in permute bundle example
* Use more meaningful names in permute element examples
* Use rangified copy() to copy elements
* Use function return value directly to eliminate variables
* Add to_array() conversion tool to eliminate more variables
* Add Tensor<>::AsSpan<>() to create view of tensor values
* Use AsSpan() to shorten check_err() calls
* Remove no-longer-used 'using' directives
* Move 'using' directive to proper code position
* Remove redudant variables
* Remove useless static_assert()
* Add check for range types
* Declare variable right before first use
* Move long return type as tailing return type
* Add BaseInvokerCRTP<> class template to generate method
* Create new base type for 'DervicePermute' implementations
* Move 'NumDim' template param to the first
* Rename 'DevicePermute' to 'DevicePermuteImpl'
* Add 'noexcept' specifier to CRTP generated method
* Move 'Block2TileMap' definition into 'GridwisePermute'
* Use type alias to reduce code
* Unify naming style in 'DevicePermute'
* Add comments in 'GridwisePermute'
* Rename permute example folder
* Use std::cerr to report error
* Use larger shape in examples
* Rename '38_permute' to '39_permute'
* Make sure we use unsigned type for shape & indices
* Remove opt-ed out assertion
* Remove template BaseInvokerCRTP<>
[ROCm/composable_kernel commit: f584ab0c54]
This commit is contained in:
@@ -0,0 +1,282 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Swap last 2 dimensions
|
||||
// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
|
||||
// ^^^^^^^^^^^
|
||||
// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
|
||||
// ^^^^^^^^^^^
|
||||
template <index_t NumDim,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename ElementwiseOperation,
|
||||
index_t BlockSize,
|
||||
index_t NPerBlock,
|
||||
index_t HPerBlock,
|
||||
index_t WPerBlock,
|
||||
index_t InBlockLdsExtraW,
|
||||
typename InBlockTransferThreadClusterLengths,
|
||||
typename InBlockTransferThreadClusterArrangeOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector>
|
||||
struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>
|
||||
{
|
||||
using BaseType = DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>;
|
||||
using typename BaseType::Lengths;
|
||||
using typename BaseType::Strides;
|
||||
|
||||
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
|
||||
static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
|
||||
static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
|
||||
static_assert(SrcVectorDim != DstVectorDim);
|
||||
|
||||
template <index_t N = NumDim>
|
||||
static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
|
||||
{
|
||||
static_assert(1 <= N && N <= NumDim);
|
||||
|
||||
return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides& stride)
|
||||
{
|
||||
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
|
||||
// d[NumDim-1]]
|
||||
const auto desc =
|
||||
make_naive_tensor_descriptor(ConvertArrayToTuple(lengths), ConvertArrayToTuple(stride));
|
||||
|
||||
// merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
|
||||
// d[NumDim-1]]
|
||||
// => [N, H, W]
|
||||
const index_t H = *std::next(rbegin(lengths));
|
||||
const index_t W = *rbegin(lengths);
|
||||
const auto desc_n_h_w = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(ConvertArrayToTuple<NumDim - 2>(lengths)),
|
||||
make_pass_through_transform(H),
|
||||
make_pass_through_transform(W)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
|
||||
Sequence<NumDim - 2>{},
|
||||
Sequence<NumDim - 1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return PadTensorDescriptor(
|
||||
desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{});
|
||||
}
|
||||
|
||||
using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
|
||||
using OutGridDesc = InGridDesc;
|
||||
|
||||
using GridwisePermute = GridwisePermute<
|
||||
InGridDesc,
|
||||
OutGridDesc,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
ElementwiseOperation,
|
||||
BlockSize,
|
||||
NPerBlock,
|
||||
HPerBlock,
|
||||
WPerBlock,
|
||||
InBlockLdsExtraW,
|
||||
InBlockTransferThreadClusterLengths,
|
||||
InBlockTransferThreadClusterArrangeOrder,
|
||||
SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor
|
||||
DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector>;
|
||||
|
||||
using Block2TileMap = typename GridwisePermute::DefaultBlock2TileMap;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const Lengths& in_lengths,
|
||||
const Strides& in_strides,
|
||||
const Lengths& out_lengths,
|
||||
const Strides& out_strides,
|
||||
const void* in_dev_buffer,
|
||||
void* out_dev_buffer,
|
||||
ElementwiseOperation elementwise_op)
|
||||
: in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
|
||||
out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
|
||||
in_grid_desc_(MakeDescriptor_N_H_W(in_lengths, in_strides)),
|
||||
out_grid_desc_(MakeDescriptor_N_H_W(out_lengths, out_strides)),
|
||||
in_lengths_(in_lengths),
|
||||
in_strides_(in_strides),
|
||||
out_lengths_(out_lengths),
|
||||
out_strides_(out_strides),
|
||||
elementwise_op_(elementwise_op),
|
||||
block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_))
|
||||
{
|
||||
}
|
||||
|
||||
const InDataType* in_dev_buffer_;
|
||||
OutDataType* out_dev_buffer_;
|
||||
InGridDesc in_grid_desc_;
|
||||
OutGridDesc out_grid_desc_;
|
||||
|
||||
Lengths in_lengths_;
|
||||
Strides in_strides_;
|
||||
Lengths out_lengths_;
|
||||
Strides out_strides_;
|
||||
|
||||
ElementwiseOperation elementwise_op_;
|
||||
|
||||
Block2TileMap block_2_tile_map_;
|
||||
};
|
||||
|
||||
struct Invoker : BaseInvoker
|
||||
{
|
||||
static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_);
|
||||
|
||||
const auto kernel = kernel_nd_permute<GridwisePermute,
|
||||
InGridDesc,
|
||||
OutGridDesc,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
ElementwiseOperation,
|
||||
Block2TileMap>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.in_grid_desc_,
|
||||
arg.out_grid_desc_,
|
||||
arg.in_dev_buffer_,
|
||||
arg.out_dev_buffer_,
|
||||
arg.elementwise_op_,
|
||||
arg.block_2_tile_map_);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override final
|
||||
{
|
||||
const auto* const argument = dynamic_cast<const Argument*>(arg);
|
||||
if(!argument)
|
||||
{
|
||||
return NAN;
|
||||
}
|
||||
|
||||
return Run(*argument, stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) {
|
||||
return math::integer_divide_ceil(length, tile_length) * tile_length;
|
||||
};
|
||||
|
||||
constexpr auto IsScalarPerVectorValid =
|
||||
[](index_t length, index_t stride, index_t scalar_per_vector) {
|
||||
if(stride == 1 && length % scalar_per_vector == 0)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if(stride != 1 && scalar_per_vector == 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
return IsScalarPerVectorValid(arg.in_lengths_[SrcVectorDim],
|
||||
arg.in_strides_[SrcVectorDim],
|
||||
SrcScalarPerVector) &&
|
||||
IsScalarPerVectorValid(
|
||||
GetPaddedLength(arg.in_lengths_[SrcVectorDim],
|
||||
(SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
|
||||
arg.in_strides_[SrcVectorDim],
|
||||
SrcScalarPerVector) &&
|
||||
IsScalarPerVectorValid(arg.out_lengths_[DstVectorDim],
|
||||
arg.out_strides_[DstVectorDim],
|
||||
DstScalarPerVector) &&
|
||||
IsScalarPerVectorValid(
|
||||
GetPaddedLength(arg.out_lengths_[DstVectorDim],
|
||||
(DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
|
||||
arg.in_strides_[DstVectorDim],
|
||||
DstScalarPerVector) &&
|
||||
GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
|
||||
};
|
||||
|
||||
// override methods inherited from 'BaseOperator'
|
||||
bool IsSupportedArgument(const BaseArgument* arg) override final
|
||||
{
|
||||
const auto* const argument = dynamic_cast<const Argument*>(arg);
|
||||
if(!argument)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return IsSupportedArgument(*argument);
|
||||
}
|
||||
|
||||
// override methods inherited from 'DevicePermute'
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const Lengths& in_lengths,
|
||||
const Strides& in_strides,
|
||||
const Lengths& out_lengths,
|
||||
const Strides& out_strides,
|
||||
const void* in_dev_buffer,
|
||||
void* out_dev_buffer,
|
||||
ElementwiseOperation elementwise_op) override final
|
||||
{
|
||||
return std::make_unique<Argument>(in_lengths,
|
||||
in_strides,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
in_dev_buffer,
|
||||
out_dev_buffer,
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
// other constructor methods
|
||||
template <typename... Args>
|
||||
static std::enable_if_t<std::is_constructible_v<Argument, Args...>, Argument>
|
||||
MakeArgument(Args&&... args) noexcept(std::is_nothrow_constructible_v<Argument, Args...>)
|
||||
{
|
||||
return Argument{std::forward<Args>(args)...};
|
||||
}
|
||||
|
||||
static std::enable_if_t<std::is_default_constructible_v<Invoker>, Invoker>
|
||||
MakeInvoker() noexcept(std::is_nothrow_default_constructible_v<Invoker>)
|
||||
{
|
||||
return Invoker{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user