From e4d8548dc54e78ffc2b6fcbaac1e218ae7e540a3 Mon Sep 17 00:00:00 2001 From: fangche123 Date: Wed, 29 Jan 2025 08:22:02 +0800 Subject: [PATCH] add batched_transpose implement (#1660) * add batched_transpose implement --------- Co-authored-by: root Co-authored-by: ThruptiRajLakshmanaGowda Co-authored-by: ThomasNing [ROCm/composable_kernel commit: c5fff071e5c60af87ed7e3a9d130d8151b353384] --- .../35_batched_transpose/CMakeLists.txt | 9 + .../ck_tile/35_batched_transpose/README.md | 27 ++ .../batched_transpose_api.cpp | 82 ++++++ .../batched_transpose_example.cpp | 261 ++++++++++++++++++ .../batched_transpose_example.hpp | 25 ++ .../35_batched_transpose/script/smoke_test.sh | 11 + example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/host.hpp | 1 + .../reference/reference_batched_transpose.hpp | 59 ++++ include/ck_tile/ops/batched_transpose.hpp | 11 + .../kernel/batched_transpose_kernel.hpp | 129 +++++++++ .../pipeline/batched_transpose_pipeline.hpp | 52 ++++ .../pipeline/batched_transpose_policy.hpp | 44 +++ .../pipeline/batched_transpose_problem.hpp | 48 ++++ 14 files changed, 760 insertions(+) create mode 100644 example/ck_tile/35_batched_transpose/CMakeLists.txt create mode 100644 example/ck_tile/35_batched_transpose/README.md create mode 100644 example/ck_tile/35_batched_transpose/batched_transpose_api.cpp create mode 100644 example/ck_tile/35_batched_transpose/batched_transpose_example.cpp create mode 100644 example/ck_tile/35_batched_transpose/batched_transpose_example.hpp create mode 100755 example/ck_tile/35_batched_transpose/script/smoke_test.sh create mode 100644 include/ck_tile/host/reference/reference_batched_transpose.hpp create mode 100644 include/ck_tile/ops/batched_transpose.hpp create mode 100644 include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp create mode 100644 include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp create mode 100644 
include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp create mode 100644 include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp diff --git a/example/ck_tile/35_batched_transpose/CMakeLists.txt b/example/ck_tile/35_batched_transpose/CMakeLists.txt new file mode 100644 index 0000000000..a08fcebb74 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/CMakeLists.txt @@ -0,0 +1,9 @@ +set(TARGET_NAME tile_example_batched_transpose) +add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL batched_transpose_example.cpp batched_transpose_api.cpp) +target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) +target_compile_options(tile_example_batched_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) + diff --git a/example/ck_tile/35_batched_transpose/README.md b/example/ck_tile/35_batched_transpose/README.md new file mode 100644 index 0000000000..d0583e7529 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/README.md @@ -0,0 +1,27 @@ +# Batched Transpose +This folder contains an example of batched transpose using the ck_tile tile-programming implementation. Currently, it supports batched transpose from NCHW to NHWC or from NHWC to NCHW; by chaining two such transposes you can also go from NCHW to NWCH. For now the transpose reads a single data point at a time; vectorized transpose support is planned. 
+ +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# Make the transpose executable +make tile_example_batched_transpose -j +``` +This will result in an executable `build/bin/tile_example_batched_transpose` + +## example +``` +args: + -N input batch size (default:2) + -C input channel size. (default:16) + -H input height size. (default:1) + -W input width size. (default:16) + -v whether do CPU validation or not (default: 1) + -layout_in input tensor data layout - NCHW by default + -layout_out output tensor data layout - NHWC by default + -seed seed to be used, -1 means random every time (default:-1) + -kname set to 1 will print kernel name (default:0) +``` \ No newline at end of file diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp new file mode 100644 index 0000000000..77d768fe3f --- /dev/null +++ b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+#include "batched_transpose_example.hpp" +#include + +template +float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) +{ + uint32_t dim_block_h = (a.height + block_y - 1) / block_y; + uint32_t dim_block_w = (a.width + block_x - 1) / block_x; + uint32_t dim_stride = a.height * a.width; + + a.dim_stride = dim_stride; + a.dim_block_h = dim_block_h; + a.dim_block_w = dim_block_w; + + using block_tile = ck_tile::sequence; + using warp_tile = ck_tile::sequence; + using thread_tile = ck_tile::sequence; + + using ts_problem = + ck_tile::BatchedTransposeProblem; + using ts_pipeline = ck_tile::BatchedTransposePipeline; + + using kernel = ck_tile::BatchedTransposeKernel; + + auto kargs = kernel::MakeKargs(a); + + const dim3 grids = kernel::GridSize(a); + constexpr dim3 blocks = kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); + + return ave_time; +} + +// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y +#define FOREACH_TRANSPOSE_PARAM(F) \ + F(fp16, ck_tile::fp16_t, 16, 16, 8, 8, 1, 1) \ + F(bf16, ck_tile::bf16_t, 16, 16, 8, 8, 1, 1) \ + F(fp32, ck_tile::fp32_t, 16, 16, 8, 8, 1, 1) \ + F(int8, ck_tile::int8_t, 16, 16, 8, 8, 1, 1) + +// Macro that defines one static function per line +#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY) \ + static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY( \ + batched_transpose_kargs& a, ck_tile::stream_config& s) \ + { \ + return batched_transpose_dispatch(a, s); \ + } + +FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN) + +float batched_transpose(batched_transpose_trait t, + batched_transpose_kargs a, + ck_tile::stream_config s) +{ + if(t.type == "fp16") + { + return transpose_fn_fp16_16_16_8_8_1_1(a, s); + } + else if(t.type == "bf16") + { + return transpose_fn_bf16_16_16_8_8_1_1(a, s); + } + else if(t.type == "fp32") + { + return transpose_fn_fp32_16_16_8_8_1_1(a, s); + } 
+ else if(t.type == "int8") + { + return transpose_fn_int8_16_16_8_8_1_1(a, s); + } + return -1; +} diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp new file mode 100644 index 0000000000..48fc2859bf --- /dev/null +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "batched_transpose_example.hpp" + +#if 0 +template +void dump_host_tensor_4d(const ck_tile::HostTensor& x) +{ + auto len = x.get_lengths(); + assert(len.size() == 4); + std::cout << "["; + for(size_t i = 0; i < len[0]; i++) + { + std::cout << i << ": ["; + for(size_t j = 0; j < len[1]; j++) + { + std::cout << j << ": ["; + for(size_t k = 0; k < len[2]; k++) + { + std::cout << k << ": ["; + for(size_t v = 0; v < len[3]; v++) + { + if constexpr(std::is_same_v) + { + auto m = + ck_tile::type_convert(x(std::vector{i, j, k, v})); + + std::cout << m; + if(v != len[3] - 1) + std::cout << ","; + } + else + { + std::cout << x(std::vector{i, j, k, v}) << " "; + } + } + std::cout << "]" << std::endl; + } + std::cout << "]" << std::endl; + } + std::cout << std::endl; + } + std::cout << "--------------------" << std::endl; +} +#endif + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string init_method) +{ + if(init_method == "ui" || init_method == "ni") + { + unsigned max_rounding_point_distance = 0; + double atol = 2e-3; + return 
ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "whether do CPU validation or not") + .insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)") + .insert("N", "2", "input batch size. ") + .insert("C", "16", "input channel size.") + .insert("H", "1", "input height size.") + .insert("W", "16", "input width size. ") + .insert("layout_in", "NCHW", "input tensor data layout - NCHW by default") + .insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ") + .insert("seed", "-1", "seed to be used, -1 means random every time") + .insert("kname", "0", "t to 1 will print kernel name"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run_batched_transpose(ck_tile::ArgParser args) +{ + int validate = args.get_int("v"); + std::string prec = args.get_str("pr"); + int N = args.get_int("N"); + int C = args.get_int("C"); + int H = args.get_int("H"); + int W = args.get_int("W"); + std::string layout_in = args.get_str("layout_in"); + std::string layout_out = args.get_str("layout_out"); + int seed = args.get_int("seed"); + + int dim_in[4], dim_out[4]; + int stride_dim_in[4], stride_dim_out[4]; + bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC"; + bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW"; + assert(nchw2nhwc != nhwc2nchw); + (void)nhwc2nchw; + + dim_in[0] = N; + dim_in[1] = nchw2nhwc ? C : H; + dim_in[2] = nchw2nhwc ? H : W; + dim_in[3] = nchw2nhwc ? W : C; + dim_out[0] = N; + dim_out[1] = nchw2nhwc ? H : C; + dim_out[2] = nchw2nhwc ? W : H; + dim_out[3] = nchw2nhwc ? C : W; + stride_dim_in[0] = C * H * W; + stride_dim_in[1] = nchw2nhwc ? 
H * W : C * W; + stride_dim_in[2] = nchw2nhwc ? W : C; + stride_dim_in[3] = 1; + stride_dim_out[0] = C * H * W; + stride_dim_out[1] = nchw2nhwc ? C * W : H * W; + stride_dim_out[2] = nchw2nhwc ? C : W; + stride_dim_out[3] = 1; + + if(seed < 0) + { + seed = std::time(nullptr); + } + + ck_tile::HostTensor x_host( + {dim_in[0], dim_in[1], dim_in[2], dim_in[3]}, + {stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]}); + ck_tile::HostTensor y_host( + {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, + {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + + ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes()); + + x_dev.ToDevice(x_host.data()); + + auto trait = batched_transpose_trait{prec, layout_in}; + + uint32_t height = nchw2nhwc ? C : H * W; + uint32_t width = nchw2nhwc ? H * W : C; + + batched_transpose_kargs karg = [&]() { + batched_transpose_kargs a_; + a_.p_input = x_dev.GetDeviceBuffer(); + a_.p_output = y_dev.GetDeviceBuffer(); + a_.batch = N; + a_.height = height; + a_.width = width; + return a_; + }(); + + ck_tile::stream_config sc{nullptr, true}; + + auto ms = batched_transpose(trait, karg, sc); + + std::size_t num_operations = N * C * H * (W - 1); + std::size_t num_bytes = N * C * H * W * sizeof(Type); + + float ave_time = ms * 1E-3; + float gb_per_sec = num_bytes / ms * 1.E-6; + float tflops = static_cast(num_operations) / ms * 1.E-6; + + std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H + << ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out + << " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops" + << gb_per_sec << " GB/s, " << std::endl; + + printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n", + prec.c_str(), + N, + C, + H, + W, + layout_in.c_str(), + ms); + if(ms < 0) + 
printf("not supported\n"); + fflush(stdout); + + if(ms < 0) + { + return false; + } + + y_dev.FromDevice(y_host.data()); + + bool rtn = true; + if(validate) + { + // this host buffer will not copy to GPU, so no need use stride + ck_tile::HostTensor y_ref( + {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, + {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); + + ck_tile::reference_batched_transpose(x_host, y_ref, layout_in, layout_out); + + auto [rtol, atol] = get_elimit(""); + + rtn &= ck_tile::check_err( + y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol); + } + printf("valid:%s\n", rtn ? "y" : "n"); + fflush(stdout); + return rtn; +} + +int main(int argc, char** argv) +{ + auto [result, args] = create_args(argc, argv); + if(!result) + return -1; + std::string prec = args.get_str("pr"); + + bool r = true; + if(prec.compare("fp32") == 0) + { + r &= run_batched_transpose(args); + } + else if(prec.compare("fp16") == 0) + { + r &= run_batched_transpose(args); + } + else if(prec.compare("bf16") == 0) + { + r &= run_batched_transpose(args); + } + else if(prec.compare("int8") == 0) + { + r &= run_batched_transpose(args); + } + + return r ? 0 : -1; +} diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp new file mode 100644 index 0000000000..487ddc17b2 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/ops/batched_transpose.hpp" + +#include +#include + +#pragma once + +struct batched_transpose_trait +{ + std::string type; + std::string layout; +}; + +struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs +{ +}; + +float batched_transpose(batched_transpose_trait t, + batched_transpose_kargs a, + ck_tile::stream_config s); diff --git a/example/ck_tile/35_batched_transpose/script/smoke_test.sh b/example/ck_tile/35_batched_transpose/script/smoke_test.sh new file mode 100755 index 0000000000..fdfef2cea8 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/script/smoke_test.sh @@ -0,0 +1,11 @@ +#!/bin/sh + +EXE=./build/bin/tile_example_batched_transpose + +for pr in "fp32" "fp16" "int8" ; do +$EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC' +done diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 296eb1ecef..7f4ba2ed35 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -17,3 +17,4 @@ add_subdirectory(14_moe_smoothquant) add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) +add_subdirectory(35_batched_transpose) diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 440b306705..bb5d8bfa86 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -34,3 +34,4 @@ #include "ck_tile/host/reference/reference_topk.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/timer.hpp" +#include "ck_tile/host/reference/reference_batched_transpose.hpp" diff --git 
a/include/ck_tile/host/reference/reference_batched_transpose.hpp b/include/ck_tile/host/reference/reference_batched_transpose.hpp new file mode 100644 index 0000000000..454ab42e32 --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_transpose.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +CK_TILE_HOST void reference_batched_transpose(const HostTensor& x, + HostTensor& y, + std::string layout_in = "NCHW", + std::string layout_out = "NHWC") +{ + const int N = x.mDesc.get_lengths()[0]; + + auto f = [&](auto batch) { + if(layout_in == "NCHW" && layout_out == "NHWC") + { + const int C = x.mDesc.get_lengths()[1]; + const int H = x.mDesc.get_lengths()[2]; + const int W = x.mDesc.get_lengths()[3]; + for(int c = 0; c < C; ++c) + { + for(int h = 0; h < H; ++h) + { + for(int w = 0; w < W; ++w) + { + Type v_x = x(batch, c, h, w); + y(batch, h, w, c) = v_x; + } + } + } + } + else if(layout_in == "NHWC" && layout_out == "NCHW") + { + const int H = x.mDesc.get_lengths()[1]; + const int W = x.mDesc.get_lengths()[2]; + const int C = x.mDesc.get_lengths()[3]; + for(int h = 0; h < H; ++h) + { + for(int w = 0; w < W; ++w) + { + for(int c = 0; c < C; ++c) + { + Type v_x = x(batch, h, w, c); + y(batch, c, h, w) = v_x; + } + } + } + } + }; + + make_ParallelTensorFunctor(f, N)(std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp new file mode 100644 index 0000000000..8741e0a49b --- /dev/null +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp new file mode 100644 index 0000000000..7e7dd03c6a --- /dev/null +++ b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include +#include + +namespace ck_tile { + +struct BatchedTransposeHostArgs +{ + const void* p_input; + void* p_output; + index_t batch; + index_t height; + index_t width; + // index_t dim_blocks; + index_t dim_stride; + index_t dim_block_h; + index_t dim_block_w; +}; + +template +struct BatchedTransposeKernel +{ + using Pipeline = remove_cvref_t; + using Problem = remove_cvref_t; + + using Type = typename Problem::InputType; + + struct BatchedTransposeKargs + { + const void* p_input; + void* p_output; + index_t batch; + index_t height; + index_t width; + index_t dim_stride; + }; + + using Kargs = BatchedTransposeKargs; + using Hargs = BatchedTransposeHostArgs; + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) + { + size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w; + size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h; + size_t grid_size_z = h.batch; + return dim3(grid_size_x, 
grid_size_y, grid_size_z); + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_input = h.p_input; + k.p_output = h.p_output; + k.batch = h.batch; + k.height = h.height; + k.width = h.width; + k.dim_stride = h.dim_stride; + return k; + } + + CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + + static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock; + static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + + static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread; + static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread; + + static_assert(kMPerThread == 1 && kNPerThread == 1); + + const auto iDim = blockIdx.z; + const auto x_m_n = [&]() { + const auto x_dram_naive = make_naive_tensor_view( + static_cast(kargs.p_input) + iDim * kargs.dim_stride, + make_tuple(kargs.height, kargs.width), + make_tuple(kargs.width, 1), + number{}, // TODO thread load value + number<1>{}); + + return pad_tensor_view(x_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock); + const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock); + + const auto y_n_m = [&]() { + const auto y_dram_naive = make_naive_tensor_view( + static_cast(kargs.p_output) + iDim * kargs.dim_stride, + make_tuple(kargs.width, kargs.height), + make_tuple(kargs.height, 1), + number{}, + number<1>{}); + + return pad_tensor_view(y_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto x_block_window = + make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {static_cast(iM * kMPerBlock), + static_cast(iN * kNPerBlock)}); + + auto y_block_window = + make_tile_window(y_n_m, + make_tuple(number{}, number{}), + 
{static_cast(iN * kNPerBlock), + static_cast(iM * kMPerBlock)}); + + Pipeline{}(x_block_window, y_block_window); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp new file mode 100644 index 0000000000..aa62333918 --- /dev/null +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct BatchedTransposePipeline +{ + // TODO: this kernel only support warp per row + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using InputType = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock; + static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock; + static constexpr index_t AlignmentM = Problem::AlignmentM; + static constexpr index_t AlignmentN = Problem::AlignmentN; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + + template + CK_TILE_DEVICE auto operator()(const InputWindow& input_window, OutputWindow& out_window) + { + auto inp_win = + make_tile_window(input_window, Policy::template MakeInputDistribution()); + auto out_win = + make_tile_window(out_window, Policy::template MakeOutputDistribution()); + + auto x = load_tile(inp_win); // x->thread input_win->block + + auto y = make_static_distributed_tensor( + Policy::template MakeOutputDistribution()); + + constexpr auto span_2d_x = decltype(x)::get_distributed_spans(); + + sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) { + sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = 
make_tuple(idx1, idx0); + y(i_j_idx) = x(i_j_idx); + }); + }); + + store_tile(out_win, y); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp new file mode 100644 index 0000000000..9953e8b8bf --- /dev/null +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/softmax.hpp" +#include "ck_tile/ops/topk.hpp" + +namespace ck_tile { + +struct BatchedTransposePolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution() + { + using S = Problem; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 1>>, + sequence<1, 2>, + sequence<2, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution() + { + using S = Problem; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 1>>, + sequence<2, 1>, + sequence<2, 2>>{}); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp new file mode 100644 index 0000000000..af6b2d51aa --- /dev/null +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include +#include + +#define VectorLoadSize 16 + +namespace ck_tile { + +template +struct BatchedTransposeProblem +{ + using InputType = remove_cvref_t; + + static constexpr index_t kMPerThread = ThreadTile::at(number<0>{}); + static constexpr index_t kNPerThread = ThreadTile::at(number<1>{}); + + static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); + static constexpr index_t kNPerWarp = WarpTile::at(number<1>{}); + + static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; + static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread; + + static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); + static constexpr index_t kNPerBlock = BlockTile::at(number<1>{}); + + static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp; + static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp; + + static constexpr index_t kBlockSize = + kMThreadPerWarp * kNThreadPerWarp * kMWarpPerBlock * kNWarpPerBlock; + + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + + static constexpr index_t AlignmentM = kPadM ? VectorLoadSize / sizeof(InputType) : 1; // TODO + static constexpr index_t AlignmentN = kPadN ? VectorLoadSize / sizeof(InputType) : 1; +}; +} // namespace ck_tile