Vectorized Transpose for Batched Transpose CK Tile Operator (#2131)

* Shared Memory for single data point

* CKTile Transpose vectorize CP1

* CKTile Transpose vectorize CP2

* CKTile Transpose vectorize CP2.1

* fixed the compile error of the transpose tile 2d

* Have the correct result for the current test sample

* Changes to printing tensor

* fp8 support added

* Debugging for transpose

* solving the corner issue

* Changed padding flag

* Intermideate Debugging

* Intermidiate Debugging

* Intermediate Debugging

* Finished debugging of the transpose op

* Code Cleanup

* Adding edge case smoke tests

* Adding Transpose test to CI/CD

* Adding Transpose test to CI/CD

* Adding Transpose test to CI/CD

* Addressing Review Comment

* Addressing Comments

* Addressing Comments

* Measuring Perf Tests

* Code Cleanup

* Changlog

* Added the running iterations

* clang format

* Fix the changelog

* Fix the compilation error

* change the printing factor

---------

Co-authored-by: ThruptiRajLakshmanaGowda <tlakshma@amd.com>
This commit is contained in:
Thomas Ning
2025-05-12 00:41:45 -07:00
committed by GitHub
parent d8faf1c6a1
commit 9d1e44e56a
14 changed files with 311 additions and 152 deletions

View File

@@ -29,24 +29,18 @@ struct BatchedTransposePipeline
{
auto inp_win =
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
auto input_tile = load_tile(inp_win);
auto output_tile = make_static_distributed_tensor<InputType>(
Policy::template MakeOutputDistribution<Problem>());
transpose_tile2d(output_tile, input_tile);
auto out_win =
make_tile_window(out_window, Policy::template MakeOutputDistribution<Problem>());
auto x = load_tile(inp_win); // x->thread input_win->block
auto y = make_static_distributed_tensor<InputType>(
Policy::template MakeOutputDistribution<Problem>());
constexpr auto span_2d_x = decltype(x)::get_distributed_spans();
sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx1, idx0);
y(i_j_idx) = x(i_j_idx);
});
});
store_tile(out_win, y);
store_tile(out_win, output_tile);
}
};
} // namespace ck_tile

View File

@@ -14,31 +14,34 @@ struct BatchedTransposePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
{
using S = Problem;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::kMPerBlock;
constexpr index_t NPerBlock = Problem::kNPerBlock;
constexpr index_t VecLoadSize = Problem::VectorSizeInput;
using TileEncodingPattern =
TileDistributionEncodingPattern2D<BlockSize,
MPerBlock,
NPerBlock,
VecLoadSize,
tile_distribution_pattern::thread_raked>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
{
using S = Problem;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>,
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>>,
tuple<sequence<2, 1>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<2, 1>,
sequence<2, 2>>{});
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::kMPerBlock;
constexpr index_t NPerBlock = Problem::kNPerBlock;
constexpr index_t VecLoadSize = Problem::VectorSizeOutput;
using TileEncodingPattern =
TileDistributionEncodingPattern2D<BlockSize,
NPerBlock,
MPerBlock,
VecLoadSize,
tile_distribution_pattern::thread_raked>;
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
}
};
} // namespace ck_tile

View File

@@ -4,7 +4,6 @@
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
#define VectorLoadSize 16
@@ -12,11 +11,11 @@
namespace ck_tile {
template <typename InputType_,
typename BlockTile, // Sequence<...
typename WarpTile, // Sequence<...
typename ThreadTile, // Sequence<...
bool kPadM_ = true,
bool kPadN_ = true>
typename BlockTile, // Sequence<...
typename WarpTile, // Sequence<...
typename ThreadTile,
bool kPadM_ = false,
bool kPadN_ = false> // Sequence<...
struct BatchedTransposeProblem
{
using InputType = remove_cvref_t<InputType_>;
@@ -42,7 +41,7 @@ struct BatchedTransposeProblem
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr index_t AlignmentM = kPadM ? VectorLoadSize / sizeof(InputType) : 1; // TODO
static constexpr index_t AlignmentN = kPadN ? VectorLoadSize / sizeof(InputType) : 1;
static constexpr index_t VectorSizeInput = kPadM ? 1 : VectorLoadSize / sizeof(InputType);
static constexpr index_t VectorSizeOutput = kPadN ? 1 : VectorLoadSize / sizeof(InputType);
};
} // namespace ck_tile