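Fix batched transpose block tile sizes and block index mapping; rename pipeline problem constants (Dim -> Num/Size, per-warp -> per-XDL)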

This commit is contained in:
joye
2025-04-23 19:27:29 -05:00
parent acce2df3bf
commit c862437f05
3 changed files with 55 additions and 52 deletions

View File

@@ -71,8 +71,8 @@ struct BatchedTransposeKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
__shared__ char smem[Pipeline::GetSmemSize()];
static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondDimWarps;
static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadDimWarps;
static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondSizePerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadSizePerBlock;
const auto iDim = blockIdx.z;
const auto x_m_n = [&]() {
@@ -88,8 +88,8 @@ struct BatchedTransposeKernel
sequence<false, false>{});
}();
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.y * kMPerBlock);
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.x * kNPerBlock);
const auto y_n_m = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(

View File

@@ -30,10 +30,10 @@ template <typename DataType_,
index_t kBlockSize_,
index_t kRowWarps_, // how many warps in row direction
index_t kColWarps_, // how many warps in col direction
index_t kRowPerBlock_,
index_t kColPerBlock_,
index_t kRowPerWarp_, // this is the row number per warp per iteration
index_t kColPerWarp_> // this is the col number per warp per iteration
index_t kRowPerBlock_, // row number per block
index_t kColPerBlock_, // col number per block
index_t kRowPerXdl_, // row number per xdl ops
index_t kColPerXdl_> // col number per xdl ops
struct TransposePipelineProblem
{
static_assert(kRowWarps_ * kColWarps_ * get_warp_size() == kBlockSize_,
@@ -41,18 +41,41 @@ struct TransposePipelineProblem
using DataType = remove_cvref_t<DataType_>;
using Layout = remove_cvref_t<Layout_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kLeadDimWarps =
static constexpr index_t kLeadNumWarps =
TransposeTraits<Layout, kRowWarps_, kColWarps_>::kLeadDim;
static constexpr index_t kSecondDimWarps =
static constexpr index_t kSecondNumWarps =
TransposeTraits<Layout, kRowWarps_, kColWarps_>::kSecondDim;
static constexpr index_t kLeadDimPerBlock =
static constexpr index_t kLeadSizePerBlock =
TransposeTraits<Layout, kRowPerBlock_, kColPerBlock_>::kLeadDim;
static constexpr index_t kSecondDimPerBlock =
static constexpr index_t kSecondSizePerBlock =
TransposeTraits<Layout, kRowPerBlock_, kColPerBlock_>::kSecondDim;
static constexpr index_t kLeadDimPerWarp =
TransposeTraits<Layout, kRowPerWarp_, kColPerWarp_>::kLeadDim;
static constexpr index_t kSecondDimPerWarp =
TransposeTraits<Layout, kRowPerWarp_, kColPerWarp_>::kSecondDim;
static constexpr index_t kLeadSizePerXdl =
TransposeTraits<Layout, kRowPerXdl_, kColPerXdl_>::kLeadDim;
static constexpr index_t kSecondSizePerXdl =
TransposeTraits<Layout, kRowPerXdl_, kColPerXdl_>::kSecondDim;
static constexpr index_t kQuadrantLeadDim = QuartTransposeTraits<DataType>::kleadDim;
static constexpr index_t kQuadrantSecondDim = QuartTransposeTraits<DataType>::ksecondDim;
static_assert(kLeadSizePerBlock % kLeadNumWarps == 0, "block dim should be divided by warp dim!");
static_assert(kSecondSizePerBlock % kSecondNumWarps == 0, "block dim should be divided by warp dim!");
static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps;
static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps;
static_assert(kLeadSizePerWarp % kLeadSizePerXdl == 0, "warp dim should be divided by xdl dim!");
static_assert(kSecondSizePerWarp % kSecondSizePerXdl == 0, "warp dim should be divided by xdl dim!");
static constexpr index_t kLeadXdlNumPerWarp = kLeadSizePerWarp / kLeadSizePerXdl;
static constexpr index_t kSecondXdlNumPerWarp = kSecondSizePerWarp / kSecondSizePerXdl;
static_assert(kLeadSizePerXdl % kQuadrantLeadDim == 0, "xdl dim should be divided by quad dim!");
static_assert(kSecondSizePerXdl % kQuadrantSecondDim == 0, "xdl dim should be divided by quad dim!");
static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerXdl / kQuadrantLeadDim;
static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerXdl / kQuadrantSecondDim;
static constexpr index_t kIterationsPerSecondDim = kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size();
};
template <typename Problem_, typename Policy_ = TransposePolicy>
@@ -65,32 +88,12 @@ struct BlockTranspose
using Layout = remove_cvref_t<typename Problem::Layout>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kLeadDimPerBlock = Problem::kLeadDimPerBlock;
static constexpr index_t kSecondDimPerBlock = Problem::kSecondDimPerBlock;
static constexpr index_t kLeadDimPerWarp = Problem::kLeadDimPerWarp;
static constexpr index_t kSecondDimPerWarp = Problem::kSecondDimPerWarp;
static constexpr index_t kQuadrantLeadDim = QuartTransposeTraits<DataType>::kleadDim;
static constexpr index_t kQuadrantSecondDim = QuartTransposeTraits<DataType>::ksecondDim;
static_assert(kLeadDimPerBlock % kLeadDimPerWarp == 0, "row per block is not correct!");
static_assert(kSecondDimPerBlock % kSecondDimPerWarp == 0, "col per block is not correct!");
static_assert(kLeadDimPerWarp % kQuadrantLeadDim == 0, "row per warp is not correct!");
static_assert(kSecondDimPerWarp % kQuadrantSecondDim == 0, "col per warp is not correct!");
static constexpr index_t kNumWarpInLeadDim = kLeadDimPerBlock / kLeadDimPerWarp;
static constexpr index_t kNumWarpInSecondDim = kSecondDimPerBlock / kSecondDimPerWarp;
static constexpr index_t kLeadDimPerWarpInQuadrant = kLeadDimPerWarp / kQuadrantLeadDim;
static constexpr index_t kSecondDimPerWarpInQuadrant = kSecondDimPerWarp / kQuadrantSecondDim;
// this pipeline is only designed for wave64 now
static_assert(get_warp_size() == 64, "the warp size is not correct!");
static_assert(kBlockSize == kNumWarpInLeadDim * kNumWarpInSecondDim * get_warp_size(),
"the block size is not correct!");
//static_assert(kLeadDimPerWarpInQuadrant * kSecondDimPerWarpInQuadrant * 4 == get_warp_size(),
// "the warp size is not correct!");
static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock;
static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock;
static constexpr index_t kLeadSizePerXdl = Problem::kLeadSizePerXdl;
static constexpr index_t kSecondSizePerXdl = Problem::kSecondSizePerXdl;
static constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps;
static constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps;
static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize<Problem>(); }
@@ -119,12 +122,12 @@ struct BlockTranspose
auto copy_to_lds_window =
make_tile_window(input_lds_block,
make_tuple(number<kSecondDimPerBlock>{}, number<kLeadDimPerBlock>{}),
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
{0, 0});
auto load_from_lds_window =
make_tile_window(output_lds_block,
make_tuple(number<kSecondDimPerBlock>{}, number<kLeadDimPerBlock>{}),
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
{0, 0},
Policy::template MakeLdsLoadTileDistribution<Problem>());

View File

@@ -59,8 +59,8 @@ struct TransposePolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t LeadDimPerBlock = Problem::kLeadDimPerBlock;
constexpr index_t SecondDimPerBlock = Problem::kSecondDimPerBlock;
constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock;
constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock;
constexpr index_t VecLoadSize = 16 / sizeof(typename Problem::DataType);
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
@@ -75,8 +75,8 @@ struct TransposePolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t LeadDimPerBlock = Problem::kLeadDimPerBlock;
constexpr index_t SecondDimPerBlock = Problem::kSecondDimPerBlock;
constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock;
constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock;
constexpr index_t VecLoadSize = 16 / sizeof(typename Problem::DataType);
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
@@ -91,8 +91,8 @@ struct TransposePolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor()
{
//using Layout = remove_cvref_t<typename Problem::Layout>;
constexpr index_t kLeadDimPerBlock = Problem::kLeadDimPerBlock;
constexpr index_t kSecondDimPerBlock = Problem::kSecondDimPerBlock;
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
@@ -130,9 +130,9 @@ struct TransposePolicy
decltype(detail::make_embed_tile_distribution_encoding(
WarpLevelOuterDistribution_{}, QuartTransposeTileDistribution{}));
constexpr index_t LeadDimIterPerWarp =
Problem::kLeadDimPerBlock / (Problem::kLeadDimPerWarp * Problem::kLeadDimWarps);
Problem::kLeadSizePerBlock / (Problem::kLeadSizePerWarp * Problem::kLeadSizeWarps);
constexpr index_t SecondDimIterPerWarp =
Problem::kSecondDimPerBlock / (Problem::kSecondDimPerWarp * Problem::kSecondDimWarps);
Problem::kSecondSizePerBlock / (Problem::kSecondSizePerWarp * Problem::kSecondSizeWarps);
constexpr auto block_outer_dst_encoding = tile_distribution_encoding<
sequence<>,