diff --git a/example/ck_tile/36_transpose/transpose_api.cpp b/example/ck_tile/36_transpose/transpose_api.cpp index c9cf60c4e6..61f02f768f 100644 --- a/example/ck_tile/36_transpose/transpose_api.cpp +++ b/example/ck_tile/36_transpose/transpose_api.cpp @@ -43,7 +43,7 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con } // Param Comb: type_size, block_x & y, warp_x & y, thread_x & y -#define FOREACH_TRANSPOSE_PARAM(F) F(fp16, ck_tile::fp16_t, 16, 32, 16, 32) +#define FOREACH_TRANSPOSE_PARAM(F) F(fp16, ck_tile::fp16_t, 16, 16, 16, 16) // Macro that defines one static function per line #define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY) \ @@ -61,7 +61,7 @@ float batched_transpose(batched_transpose_trait t, { if(t.type == "fp16") { - return transpose_fn_fp16_16_32_16_32(a, s); + return transpose_fn_fp16_16_16_16_16(a, s); } return -1; } diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 88efe0d8d9..dacae32053 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -19,3 +19,4 @@ add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) add_subdirectory(35_batched_transpose) +add_subdirectory(36_transpose) diff --git a/include/ck_tile/ops/transpose/batched_transpose_kernel.hpp b/include/ck_tile/ops/transpose/batched_transpose_kernel.hpp index 26145c39ab..33b6d7b72c 100644 --- a/include/ck_tile/ops/transpose/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/transpose/batched_transpose_kernel.hpp @@ -31,7 +31,7 @@ struct BatchedTransposeKernel using Pipeline = remove_cvref_t; using Problem = remove_cvref_t; - using Type = typename Problem::InputType; + using Type = typename Problem::DataType; struct BatchedTransposeKargs { diff --git a/include/ck_tile/ops/transpose/block_transpose.hpp b/include/ck_tile/ops/transpose/block_transpose.hpp index ee351a40d2..02d458552e 100644 --- a/include/ck_tile/ops/transpose/block_transpose.hpp +++ b/include/ck_tile/ops/transpose/block_transpose.hpp @@ -13,7 +13,7 @@ struct TransposeTraits { static constexpr index_t kLeadDim = kCol; static constexpr index_t kSecondDim = kRow; -} +}; template struct TransposeTraits @@ -46,13 +46,13 @@ struct TransposePipelineProblem static constexpr index_t kSecondDimWarps = TransposeTraits::kSecondDim; static constexpr index_t kLeadDimPerBlock = - TransposeTraits::kLeadDim; + TransposeTraits::kLeadDim; static constexpr index_t kSecondDimPerBlock = - TransposeTraits::kSecondDim; + TransposeTraits::kSecondDim; static constexpr index_t kLeadDimPerWarp = - TransposeTraits::kLeadDim; + TransposeTraits::kLeadDim; static constexpr index_t kSecondDimPerWarp = - TransposeTraits::kSecondDim; + TransposeTraits::kSecondDim; }; template @@ -89,10 +89,10 @@ struct BlockTranspose static_assert(get_warp_size() == 64, "the warp size is not correct!"); static_assert(kBlockSize == kNumWarpInLeadDim * kNumWarpInSecondDim * get_warp_size(), "the block size is not correct!"); - static_assert(kLeadDimPerWarpInQuadrant * kSecondDimPerWarpInQuadrant * 4 == get_warp_size(), - "the warp size is not correct!"); + //static_assert(kLeadDimPerWarpInQuadrant * kSecondDimPerWarpInQuadrant * 4 == get_warp_size(), + // "the warp size is not correct!"); - static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize(); } + static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize(); } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { @@ -123,17 +123,17 @@ struct BlockTranspose {0, 0}); auto load_from_lds_window = - make_tile_window(out_lds_block, + make_tile_window(output_lds_block, make_tuple(number{}, number{}), {0, 0}, - Policy::MakeLdsLoadTileDistribution()); + Policy::template MakeLdsLoadTileDistribution()); auto x = load_tile(input_tile_window); store_tile(copy_to_lds_window, x); block_sync_lds(); - auto y = load_tile_transpose(lds_tile_window); + auto y = load_tile_transpose(load_from_lds_window); store_tile(output_tile_window, y); } }; diff --git a/include/ck_tile/ops/transpose/transpose_policy.hpp b/include/ck_tile/ops/transpose/transpose_policy.hpp index c7814d6855..3175c883f9 100644 --- a/include/ck_tile/ops/transpose/transpose_policy.hpp +++ b/include/ck_tile/ops/transpose/transpose_policy.hpp @@ -51,7 +51,7 @@ struct TransposePolicy { return integer_least_multiple( sizeof(typename Problem::DataType) * - MakeLdsStoreBlockDescriptor::get_element_space_size(), + MakeLdsStoreBlockDescriptor().get_element_space_size(), 16); } @@ -90,7 +90,7 @@ struct TransposePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor() { - using Layout = remove_cvref_t; + //using Layout = remove_cvref_t; constexpr index_t kLeadDimPerBlock = Problem::kLeadDimPerBlock; constexpr index_t kSecondDimPerBlock = Problem::kSecondDimPerBlock; constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType); @@ -123,9 +123,9 @@ struct TransposePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadTileDistribution() { - using Layout = remove_cvref_t; + //using Layout = remove_cvref_t; using QuartTransposeTileDistribution = - QuartTransposeTraits::TileDistribution; + typename QuartTransposeTraits::TileDistribution; using WarpTransposeTileDistribution = decltype(detail::make_embed_tile_distribution_encoding( WarpLevelOuterDistribution_{}, QuartTransposeTileDistribution{}));