mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 21:58:13 +00:00
fix 16x16 related dimension transpose
This commit is contained in:
@@ -42,16 +42,6 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// Forwarding wrapper around batched_transpose_dispatch: the five template
// arguments (T, block_x, block_y, warp_x, warp_y) and both call arguments are
// passed through unchanged, and the float result of the dispatch call is
// returned as-is. No work is done here besides the forward.
// NOTE(review): the enclosing hunk header (-42,16 +42,6) drops ten lines, which
// matches this definition plus one blank line -- presumably this wrapper is the
// text the commit removes; confirm against the upstream commit.
template <typename T,
|
||||
ck_tile::index_t block_x,
|
||||
ck_tile::index_t block_y,
|
||||
ck_tile::index_t warp_x,
|
||||
ck_tile::index_t warp_y>
|
||||
static float transpose_fn(batched_transpose_kargs& a, ck_tile::stream_config& s)
|
||||
{
|
||||
return batched_transpose_dispatch<T, block_x, block_y, warp_x, warp_y>(a, s);
|
||||
}
|
||||
|
||||
float batched_transpose(batched_transpose_trait t,
|
||||
batched_transpose_kargs a,
|
||||
ck_tile::stream_config s)
|
||||
|
||||
@@ -48,8 +48,8 @@ struct BatchedTransposeKernel
|
||||
|
||||
// Builds the dim3 returned to the caller from the host argument struct `h`:
// x and y come from the width/height block fields, z is h.batch.
// NOTE(review): this is diff text whose +/- markers were lost. The hunk header
// (-48,8 +48,8) keeps the line count equal on both sides, so the two
// grid_size_x / grid_size_y pairs below are the before and after versions of
// the same two statements, not four live declarations.
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
// Variant 1: rounded-up integer division, i.e. how many dim_block_w-sized
// (resp. dim_block_h-sized) pieces are needed to cover h.width (resp. h.height).
// Presumably the removed side, since unified diffs list deletions first -- confirm.
size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w;
|
||||
size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h;
|
||||
// Variant 2: uses the dim_block_w / dim_block_h fields directly as the x / y
// extents, with no division. Presumably the added side -- confirm; if so, the
// caller must already store a block count in these fields.
size_t grid_size_x = h.dim_block_w;
|
||||
size_t grid_size_y = h.dim_block_h;
|
||||
// One z slot per batch entry.
size_t grid_size_z = h.batch;
|
||||
return dim3(grid_size_x, grid_size_y, grid_size_z);
|
||||
}
|
||||
@@ -107,14 +107,14 @@ struct BatchedTransposeKernel
|
||||
auto x_block_window =
|
||||
make_tile_window(x_m_n,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iM * kMPerBlock),
|
||||
static_cast<ck_tile::index_t>(iN * kNPerBlock)});
|
||||
{static_cast<ck_tile::index_t>(iM),
|
||||
static_cast<ck_tile::index_t>(iN)});
|
||||
|
||||
auto y_block_window =
|
||||
make_tile_window(y_n_m,
|
||||
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
|
||||
{static_cast<ck_tile::index_t>(iN * kNPerBlock),
|
||||
static_cast<ck_tile::index_t>(iM * kMPerBlock)});
|
||||
{static_cast<ck_tile::index_t>(iN),
|
||||
static_cast<ck_tile::index_t>(iM)});
|
||||
|
||||
Pipeline{}(x_block_window, y_block_window, smem);
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ struct QuartTransposeTraits<T, std::enable_if_t<sizeof(T) == 2>>
|
||||
using TileDistributionT =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kOuterDist, 16>, sequence<kInnerDist, 4>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<2, 1, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
@@ -115,7 +115,7 @@ struct TransposePolicy
|
||||
constexpr index_t kLeadNumWarps = Problem::kSecondNumWarps;
|
||||
// transpose is based on 64 Bytes
|
||||
constexpr index_t kLead =
|
||||
Problem::kSecondSizePerXdl / Problem::kIterations; // Problem::kLeadSizePerXdl;
|
||||
Problem::kSecondSizePerXdl; // Problem::kLeadSizePerXdl;
|
||||
constexpr index_t kSecond = Problem::kLeadSizePerXdl;
|
||||
constexpr index_t kLeadDimstr =
|
||||
kLead / QuartTransposeTraits<typename Problem::DataType>::kleadDimT;
|
||||
@@ -128,9 +128,9 @@ struct TransposePolicy
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kSecondIterPerWarp, kSecondNumWarps>,
|
||||
sequence<kLeadIterPerWarp, kLeadNumWarps>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto blk_distr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
block_outer_dst_encoding, xdllevel_dstr_encoding{});
|
||||
@@ -205,7 +205,7 @@ struct TransposePolicy
|
||||
{
|
||||
// one xdl implement kSecond x kLead
|
||||
constexpr index_t kLead = Problem::kLeadSizePerXdl;
|
||||
constexpr index_t kSecond = Problem::kSecondSizePerXdl / Problem::kIterations;
|
||||
constexpr index_t kSecond = Problem::kSecondSizePerXdl;
|
||||
constexpr index_t kLeadDimstr =
|
||||
kLead / QuartTransposeTraits<typename Problem::DataType>::kleadDim;
|
||||
constexpr index_t kSecondDimstr =
|
||||
|
||||
Reference in New Issue
Block a user