mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Add 'Permute' device op & example (#408)
* Add example folder for 'DeviceElementwise' * Re-structure example files * Move common parts into common.hpp * Use more strict input * Add more helper methods in 'DeviceElementwise' * Use more specific method to write example * Allow specify problem through command line argument * Allow specify problem 'axes' through command line argument * Add check to template type argument * Add transpose_shape() to generalize shape permute * Generalize transpose utility functions * Use better name for tensor indices * Add checks in helper functions * Remove debug messages * Refine error message for check_err() * Generalize variable naming in example code * Add device op 'DevicePermute' This device op is clone of 'DeviceElementwise' * Use 'DevicePermute' device op in example * Remove 'elementwise' from identifiers * Remove 'elementwise' from file paths * Remove base class of 'DevicePermute' * Let 'DevicePermute' inherit from 'BaseOperator' * Add simple type traits to validate device op type * Add static_assert() to check type constraints * Create 'DevicePermuteBase' to generate methods * Use indirect base type to generate methods * Remove 'is_device_op<>' type traits * Only accept single-input-single-output for 'DervicePermute' * Simplify 'DevicePermute' interface * Re-format 'DeviceElementwise' * Use CRTP to generate overridden virtual method * Remove unnecessary include directives * Distinguish input & output shape in 'DevicePermute' * Passing 'axes' to 'DevicePermute' * Use more reasonable return value for Invoker::Run() * Add 'GridwisePermute' kernel This kernel is a clone of 'GridwiseElementwise_1D' * Remove no-longer used type argument * Check if input/output shape meet the requirement * Remove no-longer used method * Remove never-entered-if-clause * Change problem description for 'DevicePermute' * Transform descriptor into 3 dimensions * Add debug code the verify result * Add comment to indicate template argument location * Add N/H/WPerBlock template parameter to 'DevicePermute' * Rename 'GridwisePermute' to 'GridwiseCopy' * Check tensor descriptor dimensions in 'GridwiseElementwise_1D' * Add missing include directive * Add 'BlockSize' parameter to 'DevicePermute' * Remove no-longer used method * Add 'BlockToTileMap' for 'GridwiseCopy' * Use the normal Block2TileMap convention * Rename 'BlockToTileMap' as 'Block2TileMap' * Fix most of compilation errors * Let 'Block2TileMap' map block to 2d coordinate * Allow data transfer in 'GridwiseCopy' * Fix wrong output descriptor for 2nd blockwise copy * Rename 'GridwiseCopy' as 'GridwisePermute' * Remove '1d' in identifiers * Remove commented-out codes * Remove 'MPerThread' template parameter * Seperate template parameters * Unify variable namming convention * Use more verbose way to create expressions * Add template parameter 'InBlockLdsExtraW' * Release the constraint on In/OutGridDesc * Use date type directly as template argument * Re-arrange template arguments for blockwise copy * Remove no-longer used template parameters * Embed layout in the variable names * Add GridwisePermute::CheckValidity() * Extract local types as template parameters * Rename local type alias * Add more template parameters (vector width related) * Calculate new SrcVectorDim/DstVectorDim after merge descriptor dimensions * Fill tensor values start from 1 * Re-formate example code * Avoid too-large block id * Add comment * Make sure 'SrcVectorDim' is not same as 'DstVectorDim' * Add check for the 'VectorDim' & 'ScalarPerVector' template params * Let 'DstVectorDim' equals 'SrcVectorDim' after transpose out grid desc * Remove no-longer used template parameter 'NPerBlock' * Fix wrong descriptor creation logics * Specify problem in each examples * Use better example name * Add new example 'example_permute_NxHxW_fp32' * Add example for demonstrating bundle multiple elems in tensor * Add support to permute multiple elements together * Change the default problem size * Add span<> class template * Use span<> to generalize check_err() interface * Fix ambiguous ctor call * Avoid create necessary objects * Use helper functions to simplify example code * Add example for 4xfp16 permute * Disable failed-to-compile example * Add check for the NUM_ELEMS_IN_BUNDLE * Remove redundant parameter in helper lambda function * Add check for the input tensor type's byte-size * Check scalar-per-vector with padded length * Use more verbose name to avoid name collision * Use fixed 'VectorDim' & 'ScalarPerVector' for LDS * Embed shape info in name of descriptor constructor * Rename example folder '36_permute' into '37_permute' * Avoid using too-large LDS in kernel code * Remove redundant example * Usw switch() to group similar codes * Add const to the span<> type arguement * Simply initialize tensor with floating point values * Use fp16 as data type in all examples * Enlarge tensor size in example * Enalrge N-dim in example * Add check for the bundled type in example * Use more stricter error threshold * Remove global load/store loop in kernel code * Measure execution time by default * Use faster device op config for example 'NxHxW_fp16' * Use faster device op config for example '1xHxW_fp16' * Use faster device op config for example 'HxWx4_fp16' * Remove cmd arg parsing logics * Rename functions * Extract bundle permutation logic out * Simplify permute bundle example * Add Tensor<>::GetElementSpaceSizeInBytes() * Add Tensor<>::data() * Use new methods to simplify code * Use type alias to replace duplicated code * Use existing method to shorten code * Allow FillUniformDistribution accept range arugment * Intialize random values in range * Add Tensor<>::size() * Use more meaningful names in permute bundle example * Use more meaningful names in permute element examples * Use rangified copy() to copy elements * Use function return value directly to eliminate variables * Add to_array() conversion tool to eliminate more variables * Add Tensor<>::AsSpan<>() to create view of tensor values * Use AsSpan() to shorten check_err() calls * Remove no-longer-used 'using' directives * Move 'using' directive to proper code position * Remove redudant variables * Remove useless static_assert() * Add check for range types * Declare variable right before first use * Move long return type as tailing return type * Add BaseInvokerCRTP<> class template to generate method * Create new base type for 'DervicePermute' implementations * Move 'NumDim' template param to the first * Rename 'DevicePermute' to 'DevicePermuteImpl' * Add 'noexcept' specifier to CRTP generated method * Move 'Block2TileMap' definition into 'GridwisePermute' * Use type alias to reduce code * Unify naming style in 'DevicePermute' * Add comments in 'GridwisePermute' * Rename permute example folder * Use std::cerr to report error * Use larger shape in examples * Rename '38_permute' to '39_permute' * Make sure we use unsigned type for shape & indices * Remove opt-ed out assertion * Remove template BaseInvokerCRTP<>
This commit is contained in:
@@ -0,0 +1,282 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Swap last 2 dimensions
|
||||
// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
|
||||
// ^^^^^^^^^^^
|
||||
// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
|
||||
// ^^^^^^^^^^^
|
||||
template <index_t NumDim,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename ElementwiseOperation,
|
||||
index_t BlockSize,
|
||||
index_t NPerBlock,
|
||||
index_t HPerBlock,
|
||||
index_t WPerBlock,
|
||||
index_t InBlockLdsExtraW,
|
||||
typename InBlockTransferThreadClusterLengths,
|
||||
typename InBlockTransferThreadClusterArrangeOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector>
|
||||
struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>
|
||||
{
|
||||
using BaseType = DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>;
|
||||
using typename BaseType::Lengths;
|
||||
using typename BaseType::Strides;
|
||||
|
||||
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
|
||||
static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
|
||||
static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
|
||||
static_assert(SrcVectorDim != DstVectorDim);
|
||||
|
||||
template <index_t N = NumDim>
|
||||
static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
|
||||
{
|
||||
static_assert(1 <= N && N <= NumDim);
|
||||
|
||||
return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides& stride)
|
||||
{
|
||||
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
|
||||
// d[NumDim-1]]
|
||||
const auto desc =
|
||||
make_naive_tensor_descriptor(ConvertArrayToTuple(lengths), ConvertArrayToTuple(stride));
|
||||
|
||||
// merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
|
||||
// d[NumDim-1]]
|
||||
// => [N, H, W]
|
||||
const index_t H = *std::next(rbegin(lengths));
|
||||
const index_t W = *rbegin(lengths);
|
||||
const auto desc_n_h_w = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(ConvertArrayToTuple<NumDim - 2>(lengths)),
|
||||
make_pass_through_transform(H),
|
||||
make_pass_through_transform(W)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
|
||||
Sequence<NumDim - 2>{},
|
||||
Sequence<NumDim - 1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return PadTensorDescriptor(
|
||||
desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{});
|
||||
}
|
||||
|
||||
using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
|
||||
using OutGridDesc = InGridDesc;
|
||||
|
||||
using GridwisePermute = GridwisePermute<
|
||||
InGridDesc,
|
||||
OutGridDesc,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
ElementwiseOperation,
|
||||
BlockSize,
|
||||
NPerBlock,
|
||||
HPerBlock,
|
||||
WPerBlock,
|
||||
InBlockLdsExtraW,
|
||||
InBlockTransferThreadClusterLengths,
|
||||
InBlockTransferThreadClusterArrangeOrder,
|
||||
SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor
|
||||
DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector>;
|
||||
|
||||
using Block2TileMap = typename GridwisePermute::DefaultBlock2TileMap;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const Lengths& in_lengths,
|
||||
const Strides& in_strides,
|
||||
const Lengths& out_lengths,
|
||||
const Strides& out_strides,
|
||||
const void* in_dev_buffer,
|
||||
void* out_dev_buffer,
|
||||
ElementwiseOperation elementwise_op)
|
||||
: in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
|
||||
out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
|
||||
in_grid_desc_(MakeDescriptor_N_H_W(in_lengths, in_strides)),
|
||||
out_grid_desc_(MakeDescriptor_N_H_W(out_lengths, out_strides)),
|
||||
in_lengths_(in_lengths),
|
||||
in_strides_(in_strides),
|
||||
out_lengths_(out_lengths),
|
||||
out_strides_(out_strides),
|
||||
elementwise_op_(elementwise_op),
|
||||
block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_))
|
||||
{
|
||||
}
|
||||
|
||||
const InDataType* in_dev_buffer_;
|
||||
OutDataType* out_dev_buffer_;
|
||||
InGridDesc in_grid_desc_;
|
||||
OutGridDesc out_grid_desc_;
|
||||
|
||||
Lengths in_lengths_;
|
||||
Strides in_strides_;
|
||||
Lengths out_lengths_;
|
||||
Strides out_strides_;
|
||||
|
||||
ElementwiseOperation elementwise_op_;
|
||||
|
||||
Block2TileMap block_2_tile_map_;
|
||||
};
|
||||
|
||||
struct Invoker : BaseInvoker
|
||||
{
|
||||
static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_);
|
||||
|
||||
const auto kernel = kernel_nd_permute<GridwisePermute,
|
||||
InGridDesc,
|
||||
OutGridDesc,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
ElementwiseOperation,
|
||||
Block2TileMap>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.in_grid_desc_,
|
||||
arg.out_grid_desc_,
|
||||
arg.in_dev_buffer_,
|
||||
arg.out_dev_buffer_,
|
||||
arg.elementwise_op_,
|
||||
arg.block_2_tile_map_);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override final
|
||||
{
|
||||
const auto* const argument = dynamic_cast<const Argument*>(arg);
|
||||
if(!argument)
|
||||
{
|
||||
return NAN;
|
||||
}
|
||||
|
||||
return Run(*argument, stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) {
|
||||
return math::integer_divide_ceil(length, tile_length) * tile_length;
|
||||
};
|
||||
|
||||
constexpr auto IsScalarPerVectorValid =
|
||||
[](index_t length, index_t stride, index_t scalar_per_vector) {
|
||||
if(stride == 1 && length % scalar_per_vector == 0)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if(stride != 1 && scalar_per_vector == 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
return IsScalarPerVectorValid(arg.in_lengths_[SrcVectorDim],
|
||||
arg.in_strides_[SrcVectorDim],
|
||||
SrcScalarPerVector) &&
|
||||
IsScalarPerVectorValid(
|
||||
GetPaddedLength(arg.in_lengths_[SrcVectorDim],
|
||||
(SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
|
||||
arg.in_strides_[SrcVectorDim],
|
||||
SrcScalarPerVector) &&
|
||||
IsScalarPerVectorValid(arg.out_lengths_[DstVectorDim],
|
||||
arg.out_strides_[DstVectorDim],
|
||||
DstScalarPerVector) &&
|
||||
IsScalarPerVectorValid(
|
||||
GetPaddedLength(arg.out_lengths_[DstVectorDim],
|
||||
(DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
|
||||
arg.in_strides_[DstVectorDim],
|
||||
DstScalarPerVector) &&
|
||||
GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
|
||||
};
|
||||
|
||||
// override methods inherited from 'BaseOperator'
|
||||
bool IsSupportedArgument(const BaseArgument* arg) override final
|
||||
{
|
||||
const auto* const argument = dynamic_cast<const Argument*>(arg);
|
||||
if(!argument)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return IsSupportedArgument(*argument);
|
||||
}
|
||||
|
||||
// override methods inherited from 'DevicePermute'
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const Lengths& in_lengths,
|
||||
const Strides& in_strides,
|
||||
const Lengths& out_lengths,
|
||||
const Strides& out_strides,
|
||||
const void* in_dev_buffer,
|
||||
void* out_dev_buffer,
|
||||
ElementwiseOperation elementwise_op) override final
|
||||
{
|
||||
return std::make_unique<Argument>(in_lengths,
|
||||
in_strides,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
in_dev_buffer,
|
||||
out_dev_buffer,
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
// other constructor methods
|
||||
template <typename... Args>
|
||||
static std::enable_if_t<std::is_constructible_v<Argument, Args...>, Argument>
|
||||
MakeArgument(Args&&... args) noexcept(std::is_nothrow_constructible_v<Argument, Args...>)
|
||||
{
|
||||
return Argument{std::forward<Args>(args)...};
|
||||
}
|
||||
|
||||
static std::enable_if_t<std::is_default_constructible_v<Invoker>, Invoker>
|
||||
MakeInvoker() noexcept(std::is_nothrow_default_constructible_v<Invoker>)
|
||||
{
|
||||
return Invoker{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user