mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
Vectorized Transpose for Batched Transpose CK Tile Operator (#2131)
* Shared Memory for single data point * CKTile Transpose vectorize CP1 * CKTile Transpose vectorize CP2 * CKTile Transpose vectorize CP2.1 * fixed the compile error of the transpose tile 2d * Have the correct result for the current test sample * Changes to printing tensor * fp8 support added * Debugging for transpose * solving the corner issue * Changed padding flag * Intermideate Debugging * Intermidiate Debugging * Intermediate Debugging * Finished debugging of the transpose op * Code Cleanup * Adding edge case smoke tests * Adding Transpose test to CI/CD * Adding Transpose test to CI/CD * Adding Transpose test to CI/CD * Addressing Review Comment * Addressing Comments * Addressing Comments * Measuring Perf Tests * Code Cleanup * Changlog * Added the running iterations * clang format * Fix the changelog * Fix the compilation error * change the printing factor --------- Co-authored-by: ThruptiRajLakshmanaGowda <tlakshma@amd.com>
This commit is contained in:
@@ -384,22 +384,6 @@ struct tensor_view
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_view{");
|
||||
|
||||
// buf_
|
||||
printf("buf_: ");
|
||||
print(buf_);
|
||||
printf(", ");
|
||||
|
||||
// desc_
|
||||
printf("desc_: ");
|
||||
print(desc_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// member
|
||||
buffer_view buf_;
|
||||
TensorDesc desc_;
|
||||
@@ -494,6 +478,7 @@ template <typename TensorView,
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
|
||||
{
|
||||
|
||||
constexpr index_t num_dim = DoPads::size();
|
||||
|
||||
static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),
|
||||
|
||||
@@ -85,7 +85,12 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
|
||||
|
||||
// SFC
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
|
||||
[&](auto i) {
|
||||
if constexpr(vec_length_in == 1)
|
||||
return 1;
|
||||
else
|
||||
return (i == y_dim_vec_in || i == y_dim_vec_out) ? y_lengths[i] : 1;
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
|
||||
@@ -103,13 +108,19 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
|
||||
// loop over SFC
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
// data index [y0, y1, ...] in the order of input tensor
|
||||
constexpr auto idx_y = SFC_Y::get_index(iAccess);
|
||||
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y);
|
||||
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y);
|
||||
|
||||
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
|
||||
constexpr auto idx_y_in =
|
||||
generate_tuple([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
|
||||
static_assert(in_offset % vec_length_in == 0);
|
||||
constexpr auto idx_y_out_tmp =
|
||||
generate_array([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
|
||||
constexpr auto idx_y_out =
|
||||
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
|
||||
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
|
||||
if constexpr(vec_length_in == 1)
|
||||
{
|
||||
|
||||
out_tensor.get_thread_buffer()[number<out_offset>{}] =
|
||||
in_tensor.get_thread_buffer()[number<in_offset>{}];
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ struct BatchedTransposeHostArgs
|
||||
index_t batch;
|
||||
index_t height;
|
||||
index_t width;
|
||||
// index_t dim_blocks;
|
||||
index_t dim_stride;
|
||||
index_t dim_block_h;
|
||||
index_t dim_block_w;
|
||||
@@ -28,8 +27,10 @@ struct BatchedTransposeHostArgs
|
||||
template <typename Pipeline_>
|
||||
struct BatchedTransposeKernel
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = remove_cvref_t<typename Pipeline::Problem>;
|
||||
|
||||
CK_TILE_DEVICE static index_t counter = 0;
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = remove_cvref_t<typename Pipeline::Problem>;
|
||||
|
||||
using Type = typename Problem::InputType;
|
||||
|
||||
@@ -46,11 +47,11 @@ struct BatchedTransposeKernel
|
||||
using Kargs = BatchedTransposeKargs;
|
||||
using Hargs = BatchedTransposeHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& host_args)
|
||||
{
|
||||
size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w;
|
||||
size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h;
|
||||
size_t grid_size_z = h.batch;
|
||||
size_t grid_size_x = (host_args.height + host_args.dim_block_h - 1) / host_args.dim_block_h;
|
||||
size_t grid_size_y = (host_args.width + host_args.dim_block_w - 1) / host_args.dim_block_w;
|
||||
size_t grid_size_z = host_args.batch;
|
||||
return dim3(grid_size_x, grid_size_y, grid_size_z);
|
||||
}
|
||||
|
||||
@@ -70,58 +71,52 @@ struct BatchedTransposeKernel
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr ck_tile::index_t VectorSizeInput = Problem::VectorSizeInput;
|
||||
static constexpr ck_tile::index_t VectorSizeOutput = Problem::VectorSizeOutput;
|
||||
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
|
||||
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
|
||||
const auto iDim = blockIdx.z;
|
||||
|
||||
static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread;
|
||||
static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread;
|
||||
|
||||
static_assert(kMPerThread == 1 && kNPerThread == 1);
|
||||
|
||||
const auto iDim = blockIdx.z;
|
||||
const auto x_m_n = [&]() {
|
||||
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
|
||||
make_tuple(kargs.height, kargs.width),
|
||||
make_tuple(kargs.width, 1),
|
||||
number<kNPerThread>{}, // TODO thread load value
|
||||
number<VectorSizeInput>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(x_dram_naive,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
sequence<kPadN, kPadM>{});
|
||||
}();
|
||||
|
||||
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
|
||||
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
|
||||
|
||||
const auto y_n_m = [&]() {
|
||||
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
|
||||
make_tuple(kargs.width, kargs.height),
|
||||
make_tuple(kargs.height, 1),
|
||||
number<kMPerThread>{},
|
||||
number<VectorSizeOutput>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(y_dram_naive,
|
||||
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
|
||||
sequence<kPadN, kPadM>{});
|
||||
sequence<kPadM, kPadN>{});
|
||||
}();
|
||||
|
||||
auto x_block_window =
|
||||
make_tile_window(x_m_n,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iM * kMPerBlock),
|
||||
static_cast<ck_tile::index_t>(iN * kNPerBlock)});
|
||||
auto x_block_window = make_tile_window(
|
||||
x_m_n,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iM), static_cast<ck_tile::index_t>(iN)});
|
||||
|
||||
auto y_block_window =
|
||||
make_tile_window(y_n_m,
|
||||
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iN * kNPerBlock),
|
||||
static_cast<ck_tile::index_t>(iM * kMPerBlock)});
|
||||
auto y_block_window = make_tile_window(
|
||||
y_n_m,
|
||||
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iN), static_cast<ck_tile::index_t>(iM)});
|
||||
|
||||
Pipeline{}(x_block_window, y_block_window);
|
||||
}
|
||||
|
||||
@@ -29,24 +29,18 @@ struct BatchedTransposePipeline
|
||||
{
|
||||
auto inp_win =
|
||||
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
|
||||
|
||||
auto input_tile = load_tile(inp_win);
|
||||
|
||||
auto output_tile = make_static_distributed_tensor<InputType>(
|
||||
Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
transpose_tile2d(output_tile, input_tile);
|
||||
|
||||
auto out_win =
|
||||
make_tile_window(out_window, Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
auto x = load_tile(inp_win); // x->thread input_win->block
|
||||
|
||||
auto y = make_static_distributed_tensor<InputType>(
|
||||
Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
constexpr auto span_2d_x = decltype(x)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx1, idx0);
|
||||
y(i_j_idx) = x(i_j_idx);
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(out_win, y);
|
||||
store_tile(out_win, output_tile);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -14,31 +14,34 @@ struct BatchedTransposePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
|
||||
{
|
||||
using S = Problem;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
|
||||
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>{});
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::kMPerBlock;
|
||||
constexpr index_t NPerBlock = Problem::kNPerBlock;
|
||||
constexpr index_t VecLoadSize = Problem::VectorSizeInput;
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
{
|
||||
using S = Problem;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>,
|
||||
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>>,
|
||||
tuple<sequence<2, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<2, 2>>{});
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::kMPerBlock;
|
||||
constexpr index_t NPerBlock = Problem::kNPerBlock;
|
||||
constexpr index_t VecLoadSize = Problem::VectorSizeOutput;
|
||||
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<BlockSize,
|
||||
NPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#define VectorLoadSize 16
|
||||
@@ -12,11 +11,11 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InputType_,
|
||||
typename BlockTile, // Sequence<...
|
||||
typename WarpTile, // Sequence<...
|
||||
typename ThreadTile, // Sequence<...
|
||||
bool kPadM_ = true,
|
||||
bool kPadN_ = true>
|
||||
typename BlockTile, // Sequence<...
|
||||
typename WarpTile, // Sequence<...
|
||||
typename ThreadTile,
|
||||
bool kPadM_ = false,
|
||||
bool kPadN_ = false> // Sequence<...
|
||||
struct BatchedTransposeProblem
|
||||
{
|
||||
using InputType = remove_cvref_t<InputType_>;
|
||||
@@ -42,7 +41,7 @@ struct BatchedTransposeProblem
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
|
||||
static constexpr index_t AlignmentM = kPadM ? VectorLoadSize / sizeof(InputType) : 1; // TODO
|
||||
static constexpr index_t AlignmentN = kPadN ? VectorLoadSize / sizeof(InputType) : 1;
|
||||
static constexpr index_t VectorSizeInput = kPadM ? 1 : VectorLoadSize / sizeof(InputType);
|
||||
static constexpr index_t VectorSizeOutput = kPadN ? 1 : VectorLoadSize / sizeof(InputType);
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user