mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Add an instance of load_tile_transpose that takes a reference to the output tensor as an input
This commit is contained in:
@@ -373,6 +373,7 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
@@ -380,18 +381,19 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param out_tensor A statically distributed tensor containing the transposed tile
|
||||
* data.
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
* @param offset The offset (in elements) added to the base address before
|
||||
* indexing.
|
||||
*
|
||||
* @return A statically distributed tensor containing the transposed tile data.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
@@ -401,21 +403,17 @@ template <
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
CK_TILE_DEVICE void load_tile_transpose_with_offset(
|
||||
DistributedTensor_& out_tensor,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window,
|
||||
index_t offset)
|
||||
{
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
|
||||
constexpr auto input_distr = TileDistribution_{};
|
||||
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
|
||||
constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{};
|
||||
|
||||
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
|
||||
@@ -442,6 +440,32 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
number<iAccess>{},
|
||||
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
|
||||
});
|
||||
}
|
||||
|
||||
template <
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window,
|
||||
index_t offset)
|
||||
{
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
|
||||
load_tile_transpose_with_offset(out_tensor, tile_window, offset);
|
||||
|
||||
return out_tensor;
|
||||
}
|
||||
@@ -455,6 +479,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
@@ -462,16 +487,37 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param out_tensor A statically distributed tensor containing the transposed tile
|
||||
* data.
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
* indexing.
|
||||
*
|
||||
* @return A statically distributed tensor containing the transposed tile data.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE void
|
||||
load_tile_transpose(DistributedTensor_& out_tensor,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window)
|
||||
{
|
||||
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
|
||||
}
|
||||
|
||||
template <
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
@@ -488,7 +534,15 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window)
|
||||
{
|
||||
return load_tile_transpose_with_offset(tile_window, 0);
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
|
||||
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
|
||||
|
||||
return out_tensor;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user