fix 16x16 related dimension transpose

This commit is contained in:
joye
2025-04-27 02:44:49 -05:00
parent 57e8e34705
commit de9407ed93
3 changed files with 11 additions and 21 deletions

View File

@@ -42,16 +42,6 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con
return ave_time;
}
template <typename T,
ck_tile::index_t block_x,
ck_tile::index_t block_y,
ck_tile::index_t warp_x,
ck_tile::index_t warp_y>
static float transpose_fn(batched_transpose_kargs& a, ck_tile::stream_config& s)
{
return batched_transpose_dispatch<T, block_x, block_y, warp_x, warp_y>(a, s);
}
float batched_transpose(batched_transpose_trait t,
batched_transpose_kargs a,
ck_tile::stream_config s)

View File

@@ -48,8 +48,8 @@ struct BatchedTransposeKernel
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w;
size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h;
size_t grid_size_x = h.dim_block_w;
size_t grid_size_y = h.dim_block_h;
size_t grid_size_z = h.batch;
return dim3(grid_size_x, grid_size_y, grid_size_z);
}
@@ -107,14 +107,14 @@ struct BatchedTransposeKernel
auto x_block_window =
make_tile_window(x_m_n,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{static_cast<ck_tile::index_t>(iM * kMPerBlock),
static_cast<ck_tile::index_t>(iN * kNPerBlock)});
{static_cast<ck_tile::index_t>(iM),
static_cast<ck_tile::index_t>(iN)});
auto y_block_window =
make_tile_window(y_n_m,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
{static_cast<ck_tile::index_t>(iN * kNPerBlock),
static_cast<ck_tile::index_t>(iM * kMPerBlock)});
{static_cast<ck_tile::index_t>(iN),
static_cast<ck_tile::index_t>(iM)});
Pipeline{}(x_block_window, y_block_window, smem);
}

View File

@@ -33,7 +33,7 @@ struct QuartTransposeTraits<T, std::enable_if_t<sizeof(T) == 2>>
using TileDistributionT =
tile_distribution_encoding<sequence<>,
tuple<sequence<kOuterDist, 16>, sequence<kInnerDist, 4>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<2, 1, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>;
@@ -115,7 +115,7 @@ struct TransposePolicy
constexpr index_t kLeadNumWarps = Problem::kSecondNumWarps;
// transpose is based on 64 Bytes
constexpr index_t kLead =
Problem::kSecondSizePerXdl / Problem::kIterations; // Problem::kLeadSizePerXdl;
Problem::kSecondSizePerXdl; // Problem::kLeadSizePerXdl;
constexpr index_t kSecond = Problem::kLeadSizePerXdl;
constexpr index_t kLeadDimstr =
kLead / QuartTransposeTraits<typename Problem::DataType>::kleadDimT;
@@ -128,9 +128,9 @@ struct TransposePolicy
tile_distribution_encoding<sequence<>,
tuple<sequence<kSecondIterPerWarp, kSecondNumWarps>,
sequence<kLeadIterPerWarp, kLeadNumWarps>>,
tuple<sequence<1, 2>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto blk_distr_encode = detail::make_embed_tile_distribution_encoding(
block_outer_dst_encoding, xdllevel_dstr_encoding{});
@@ -205,7 +205,7 @@ struct TransposePolicy
{
// one xdl implement kSecond x kLead
constexpr index_t kLead = Problem::kLeadSizePerXdl;
constexpr index_t kSecond = Problem::kSecondSizePerXdl / Problem::kIterations;
constexpr index_t kSecond = Problem::kSecondSizePerXdl;
constexpr index_t kLeadDimstr =
kLead / QuartTransposeTraits<typename Problem::DataType>::kleadDim;
constexpr index_t kSecondDimstr =