mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK-Tile] Merge transpose examples (#2450)
* unify pipeline signature with existing example * iwyu * move stuff around in load-tile-transpose * cleanups in batched transpose pipeline * comments * use same inputs size * cleaner printf * print host args * use 64 block sides in the 37_transpose example * roll back grid dimension size adjustment for 37_transpose example * transpose grid for 37_transpose to unify with 35_batched_transpose * unify grid computation logic * make policy methods device only (since they are used only on device from the pipeline) * more host/device attribute cleanups * copy over problem * move over pipeline and policy * add switch to batched transpose api * make the lds problem more similar to original problem * factor out logic into traits * factor out conditional compilation into trait parameter * propagate pipeline to args * unhardcode pipeline dispatch parameter * refactor vector size * put warp tile out of dispatch * rename template parameter for trait * rewrite vector size in terms of problem * mark policy-internal struct variable as device * factor out input distribution and thread access pattern from policies * reword vector size * use datatype across batched transpose pipelines, problems and kernel * remove transpose traits from lds pipeline * add padding to the lds pipeline *interface* * add comment * remove ck_tile example #37 * update cmakelists * add test for new pipeline * update batched transpose test * roll back load_tile_transpose changes * remove comments * pack dispatch parameters into a config * padM can be enabled * adjust lds vector size to enable padding along N * update test * clean up logic * swap m/n input vector size * adjust perf test script * sweep over C/W in perf test * count both read and written bytes into bandwidth (x2 the number) * clang-format * widen size range for perf test * remove 64k x 64k case; it's too large for index * remove thread tile from dispatch * Solve merge conflict * fix compile * modify the transpose * solve the test error and clang format * Add v3 support for Groupd fwd conv+bias+clamp & ckProfiler (#2463) * Add logging to IsSupported. * Less casting in AddClamp * Conv+bias+clamp instances & profiler BF16 * Fix 3D instances & run just 1x for verification. * :Run just once for verification conv fwd. * ckProfiler conv fwd clampwq * Remove exec bit & formatting * Add support for MultiD for grouped conv fwd v3. * Enable 2Lds. * clean * align instances * align instances * profiler fixes * Fixes * fix * fix --------- Co-authored-by: Adam Osewski <root@quanta-ccs-aus-f01-19.cs-aus.dcgpu> Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> * Fixing 0ms and inf GB/s issue in img2col (#2565) issue : ==== ``` sh $ bin/tile_example_img2col Perf: 0 ms, inf GB/s ``` solution : ====== Problem occured because config.time_kernel is false by default. if false, then no need to calculate perf, just print proper message `image_to_coloumn: pass, No Perf generated due to config.time_kernel=0` * merge with develop * solve clang format --------- Co-authored-by: ThomasNing <thomas.ning@amd.com> Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> Co-authored-by: Adam Osewski <root@quanta-ccs-aus-f01-19.cs-aus.dcgpu> Co-authored-by: Bartłomiej Kocot <barkocot@amd.com> Co-authored-by: rahjain-amd <Rahul.Jain@amd.com>
This commit is contained in:
@@ -32,7 +32,7 @@ struct BatchedTransposeKernel
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = remove_cvref_t<typename Pipeline::Problem>;
|
||||
|
||||
using Type = typename Problem::InputType;
|
||||
using Type = typename Problem::DataType;
|
||||
|
||||
struct BatchedTransposeKargs
|
||||
{
|
||||
@@ -67,7 +67,7 @@ struct BatchedTransposeKernel
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::kBlockSize; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BatchedTransposeCommonPolicy
|
||||
{
|
||||
CK_TILE_DEVICE static constexpr auto TileAccessPattern =
|
||||
tile_distribution_pattern::thread_raked;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeInputDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t LeadDimPerBlock = Problem::kMPerBlock;
|
||||
constexpr index_t SecondDimPerBlock = Problem::kNPerBlock;
|
||||
|
||||
constexpr index_t kVectorSize = Problem::VectorSizeOutput;
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
SecondDimPerBlock,
|
||||
LeadDimPerBlock,
|
||||
kVectorSize,
|
||||
TileAccessPattern>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct BatchedTransposeLdsPipeline
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using DataType = remove_cvref_t<typename Problem::DataType>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock;
|
||||
static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock;
|
||||
|
||||
static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize<Problem>(); }
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename InputTileWindow, typename OutputTileWindow>
|
||||
CK_TILE_DEVICE void operator()(const InputTileWindow& input_window,
|
||||
OutputTileWindow& output_window)
|
||||
{
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
auto input_tile_window =
|
||||
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
|
||||
auto output_tile_window =
|
||||
make_tile_window(output_window, Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
DataType* p_lds_ptr = reinterpret_cast<DataType*>(smem);
|
||||
constexpr auto in_lds_block_desc = Policy::template MakeLdsStoreBlockDescriptor<Problem>();
|
||||
auto input_lds_block =
|
||||
make_tensor_view<address_space_enum::lds>(p_lds_ptr, in_lds_block_desc);
|
||||
|
||||
constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor<Problem>();
|
||||
auto output_lds_block =
|
||||
make_tensor_view<address_space_enum::lds>(p_lds_ptr, out_lds_block_desc);
|
||||
|
||||
auto copy_to_lds_window =
|
||||
make_tile_window(input_lds_block,
|
||||
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
|
||||
{0, 0});
|
||||
auto load_from_lds_window =
|
||||
make_tile_window(output_lds_block,
|
||||
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeLdsLoadTileDistribution<Problem>());
|
||||
|
||||
auto x = load_tile(input_tile_window);
|
||||
|
||||
store_tile(copy_to_lds_window, x);
|
||||
block_sync_lds();
|
||||
|
||||
auto y = load_tile_transpose(load_from_lds_window);
|
||||
|
||||
store_tile(output_tile_window, y);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,123 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "batched_transpose_common_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BatchedTransposeLdsPolicy : public BatchedTransposeCommonPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return integer_least_multiple(
|
||||
sizeof(typename Problem::DataType) *
|
||||
MakeLdsStoreBlockDescriptor<Problem>().get_element_space_size(),
|
||||
16);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
{
|
||||
constexpr auto input_dstr = MakeLdsLoadTileDistribution<Problem>();
|
||||
|
||||
using OutTileDstrEncode =
|
||||
typename OutputTileDistributionTraits<typename decltype(input_dstr)::DstrEncode,
|
||||
typename Problem::DataType>::TransposedDstrEncode;
|
||||
constexpr auto block_dstr = make_static_tile_distribution(OutTileDstrEncode{});
|
||||
|
||||
return block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
constexpr index_t kVectorSize = Problem::LDSVectorSize;
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kSecondDimPerBlock>{},
|
||||
number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}),
|
||||
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
|
||||
number<kVectorSize>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_block_desc = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
constexpr index_t kVectorSize = Problem::LDSVectorSize;
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kSecondDimPerBlock>{},
|
||||
number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}),
|
||||
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
|
||||
number<kVectorSize>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_block_desc = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
|
||||
number<kVectorSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsLoadTileDistribution()
|
||||
{
|
||||
using DataType = typename Problem::DataType;
|
||||
|
||||
// Calculate block-level dimensions
|
||||
constexpr index_t kLeadIterPerWarp = 1;
|
||||
constexpr index_t kSecondIterPerWarp = 1;
|
||||
constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps;
|
||||
constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps;
|
||||
|
||||
// Calculate repetitions of base pattern
|
||||
constexpr index_t kLeadRepetitions = Problem::kQuadNumPerLeadDim;
|
||||
constexpr index_t kSecondRepetitions = Problem::kQuadNumPerSecondDim;
|
||||
constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim;
|
||||
constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations;
|
||||
|
||||
constexpr index_t kLaneGroupSize = 16;
|
||||
constexpr auto xdllevel_dstr_encoding = make_transposed_distr_encode<DataType,
|
||||
kLaneGroupSize,
|
||||
kSecondDimStrSub,
|
||||
kSecondDimIterations,
|
||||
kLeadRepetitions,
|
||||
1>();
|
||||
|
||||
constexpr auto input_tile_encode =
|
||||
InputTileDistributionEncoding<decltype(xdllevel_dstr_encoding),
|
||||
kLeadIterPerWarp,
|
||||
kSecondIterPerWarp,
|
||||
kLeadNumWarps,
|
||||
kSecondNumWarps>();
|
||||
constexpr auto block_dstr = make_static_tile_distribution(input_tile_encode);
|
||||
return block_dstr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// supports 2D transpose which will store to lds,
|
||||
// then use ds_read_b*_tr_b* instruction to get the transposed data
|
||||
template <typename DataType_,
|
||||
typename BlockTile, // sequence<block_x, block_y>
|
||||
typename NumWarps,
|
||||
bool kPadM_,
|
||||
bool kPadN_>
|
||||
struct BatchedTransposeLdsProblem
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr index_t kRowWarps_ = NumWarps::at(number<1>{});
|
||||
static constexpr index_t kColWarps_ = NumWarps::at(number<0>{});
|
||||
static constexpr index_t kBlockSize_ = get_warp_size() * kRowWarps_ * kColWarps_;
|
||||
static constexpr index_t kRowPerBlock_ = BlockTile::at(number<1>{});
|
||||
static constexpr index_t kColPerBlock_ = BlockTile::at(number<0>{});
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
// warps per block
|
||||
static constexpr index_t kLeadNumWarps = kRowWarps_;
|
||||
static constexpr index_t kSecondNumWarps = kColWarps_;
|
||||
|
||||
static constexpr index_t kLeadSizePerBlock = kRowPerBlock_;
|
||||
static constexpr index_t kSecondSizePerBlock = kColPerBlock_;
|
||||
|
||||
static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits<DataType>::kleadDim;
|
||||
static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits<DataType>::ksecondDim;
|
||||
|
||||
static_assert(kLeadSizePerBlock % kLeadNumWarps == 0,
|
||||
"block dim should be divided by warp count!");
|
||||
static_assert(kSecondSizePerBlock % kSecondNumWarps == 0,
|
||||
"block dim should be divided by warp count!");
|
||||
// rows/cols per warp
|
||||
static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps;
|
||||
static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps;
|
||||
|
||||
static_assert(kLeadSizePerWarp % kQuadrantLeadDim == 0,
|
||||
"xdl dim should be divided by quad dim!");
|
||||
static_assert(kSecondSizePerWarp % kQuadrantSecondDim == 0,
|
||||
"xdl dim should be divided by quad dim!");
|
||||
// xdl rows/cols is divided into quadrants.
|
||||
static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerWarp / kQuadrantLeadDim;
|
||||
static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerWarp / kQuadrantSecondDim;
|
||||
|
||||
static constexpr index_t kIterationsInSecondDim =
|
||||
kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size();
|
||||
|
||||
// definitions to adapt to BatchedTransposeKernel
|
||||
|
||||
// FIXME: support padding
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
|
||||
static constexpr auto kMPerBlock = kLeadSizePerBlock;
|
||||
static constexpr auto kNPerBlock = kSecondSizePerBlock;
|
||||
|
||||
// 128-bit is the max single-instruction bandwidth for load/store
|
||||
static constexpr index_t MaxLoadStoreSize = 16;
|
||||
static constexpr auto VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType);
|
||||
static constexpr auto VectorSizeOutput = kPadM ? 1 : MaxLoadStoreSize / sizeof(DataType);
|
||||
static constexpr auto LDSVectorSize = MaxLoadStoreSize / sizeof(DataType);
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -5,8 +5,6 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -14,15 +12,8 @@ template <typename Problem_, typename Policy_ = BatchedTransposePolicy>
|
||||
struct BatchedTransposePipeline
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using InputType = ck_tile::remove_cvref_t<typename Problem::InputType>;
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr index_t AlignmentM = Problem::AlignmentM;
|
||||
static constexpr index_t AlignmentN = Problem::AlignmentN;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
template <typename InputWindow, typename OutputWindow>
|
||||
CK_TILE_DEVICE auto operator()(const InputWindow& input_window, OutputWindow& out_window)
|
||||
@@ -32,7 +23,7 @@ struct BatchedTransposePipeline
|
||||
|
||||
auto input_tile = load_tile(inp_win);
|
||||
|
||||
auto output_tile = make_static_distributed_tensor<InputType>(
|
||||
auto output_tile = make_static_distributed_tensor<typename Problem::DataType>(
|
||||
Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
transpose_tile2d(output_tile, input_tile);
|
||||
|
||||
@@ -4,43 +4,25 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/softmax.hpp"
|
||||
#include "ck_tile/ops/topk.hpp"
|
||||
#include "batched_transpose_common_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BatchedTransposePolicy
|
||||
struct BatchedTransposePolicy : public BatchedTransposeCommonPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::kMPerBlock;
|
||||
constexpr index_t NPerBlock = Problem::kNPerBlock;
|
||||
constexpr index_t VecLoadSize = Problem::VectorSizeInput;
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
CK_TILE_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::kMPerBlock;
|
||||
constexpr index_t NPerBlock = Problem::kNPerBlock;
|
||||
constexpr index_t VecLoadSize = Problem::VectorSizeOutput;
|
||||
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<BlockSize,
|
||||
NPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
NPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
TileAccessPattern>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -6,42 +6,31 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
#define VectorLoadSize 16
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InputType_,
|
||||
template <typename DataType_,
|
||||
typename BlockTile, // Sequence<...
|
||||
typename WarpTile, // Sequence<...
|
||||
typename ThreadTile,
|
||||
typename WarpLayout,
|
||||
bool kPadM_ = false,
|
||||
bool kPadN_ = false> // Sequence<...
|
||||
struct BatchedTransposeProblem
|
||||
{
|
||||
using InputType = remove_cvref_t<InputType_>;
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
|
||||
static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
|
||||
static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
|
||||
static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;
|
||||
static constexpr index_t kMPerWarp = WarpLayout::at(number<0>{});
|
||||
static constexpr index_t kNPerWarp = WarpLayout::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
|
||||
static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp;
|
||||
|
||||
static constexpr index_t kBlockSize =
|
||||
kMThreadPerWarp * kNThreadPerWarp * kMWarpPerBlock * kNWarpPerBlock;
|
||||
static constexpr index_t kBlockSize = kMPerWarp * kNPerWarp * get_warp_size();
|
||||
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
|
||||
static constexpr index_t VectorSizeInput = kPadM ? 1 : VectorLoadSize / sizeof(InputType);
|
||||
static constexpr index_t VectorSizeOutput = kPadN ? 1 : VectorLoadSize / sizeof(InputType);
|
||||
// 128-bit is the max single-instruction bandwidth for load/store
|
||||
static constexpr index_t MaxLoadStoreSize = 16;
|
||||
static constexpr index_t VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType);
|
||||
static constexpr index_t VectorSizeOutput = kPadM ? 1 : MaxLoadStoreSize / sizeof(DataType);
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user