From d53bca08e9072adbe3a56b97c6ea17aa1a0a4dbc Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Sat, 26 Jul 2025 21:51:54 -0700 Subject: [PATCH] [CK-Tile] Merge transpose examples (#2450) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * unify pipeline signature with existing example * iwyu * move stuff around in load-tile-transpose * cleanups in batched transpose pipeline * comments * use same inputs size * cleaner printf * print host args * use 64 block sides in the 37_transpose example * roll back grid dimension size adjustment for 37_transpose example * transpose grid for 37_transpose to unify with 35_batched_transpose * unify grid computation logic * make policy methods device only (since they are used only on device from the pipeline) * more host/device attribute cleanups * copy over problem * move over pipeline and policy * add switch to batched transpose api * make the lds problem more similar to original problem * factor out logic into traits * factor out conditional compilation into trait parameter * propagate pipeline to args * unhardcode pipeline dispatch parameter * refactor vector size * put warp tile out of dispatch * rename template parameter for trait * rewrite vector size in terms of problem * mark policy-internal struct variable as device * factor out input distribution and thread access pattern from policies * reword vector size * use datatype across batched transpose pipelines, problems and kernel * remove transpose traits from lds pipeline * add padding to the lds pipeline *interface* * add comment * remove ck_tile example #37 * update cmakelists * add test for new pipeline * update batched transpose test * roll back load_tile_transpose changes * remove comments * pack dispatch parameters into a config * padM can be enabled * adjust lds vector size to enable padding along N * update test * clean up logic * swap m/n input vector size * adjust perf test 
script * sweep over C/W in perf test * count both read and written bytes into bandwidth (x2 the number) * clang-format * widen size range for perf test * remove 64k x 64k case; it's too large for index * remove thread tile from dispatch * Solve merge conflict * fix compile * modify the transpose * solve the test error and clang format * Add v3 support for Groupd fwd conv+bias+clamp & ckProfiler (#2463) * Add logging to IsSupported. * Less casting in AddClamp * Conv+bias+clamp instances & profiler BF16 * Fix 3D instances & run just 1x for verification. * :Run just once for verification conv fwd. * ckProfiler conv fwd clampwq * Remove exec bit & formatting * Add support for MultiD for grouped conv fwd v3. * Enable 2Lds. * clean * align instances * align instances * profiler fixes * Fixes * fix * fix --------- Co-authored-by: Adam Osewski Co-authored-by: Bartłomiej Kocot * Fixing 0ms and inf GB/s issue in img2col (#2565) issue : ==== ``` sh $ bin/tile_example_img2col Perf: 0 ms, inf GB/s ``` solution : ====== Problem occurred because config.time_kernel is false by default. 
if false, then no need to calculate perf, just print proper message `image_to_coloumn: pass, No Perf generated due to config.time_kernel=0` * merge with develop * solve clang format --------- Co-authored-by: ThomasNing Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> Co-authored-by: Adam Osewski Co-authored-by: Bartłomiej Kocot Co-authored-by: rahjain-amd [ROCm/composable_kernel commit: 821cd26c13323672b50c4cd2b35510d94f2199b8] --- .../batched_transpose_api.cpp | 215 +++++++++++---- .../batched_transpose_example.cpp | 16 +- .../batched_transpose_example.hpp | 1 + .../35_batched_transpose/script/perf_test.sh | 12 +- .../35_batched_transpose/script/smoke_test.sh | 42 +-- example/ck_tile/37_transpose/CMakeLists.txt | 9 - example/ck_tile/37_transpose/README.md | 27 -- .../37_transpose/batched_transpose_kernel.hpp | 120 -------- .../ck_tile/37_transpose/block_transpose.hpp | 149 ---------- .../ck_tile/37_transpose/transpose_api.cpp | 59 ---- .../37_transpose/transpose_example.cpp | 257 ------------------ .../37_transpose/transpose_example.hpp | 27 -- example/ck_tile/CMakeLists.txt | 1 - include/ck_tile/ops/batched_transpose.hpp | 4 + .../kernel/batched_transpose_kernel.hpp | 4 +- .../batched_transpose_common_policy.hpp | 33 +++ .../batched_transpose_lds_pipeline.hpp | 67 +++++ .../pipeline/batched_transpose_lds_policy.hpp | 58 +--- .../batched_transpose_lds_problem.hpp | 73 +++++ .../pipeline/batched_transpose_pipeline.hpp | 15 +- .../pipeline/batched_transpose_policy.hpp | 34 +-- .../pipeline/batched_transpose_problem.hpp | 31 +-- include/ck_tile/ops/gemm.hpp | 2 +- .../batched_transpose_api.cpp | 44 ++- 24 files changed, 431 insertions(+), 869 deletions(-) delete mode 100644 example/ck_tile/37_transpose/CMakeLists.txt delete mode 100644 example/ck_tile/37_transpose/README.md delete mode 100644 example/ck_tile/37_transpose/batched_transpose_kernel.hpp delete mode 100644 example/ck_tile/37_transpose/block_transpose.hpp delete mode 100644 
example/ck_tile/37_transpose/transpose_api.cpp delete mode 100644 example/ck_tile/37_transpose/transpose_example.cpp delete mode 100644 example/ck_tile/37_transpose/transpose_example.hpp create mode 100644 include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp create mode 100644 include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp rename example/ck_tile/37_transpose/transpose_policy.hpp => include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp (65%) create mode 100644 include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp index 1eb0445c84..1f0f0b9bc1 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp @@ -2,41 +2,93 @@ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "batched_transpose_example.hpp" -template +namespace { + +template +struct kernel_traits; + +template <> +struct kernel_traits<0> +{ + template + using Problem = + ck_tile::BatchedTransposeProblem; + using Policy = ck_tile::BatchedTransposePolicy; + template + using Pipeline = + ck_tile::BatchedTransposePipeline, + Policy>; +}; + +template <> +struct kernel_traits<1> +{ + template + using Problem = + ck_tile::BatchedTransposeLdsProblem; + using Policy = ck_tile::BatchedTransposeLdsPolicy; + template + using Pipeline = ck_tile::BatchedTransposeLdsPipeline< + Problem, + Policy>; +}; +} // namespace + +template +struct BatchedTransposeConfig +{ + using InputType = InputType_; + static constexpr ck_tile::index_t kBlockX = BlockX_; + static constexpr ck_tile::index_t kBlockY = BlockY_; + static constexpr ck_tile::index_t kNumWarpsX = NumWarpsX_; + static constexpr ck_tile::index_t kNumWarpsY = NumWarpsY_; + static constexpr bool kPadM = PadM_; + static constexpr bool kPadN = PadN_; + static constexpr ck_tile::index_t kPipelineId = PipelineId_; +}; + +template float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) { uint32_t dim_stride = a.height * a.width; a.dim_stride = dim_stride; - a.dim_block_h = block_y; - a.dim_block_w = block_x; + a.dim_block_h = Config::kBlockY; + a.dim_block_w = Config::kBlockX; - using block_tile = ck_tile::sequence; - using warp_tile = ck_tile::sequence; - using thread_tile = ck_tile::sequence; - - using ts_problem = - ck_tile::BatchedTransposeProblem; - using ts_pipeline = ck_tile::BatchedTransposePipeline; - - using kernel = ck_tile::BatchedTransposeKernel; + // TODO: this is fragile and slow to compile + using kernel = ck_tile::BatchedTransposeKernel< + typename kernel_traits::template Pipeline< + typename Config::InputType, + ck_tile::sequence, + ck_tile::sequence, + Config::kPadM, + Config::kPadN>>; auto kargs = kernel::MakeKargs(a); const dim3 grids = kernel::GridSize(a); constexpr dim3 blocks = 
kernel::BlockSize(); - printf("Grid: %u %u %u\n", grids.x, grids.y, grids.z); - printf("Block: %u %u %u\n", blocks.x, blocks.y, blocks.z); - printf("kargs: kargs.batch %d kargs.height %d kargs.width %d kargs.dim_strid %d\n", + printf("Pipeline: %d\n", Config::kPipelineId); + printf("Grid: x=%u y=%u z=%u\n", grids.x, grids.y, grids.z); + printf("Block: x=%u y=%u z=%u\n", blocks.x, blocks.y, blocks.z); + printf( + "Host args: batch=%d, height=%d, width=%d, dim_stride=%d, dim_block_h=%d, dim_block_w=%d\n", + a.batch, + a.height, + a.width, + a.dim_stride, + a.dim_block_h, + a.dim_block_w); + printf("kargs: kargs.batch=%d kargs.height=%d kargs.width=%d kargs.dim_stride=%d\n", kargs.batch, kargs.height, kargs.width, @@ -52,22 +104,29 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con return ave_time; } -// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y -#define FOREACH_TRANSPOSE_PARAM(F) \ - F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, true, true) \ - F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, false, false) \ - F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, true, true) \ - F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, false, false) \ - F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, true, true) \ - F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, false, false) +// Param Comb: type_size, block_x & y, WarpNum_x & y +#define FOREACH_TRANSPOSE_PARAM(F) \ + F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, true, true, 0) \ + F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, false, false, 0) \ + F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, true, true, 0) \ + F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, false, false, 0) \ + F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, true, true, 0) \ + F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, false, false, 0) \ + F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, true, true, 1) \ + F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, false, false, 1) \ + F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, true, true, 1) \ + F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, false, false, 1) \ + 
F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, true, true, 1) \ + F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, false, false, 1) // Macro that defines one static function per line -#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY, PADM, PADN) \ - static float \ - transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY##_##PADM##_##PADN( \ - batched_transpose_kargs& a, ck_tile::stream_config& s) \ - { \ - return batched_transpose_dispatch(a, s); \ +#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, PADM, PADN, PIPE) \ + static float \ + transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##PADM##_##PADN##_v##PIPE( \ + batched_transpose_kargs& a, ck_tile::stream_config& s) \ + { \ + return batched_transpose_dispatch< \ + BatchedTransposeConfig>(a, s); \ } FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN) @@ -76,38 +135,78 @@ float batched_transpose(batched_transpose_trait t, batched_transpose_kargs a, ck_tile::stream_config s) { - if(t.type == "fp8") + if(t.pipeline == "0") { - if(a.height % 64 == 0 && a.width % 64 == 0) + if(t.type == "fp8") { - return transpose_fn_fp8_64_64_64_64_8_8_false_false(a, s); + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp8_64_64_1_1_false_false_v0(a, s); + } + else + { + return transpose_fn_fp8_64_64_1_1_true_true_v0(a, s); + } } - else + else if(t.type == "fp16") { - return transpose_fn_fp8_64_64_64_64_8_8_true_true(a, s); + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp16_64_64_1_1_false_false_v0(a, s); + } + else + { + return transpose_fn_fp16_64_64_1_1_true_true_v0(a, s); + } + } + else if(t.type == "bf16") + { + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_bf16_64_64_1_1_false_false_v0(a, s); + } + else + { + return transpose_fn_bf16_64_64_1_1_true_true_v0(a, s); + } } } - else if(t.type == "fp16") + else if(t.pipeline == "1") { - if(a.height % 64 == 0 && a.width % 64 == 0) + if(t.type == "fp8") { - return 
transpose_fn_fp16_64_64_64_64_8_8_false_false(a, s); + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp8_64_64_1_1_false_false_v1(a, s); + } + else + { + return transpose_fn_fp8_64_64_1_1_true_true_v1(a, s); + } } - else + else if(t.type == "fp16") { - return transpose_fn_fp16_64_64_64_64_8_8_true_true(a, s); - } - } - else if(t.type == "bf16") - { - if(a.height % 64 == 0 && a.width % 64 == 0) - { - return transpose_fn_bf16_64_64_64_64_8_8_false_false(a, s); - } - else - { - return transpose_fn_bf16_64_64_64_64_8_8_true_true(a, s); + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp16_64_64_1_1_false_false_v1(a, s); + } + else + { + return transpose_fn_fp16_64_64_1_1_true_true_v1(a, s); + } + } + else if(t.type == "bf16") + { + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_bf16_64_64_1_1_false_false_v1(a, s); + } + else + { + return transpose_fn_bf16_64_64_1_1_true_true_v1(a, s); + } } } + return -1; } diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp index 33b6f0eacf..571386694b 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp @@ -102,7 +102,8 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("seed", "-1", "seed to be used, -1 means random every time") - .insert("kname", "0", "t to 1 will print kernel name"); + .insert("kname", "0", "t to 1 will print kernel name") + .insert("pipeline", "0", "0: no LDS usage, 1: LDS-accelerated (gfx950)"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -121,6 +122,7 @@ bool run_batched_transpose(ck_tile::ArgParser args) int n_repeat = args.get_int("repeat"); std::string 
layout_in = args.get_str("layout_in"); std::string layout_out = args.get_str("layout_out"); + std::string pipeline = args.get_str("pipeline"); int seed = args.get_int("seed"); int dim_in[4], dim_out[4]; @@ -166,7 +168,7 @@ bool run_batched_transpose(ck_tile::ArgParser args) x_dev.ToDevice(x_host.data()); - auto trait = batched_transpose_trait{prec, layout_in}; + auto trait = batched_transpose_trait{prec, layout_in, pipeline}; uint32_t height = nchw2nhwc ? C : H * W; uint32_t width = nchw2nhwc ? H * W : C; @@ -185,17 +187,15 @@ bool run_batched_transpose(ck_tile::ArgParser args) auto ms = batched_transpose(trait, karg, sc); - std::size_t num_operations = N * C * H * (W - 1); - std::size_t num_bytes = N * C * H * W * sizeof(Type); + std::size_t num_bytes = N * C * H * W * sizeof(Type) * 2; // read + written - float ave_time = ms * 1E-3; float gb_per_sec = num_bytes / ms * 1.E-6; - float tflops = static_cast(num_operations) / ms * 1.E-6; std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H << ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out - << " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops" - << gb_per_sec << " GB/s, " << std::endl; + << " : " << std::endl + << ms << " ms " << std::endl + << gb_per_sec << " GB/s " << std::endl; printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n", prec.c_str(), diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp index 487ddc17b2..c37dbed4b3 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp @@ -14,6 +14,7 @@ struct batched_transpose_trait { std::string type; std::string layout; + std::string pipeline; }; struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs diff --git a/example/ck_tile/35_batched_transpose/script/perf_test.sh 
b/example/ck_tile/35_batched_transpose/script/perf_test.sh index dde646eb2a..f19242af28 100755 --- a/example/ck_tile/35_batched_transpose/script/perf_test.sh +++ b/example/ck_tile/35_batched_transpose/script/perf_test.sh @@ -5,10 +5,14 @@ EXE=./build/bin/tile_example_batched_transpose +for C in "64" "256" "1024" "4096" "16384"; do +for W in "64" "256" "1024" "4096" "16384"; do for pr in "fp8" "fp16" "bf16"; do -$EXE -pr=$pr -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=4096 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC' +for pipeline in "0" "1"; do + +$EXE -pipeline=$pipeline -pr=$pr -N=1 -C=$C -H=1 -W=$W -layout_in='NCHW' -layout_out='NHWC' done +done +done +done \ No newline at end of file diff --git a/example/ck_tile/35_batched_transpose/script/smoke_test.sh b/example/ck_tile/35_batched_transpose/script/smoke_test.sh index 5ba2743364..a8bd692183 100755 --- a/example/ck_tile/35_batched_transpose/script/smoke_test.sh +++ b/example/ck_tile/35_batched_transpose/script/smoke_test.sh @@ -6,25 +6,27 @@ EXE=./build/bin/tile_example_batched_transpose for pr in "fp8" "fp16" "bf16"; do -$EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=128 -C=1024 -H=64 -W=64 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=128 -C=1024 -H=64 -W=64 
-layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=16 -C=64 -H=32 -W=128 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=16 -C=64 -H=128 -W=32 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=1 -C=2048 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=2048 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=1 -C=1 -H=1024 -W=1024 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1 -H=1024 -W=1024 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=1 -C=64 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=64 -H=1024 -W=1 -layout_in='NHWC' -layout_out='NCHW' +for pipeline in "0" "1"; do +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=1 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=1 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=128 -C=1024 -H=64 -W=64 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=128 -C=1024 -H=64 -W=64 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=16 -C=64 -H=32 -W=128 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=16 -C=64 -H=128 -W=32 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=2048 -H=1 
-W=1 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=2048 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=1 -H=1024 -W=1024 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=1 -H=1024 -W=1024 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=8 -C=16 -H=8 -W=16 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=8 -C=16 -H=8 -W=16 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=64 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=64 -H=1024 -W=1 -layout_in='NHWC' -layout_out='NCHW' done +done diff --git a/example/ck_tile/37_transpose/CMakeLists.txt b/example/ck_tile/37_transpose/CMakeLists.txt deleted file mode 100644 index d6f374a9b4..0000000000 --- a/example/ck_tile/37_transpose/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -set(TARGET_NAME tile_example_transpose) -add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL transpose_example.cpp transpose_api.cpp) -target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) - -# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) -# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) -target_compile_options(tile_example_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) - diff --git a/example/ck_tile/37_transpose/README.md b/example/ck_tile/37_transpose/README.md deleted file mode 100644 index 21578dd00e..0000000000 --- a/example/ck_tile/37_transpose/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# Batched Transpose -This folder contains example for transpose load for architecture gfx950. This transpose load has some constraints in input tile distribution. 
- -## build -``` -# in the root of ck_tile -mkdir build && cd build -# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ -# Make the transpose executable -make tile_example_transpose -j -``` -This will result in an executable `build/bin/tile_example_transpose` - -## example -``` -args: - -N input batch size (default:2) - -C input channel size. (default:64) - -H input height size. (default:1) - -W input width size. (default:64) - -v whether do CPU validation or not (default: 1) - -layout_in input tensor data layout - NCHW by default - -layout_out output tensor data layout - NHWC by default - -seed seed to be used, -1 means random every time (default:-1) - -k_name t to 1 will print kernel name (default:0) -``` \ No newline at end of file diff --git a/example/ck_tile/37_transpose/batched_transpose_kernel.hpp b/example/ck_tile/37_transpose/batched_transpose_kernel.hpp deleted file mode 100644 index 4681a12cf7..0000000000 --- a/example/ck_tile/37_transpose/batched_transpose_kernel.hpp +++ /dev/null @@ -1,120 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/common.hpp" -#include "ck_tile/ops/elementwise.hpp" -#include "ck_tile/host/hip_check_error.hpp" -#include -#include - -namespace ck_tile { - -struct BatchedTransposeHostArgs -{ - const void* p_input; - void* p_output; - index_t batch; - index_t height; - index_t width; - // index_t dim_blocks; - index_t dim_stride; - index_t dim_block_h; - index_t dim_block_w; -}; - -template -struct BatchedTransposeKernel -{ - using Pipeline = remove_cvref_t; - using Problem = remove_cvref_t; - - using Type = typename Problem::DataType; - - struct BatchedTransposeKargs - { - const void* p_input; - void* p_output; - index_t batch; - index_t height; - index_t width; - index_t dim_stride; - }; - - using Kargs = BatchedTransposeKargs; - using Hargs = BatchedTransposeHostArgs; - - CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) - { - size_t grid_size_x = h.dim_block_w; - size_t grid_size_y = h.dim_block_h; - size_t grid_size_z = h.batch; - return dim3(grid_size_x, grid_size_y, grid_size_z); - } - - CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) - { - Kargs k; - k.p_input = h.p_input; - k.p_output = h.p_output; - k.batch = h.batch; - k.height = h.height; - k.width = h.width; - k.dim_stride = h.dim_stride; - return k; - } - - CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; } - - CK_TILE_DEVICE void operator()(Kargs kargs) const - { - __shared__ char smem[Pipeline::GetSmemSize()]; - static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondSizePerBlock; - static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadSizePerBlock; - - const auto iDim = blockIdx.z; - const auto x_m_n = [&]() { - const auto x_dram_naive = make_naive_tensor_view( - static_cast(kargs.p_input) + iDim * kargs.dim_stride, - make_tuple(kargs.height, kargs.width), - make_tuple(kargs.width, 1), - number{}, - number<1>{}); - - return pad_tensor_view(x_dram_naive, - 
make_tuple(number{}, number{}), - sequence{}); - }(); - - const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.y * kMPerBlock); - const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.x * kNPerBlock); - - const auto y_n_m = [&]() { - const auto y_dram_naive = make_naive_tensor_view( - static_cast(kargs.p_output) + iDim * kargs.dim_stride, - make_tuple(kargs.width, kargs.height), - make_tuple(kargs.height, 1), - number{}, - number<1>{}); - - return pad_tensor_view(y_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - auto x_block_window = make_tile_window( - x_m_n, - make_tuple(number{}, number{}), - {static_cast(iM), static_cast(iN)}); - - auto y_block_window = make_tile_window( - y_n_m, - make_tuple(number{}, number{}), - {static_cast(iN), static_cast(iM)}); - - Pipeline{}(x_block_window, y_block_window, smem); - } -}; -} // namespace ck_tile diff --git a/example/ck_tile/37_transpose/block_transpose.hpp b/example/ck_tile/37_transpose/block_transpose.hpp deleted file mode 100644 index 5c0baab846..0000000000 --- a/example/ck_tile/37_transpose/block_transpose.hpp +++ /dev/null @@ -1,149 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "ck_tile/core.hpp" -#include "transpose_policy.hpp" - -namespace ck_tile { - -template -struct TransposeTraits -{ - static constexpr index_t kLeadDim = kCol; - static constexpr index_t kSecondDim = kRow; -}; - -template -struct TransposeTraits -{ - static constexpr index_t kLeadDim = kRow; - static constexpr index_t kSecondDim = kCol; -}; - -// supports 2D transpose which will store to lds, then use ds_read_b*_tr_b* instruction to get the -// transposed data; Layout in TransposePipelineProblem is the original layout of the data in the -// global memory -template // col number per xdl ops -struct TransposePipelineProblem -{ - static_assert(kRowWarps_ * kColWarps_ * get_warp_size() == kBlockSize_, - "the block size is not correct!"); - using DataType = remove_cvref_t; - using Layout = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kLeadNumWarps = - TransposeTraits::kLeadDim; - static constexpr index_t kSecondNumWarps = - TransposeTraits::kSecondDim; - static constexpr index_t kLeadSizePerBlock = - TransposeTraits::kLeadDim; - static constexpr index_t kSecondSizePerBlock = - TransposeTraits::kSecondDim; - static constexpr index_t kLeadSizePerXdl = - TransposeTraits::kLeadDim; - static constexpr index_t kSecondSizePerXdl = - TransposeTraits::kSecondDim; - - static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits::kleadDim; - static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits::ksecondDim; - - static_assert(kLeadSizePerBlock % kLeadNumWarps == 0, - "block dim should be divided by warp dim!"); - static_assert(kSecondSizePerBlock % kSecondNumWarps == 0, - "block dim should be divided by warp dim!"); - // how many rows/cols implemented in one warp - static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps; - static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps; - - static_assert(kLeadSizePerWarp % kLeadSizePerXdl 
== 0, - "warp dim should be divided by xdl dim!"); - static_assert(kSecondSizePerWarp % kSecondSizePerXdl == 0, - "warp dim should be divided by xdl dim!"); - - // warp rows/cols is divided into xdl. - static constexpr index_t kLeadXdlNumPerWarp = kLeadSizePerWarp / kLeadSizePerXdl; - static constexpr index_t kSecondXdlNumPerWarp = kSecondSizePerWarp / kSecondSizePerXdl; - - static_assert(kLeadSizePerXdl % kQuadrantLeadDim == 0, - "xdl dim should be divided by quad dim!"); - static_assert(kSecondSizePerXdl % kQuadrantSecondDim == 0, - "xdl dim should be divided by quad dim!"); - // xdl rows/cols is divided into quadrants. - static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerXdl / kQuadrantLeadDim; - static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerXdl / kQuadrantSecondDim; - - static constexpr index_t kIterationsInSecondDim = - kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size(); -}; - -template -struct BlockTranspose -{ - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - - using DataType = remove_cvref_t; - using Layout = remove_cvref_t; - - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock; - static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock; - - static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize(); } - - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - template - CK_TILE_DEVICE void operator()(const InputTileWindow& input_window, - OutputTileWindow& output_window, - void* __restrict__ p_smem) - { - auto input_tile_window = - make_tile_window(input_window, Policy::template MakeInputDistribution()); - auto output_tile_window = - make_tile_window(output_window, Policy::template MakeOutputDistribution()); - - DataType* p_lds_ptr = static_cast(p_smem); - constexpr auto in_lds_block_desc = Policy::template 
MakeLdsStoreBlockDescriptor(); - auto input_lds_block = - make_tensor_view(p_lds_ptr, in_lds_block_desc); - - constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor(); - auto output_lds_block = - make_tensor_view(p_lds_ptr, out_lds_block_desc); - - auto copy_to_lds_window = - make_tile_window(input_lds_block, - make_tuple(number{}, number{}), - {0, 0}); - auto load_from_lds_window = - make_tile_window(output_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeLdsLoadTileDistribution()); - - auto x = load_tile(input_tile_window); - - store_tile(copy_to_lds_window, x); - block_sync_lds(); - - auto y = load_tile_transpose(load_from_lds_window); - - store_tile(output_tile_window, y); - } -}; - -} // namespace ck_tile diff --git a/example/ck_tile/37_transpose/transpose_api.cpp b/example/ck_tile/37_transpose/transpose_api.cpp deleted file mode 100644 index fe184b4023..0000000000 --- a/example/ck_tile/37_transpose/transpose_api.cpp +++ /dev/null @@ -1,59 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
-#include "transpose_example.hpp" -#include - -template -float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) -{ - uint32_t dim_block_h = (a.height + block_y - 1) / block_y; - uint32_t dim_block_w = (a.width + block_x - 1) / block_x; - uint32_t dim_stride = a.height * a.width; - - a.dim_stride = dim_stride; - a.dim_block_h = dim_block_h; - a.dim_block_w = dim_block_w; - - using ts_problem = ck_tile::TransposePipelineProblem; - using ts_pipeline = ck_tile::BlockTranspose; - - using kernel = ck_tile::BatchedTransposeKernel; - - auto kargs = kernel::MakeKargs(a); - - const dim3 grids = kernel::GridSize(a); - constexpr dim3 blocks = kernel::BlockSize(); - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); - - return ave_time; -} - -float batched_transpose(batched_transpose_trait t, - batched_transpose_kargs a, - ck_tile::stream_config s) -{ - if(t.type == "fp16") - { - return batched_transpose_dispatch(a, s); - } - else if(t.type == "fp8") - { - return batched_transpose_dispatch(a, s); - } - - return -1; -} diff --git a/example/ck_tile/37_transpose/transpose_example.cpp b/example/ck_tile/37_transpose/transpose_example.cpp deleted file mode 100644 index ac27ca7911..0000000000 --- a/example/ck_tile/37_transpose/transpose_example.cpp +++ /dev/null @@ -1,257 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include -#include -#include -#include -#include -#include -#include -#include - -#include "transpose_example.hpp" - -#if 0 -template -void dump_host_tensor_4d(const ck_tile::HostTensor& x) -{ - auto len = x.get_lengths(); - assert(len.size() == 4); - std::cout << "["; - for(size_t i = 0; i < len[0]; i++) - { - std::cout << i << ": ["; - for(size_t j = 0; j < len[1]; j++) - { - std::cout << j << ": ["; - for(size_t k = 0; k < len[2]; k++) - { - std::cout << k << ": ["; - for(size_t v = 0; v < len[3]; v++) - { - if constexpr(std::is_same_v) - { - auto m = - ck_tile::type_convert(x(std::vector{i, j, k, v})); - - std::cout << m; - if(v != len[3] - 1) - std::cout << ","; - } - else - { - std::cout << x(std::vector{i, j, k, v}) << " "; - } - } - std::cout << "]" << std::endl; - } - std::cout << "]" << std::endl; - } - std::cout << std::endl; - } - std::cout << "--------------------" << std::endl; -} -#endif - -// different threshold for different dtype -template -auto get_elimit(std::string /*init_method*/) -{ - double rtol = 1e-3; - double atol = 1e-3; - return ck_tile::make_tuple(rtol, atol); -} - -template <> -auto get_elimit(std::string /*init_method*/) -{ - double rtol = 1e-2; - double atol = 1e-2; - return ck_tile::make_tuple(rtol, atol); -} - -template <> -auto get_elimit(std::string init_method) -{ - if(init_method == "ui" || init_method == "ni") - { - unsigned max_rounding_point_distance = 0; - double atol = 2e-3; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } - else - { - unsigned max_rounding_point_distance = 1; - double atol = 0.0625; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } -} - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("v", "1", "whether do CPU validation or not") - .insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)") - .insert("N", "2", "input batch size. 
") - .insert("C", "64", "input channel size.") - .insert("H", "1", "input height size.") - .insert("W", "64", "input width size. ") - .insert("layout_in", "NCHW", "input tensor data layout - NCHW by default") - .insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ") - .insert("seed", "-1", "seed to be used, -1 means random every time") - .insert("kname", "0", "t to 1 will print kernel name"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -template -bool run_batched_transpose(ck_tile::ArgParser args) -{ - int validate = args.get_int("v"); - std::string prec = args.get_str("pr"); - int N = args.get_int("N"); - int C = args.get_int("C"); - int H = args.get_int("H"); - int W = args.get_int("W"); - std::string layout_in = args.get_str("layout_in"); - std::string layout_out = args.get_str("layout_out"); - int seed = args.get_int("seed"); - - int dim_in[4], dim_out[4]; - int stride_dim_in[4], stride_dim_out[4]; - bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC"; - bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW"; - assert(nchw2nhwc != nhwc2nchw); - (void)nhwc2nchw; - - dim_in[0] = N; - dim_in[1] = nchw2nhwc ? C : H; - dim_in[2] = nchw2nhwc ? H : W; - dim_in[3] = nchw2nhwc ? W : C; - dim_out[0] = N; - dim_out[1] = nchw2nhwc ? H : C; - dim_out[2] = nchw2nhwc ? W : H; - dim_out[3] = nchw2nhwc ? C : W; - stride_dim_in[0] = C * H * W; - stride_dim_in[1] = nchw2nhwc ? H * W : C * W; - stride_dim_in[2] = nchw2nhwc ? W : C; - stride_dim_in[3] = 1; - stride_dim_out[0] = C * H * W; - stride_dim_out[1] = nchw2nhwc ? C * W : H * W; - stride_dim_out[2] = nchw2nhwc ? 
C : W; - stride_dim_out[3] = 1; - - if(seed < 0) - { - seed = std::time(nullptr); - } - - ck_tile::HostTensor x_host( - {dim_in[0], dim_in[1], dim_in[2], dim_in[3]}, - {stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]}); - ck_tile::HostTensor y_host( - {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, - {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); - - ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); - - ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes()); - - x_dev.ToDevice(x_host.data()); - - auto trait = batched_transpose_trait{prec, layout_in}; - - uint32_t height = nchw2nhwc ? C : H * W; - uint32_t width = nchw2nhwc ? H * W : C; - - batched_transpose_kargs karg = [&]() { - batched_transpose_kargs a_; - a_.p_input = x_dev.GetDeviceBuffer(); - a_.p_output = y_dev.GetDeviceBuffer(); - a_.batch = N; - a_.height = height; - a_.width = width; - return a_; - }(); - - ck_tile::stream_config sc{nullptr, true}; - - auto ms = batched_transpose(trait, karg, sc); - - std::size_t num_operations = N * C * H * (W - 1); - std::size_t num_bytes = N * C * H * W * sizeof(Type); - - float ave_time = ms * 1E-3; - float gb_per_sec = num_bytes / ms * 1.E-6; - float tflops = static_cast(num_operations) / ms * 1.E-6; - - std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H - << ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out - << " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops" - << gb_per_sec << " GB/s, " << std::endl; - - printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n", - prec.c_str(), - N, - C, - H, - W, - layout_in.c_str(), - ms); - if(ms < 0) - printf("not supported\n"); - fflush(stdout); - - if(ms < 0) - { - return false; - } - - y_dev.FromDevice(y_host.data()); - - bool rtn = true; - if(validate) - { - // this host buffer will not copy to 
GPU, so no need use stride - ck_tile::HostTensor y_ref( - {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, - {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); - - ck_tile::reference_batched_transpose(x_host, y_ref, layout_in, layout_out); - - auto [rtol, atol] = get_elimit(""); - - rtn &= ck_tile::check_err( - y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol); - } - printf("valid:%s\n", rtn ? "y" : "n"); - fflush(stdout); - return rtn; -} - -int main(int argc, char** argv) -{ - auto [result, args] = create_args(argc, argv); - if(!result) - return -1; - std::string prec = args.get_str("pr"); - - bool r = true; - if(prec.compare("fp16") == 0) - { - r &= run_batched_transpose(args); - } - else if(prec.compare("fp8") == 0) - { - r &= run_batched_transpose(args); - } - else - { - std::cerr << "Unsupported data type: " << prec << std::endl; - } - - return r ? 0 : -1; -} diff --git a/example/ck_tile/37_transpose/transpose_example.hpp b/example/ck_tile/37_transpose/transpose_example.hpp deleted file mode 100644 index 8128d583ef..0000000000 --- a/example/ck_tile/37_transpose/transpose_example.hpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
-#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "ck_tile/ops/reduce.hpp" -#include "batched_transpose_kernel.hpp" -#include "block_transpose.hpp" -#include "transpose_policy.hpp" - -#include -#include - -#pragma once - -struct batched_transpose_trait -{ - std::string type; - std::string layout; -}; - -struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs -{ -}; - -float batched_transpose(batched_transpose_trait t, - batched_transpose_kargs a, - ck_tile::stream_config s); diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index f85346e9be..630b96ede0 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -22,5 +22,4 @@ add_subdirectory(19_gemm_multi_d) add_subdirectory(20_grouped_convolution) add_subdirectory(21_elementwise) add_subdirectory(35_batched_transpose) -add_subdirectory(37_transpose) add_subdirectory(38_block_scale_gemm) diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index 200e2a618c..ca0088c812 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -4,6 +4,10 @@ #pragma once #include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" diff --git a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp 
b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp index 4c3aa2ba29..a89a190489 100644 --- a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp @@ -32,7 +32,7 @@ struct BatchedTransposeKernel using Pipeline = remove_cvref_t; using Problem = remove_cvref_t; - using Type = typename Problem::InputType; + using Type = typename Problem::DataType; struct BatchedTransposeKargs { @@ -67,7 +67,7 @@ struct BatchedTransposeKernel return k; } - CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() { return Problem::kBlockSize; } CK_TILE_DEVICE void operator()(Kargs kargs) const { diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp new file mode 100644 index 0000000000..e344c24bf5 --- /dev/null +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +struct BatchedTransposeCommonPolicy +{ + CK_TILE_DEVICE static constexpr auto TileAccessPattern = + tile_distribution_pattern::thread_raked; + + template + CK_TILE_DEVICE static constexpr auto MakeInputDistribution() + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t LeadDimPerBlock = Problem::kMPerBlock; + constexpr index_t SecondDimPerBlock = Problem::kNPerBlock; + + constexpr index_t kVectorSize = Problem::VectorSizeOutput; + + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp new file mode 100644 index 0000000000..ef0b7fa229 --- /dev/null +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck_tile { + +template +struct BatchedTransposeLdsPipeline +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using DataType = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock; + static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock; + + static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize(); } + + CK_TILE_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE void operator()(const InputTileWindow& input_window, + OutputTileWindow& output_window) + { + __shared__ char smem[GetSmemSize()]; + auto input_tile_window = + make_tile_window(input_window, Policy::template MakeInputDistribution()); + auto output_tile_window = + make_tile_window(output_window, Policy::template MakeOutputDistribution()); + + DataType* p_lds_ptr = reinterpret_cast(smem); + constexpr auto in_lds_block_desc = Policy::template MakeLdsStoreBlockDescriptor(); + auto input_lds_block = + make_tensor_view(p_lds_ptr, in_lds_block_desc); + + constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor(); + auto output_lds_block = + make_tensor_view(p_lds_ptr, out_lds_block_desc); + + auto copy_to_lds_window = + make_tile_window(input_lds_block, + make_tuple(number{}, number{}), + {0, 0}); + auto load_from_lds_window = + make_tile_window(output_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeLdsLoadTileDistribution()); + + auto x = load_tile(input_tile_window); + + store_tile(copy_to_lds_window, x); + block_sync_lds(); + + auto y = load_tile_transpose(load_from_lds_window); + + store_tile(output_tile_window, y); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/37_transpose/transpose_policy.hpp 
b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp similarity index 65% rename from example/ck_tile/37_transpose/transpose_policy.hpp rename to include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp index b7e52a94f7..77c3db9c06 100644 --- a/example/ck_tile/37_transpose/transpose_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp @@ -1,24 +1,17 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" +#include "batched_transpose_common_policy.hpp" namespace ck_tile { -struct TransposePolicy +struct BatchedTransposeLdsPolicy : public BatchedTransposeCommonPolicy { - static constexpr auto TileAccessPattern = tile_distribution_pattern::thread_raked; - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSize() - { - return 16 / sizeof(typename Problem::DataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + CK_TILE_DEVICE static constexpr index_t GetSmemSize() { return integer_least_multiple( sizeof(typename Problem::DataType) * @@ -27,23 +20,7 @@ struct TransposePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution() - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock; - constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock; - constexpr index_t VecLoadSize = 16 / sizeof(typename Problem::DataType); - - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution() + CK_TILE_DEVICE static constexpr auto MakeOutputDistribution() { constexpr auto input_dstr = MakeLdsLoadTileDistribution(); @@ -56,11 +33,11 @@ 
struct TransposePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor() + CK_TILE_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor() { constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock; constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock; - constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType); + constexpr index_t kVectorSize = Problem::LDSVectorSize; constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, @@ -82,12 +59,11 @@ struct TransposePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor() + CK_TILE_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor() { constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock; constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock; - - constexpr index_t kVectorSize = 8 / sizeof(typename Problem::DataType); + constexpr index_t kVectorSize = Problem::LDSVectorSize; constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, @@ -109,25 +85,19 @@ struct TransposePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeLdsLoadTileDistribution() { using DataType = typename Problem::DataType; - // Extract base dimensions from the traits - constexpr index_t kBaseLeadDim = LaneGroupTransposeTraits::kleadDim; - constexpr index_t kBaseSecondDim = LaneGroupTransposeTraits::ksecondDim; - // Calculate block-level dimensions - constexpr index_t kLead = Problem::kLeadSizePerXdl; - constexpr index_t kSecond = Problem::kSecondSizePerXdl; - constexpr index_t kLeadIterPerWarp = Problem::kLeadXdlNumPerWarp; - constexpr index_t kSecondIterPerWarp = Problem::kSecondXdlNumPerWarp; + constexpr index_t kLeadIterPerWarp = 1; + constexpr index_t kSecondIterPerWarp = 1; constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps; constexpr index_t kSecondNumWarps = 
Problem::kSecondNumWarps; // Calculate repetitions of base pattern - constexpr index_t kLeadRepetitions = kLead / kBaseLeadDim; - constexpr index_t kSecondRepetitions = kSecond / kBaseSecondDim; + constexpr index_t kLeadRepetitions = Problem::kQuadNumPerLeadDim; + constexpr index_t kSecondRepetitions = Problem::kQuadNumPerSecondDim; constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim; constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations; diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp new file mode 100644 index 0000000000..491db37564 --- /dev/null +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// supports 2D transpose which will store to lds, +// then use ds_read_b*_tr_b* instruction to get the transposed data +template + typename NumWarps, + bool kPadM_, + bool kPadN_> +struct BatchedTransposeLdsProblem +{ + using DataType = remove_cvref_t; + + static constexpr index_t kRowWarps_ = NumWarps::at(number<1>{}); + static constexpr index_t kColWarps_ = NumWarps::at(number<0>{}); + static constexpr index_t kBlockSize_ = get_warp_size() * kRowWarps_ * kColWarps_; + static constexpr index_t kRowPerBlock_ = BlockTile::at(number<1>{}); + static constexpr index_t kColPerBlock_ = BlockTile::at(number<0>{}); + + static constexpr index_t kBlockSize = kBlockSize_; + // warps per block + static constexpr index_t kLeadNumWarps = kRowWarps_; + static constexpr index_t kSecondNumWarps = kColWarps_; + + static constexpr index_t kLeadSizePerBlock = kRowPerBlock_; + static constexpr index_t kSecondSizePerBlock = kColPerBlock_; + + static constexpr index_t kQuadrantLeadDim = 
LaneGroupTransposeTraits::kleadDim; + static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits::ksecondDim; + + static_assert(kLeadSizePerBlock % kLeadNumWarps == 0, + "block dim should be divided by warp count!"); + static_assert(kSecondSizePerBlock % kSecondNumWarps == 0, + "block dim should be divided by warp count!"); + // rows/cols per warp + static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps; + static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps; + + static_assert(kLeadSizePerWarp % kQuadrantLeadDim == 0, + "xdl dim should be divided by quad dim!"); + static_assert(kSecondSizePerWarp % kQuadrantSecondDim == 0, + "xdl dim should be divided by quad dim!"); + // xdl rows/cols is divided into quadrants. + static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerWarp / kQuadrantLeadDim; + static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerWarp / kQuadrantSecondDim; + + static constexpr index_t kIterationsInSecondDim = + kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size(); + + // definitions to adapt to BatchedTransposeKernel + + // FIXME: support padding + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + + static constexpr auto kMPerBlock = kLeadSizePerBlock; + static constexpr auto kNPerBlock = kSecondSizePerBlock; + + // 128-bit is the max single-instruction bandwidth for load/store + static constexpr index_t MaxLoadStoreSize = 16; + static constexpr auto VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType); + static constexpr auto VectorSizeOutput = kPadM ? 
1 : MaxLoadStoreSize / sizeof(DataType); + static constexpr auto LDSVectorSize = MaxLoadStoreSize / sizeof(DataType); +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp index e815313c06..633827f3c3 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp @@ -5,8 +5,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" -#include -#include namespace ck_tile { @@ -14,15 +12,8 @@ template struct BatchedTransposePipeline { // TODO: this kernel only support warp per row - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using InputType = ck_tile::remove_cvref_t; - static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock; - static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock; - static constexpr index_t AlignmentM = Problem::AlignmentM; - static constexpr index_t AlignmentN = Problem::AlignmentN; - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; template CK_TILE_DEVICE auto operator()(const InputWindow& input_window, OutputWindow& out_window) @@ -32,7 +23,7 @@ struct BatchedTransposePipeline auto input_tile = load_tile(inp_win); - auto output_tile = make_static_distributed_tensor( + auto output_tile = make_static_distributed_tensor( Policy::template MakeOutputDistribution()); transpose_tile2d(output_tile, input_tile); diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp index dd9a6d79a8..5238fecdc5 100644 --- 
a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp @@ -4,43 +4,25 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/softmax.hpp" -#include "ck_tile/ops/topk.hpp" +#include "batched_transpose_common_policy.hpp" namespace ck_tile { -struct BatchedTransposePolicy +struct BatchedTransposePolicy : public BatchedTransposeCommonPolicy { template - CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution() - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::kMPerBlock; - constexpr index_t NPerBlock = Problem::kNPerBlock; - constexpr index_t VecLoadSize = Problem::VectorSizeInput; - using TileEncodingPattern = - TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution() + CK_TILE_DEVICE static constexpr auto MakeOutputDistribution() { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::kMPerBlock; constexpr index_t NPerBlock = Problem::kNPerBlock; constexpr index_t VecLoadSize = Problem::VectorSizeOutput; - using TileEncodingPattern = - TileDistributionEncodingPattern2D; + using TileEncodingPattern = TileDistributionEncodingPattern2D; return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } }; diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp index fd5ea004b6..2be979723b 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp @@ -6,42 +6,31 @@ #include "ck_tile/core.hpp" #include -#define VectorLoadSize 16 - namespace ck_tile { -template // Sequence<... 
struct BatchedTransposeProblem { - using InputType = remove_cvref_t; + using DataType = remove_cvref_t; - static constexpr index_t kMPerThread = ThreadTile::at(number<0>{}); - static constexpr index_t kNPerThread = ThreadTile::at(number<1>{}); - - static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); - static constexpr index_t kNPerWarp = WarpTile::at(number<1>{}); - - static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; - static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread; + static constexpr index_t kMPerWarp = WarpLayout::at(number<0>{}); + static constexpr index_t kNPerWarp = WarpLayout::at(number<1>{}); static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); static constexpr index_t kNPerBlock = BlockTile::at(number<1>{}); - static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp; - static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp; - - static constexpr index_t kBlockSize = - kMThreadPerWarp * kNThreadPerWarp * kMWarpPerBlock * kNWarpPerBlock; + static constexpr index_t kBlockSize = kMPerWarp * kNPerWarp * get_warp_size(); static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; - static constexpr index_t VectorSizeInput = kPadM ? 1 : VectorLoadSize / sizeof(InputType); - static constexpr index_t VectorSizeOutput = kPadN ? 1 : VectorLoadSize / sizeof(InputType); + // 128-bit is the max single-instruction bandwidth for load/store + static constexpr index_t MaxLoadStoreSize = 16; + static constexpr index_t VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType); + static constexpr index_t VectorSizeOutput = kPadM ? 
1 : MaxLoadStoreSize / sizeof(DataType); }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 9d00de5f73..c201293389 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -29,9 +29,9 @@ #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" -#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" diff --git a/test/ck_tile/batched_transpose/batched_transpose_api.cpp b/test/ck_tile/batched_transpose/batched_transpose_api.cpp index 27c2269a06..973a1967f2 100644 --- a/test/ck_tile/batched_transpose/batched_transpose_api.cpp +++ b/test/ck_tile/batched_transpose/batched_transpose_api.cpp @@ -7,8 +7,6 @@ template float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) @@ -20,11 +18,10 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con a.dim_block_w = block_x; using block_tile = ck_tile::sequence; - using warp_tile = ck_tile::sequence; - using thread_tile = ck_tile::sequence; + using warp_layout = ck_tile::sequence; using ts_problem = - ck_tile::BatchedTransposeProblem; + ck_tile::BatchedTransposeProblem; using ts_pipeline = ck_tile::BatchedTransposePipeline; using kernel = ck_tile::BatchedTransposeKernel; @@ -53,21 +50,20 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con } // Param Comb: type_size, block_x & y, warp_x & y, thread_x & y -#define FOREACH_TRANSPOSE_PARAM(F) \ - F(fp8, 
ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, true, true) \ - F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, false, false) \ - F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, true, true) \ - F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, false, false) \ - F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, true, true) \ - F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, false, false) +#define FOREACH_TRANSPOSE_PARAM(F) \ + F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, true, true) \ + F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, false, false) \ + F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, true, true) \ + F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, false, false) \ + F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, true, true) \ + F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, false, false) // Macro that defines one static function per line -#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY, PADM, PADN) \ - static float \ - transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY##_##PADM##_##PADN( \ - batched_transpose_kargs& a, ck_tile::stream_config& s) \ - { \ - return batched_transpose_dispatch(a, s); \ +#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, PADM, PADN) \ + static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##PADM##_##PADN( \ + batched_transpose_kargs& a, ck_tile::stream_config& s) \ + { \ + return batched_transpose_dispatch(a, s); \ } FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN) @@ -80,33 +76,33 @@ float batched_transpose(batched_transpose_trait t, { if(a.height % 64 == 0 && a.width % 64 == 0) { - return transpose_fn_fp8_64_64_64_64_8_8_false_false(a, s); + return transpose_fn_fp8_64_64_1_1_false_false(a, s); } else { - return transpose_fn_fp8_64_64_64_64_8_8_true_true(a, s); + return transpose_fn_fp8_64_64_1_1_true_true(a, s); } } else if(t.type == "fp16") { if(a.height % 64 == 0 && a.width % 64 == 0) { - return transpose_fn_fp16_64_64_64_64_8_8_false_false(a, s); + return transpose_fn_fp16_64_64_1_1_false_false(a, s); } else { - return 
transpose_fn_fp16_64_64_64_64_8_8_true_true(a, s); + return transpose_fn_fp16_64_64_1_1_true_true(a, s); } } else if(t.type == "bf16") { if(a.height % 64 == 0 && a.width % 64 == 0) { - return transpose_fn_bf16_64_64_64_64_8_8_false_false(a, s); + return transpose_fn_bf16_64_64_1_1_false_false(a, s); } else { - return transpose_fn_bf16_64_64_64_64_8_8_true_true(a, s); + return transpose_fn_bf16_64_64_1_1_true_true(a, s); } } return -1;