diff --git a/example/ck_tile/36_transpose/transpose_api.cpp b/example/ck_tile/36_transpose/transpose_api.cpp index aa0265bbf1..de4de298a0 100644 --- a/example/ck_tile/36_transpose/transpose_api.cpp +++ b/example/ck_tile/36_transpose/transpose_api.cpp @@ -42,16 +42,6 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con return ave_time; } -template -static float transpose_fn(batched_transpose_kargs& a, ck_tile::stream_config& s) -{ - return batched_transpose_dispatch(a, s); -} - float batched_transpose(batched_transpose_trait t, batched_transpose_kargs a, ck_tile::stream_config s) diff --git a/include/ck_tile/ops/transpose/batched_transpose_kernel.hpp b/include/ck_tile/ops/transpose/batched_transpose_kernel.hpp index 24e2492ce5..28d1693b11 100644 --- a/include/ck_tile/ops/transpose/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/transpose/batched_transpose_kernel.hpp @@ -48,8 +48,8 @@ struct BatchedTransposeKernel CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { - size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w; - size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h; + size_t grid_size_x = h.dim_block_w; + size_t grid_size_y = h.dim_block_h; size_t grid_size_z = h.batch; return dim3(grid_size_x, grid_size_y, grid_size_z); } @@ -107,14 +107,14 @@ struct BatchedTransposeKernel auto x_block_window = make_tile_window(x_m_n, make_tuple(number{}, number{}), - {static_cast(iM * kMPerBlock), - static_cast(iN * kNPerBlock)}); + {static_cast(iM), + static_cast(iN)}); auto y_block_window = make_tile_window(y_n_m, make_tuple(number{}, number{}), - {static_cast(iN * kNPerBlock), - static_cast(iM * kMPerBlock)}); + {static_cast(iN), + static_cast(iM)}); Pipeline{}(x_block_window, y_block_window, smem); } diff --git a/include/ck_tile/ops/transpose/transpose_policy.hpp b/include/ck_tile/ops/transpose/transpose_policy.hpp index 91c3f7ce03..ff65231924 100644 --- a/include/ck_tile/ops/transpose/transpose_policy.hpp +++ b/include/ck_tile/ops/transpose/transpose_policy.hpp @@ -33,7 +33,7 @@ struct QuartTransposeTraits> using TileDistributionT = tile_distribution_encoding, tuple, sequence>, - tuple>, + tuple>, tuple>, sequence<2>, sequence<1>>; @@ -115,7 +115,7 @@ struct TransposePolicy constexpr index_t kLeadNumWarps = Problem::kSecondNumWarps; // transpose is based on 64 Bytes constexpr index_t kLead = - Problem::kSecondSizePerXdl / Problem::kIterations; // Problem::kLeadSizePerXdl; + Problem::kSecondSizePerXdl; // Problem::kLeadSizePerXdl; constexpr index_t kSecond = Problem::kLeadSizePerXdl; constexpr index_t kLeadDimstr = kLead / QuartTransposeTraits::kleadDimT; @@ -128,9 +128,9 @@ struct TransposePolicy tile_distribution_encoding, tuple, sequence>, - tuple>, + tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto blk_distr_encode = detail::make_embed_tile_distribution_encoding( block_outer_dst_encoding, xdllevel_dstr_encoding{}); @@ -205,7 +205,7 @@ struct TransposePolicy { // one xdl implement kSecond x kLead constexpr index_t kLead = Problem::kLeadSizePerXdl; - constexpr index_t kSecond = Problem::kSecondSizePerXdl / Problem::kIterations; + constexpr index_t kSecond = Problem::kSecondSizePerXdl; constexpr index_t kLeadDimstr = kLead / QuartTransposeTraits::kleadDim; constexpr index_t kSecondDimstr =