diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index 0ac2ded5f6..0266fc653f 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -373,6 +373,7 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding() * element space size and vector length remain consistent between the input and output * distributions. * + * @tparam DistributedTensor_ The type of the tensor containing the transposed tile data. * @tparam BottomTensorView_ The type of the bottom tensor view. * @tparam WindowLengths_ The type representing the window lengths. * @tparam TileDistribution_ The type representing the tile distribution. @@ -380,18 +381,19 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding() * @tparam Policy The transpose policy to use (defaults to DefaultTranspose). * the last is SFINAE to ensure the tile distribution encoding is valid. * + * @param[out] out_tensor Statically distributed tensor that receives the transposed tile + * data. * @param tile_window The tile window with static distribution to load and transpose. * @param offset The offset (in elements) added to the base address before * indexing. * - * @return A statically distributed tensor containing the transposed tile data. - * * @note * - The function uses compile-time checks to ensure the input and output tile distributions * are compatible in terms of element space size and vector length. * - The transpose operation is performed according to the specified Policy. 
*/ template < + typename DistributedTensor_, typename BottomTensorView_, typename WindowLengths_, typename TileDistribution_, @@ -401,21 +403,17 @@ template < typename BottomTensorView_::DataType, Policy>::distr_encoding_valid, Policy>> -CK_TILE_DEVICE auto load_tile_transpose_with_offset( +CK_TILE_DEVICE void load_tile_transpose_with_offset( + DistributedTensor_& out_tensor, const tile_window_with_static_distribution& __restrict__ tile_window, index_t offset) { - using OutTileDstrEncode = typename OutputTileDistributionTraits< - typename TileDistribution_::DstrEncode, - typename BottomTensorView_::DataType>::TransposedDstrEncode; - auto out_tensor = make_static_distributed_tensor( - make_static_tile_distribution(OutTileDstrEncode{})); auto trans_tensor = tile_window.template load_transpose_with_offset(offset); constexpr auto input_distr = TileDistribution_{}; - constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{}); + constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{}; constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor(); constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor(); @@ -442,6 +440,32 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( number{}, trans_tensor.get_thread_buffer().template get_as(number{})); }); +} + +template < + typename BottomTensorView_, + typename WindowLengths_, + typename TileDistribution_, + index_t NumCoord, + typename Policy = DefaultTranspose, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE auto load_tile_transpose_with_offset( + const tile_window_with_static_distribution& __restrict__ tile_window, + index_t offset) +{ + using OutTileDstrEncode = typename OutputTileDistributionTraits< + typename TileDistribution_::DstrEncode, + typename BottomTensorView_::DataType>::TransposedDstrEncode; + auto out_tensor = make_static_distributed_tensor( + make_static_tile_distribution(OutTileDstrEncode{})); + + 
load_tile_transpose_with_offset(out_tensor, tile_window, offset); return out_tensor; } @@ -455,6 +479,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( * element space size and vector length remain consistent between the input and output * distributions. * + * @tparam DistributedTensor_ The type of the tensor containing the transposed tile data. * @tparam BottomTensorView_ The type of the bottom tensor view. * @tparam WindowLengths_ The type representing the window lengths. * @tparam TileDistribution_ The type representing the tile distribution. @@ -462,16 +487,37 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( * @tparam Policy The transpose policy to use (defaults to DefaultTranspose). * the last is SFINAE to ensure the tile distribution encoding is valid. * + * @param[out] out_tensor Statically distributed tensor that receives the transposed tile + * data. * @param tile_window The tile window with static distribution to load and transpose. * indexing. * - * @return A statically distributed tensor containing the transposed tile data. - * * @note * - The function uses compile-time checks to ensure the input and output tile distributions * are compatible in terms of element space size and vector length. * - The transpose operation is performed according to the specified Policy. 
*/ +template < + typename DistributedTensor_, + typename BottomTensorView_, + typename WindowLengths_, + typename TileDistribution_, + index_t NumCoord, + typename Policy = DefaultTranspose, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void +load_tile_transpose(DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window) +{ + load_tile_transpose_with_offset(out_tensor, tile_window, 0); +} + template < typename BottomTensorView_, typename WindowLengths_, @@ -488,7 +534,15 @@ load_tile_transpose(const tile_window_with_static_distribution& __restrict__ tile_window) { - return load_tile_transpose_with_offset(tile_window, 0); + using OutTileDstrEncode = typename OutputTileDistributionTraits< + typename TileDistribution_::DstrEncode, + typename BottomTensorView_::DataType>::TransposedDstrEncode; + auto out_tensor = make_static_distributed_tensor( + make_static_tile_distribution(OutTileDstrEncode{})); + + load_tile_transpose_with_offset(out_tensor, tile_window, 0); + + return out_tensor; } } // namespace ck_tile