diff --git a/example/ck_tile/37_transpose/CMakeLists.txt b/example/ck_tile/37_transpose/CMakeLists.txt new file mode 100644 index 0000000000..d6f374a9b4 --- /dev/null +++ b/example/ck_tile/37_transpose/CMakeLists.txt @@ -0,0 +1,9 @@ +set(TARGET_NAME tile_example_transpose) +add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL transpose_example.cpp transpose_api.cpp) +target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) +target_compile_options(tile_example_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) + diff --git a/example/ck_tile/37_transpose/README.md b/example/ck_tile/37_transpose/README.md new file mode 100644 index 0000000000..21578dd00e --- /dev/null +++ b/example/ck_tile/37_transpose/README.md @@ -0,0 +1,27 @@ +# Batched Transpose +This folder contains example for transpose load for architecture gfx950. This transpose load has some constraints in input tile distribution. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# Make the transpose executable +make tile_example_transpose -j +``` +This will result in an executable `build/bin/tile_example_transpose` + +## example +``` +args: + -N input batch size (default:2) + -C input channel size. (default:64) + -H input height size. (default:1) + -W input width size. 
(default:64) + -v whether do CPU validation or not (default: 1) + -layout_in input tensor data layout - NCHW by default + -layout_out output tensor data layout - NHWC by default + -seed seed to be used, -1 means random every time (default:-1) + -k_name t to 1 will print kernel name (default:0) +``` \ No newline at end of file diff --git a/example/ck_tile/37_transpose/batched_transpose_kernel.hpp b/example/ck_tile/37_transpose/batched_transpose_kernel.hpp new file mode 100644 index 0000000000..4681a12cf7 --- /dev/null +++ b/example/ck_tile/37_transpose/batched_transpose_kernel.hpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include +#include + +namespace ck_tile { + +struct BatchedTransposeHostArgs +{ + const void* p_input; + void* p_output; + index_t batch; + index_t height; + index_t width; + // index_t dim_blocks; + index_t dim_stride; + index_t dim_block_h; + index_t dim_block_w; +}; + +template +struct BatchedTransposeKernel +{ + using Pipeline = remove_cvref_t; + using Problem = remove_cvref_t; + + using Type = typename Problem::DataType; + + struct BatchedTransposeKargs + { + const void* p_input; + void* p_output; + index_t batch; + index_t height; + index_t width; + index_t dim_stride; + }; + + using Kargs = BatchedTransposeKargs; + using Hargs = BatchedTransposeHostArgs; + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) + { + size_t grid_size_x = h.dim_block_w; + size_t grid_size_y = h.dim_block_h; + size_t grid_size_z = h.batch; + return dim3(grid_size_x, grid_size_y, grid_size_z); + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_input = h.p_input; + k.p_output = h.p_output; + k.batch = h.batch; + k.height = h.height; + k.width = h.width; + 
k.dim_stride = h.dim_stride; + return k; + } + + CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + __shared__ char smem[Pipeline::GetSmemSize()]; + static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondSizePerBlock; + static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadSizePerBlock; + + const auto iDim = blockIdx.z; + const auto x_m_n = [&]() { + const auto x_dram_naive = make_naive_tensor_view( + static_cast(kargs.p_input) + iDim * kargs.dim_stride, + make_tuple(kargs.height, kargs.width), + make_tuple(kargs.width, 1), + number{}, + number<1>{}); + + return pad_tensor_view(x_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.y * kMPerBlock); + const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.x * kNPerBlock); + + const auto y_n_m = [&]() { + const auto y_dram_naive = make_naive_tensor_view( + static_cast(kargs.p_output) + iDim * kargs.dim_stride, + make_tuple(kargs.width, kargs.height), + make_tuple(kargs.height, 1), + number{}, + number<1>{}); + + return pad_tensor_view(y_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto x_block_window = make_tile_window( + x_m_n, + make_tuple(number{}, number{}), + {static_cast(iM), static_cast(iN)}); + + auto y_block_window = make_tile_window( + y_n_m, + make_tuple(number{}, number{}), + {static_cast(iN), static_cast(iM)}); + + Pipeline{}(x_block_window, y_block_window, smem); + } +}; +} // namespace ck_tile diff --git a/example/ck_tile/37_transpose/block_transpose.hpp b/example/ck_tile/37_transpose/block_transpose.hpp new file mode 100644 index 0000000000..5c0baab846 --- /dev/null +++ b/example/ck_tile/37_transpose/block_transpose.hpp @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "transpose_policy.hpp" + +namespace ck_tile { + +template +struct TransposeTraits +{ + static constexpr index_t kLeadDim = kCol; + static constexpr index_t kSecondDim = kRow; +}; + +template +struct TransposeTraits +{ + static constexpr index_t kLeadDim = kRow; + static constexpr index_t kSecondDim = kCol; +}; + +// supports 2D transpose which will store to lds, then use ds_read_b*_tr_b* instruction to get the +// transposed data; Layout in TransposePipelineProblem is the original layout of the data in the +// global memory +template // col number per xdl ops +struct TransposePipelineProblem +{ + static_assert(kRowWarps_ * kColWarps_ * get_warp_size() == kBlockSize_, + "the block size is not correct!"); + using DataType = remove_cvref_t; + using Layout = remove_cvref_t; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kLeadNumWarps = + TransposeTraits::kLeadDim; + static constexpr index_t kSecondNumWarps = + TransposeTraits::kSecondDim; + static constexpr index_t kLeadSizePerBlock = + TransposeTraits::kLeadDim; + static constexpr index_t kSecondSizePerBlock = + TransposeTraits::kSecondDim; + static constexpr index_t kLeadSizePerXdl = + TransposeTraits::kLeadDim; + static constexpr index_t kSecondSizePerXdl = + TransposeTraits::kSecondDim; + + static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits::kleadDim; + static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits::ksecondDim; + + static_assert(kLeadSizePerBlock % kLeadNumWarps == 0, + "block dim should be divided by warp dim!"); + static_assert(kSecondSizePerBlock % kSecondNumWarps == 0, + "block dim should be divided by warp dim!"); + // how many rows/cols implemented in one warp + static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps; + static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps; + + static_assert(kLeadSizePerWarp % kLeadSizePerXdl 
== 0, + "warp dim should be divided by xdl dim!"); + static_assert(kSecondSizePerWarp % kSecondSizePerXdl == 0, + "warp dim should be divided by xdl dim!"); + + // warp rows/cols is divided into xdl. + static constexpr index_t kLeadXdlNumPerWarp = kLeadSizePerWarp / kLeadSizePerXdl; + static constexpr index_t kSecondXdlNumPerWarp = kSecondSizePerWarp / kSecondSizePerXdl; + + static_assert(kLeadSizePerXdl % kQuadrantLeadDim == 0, + "xdl dim should be divided by quad dim!"); + static_assert(kSecondSizePerXdl % kQuadrantSecondDim == 0, + "xdl dim should be divided by quad dim!"); + // xdl rows/cols is divided into quadrants. + static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerXdl / kQuadrantLeadDim; + static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerXdl / kQuadrantSecondDim; + + static constexpr index_t kIterationsInSecondDim = + kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size(); +}; + +template +struct BlockTranspose +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using DataType = remove_cvref_t; + using Layout = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock; + static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock; + + static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize(); } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE void operator()(const InputTileWindow& input_window, + OutputTileWindow& output_window, + void* __restrict__ p_smem) + { + auto input_tile_window = + make_tile_window(input_window, Policy::template MakeInputDistribution()); + auto output_tile_window = + make_tile_window(output_window, Policy::template MakeOutputDistribution()); + + DataType* p_lds_ptr = static_cast(p_smem); + constexpr auto in_lds_block_desc = Policy::template 
MakeLdsStoreBlockDescriptor(); + auto input_lds_block = + make_tensor_view(p_lds_ptr, in_lds_block_desc); + + constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor(); + auto output_lds_block = + make_tensor_view(p_lds_ptr, out_lds_block_desc); + + auto copy_to_lds_window = + make_tile_window(input_lds_block, + make_tuple(number{}, number{}), + {0, 0}); + auto load_from_lds_window = + make_tile_window(output_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeLdsLoadTileDistribution()); + + auto x = load_tile(input_tile_window); + + store_tile(copy_to_lds_window, x); + block_sync_lds(); + + auto y = load_tile_transpose(load_from_lds_window); + + store_tile(output_tile_window, y); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/37_transpose/transpose_api.cpp b/example/ck_tile/37_transpose/transpose_api.cpp new file mode 100644 index 0000000000..fe184b4023 --- /dev/null +++ b/example/ck_tile/37_transpose/transpose_api.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+#include "transpose_example.hpp" +#include + +template +float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) +{ + uint32_t dim_block_h = (a.height + block_y - 1) / block_y; + uint32_t dim_block_w = (a.width + block_x - 1) / block_x; + uint32_t dim_stride = a.height * a.width; + + a.dim_stride = dim_stride; + a.dim_block_h = dim_block_h; + a.dim_block_w = dim_block_w; + + using ts_problem = ck_tile::TransposePipelineProblem; + using ts_pipeline = ck_tile::BlockTranspose; + + using kernel = ck_tile::BatchedTransposeKernel; + + auto kargs = kernel::MakeKargs(a); + + const dim3 grids = kernel::GridSize(a); + constexpr dim3 blocks = kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); + + return ave_time; +} + +float batched_transpose(batched_transpose_trait t, + batched_transpose_kargs a, + ck_tile::stream_config s) +{ + if(t.type == "fp16") + { + return batched_transpose_dispatch(a, s); + } + else if(t.type == "fp8") + { + return batched_transpose_dispatch(a, s); + } + + return -1; +} diff --git a/example/ck_tile/37_transpose/transpose_example.cpp b/example/ck_tile/37_transpose/transpose_example.cpp new file mode 100644 index 0000000000..ac27ca7911 --- /dev/null +++ b/example/ck_tile/37_transpose/transpose_example.cpp @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "transpose_example.hpp" + +#if 0 +template +void dump_host_tensor_4d(const ck_tile::HostTensor& x) +{ + auto len = x.get_lengths(); + assert(len.size() == 4); + std::cout << "["; + for(size_t i = 0; i < len[0]; i++) + { + std::cout << i << ": ["; + for(size_t j = 0; j < len[1]; j++) + { + std::cout << j << ": ["; + for(size_t k = 0; k < len[2]; k++) + { + std::cout << k << ": ["; + for(size_t v = 0; v < len[3]; v++) + { + if constexpr(std::is_same_v) + { + auto m = + ck_tile::type_convert(x(std::vector{i, j, k, v})); + + std::cout << m; + if(v != len[3] - 1) + std::cout << ","; + } + else + { + std::cout << x(std::vector{i, j, k, v}) << " "; + } + } + std::cout << "]" << std::endl; + } + std::cout << "]" << std::endl; + } + std::cout << std::endl; + } + std::cout << "--------------------" << std::endl; +} +#endif + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string init_method) +{ + if(init_method == "ui" || init_method == "ni") + { + unsigned max_rounding_point_distance = 0; + double atol = 2e-3; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "whether do CPU validation or not") + .insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)") + .insert("N", "2", "input batch size. 
") + .insert("C", "64", "input channel size.") + .insert("H", "1", "input height size.") + .insert("W", "64", "input width size. ") + .insert("layout_in", "NCHW", "input tensor data layout - NCHW by default") + .insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ") + .insert("seed", "-1", "seed to be used, -1 means random every time") + .insert("kname", "0", "t to 1 will print kernel name"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run_batched_transpose(ck_tile::ArgParser args) +{ + int validate = args.get_int("v"); + std::string prec = args.get_str("pr"); + int N = args.get_int("N"); + int C = args.get_int("C"); + int H = args.get_int("H"); + int W = args.get_int("W"); + std::string layout_in = args.get_str("layout_in"); + std::string layout_out = args.get_str("layout_out"); + int seed = args.get_int("seed"); + + int dim_in[4], dim_out[4]; + int stride_dim_in[4], stride_dim_out[4]; + bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC"; + bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW"; + assert(nchw2nhwc != nhwc2nchw); + (void)nhwc2nchw; + + dim_in[0] = N; + dim_in[1] = nchw2nhwc ? C : H; + dim_in[2] = nchw2nhwc ? H : W; + dim_in[3] = nchw2nhwc ? W : C; + dim_out[0] = N; + dim_out[1] = nchw2nhwc ? H : C; + dim_out[2] = nchw2nhwc ? W : H; + dim_out[3] = nchw2nhwc ? C : W; + stride_dim_in[0] = C * H * W; + stride_dim_in[1] = nchw2nhwc ? H * W : C * W; + stride_dim_in[2] = nchw2nhwc ? W : C; + stride_dim_in[3] = 1; + stride_dim_out[0] = C * H * W; + stride_dim_out[1] = nchw2nhwc ? C * W : H * W; + stride_dim_out[2] = nchw2nhwc ? 
C : W; + stride_dim_out[3] = 1; + + if(seed < 0) + { + seed = std::time(nullptr); + } + + ck_tile::HostTensor x_host( + {dim_in[0], dim_in[1], dim_in[2], dim_in[3]}, + {stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]}); + ck_tile::HostTensor y_host( + {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, + {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + + ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes()); + + x_dev.ToDevice(x_host.data()); + + auto trait = batched_transpose_trait{prec, layout_in}; + + uint32_t height = nchw2nhwc ? C : H * W; + uint32_t width = nchw2nhwc ? H * W : C; + + batched_transpose_kargs karg = [&]() { + batched_transpose_kargs a_; + a_.p_input = x_dev.GetDeviceBuffer(); + a_.p_output = y_dev.GetDeviceBuffer(); + a_.batch = N; + a_.height = height; + a_.width = width; + return a_; + }(); + + ck_tile::stream_config sc{nullptr, true}; + + auto ms = batched_transpose(trait, karg, sc); + + std::size_t num_operations = N * C * H * (W - 1); + std::size_t num_bytes = N * C * H * W * sizeof(Type); + + float ave_time = ms * 1E-3; + float gb_per_sec = num_bytes / ms * 1.E-6; + float tflops = static_cast(num_operations) / ms * 1.E-6; + + std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H + << ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out + << " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops" + << gb_per_sec << " GB/s, " << std::endl; + + printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n", + prec.c_str(), + N, + C, + H, + W, + layout_in.c_str(), + ms); + if(ms < 0) + printf("not supported\n"); + fflush(stdout); + + if(ms < 0) + { + return false; + } + + y_dev.FromDevice(y_host.data()); + + bool rtn = true; + if(validate) + { + // this host buffer will not copy to 
GPU, so no need use stride + ck_tile::HostTensor y_ref( + {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, + {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); + + ck_tile::reference_batched_transpose(x_host, y_ref, layout_in, layout_out); + + auto [rtol, atol] = get_elimit(""); + + rtn &= ck_tile::check_err( + y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol); + } + printf("valid:%s\n", rtn ? "y" : "n"); + fflush(stdout); + return rtn; +} + +int main(int argc, char** argv) +{ + auto [result, args] = create_args(argc, argv); + if(!result) + return -1; + std::string prec = args.get_str("pr"); + + bool r = true; + if(prec.compare("fp16") == 0) + { + r &= run_batched_transpose(args); + } + else if(prec.compare("fp8") == 0) + { + r &= run_batched_transpose(args); + } + else + { + std::cerr << "Unsupported data type: " << prec << std::endl; + } + + return r ? 0 : -1; +} diff --git a/example/ck_tile/37_transpose/transpose_example.hpp b/example/ck_tile/37_transpose/transpose_example.hpp new file mode 100644 index 0000000000..8128d583ef --- /dev/null +++ b/example/ck_tile/37_transpose/transpose_example.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "batched_transpose_kernel.hpp" +#include "block_transpose.hpp" +#include "transpose_policy.hpp" + +#include +#include + +#pragma once + +struct batched_transpose_trait +{ + std::string type; + std::string layout; +}; + +struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs +{ +}; + +float batched_transpose(batched_transpose_trait t, + batched_transpose_kargs a, + ck_tile::stream_config s); diff --git a/example/ck_tile/37_transpose/transpose_policy.hpp b/example/ck_tile/37_transpose/transpose_policy.hpp new file mode 100644 index 0000000000..ea1a4130fe --- /dev/null +++ b/example/ck_tile/37_transpose/transpose_policy.hpp @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +struct TransposePolicy +{ + static constexpr auto TileAccessPattern = tile_distribution_pattern::thread_raked; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSize() + { + return 16 / sizeof(typename Problem::DataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return integer_least_multiple( + sizeof(typename Problem::DataType) * + MakeLdsStoreBlockDescriptor().get_element_space_size(), + 16); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution() + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock; + constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock; + constexpr index_t VecLoadSize = 16 / sizeof(typename Problem::DataType); + + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution() + { + constexpr auto 
input_dstr = MakeLdsLoadTileDistribution(); + + using OutTileDstrEncode = + typename OutputTileDistributionTraits, + typename Problem::DataType>::OutDstrEncode; + constexpr auto block_dstr = make_static_tile_distribution(OutTileDstrEncode{}); + + return block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor() + { + constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock; + constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock; + constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType); + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto lds_block_desc = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor() + { + constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock; + constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock; + + constexpr index_t kVectorSize = 8 / sizeof(typename Problem::DataType); + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto lds_block_desc = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadTileDistribution() + { + using 
DataType = typename Problem::DataType; + + // Extract base dimensions from the traits + constexpr index_t kBaseLeadDim = LaneGroupTransposeTraits::kleadDim; + constexpr index_t kBaseSecondDim = LaneGroupTransposeTraits::ksecondDim; + + // Calculate block-level dimensions + constexpr index_t kLead = Problem::kLeadSizePerXdl; + constexpr index_t kSecond = Problem::kSecondSizePerXdl; + constexpr index_t kLeadIterPerWarp = Problem::kLeadXdlNumPerWarp; + constexpr index_t kSecondIterPerWarp = Problem::kSecondXdlNumPerWarp; + constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps; + constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps; + + // Calculate repetitions of base pattern + constexpr index_t kLeadRepetitions = kLead / kBaseLeadDim; + constexpr index_t kSecondRepetitions = kSecond / kBaseSecondDim; + constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim; + constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations; + + constexpr auto xdllevel_dstr_encoding = make_transposed_distr_encode(); + + constexpr auto input_tile_encode = + InputTileDistributionEncoding(); + constexpr auto block_dstr = make_static_tile_distribution(input_tile_encode); + return block_dstr; + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index f2f39b6e17..92b859a750 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -21,3 +21,4 @@ add_subdirectory(18_flatmm) add_subdirectory(19_gemm_multi_d) add_subdirectory(35_batched_transpose) add_subdirectory(36_copy) +add_subdirectory(37_transpose) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index be84842347..ed39719cf4 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -10,6 +10,7 @@ #include "ck_tile/core/algorithm/static_encoding_pattern.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp" +#include 
"ck_tile/core/arch/amd_transpose_load_encoding.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/utility.hpp" @@ -39,6 +40,7 @@ #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/tensor/buffer_view.hpp" #include "ck_tile/core/tensor/load_tile.hpp" +#include "ck_tile/core/tensor/load_tile_transpose.hpp" #include "ck_tile/core/tensor/null_tensor.hpp" #include "ck_tile/core/tensor/null_tile_window.hpp" #include "ck_tile/core/tensor/shuffle_tile.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7111eed596..0ec1a95511 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -2784,6 +2784,40 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, #endif } +template +__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) +{ + + if constexpr(std::is_same_v, ck_tile::half_t>) + { + typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t; + __attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>( + reinterpret_cast(in_ptr)); + return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr)); + } + else if constexpr(std::is_same_v, ck_tile::bf16_t>) + { + typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t; + __attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>( + reinterpret_cast(in_ptr)); + return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr)); + } + else if constexpr(std::is_same_v, ck_tile::fp8_t>) + { + typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t; + __attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr = + 
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>( + reinterpret_cast(in_ptr)); + return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); + } + else + { + static_assert(false, "not implemented"); + } +} + } // namespace ck_tile #endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN diff --git a/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp b/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp new file mode 100644 index 0000000000..7ffe6dc0fb --- /dev/null +++ b/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/tensor/tile_distribution_encoding.hpp" + +namespace ck_tile { + +// this generate wave level tile distribution +template +struct LaneGroupTransposeTraits; + +template +struct LaneGroupTransposeTraits> +{ + // before transpose, 4x16 + static constexpr index_t ksecondDim = 4; + static constexpr index_t kleadDim = 16; + // after transpose, 16x4 + static constexpr index_t ksecondDimT = 16; + static constexpr index_t kleadDimT = 4; + template + using TileDistribution = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 1, 2>, + sequence<1, 1, 3>>; +}; + +template +struct LaneGroupTransposeTraits> +{ + static constexpr index_t ksecondDim = 8; + static constexpr index_t kleadDim = 16; + + static constexpr index_t ksecondDimT = 16; + static constexpr index_t kleadDimT = 8; + + template + using TileDistribution = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 1, 2>, + sequence<1, 1, 3>>; +}; + +/* + * @brief This function is used to generate the transposed distribution encoding + * for the given data type and distribution dimensions. + * + * @tparam T The data type of the elements in the tensor. 
+ * @tparam kOuterDistDim0 The outer distribution dimension 0, which is outer dimension for stride. + * @tparam kOuterDistDim1 The outer distribution dimension 1, which is inner dimension for stride. + * @tparam kInnerDistDim0 The inner distribution dimension 0, which is outer dimension for + * consecutive. + * @tparam kInnerDistDim1 The inner distribution dimension 1, which is inner dimension for + * consecutive. + */ +template +CK_TILE_DEVICE constexpr auto make_transposed_distr_encode() +{ + using xdllevel_dstr_encoding = typename LaneGroupTransposeTraits:: + template TileDistribution; + return xdllevel_dstr_encoding{}; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index bdcfbdd920..cd7b7d0a1f 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -18,6 +18,7 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/utility/ignore.hpp" namespace ck_tile { @@ -133,6 +134,28 @@ struct buffer_view>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto transpose_get(index_t i, + index_t linear_offset, + bool is_valid_element, + bool_constant = {}) const + { + static_assert(false, "Error: transpose load not supported in global memory space."); + ignore = i; + ignore = linear_offset; + ignore = is_valid_element; + return; + } + // i is offset of T, not X. 
i should be aligned to X template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto transpose_get(index_t i, + index_t linear_offset, + bool is_valid_element, + bool_constant = {}) const + { + static_assert(false, "Error: transpose load not supported in global memory space."); + ignore = i; + ignore = linear_offset; + ignore = is_valid_element; + return; + } + // i is offset of T, not X. i should be aligned to X template {}(dst, v_offset * sizeof(T), i_offset * sizeof(T)); } + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + transpose_get(index_t i, index_t linear_offset, bool is_valid_element) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + if(is_valid_element) + { + constexpr address_space_enum addr_space = get_address_space(); + return amd_transpose_load_to_vgpr, t_per_x, addr_space>( + p_data_ + i + linear_offset); + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{numeric>::zero()}; + } + else + { + return X{invalid_element_value_}; + } + } + } + // i is offset of T, not X. 
// i should be aligned to X

// Compile-time predicate: is the sequence `Suffix` a suffix of `Sequence`?
// (template headers below lost in extraction; tokens reproduced as found)
template
struct is_sequence_suffix
{
    // A suffix can never be longer than the sequence it is tested against.
    static constexpr bool size_check = (Suffix::size() <= Sequence::size());

    // First index of the candidate suffix inside Sequence.
    static constexpr index_t start_pos = Sequence::size() - Suffix::size();
    using extract_indices = typename arithmetic_sequence_gen::type;

    // Element-wise comparison of the trailing slice against Suffix.
    static constexpr bool value =
        size_check && (Suffix{} == decltype(Sequence::extract(extract_indices{})){});
};

// Specialization: presumably the empty-suffix (always true) case — the original
// specialization arguments were stripped in extraction; TODO confirm.
template
struct is_sequence_suffix, sequence>
{
    static constexpr bool value = true;
};

template
constexpr bool is_sequence_suffix_v = is_sequence_suffix::value;

} // namespace util

// Default policy: Retains original 2D transpose behavior
template
struct DefaultTranspose
{
    // Lane-group ("quad") encodings. Quad16/Quad8 differ in the inner H-lengths
    // (4x4 vs 2x8 on the input side), selected by element size below.
    struct Quad16
    {
        using InputEncoding = tile_distribution_encoding,
                                                         tuple, sequence<4, 4>>,
                                                         tuple>,
                                                         tuple>,
                                                         sequence<2>,
                                                         sequence<1>>;

        using OutputEncoding = tile_distribution_encoding,
                                                          tuple, sequence<4>>,
                                                          tuple>,
                                                          tuple>,
                                                          sequence<2>,
                                                          sequence<0>>;
    };

    struct Quad8
    {
        using InputEncoding = tile_distribution_encoding,
                                                         tuple, sequence<2, 8>>,
                                                         tuple>,
                                                         tuple>,
                                                         sequence<2>,
                                                         sequence<1>>;

        using OutputEncoding = tile_distribution_encoding,
                                                          tuple, sequence<8>>,
                                                          tuple>,
                                                          tuple>,
                                                          sequence<2>,
                                                          sequence<0>>;
    };

    // Select based on data size
    using QuadInputEncoding = std::conditional_t;

    using QuadOutputEncoding = std::conditional_t;

    // Always swap last two dimensions
    static constexpr auto transpose_dims = sequence<1, 0>{};

    // Programmable: Element grouping function
    static constexpr auto group_func = [](auto idx) {
        return idx; // Identity mapping
    };

    // Validates that an input tile-distribution encoding is compatible with the
    // transpose load. Aggregated into `value`; consumed by TransposeTileDistrChecker.
    template
    struct ValidationTraits
    {
        static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
        static constexpr auto quad_hs_lengthss  = QuadInputEncoding::hs_lengthss_;
        // 1. Must be 2D tensor
        static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
        // 2. Quad pattern must be suffix of input pattern
        static constexpr bool suffix_valid_dim0 =
            util::is_sequence_suffix_v()),
                                       decltype(input_hs_lengthss.template get<0>())>;
        static constexpr bool suffix_valid_dim1 =
            util::is_sequence_suffix_v()),
                                       decltype(input_hs_lengthss.template get<1>())>;

        // 3. PS→RHS mapping constraints
        static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
        static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;

        static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1;
        static constexpr index_t ndimp_inner =
            input_ps_to_rhss_major[number{}].size() - 1;

        // The last P entry must map onto the tail of the W (major 2) H-lengths and
        // the preceding one onto the tail of the H (major 1) H-lengths.
        static constexpr bool ps_mapping_valid =
            (input_ps_to_rhss_major[number{}][number{}] == 2) &&
            (input_ps_to_rhss_minor[number{}][number{}] ==
             input_hs_lengthss[number<1>{}].size() - 2) &&
            (input_ps_to_rhss_major[number{}][number{}] == 1) &&
            (input_ps_to_rhss_minor[number{}][number{}] ==
             input_hs_lengthss[number<0>{}].size() - 1);

        // 4. YS→RHS mapping constraints
        static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
        static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;

        static constexpr bool ys_mapping_valid =
            (input_ys_to_rhs_major.back() == 2) &&
            (input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) &&
            (input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) &&
            (input_ys_to_rhs_minor[input_ys_to_rhs_minor.size() - 2] ==
             input_hs_lengthss[number<0>{}].size() - 2);

        static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
                                      ps_mapping_valid && ys_mapping_valid;
    };
};

// Thin wrapper exposing the policy's validation result for a given tile distribution;
// used as the SFINAE gate on load_tile_transpose().
template
struct TransposeTileDistrChecker
{
    using InDstrEncode = typename remove_cvref_t::DstrEncode;

    using Validator = typename Policy::template ValidationTraits;

    static constexpr bool distr_encoding_valid = Validator::value;
};

// this is used to generate the transposed output tile distribution encoding
// based on the input tile distribution encoding
template >
struct OutputTileDistributionTraits
{
    using InDstrEncode = typename remove_cvref_t::DstrEncode;
    static constexpr auto input_hs_lengthss       = InDstrEncode::hs_lengthss_;
    static constexpr auto quad_input_hs_lengthss  = Policy::QuadInputEncoding::hs_lengthss_;
    static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_;

    static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
    static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
    static constexpr auto input_ys_to_rhs_major  = InDstrEncode::ys_to_rhs_major_;
    static constexpr auto input_ys_to_rhs_minor  = InDstrEncode::ys_to_rhs_minor_;

    static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_;
    static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_;

    // for transpose load
    // append the reversed quad output hs lengths
    // to the input hs lengthss after removing
    // the quad_input_hs_lengthss
    // then reverse the whole sequence to get the dst_out_hs_lengthss
    static constexpr auto reversed_quad_output_hs_lengthss = tuple_reverse(quad_output_hs_lengthss);

    // Per X-dim: drop the trailing quad-input lengths, then append the
    // (reversed) quad-output lengths.
    static constexpr auto full_out_hs_lengthss = generate_tuple(
        [](auto i) {
            return input_hs_lengthss[i]
                .extract(typename arithmetic_sequence_gen<0,
                                                          input_hs_lengthss[i].size() -
                                                              quad_input_hs_lengthss[i].size(),
                                                          1>::type{})
                .push_back(reversed_quad_output_hs_lengthss[i]);
        },
        number{});

    static constexpr auto dst_out_hs_lengthss = tuple_reverse(full_out_hs_lengthss);

    // for PS→RHS mapping(both major and minor), we need to modify the last element of the major
    // sequence
    static constexpr auto modified_ps_to_rhss_major = generate_tuple(
        [](auto i) {
            if constexpr(i == input_ps_to_rhss_major.size() - 1)
            {
                constexpr auto current_size = input_ps_to_rhss_major[i].size();
                constexpr auto reduce_size  = quad_ps_to_rhss_major[number<0>{}].size();
                constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
                    typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
                return reduced_ps_to_rhss_major.push_back(number<2>{});
            }
            else
            {
                // For all other sequences, keep them unchanged
                return input_ps_to_rhss_major[i];
            }
        },
        number{});

    static constexpr auto minor_last_index =
        full_out_hs_lengthss[number{}].size() - 1;
    static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1;

    static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
        [](auto i) {
            if constexpr(i == input_ps_to_rhss_minor.size() - 1)
            {
                constexpr auto current_size = input_ps_to_rhss_minor[i].size();
                constexpr auto reduce_size  = quad_ps_to_rhss_minor[number<0>{}].size();
                constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract(
                    typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
                return reduced_ps_to_rhss_minor.push_back(number{});
            }
            else
            {
                // For all other sequences, keep them unchanged
                return input_ps_to_rhss_minor[i];
            }
        },
        number{});

    // for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
    static constexpr auto swap_one_and_two = [](const index_t idx) {
        return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
    };
    static constexpr auto dst_ps_to_rhss_major = generate_tuple(
        [](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); },
        number{});

    // YS mapping: last Y entry is redirected to major 1 / the new last minor slot.
    static constexpr auto modified_input_ys_to_rhs_major =
        input_ys_to_rhs_major.pop_back().push_back(number<1>{});

    static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2(
        [](auto i) { return number{}; },
        number{});

    static constexpr auto dst_ys_to_rhs_minor =
        input_ys_to_rhs_minor.pop_back().push_back(number{});

    // The assembled transposed-output encoding (argument list garbled in extraction).
    using OutDstrEncode = tile_distribution_encoding,
                                                     remove_cvref_t,
                                                     remove_cvref_t,
                                                     remove_cvref_t,
                                                     remove_cvref_t>;
};

// Embeds an inner (lane-group) encoding into a block-level outer encoding; returns
// the combined input tile-distribution encoding for the transpose-load window.
template
CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
{
    constexpr auto block_outer_dst_encoding =
        tile_distribution_encoding,
                                   tuple,
                                         sequence>,
                                   tuple>,
                                   tuple>,
                                   sequence<2, 1>,
                                   sequence<0, 0>>{};
    constexpr auto blk_distr_encode =
        detail::make_embed_tile_distribution_encoding(block_outer_dst_encoding, InnerEncode{});

    return blk_distr_encode;
}

/**
 * @brief transpose loads tile from a tensor and returns the resulting tensor with a new
 * (transposed) tile distribution. use SFINAE to ensure the tile distribution encoding is valid.
 *
 * This function is intended for use with statically distributed tensor tiles, where the input
 * and output tile distributions differ due to the transpose operation. It ensures that the
 * element space size and vector length remain consistent between the input and output
 * distributions.
 *
 * @tparam BottomTensorView_ The type of the bottom tensor view.
 * @tparam WindowLengths_ The type representing the window lengths.
 * @tparam TileDistribution_ The type representing the tile distribution.
 * @tparam NumCoord The number of coordinates (dimensions).
 * @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
 * the last is SFINAE to ensure the tile distribution encoding is valid.
 *
 * @param tile_window The tile window with static distribution to load and transpose.
 *
 * @return A statically distributed tensor containing the transposed tile data.
 *
 * @note
 * - The function uses compile-time checks to ensure the input and output tile distributions
 *   are compatible in terms of element space size and vector length.
 * - The transpose operation is performed according to the specified Policy.
 */
template <
    typename BottomTensorView_,
    typename WindowLengths_,
    typename TileDistribution_,
    index_t NumCoord,
    typename Policy = DefaultTranspose,
    typename = std::enable_if_t::distr_encoding_valid,
                                Policy>>
CK_TILE_DEVICE auto
load_tile_transpose(const tile_window_with_static_distribution& tile_window)
{
    using OutTileDstrEncode =
        typename OutputTileDistributionTraits::OutDstrEncode;
    // NOTE(review): the output distribution is constructed twice (here and as
    // `output_distr` below); could be hoisted into one constexpr value.
    auto out_tensor = make_static_distributed_tensor(
        make_static_tile_distribution(OutTileDstrEncode{}));
    auto trans_tensor = tile_window.template load_transpose();
    constexpr auto input_distr  = TileDistribution_{};
    constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});

    constexpr auto y_in_desc  = input_distr.get_ys_to_d_descriptor();
    constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();

    constexpr index_t NDimYIn  = input_distr.get_num_of_dimension_y();
    constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();

    constexpr auto y_in_lengths  = to_sequence(y_in_desc.get_lengths());
    constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());

    // Compile-time compatibility checks between the two distributions.
    constexpr auto y_in_element_space_size  = y_in_desc.get_element_space_size();
    constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();
    static_assert(y_in_element_space_size == y_out_element_space_size,
                  "the element space size is not the same!");
    static_assert(y_in_lengths[NDimYIn - 1] == y_out_lengths[NDimYOut - 1],
                  "the vector length is not the same!");
    constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
    constexpr index_t num_of_access =
        reduce_on_sequence(y_in_lengths, multiplies{}, number<1>{}) / vecLoadSize;

    // Copy the transposed thread buffer vector-by-vector into the output tensor.
    using DataVec = array;
    static_for<0, num_of_access, 1>{}([&](auto iAccess) {
        out_tensor.get_thread_buffer().template set_as(
            number{},
            trans_tensor.get_thread_buffer().template get_as(number{}));
    });

    return out_tensor;
}

} // namespace ck_tile

// diff --git a/include/ck_tile/core/tensor/tensor_view.hpp
//          b/include/ck_tile/core/tensor/tensor_view.hpp
// index 656ce8d20d..9429a960d8 100644
// @@ -251,6 +251,33 @@ struct tensor_view
// (pre-existing context line from the surrounding diff hunk)
bool_constant{});
}

// Convenience overload: validity is derived from the coordinate itself.
// (SFINAE header garbled in extraction; tokens reproduced as found.)
template >::scalar_type,
                     typename vector_traits>::scalar_type>,
        bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t
get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const
{
    return buf_.template transpose_get(
        coord.get_offset(),
        linear_offset,
        coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}

// Overload taking an explicit validity flag, for callers that have already
// computed / cached it.
template >::scalar_type,
                     typename vector_traits>::scalar_type>,
        bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t
get_transpose_vectorized_elements(const TensorCoord& coord,
                                  index_t linear_offset,
                                  bool is_valid_element // flag
                                  ) const
{
    return buf_.template transpose_get(coord.get_offset(), linear_offset, is_valid_element);
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X.
// "coord" should be aligned to X

// Convenience entry point: allocates the destination distributed tensor and
// forwards to the in-place load_transpose() below.
template
CK_TILE_DEVICE auto load_transpose() const
{
    constexpr auto tile_dstr = typename Base::TileDstr{};
    auto dst_tensor = make_static_distributed_tensor(tile_dstr);
    this->template load_transpose(
        dst_tensor, number{}, bool_constant{});
    return dst_tensor;
}

// Transpose-load into an existing distributed tensor: iterates the per-thread
// access pattern (SFC_Ys), issues one transpose vector load per access, and
// scatters the lanes into dst_tensor via the policy's group_func remapping.
template
CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor,
                                   number = {},
                                   bool_constant = {}) const
{
    using Traits = typename Base::Traits;
    using vector_t = typename Traits::vector_t;
    using SFC_Ys = typename Traits::SFC_Ys;

    constexpr auto tile_dstr = typename Base::TileDstr{};

    constexpr auto group_func = Policy::group_func;

    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        /// TODO: use structure binding (to be captured later) if compiled in C++20
        auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number{};

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

            // read from bottom tensor
            const vector_t vec_value =
                this->get_bottom_tensor_view()
                    .template get_transpose_vectorized_elements(
                        bottom_tensor_thread_coord, 0);
            // write into distributed tensor
            static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                // Advance only the vectorized Y dimension by lane index j.
                constexpr auto orig_idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                        : idx_ys_start[jj];
                    },
                    number{});

                constexpr auto grouped_idx_ys = group_func(orig_idx_ys);

                constexpr index_t linear_distributed_index =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);

                dst_tensor.get_thread_buffer().template at() =
                    vec_value.template get_as()[j];
            });
            // move thread coordinate
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number{}),
                    idx_diff_ys);

                Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
            }
        });
    });
}

// (pre-existing context from the diff hunk: store() follows; the removed line
//  `// using vector_type_t = typename Traits::vector_type_t;` was dead comment.)
template
CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor,
// @@ -415,7 +491,6 @@ struct tile_window_with_static_distribution
{
    using Traits = typename Base::Traits;

    using vector_t = typename Traits::vector_t;
    using SFC_Ys = typename Traits::SFC_Ys;

// diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp
//          b/include/ck_tile/core/tensor/tile_window_linear.hpp
// index f11610d658..56c5066774 100644
// @@ -613,6 +613,60 @@ struct tile_window_linear
WINDOW_DISPATCH_ISSUE();
}

