mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
fix some issues
This commit is contained in:
@@ -71,8 +71,8 @@ struct BatchedTransposeKernel
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
__shared__ char smem[Pipeline::GetSmemSize()];
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondDimWarps;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadDimWarps;
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondSizePerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadSizePerBlock;
|
||||
|
||||
const auto iDim = blockIdx.z;
|
||||
const auto x_m_n = [&]() {
|
||||
@@ -88,8 +88,8 @@ struct BatchedTransposeKernel
|
||||
sequence<false, false>{});
|
||||
}();
|
||||
|
||||
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
|
||||
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
|
||||
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.y * kMPerBlock);
|
||||
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.x * kNPerBlock);
|
||||
|
||||
const auto y_n_m = [&]() {
|
||||
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
|
||||
@@ -30,10 +30,10 @@ template <typename DataType_,
|
||||
index_t kBlockSize_,
|
||||
index_t kRowWarps_, // how many warps in row direction
|
||||
index_t kColWarps_, // how many warps in col direction
|
||||
index_t kRowPerBlock_,
|
||||
index_t kColPerBlock_,
|
||||
index_t kRowPerWarp_, // this is the row number per warp per iteration
|
||||
index_t kColPerWarp_> // this is the col number per warp per iteration
|
||||
index_t kRowPerBlock_, // row number per block
|
||||
index_t kColPerBlock_, // col number per block
|
||||
index_t kRowPerXdl_, // row number per xdl ops
|
||||
index_t kColPerXdl_> // col number per xdl ops
|
||||
struct TransposePipelineProblem
|
||||
{
|
||||
static_assert(kRowWarps_ * kColWarps_ * get_warp_size() == kBlockSize_,
|
||||
@@ -41,18 +41,41 @@ struct TransposePipelineProblem
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
using Layout = remove_cvref_t<Layout_>;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kLeadDimWarps =
|
||||
static constexpr index_t kLeadNumWarps =
|
||||
TransposeTraits<Layout, kRowWarps_, kColWarps_>::kLeadDim;
|
||||
static constexpr index_t kSecondDimWarps =
|
||||
static constexpr index_t kSecondNumWarps =
|
||||
TransposeTraits<Layout, kRowWarps_, kColWarps_>::kSecondDim;
|
||||
static constexpr index_t kLeadDimPerBlock =
|
||||
static constexpr index_t kLeadSizePerBlock =
|
||||
TransposeTraits<Layout, kRowPerBlock_, kColPerBlock_>::kLeadDim;
|
||||
static constexpr index_t kSecondDimPerBlock =
|
||||
static constexpr index_t kSecondSizePerBlock =
|
||||
TransposeTraits<Layout, kRowPerBlock_, kColPerBlock_>::kSecondDim;
|
||||
static constexpr index_t kLeadDimPerWarp =
|
||||
TransposeTraits<Layout, kRowPerWarp_, kColPerWarp_>::kLeadDim;
|
||||
static constexpr index_t kSecondDimPerWarp =
|
||||
TransposeTraits<Layout, kRowPerWarp_, kColPerWarp_>::kSecondDim;
|
||||
static constexpr index_t kLeadSizePerXdl =
|
||||
TransposeTraits<Layout, kRowPerXdl_, kColPerXdl_>::kLeadDim;
|
||||
static constexpr index_t kSecondSizePerXdl =
|
||||
TransposeTraits<Layout, kRowPerXdl_, kColPerXdl_>::kSecondDim;
|
||||
|
||||
static constexpr index_t kQuadrantLeadDim = QuartTransposeTraits<DataType>::kleadDim;
|
||||
static constexpr index_t kQuadrantSecondDim = QuartTransposeTraits<DataType>::ksecondDim;
|
||||
|
||||
static_assert(kLeadSizePerBlock % kLeadNumWarps == 0, "block dim should be divided by warp dim!");
|
||||
static_assert(kSecondSizePerBlock % kSecondNumWarps == 0, "block dim should be divided by warp dim!");
|
||||
|
||||
static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps;
|
||||
static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps;
|
||||
|
||||
static_assert(kLeadSizePerWarp % kLeadSizePerXdl == 0, "warp dim should be divided by xdl dim!");
|
||||
static_assert(kSecondSizePerWarp % kSecondSizePerXdl == 0, "warp dim should be divided by xdl dim!");
|
||||
|
||||
static constexpr index_t kLeadXdlNumPerWarp = kLeadSizePerWarp / kLeadSizePerXdl;
|
||||
static constexpr index_t kSecondXdlNumPerWarp = kSecondSizePerWarp / kSecondSizePerXdl;
|
||||
|
||||
static_assert(kLeadSizePerXdl % kQuadrantLeadDim == 0, "xdl dim should be divided by quad dim!");
|
||||
static_assert(kSecondSizePerXdl % kQuadrantSecondDim == 0, "xdl dim should be divided by quad dim!");
|
||||
|
||||
static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerXdl / kQuadrantLeadDim;
|
||||
static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerXdl / kQuadrantSecondDim;
|
||||
|
||||
static constexpr index_t kIterationsPerSecondDim = kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size();
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = TransposePolicy>
|
||||
@@ -65,32 +88,12 @@ struct BlockTranspose
|
||||
using Layout = remove_cvref_t<typename Problem::Layout>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kLeadDimPerBlock = Problem::kLeadDimPerBlock;
|
||||
static constexpr index_t kSecondDimPerBlock = Problem::kSecondDimPerBlock;
|
||||
static constexpr index_t kLeadDimPerWarp = Problem::kLeadDimPerWarp;
|
||||
static constexpr index_t kSecondDimPerWarp = Problem::kSecondDimPerWarp;
|
||||
|
||||
static constexpr index_t kQuadrantLeadDim = QuartTransposeTraits<DataType>::kleadDim;
|
||||
static constexpr index_t kQuadrantSecondDim = QuartTransposeTraits<DataType>::ksecondDim;
|
||||
|
||||
static_assert(kLeadDimPerBlock % kLeadDimPerWarp == 0, "row per block is not correct!");
|
||||
static_assert(kSecondDimPerBlock % kSecondDimPerWarp == 0, "col per block is not correct!");
|
||||
|
||||
static_assert(kLeadDimPerWarp % kQuadrantLeadDim == 0, "row per warp is not correct!");
|
||||
static_assert(kSecondDimPerWarp % kQuadrantSecondDim == 0, "col per warp is not correct!");
|
||||
|
||||
static constexpr index_t kNumWarpInLeadDim = kLeadDimPerBlock / kLeadDimPerWarp;
|
||||
static constexpr index_t kNumWarpInSecondDim = kSecondDimPerBlock / kSecondDimPerWarp;
|
||||
|
||||
static constexpr index_t kLeadDimPerWarpInQuadrant = kLeadDimPerWarp / kQuadrantLeadDim;
|
||||
static constexpr index_t kSecondDimPerWarpInQuadrant = kSecondDimPerWarp / kQuadrantSecondDim;
|
||||
|
||||
// this pipeline is only designed for wave64 now
|
||||
static_assert(get_warp_size() == 64, "the warp size is not correct!");
|
||||
static_assert(kBlockSize == kNumWarpInLeadDim * kNumWarpInSecondDim * get_warp_size(),
|
||||
"the block size is not correct!");
|
||||
//static_assert(kLeadDimPerWarpInQuadrant * kSecondDimPerWarpInQuadrant * 4 == get_warp_size(),
|
||||
// "the warp size is not correct!");
|
||||
static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock;
|
||||
static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock;
|
||||
static constexpr index_t kLeadSizePerXdl = Problem::kLeadSizePerXdl;
|
||||
static constexpr index_t kSecondSizePerXdl = Problem::kSecondSizePerXdl;
|
||||
static constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps;
|
||||
static constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps;
|
||||
|
||||
static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize<Problem>(); }
|
||||
|
||||
@@ -119,12 +122,12 @@ struct BlockTranspose
|
||||
|
||||
auto copy_to_lds_window =
|
||||
make_tile_window(input_lds_block,
|
||||
make_tuple(number<kSecondDimPerBlock>{}, number<kLeadDimPerBlock>{}),
|
||||
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
|
||||
{0, 0});
|
||||
|
||||
auto load_from_lds_window =
|
||||
make_tile_window(output_lds_block,
|
||||
make_tuple(number<kSecondDimPerBlock>{}, number<kLeadDimPerBlock>{}),
|
||||
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeLdsLoadTileDistribution<Problem>());
|
||||
|
||||
|
||||
@@ -59,8 +59,8 @@ struct TransposePolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t LeadDimPerBlock = Problem::kLeadDimPerBlock;
|
||||
constexpr index_t SecondDimPerBlock = Problem::kSecondDimPerBlock;
|
||||
constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
constexpr index_t VecLoadSize = 16 / sizeof(typename Problem::DataType);
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
@@ -75,8 +75,8 @@ struct TransposePolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t LeadDimPerBlock = Problem::kLeadDimPerBlock;
|
||||
constexpr index_t SecondDimPerBlock = Problem::kSecondDimPerBlock;
|
||||
constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
constexpr index_t VecLoadSize = 16 / sizeof(typename Problem::DataType);
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
@@ -91,8 +91,8 @@ struct TransposePolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor()
|
||||
{
|
||||
//using Layout = remove_cvref_t<typename Problem::Layout>;
|
||||
constexpr index_t kLeadDimPerBlock = Problem::kLeadDimPerBlock;
|
||||
constexpr index_t kSecondDimPerBlock = Problem::kSecondDimPerBlock;
|
||||
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
|
||||
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
|
||||
constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -130,9 +130,9 @@ struct TransposePolicy
|
||||
decltype(detail::make_embed_tile_distribution_encoding(
|
||||
WarpLevelOuterDistribution_{}, QuartTransposeTileDistribution{}));
|
||||
constexpr index_t LeadDimIterPerWarp =
|
||||
Problem::kLeadDimPerBlock / (Problem::kLeadDimPerWarp * Problem::kLeadDimWarps);
|
||||
Problem::kLeadSizePerBlock / (Problem::kLeadSizePerWarp * Problem::kLeadSizeWarps);
|
||||
constexpr index_t SecondDimIterPerWarp =
|
||||
Problem::kSecondDimPerBlock / (Problem::kSecondDimPerWarp * Problem::kSecondDimWarps);
|
||||
Problem::kSecondSizePerBlock / (Problem::kSecondSizePerWarp * Problem::kSecondSizeWarps);
|
||||
|
||||
constexpr auto block_outer_dst_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
|
||||
Reference in New Issue
Block a user