Add an instance of load_tile_transpose that takes a reference to the output tensor as an input

This commit is contained in:
Sami Aario
2026-01-02 14:47:32 +00:00
parent 63a455952a
commit 8fc4030a57

View File

@@ -373,6 +373,7 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
@@ -380,18 +381,19 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
* the last is SFINAE to ensure the tile distribution encoding is valid.
*
* @param out_tensor A statically distributed tensor containing the transposed tile
* data.
* @param tile_window The tile window with static distribution to load and transpose.
* @param offset The offset (in elements) added to the base address before
* indexing.
*
* @return A statically distributed tensor containing the transposed tile data.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
@@ -401,21 +403,17 @@ template <
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE auto load_tile_transpose_with_offset(
CK_TILE_DEVICE void load_tile_transpose_with_offset(
DistributedTensor_& out_tensor,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& __restrict__ tile_window,
index_t offset)
{
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{};
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
@@ -442,6 +440,32 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
number<iAccess>{},
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
});
}
/**
 * @brief Load a tile through the transpose path and return it as a new distributed tensor.
 *
 * Convenience overload: it creates a statically distributed tensor whose distribution is
 * built from the transposed encoding derived from TileDistribution_, fills it through the
 * out-parameter overload of load_tile_transpose_with_offset, and returns it by value.
 *
 * @tparam BottomTensorView_ The type of the bottom tensor view.
 * @tparam WindowLengths_    The type representing the window lengths.
 * @tparam TileDistribution_ The type representing the tile distribution.
 * @tparam NumCoord          Non-type parameter of the tile window type.
 * @tparam Policy            The transpose policy to use (defaults to DefaultTranspose).
 *                           The last (unnamed) parameter is SFINAE to ensure the tile
 *                           distribution encoding is valid for this policy.
 *
 * @param tile_window The tile window with static distribution to load and transpose.
 * @param offset      The offset (in elements) added to the base address before indexing.
 *
 * @return A statically distributed tensor containing the transposed tile data.
 */
template <
    typename BottomTensorView_,
    typename WindowLengths_,
    typename TileDistribution_,
    index_t NumCoord,
    typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
    typename        = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
                                                          typename BottomTensorView_::DataType,
                                                          Policy>::distr_encoding_valid,
                                Policy>>
CK_TILE_DEVICE auto load_tile_transpose_with_offset(
    const tile_window_with_static_distribution<BottomTensorView_,
                                               WindowLengths_,
                                               TileDistribution_,
                                               NumCoord>& __restrict__ tile_window,
    index_t offset)
{
    using OutTileDstrEncode = typename OutputTileDistributionTraits<
        typename TileDistribution_::DstrEncode,
        typename BottomTensorView_::DataType>::TransposedDstrEncode;

    auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
        make_static_tile_distribution(OutTileDstrEncode{}));

    // Policy cannot be deduced from the function arguments, so it must be forwarded
    // explicitly; otherwise the out-parameter overload falls back to its default policy and
    // a caller-selected Policy would be silently ignored.
    // NOTE(review): the argument order assumes the out-parameter overload is declared as
    // <DistributedTensor_, BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord,
    // Policy> -- confirm against its template parameter list.
    load_tile_transpose_with_offset<decltype(out_tensor),
                                    BottomTensorView_,
                                    WindowLengths_,
                                    TileDistribution_,
                                    NumCoord,
                                    Policy>(out_tensor, tile_window, offset);
    return out_tensor;
}
@@ -455,6 +479,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
@@ -462,16 +487,37 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
* the last is SFINAE to ensure the tile distribution encoding is valid.
*
* @param out_tensor A statically distributed tensor containing the transposed tile
* data.
* @param tile_window The tile window with static distribution to load and transpose.
*
* @return A statically distributed tensor containing the transposed tile data.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
    typename DistributedTensor_,
    typename BottomTensorView_,
    typename WindowLengths_,
    typename TileDistribution_,
    index_t NumCoord,
    typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
    typename        = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
                                                          typename BottomTensorView_::DataType,
                                                          Policy>::distr_encoding_valid,
                                Policy>>
CK_TILE_DEVICE void
load_tile_transpose(DistributedTensor_& out_tensor,
                    const tile_window_with_static_distribution<BottomTensorView_,
                                                               WindowLengths_,
                                                               TileDistribution_,
                                                               NumCoord>& __restrict__ tile_window)
{
    // Zero-offset variant: delegates to load_tile_transpose_with_offset, which writes the
    // transposed tile into out_tensor.
    //
    // Policy cannot be deduced from the function arguments, so it is forwarded explicitly;
    // without this the callee would fall back to its default policy and a caller-selected
    // Policy would be silently ignored.
    // NOTE(review): the argument order assumes the callee's template parameter list mirrors
    // this function's (DistributedTensor_, BottomTensorView_, WindowLengths_,
    // TileDistribution_, NumCoord, Policy) -- confirm against its declaration.
    load_tile_transpose_with_offset<DistributedTensor_,
                                    BottomTensorView_,
                                    WindowLengths_,
                                    TileDistribution_,
                                    NumCoord,
                                    Policy>(out_tensor, tile_window, 0);
}
template <
typename BottomTensorView_,
typename WindowLengths_,
@@ -488,7 +534,15 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
TileDistribution_,
NumCoord>& __restrict__ tile_window)
{
return load_tile_transpose_with_offset(tile_window, 0);
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
return out_tensor;
}
} // namespace ck_tile