// Linear-window counterpart of load_transpose(): allocates and forwards.
template
CK_TILE_DEVICE auto load_transpose() const
{
    constexpr auto tile_dstr = typename Base::TileDstr{};
    auto dst_tensor = make_static_distributed_tensor(tile_dstr);
    this->template load_transpose_linear(
        dst_tensor, number{}, bool_constant{});
    return dst_tensor;
}

// Transpose-load for the linear tile window, dispatched per access via
// WINDOW_DISPATCH_ISSUE() over cached coordinates.
template
CK_TILE_DEVICE auto load_transpose_linear(DistributedTensor& dst_tensor,
                                          number = {},
                                          bool_constant = {}) const
{
    using vector_t = typename traits::vector_t;
    using SFC_Ys = typename traits::SFC_Ys;

    constexpr auto tile_dstr = typename Base::TileDstr{};

    // NOTE(review): unlike the static-distribution load_transpose above,
    // group_func is declared here but never applied to idx_ys — looks like an
    // inconsistency between the two paths; confirm intent.
    constexpr auto group_func = Policy::group_func;

    auto issue = [&](auto i_access_) {
        constexpr auto IAccess       = number{};
        constexpr auto non_linear_id = number{};
        auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
        // NOTE(review): bottom_tensor_flag is fetched but never used — the
        // 2-arg get_transpose_vectorized_elements() below recomputes validity
        // instead; presumably the cached flag should be forwarded via the
        // 3-arg overload. TODO confirm against the other issue() paths.
        auto bottom_tensor_flag = cached_flags_[IAccess];

        constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

        // read from bottom tensor
        const vector_t vec_value =
            this->get_bottom_tensor_view().template get_transpose_vectorized_elements(
                bottom_tensor_thread_coord, 0);
        // write into distributed tensor
        static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
            constexpr auto idx_ys = generate_tuple(
                [&](auto jj) {
                    return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                },
                number{});

            constexpr index_t linear_distributed_index =
                tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
            dst_tensor.get_thread_buffer().template at() =
                vec_value.template get_as()[j];
        });
    };
    WINDOW_DISPATCH_ISSUE();
}

template
CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor,