Files
composable_kernel/example/ck_tile/37_transpose/transpose_policy.hpp
joyeamd a2f01141aa transpose load api development (#2177)
* add transpose load; no real logic

* fix some compile errors

* fix some issues

* update transpose load logic

* add some fixes

* fix a distribution issue

* update some codes

* add some fix

* can pass; but no logic

* transpose load enable

* update tile transpose

* miss output tile distribution mapping

* hack for transpose 16x16

* update output tensor distribution

* delete unused variables

* fix transpose related codes

* update transpose load example

* exchange the iteration order

* fix 16x16 related dimension transpose

* fix a transpose index issue

* fix a transpose index issue

* fix clang format check

* update load tile transpose related codes

* fix compile errors and pass 16x16 tests

* fix a typo

* update logic

* check other data types

* add transpose load api

* update transpose load api

* fix clang format check

* change file name

* refactor codes

* update code name

* delete some unused codes

* delete the unused oob flag for transpose load

* update tensor view api for transpose load

* update for testing

* fix a typo error

* move transpose ops to example directory

* update transpose api

* update include file

* fix for pr review

* fix compile errors

* add transpose load; no real logic

* fix some compile errors

* fix some issues

* update transpose load logic

* add some fixes

* fix a distribution issue

* update some codes

* add some fix

* can pass; but no logic

* transpose load enable

* update tile transpose

* miss output tile distribution mapping

* hack for transpose 16x16

* update output tensor distribution

* delete unused variables

* fix transpose related codes

* update transpose load example

* exchange the iteration order

* fix 16x16 related dimension transpose

* fix a transpose index issue

* fix a transpose index issue

* fix clang format check

* update load tile transpose related codes

* fix compile errors and pass 16x16 tests

* fix a typo

* update logic

* check other data types

* add transpose load api

* update transpose load api

* fix clang format check

* change file name

* refactor codes

* update code name

* delete some unused codes

* delete the unused oob flag for transpose load

* update tensor view api for transpose load

* update for testing

* fix a typo error

* move transpose ops to example directory

* update transpose api

* update include file

* fix for pr review

* fix compile errors

* change directory name

* delete the duplicated directory

* update cmakelists file

* delete the unused codes

* update function names

* update transpose policy

* update code after remod.py

* update codes

* add some comment

* Polish the instr infrastructure

* build up the fixed instr

* redesign the transpose api, currently it has numerical error

* add the bf16 transpose

* fix some issues

* add some comments

* update document

* Finished the refactor of API and pass through the verification

* fix the merging issue

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
2025-06-18 01:28:34 -07:00

152 lines
6.7 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
struct TransposePolicy
{
static constexpr auto TileAccessPattern = tile_distribution_pattern::thread_raked;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSize()
{
return 16 / sizeof(typename Problem::DataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return integer_least_multiple(
sizeof(typename Problem::DataType) *
MakeLdsStoreBlockDescriptor<Problem>().get_element_space_size(),
16);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock;
constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock;
constexpr index_t VecLoadSize = 16 / sizeof(typename Problem::DataType);
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
SecondDimPerBlock,
LeadDimPerBlock,
VecLoadSize,
TileAccessPattern>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
{
constexpr auto input_dstr = MakeLdsLoadTileDistribution<Problem>();
using OutTileDstrEncode =
typename OutputTileDistributionTraits<remove_cvref_t<decltype(input_dstr)>,
typename Problem::DataType>::OutDstrEncode;
constexpr auto block_dstr = make_static_tile_distribution(OutTileDstrEncode{});
return block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor()
{
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kSecondDimPerBlock>{},
number<kLeadDimPerBlock / kVectorSize>{},
number<kVectorSize>{}),
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
number<kVectorSize>{},
number<1>{});
constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
number<kVectorSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor()
{
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
constexpr index_t kVectorSize = 8 / sizeof(typename Problem::DataType);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kSecondDimPerBlock>{},
number<kLeadDimPerBlock / kVectorSize>{},
number<kVectorSize>{}),
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
number<kVectorSize>{},
number<1>{});
constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
number<kVectorSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadTileDistribution()
{
using DataType = typename Problem::DataType;
// Extract base dimensions from the traits
constexpr index_t kBaseLeadDim = LaneGroupTransposeTraits<DataType>::kleadDim;
constexpr index_t kBaseSecondDim = LaneGroupTransposeTraits<DataType>::ksecondDim;
// Calculate block-level dimensions
constexpr index_t kLead = Problem::kLeadSizePerXdl;
constexpr index_t kSecond = Problem::kSecondSizePerXdl;
constexpr index_t kLeadIterPerWarp = Problem::kLeadXdlNumPerWarp;
constexpr index_t kSecondIterPerWarp = Problem::kSecondXdlNumPerWarp;
constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps;
constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps;
// Calculate repetitions of base pattern
constexpr index_t kLeadRepetitions = kLead / kBaseLeadDim;
constexpr index_t kSecondRepetitions = kSecond / kBaseSecondDim;
constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim;
constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations;
constexpr auto xdllevel_dstr_encoding = make_transposed_distr_encode<DataType,
kSecondDimStrSub,
kSecondDimIterations,
kLeadRepetitions,
1>();
constexpr auto input_tile_encode =
InputTileDistributionEncoding<decltype(xdllevel_dstr_encoding),
kLeadIterPerWarp,
kSecondIterPerWarp,
kLeadNumWarps,
kSecondNumWarps>();
constexpr auto block_dstr = make_static_tile_distribution(input_tile_encode);
return block_dstr;
}
};
} // namespace ck_tile