Vectorized Transpose for Batched Transpose CK Tile Operator (#2131)

* Shared Memory for single data point

* CKTile Transpose vectorize CP1

* CKTile Transpose vectorize CP2

* CKTile Transpose vectorize CP2.1

* Fixed the compile error of the transpose tile 2d

* Have the correct result for the current test sample

* Changes to printing tensor

* fp8 support added

* Debugging for transpose

* solving the corner issue

* Changed padding flag

* Intermediate Debugging

* Finished debugging of the transpose op

* Code Cleanup

* Adding edge case smoke tests

* Adding Transpose test to CI/CD

* Addressing Review Comments

* Addressing Comments

* Measuring Perf Tests

* Code Cleanup

* Changelog

* Added the running iterations

* clang-format

* Fix the changelog

* Fix the compilation error

* Change the printing factor

---------

Co-authored-by: ThruptiRajLakshmanaGowda <tlakshma@amd.com>

[ROCm/composable_kernel commit: 9d1e44e56a]
Author: Thomas Ning
Date: 2025-05-12 00:41:45 -07:00
Committed by: GitHub
Parent: 2dc3190518
Commit: 1f7c2a88c0
14 changed files with 311 additions and 152 deletions

View File

@@ -19,7 +19,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
### Optimized
None
* Added Vectorize Transpose optimization for CK Tile (#2131)
### Fixes

Jenkinsfile (vendored, 73 lines changed)
View File

@@ -362,6 +362,20 @@ def cmake_build(Map conf=[:]){
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
}
}
if (params.RUN_CK_TILE_TRANSPOSE_TESTS){
try{
archiveArtifacts "perf_transpose_*.log"
if (arch_type == 1){
stash includes: "perf_transpose_**_gfx90a.log", name: "perf_transpose_log_gfx90a"
}
else if (arch_type == 2){
stash includes: "perf_transpose_**_gfx942.log", name: "perf_transpose_log_gfx942"
}
}
catch(Exception err){
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
}
}
if (params.RUN_CK_TILE_GEMM_TESTS){
try{
archiveArtifacts "perf_tile_gemm_**.log"
@@ -698,6 +712,15 @@ def process_results(Map conf=[:]){
echo "could not locate the FMHA performance logs: ${err.getMessage()}."
}
}
if (params.RUN_CK_TILE_TRANSPOSE_TESTS){
try{
unstash "perf_transpose_log_gfx942"
unstash "perf_transpose_log_gfx90a"
}
catch(Exception err){
echo "could not locate the Transpose performance logs: ${err.getMessage()}."
}
}
if (params.RUN_CK_TILE_GEMM_TESTS){
try{
unstash "perf_tile_gemm_log_gfx942"
@@ -753,7 +776,7 @@ def process_results(Map conf=[:]){
}
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.4;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.4;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
0 21 * * * % ROCMVERSION=6.4;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
@@ -833,6 +856,10 @@ pipeline {
name: "RUN_CK_TILE_FMHA_TESTS",
defaultValue: false,
description: "Run the ck_tile FMHA tests (default: OFF)")
booleanParam(
name: "RUN_CK_TILE_TRANSPOSE_TESTS",
defaultValue: false,
description: "Run the ck_tile Transpose tests (default: OFF)")
booleanParam(
name: "RUN_CK_TILE_GEMM_TESTS",
defaultValue: false,
@@ -1032,6 +1059,50 @@ pipeline {
}
}
}
stage("Run CK_TILE_TRANSPOSE Tests")
{
parallel
{
stage("Run CK_TILE_TRANSPOSE Tests on gfx90a")
{
when {
beforeAgent true
expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
make -j64 tile_example_batched_transpose && \
cd ../ &&
example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
stage("Run CK_TILE_TRANSPOSE Tests on gfx942")
{
when {
beforeAgent true
expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx942") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
make -j64 tile_example_batched_transpose && \
cd ../ &&
example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
}
}
stage("Run CK_TILE_GEMM Tests")
{
parallel

View File

@@ -24,4 +24,6 @@ args:
-layout_out output tensor data layout - NHWC by default
-seed seed to be used, -1 means random every time (default:-1)
-k_name set to 1 will print kernel name (default:0)
-warmup warmup iterations to run this kernel (default:50)
-repeat number of iterations to run this kernel (default:100)
```
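
As a quick illustration, a typical benchmarking run might look like the following (the shape is just an example, taken from the smoke-test invocations elsewhere in this change):

```shell
# Benchmark an fp16 NCHW -> NHWC batched transpose with explicit
# warmup/repeat iteration counts.
./build/bin/tile_example_batched_transpose -pr=fp16 -N=16 -C=64 -H=32 -W=128 \
    -layout_in=NCHW -layout_out=NHWC -warmup=50 -repeat=100
```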

View File

@@ -1,7 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "batched_transpose_example.hpp"
#include <iostream>
template <typename ts_type,
ck_tile::index_t block_x,
@@ -9,23 +8,23 @@ template <typename ts_type,
ck_tile::index_t warp_x,
ck_tile::index_t warp_y,
ck_tile::index_t thread_x,
ck_tile::index_t thread_y>
ck_tile::index_t thread_y,
bool kPadM,
bool kPadN>
float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s)
{
uint32_t dim_block_h = (a.height + block_y - 1) / block_y;
uint32_t dim_block_w = (a.width + block_x - 1) / block_x;
uint32_t dim_stride = a.height * a.width;
uint32_t dim_stride = a.height * a.width;
a.dim_stride = dim_stride;
a.dim_block_h = dim_block_h;
a.dim_block_w = dim_block_w;
a.dim_block_h = block_y;
a.dim_block_w = block_x;
using block_tile = ck_tile::sequence<block_x, block_y>;
using warp_tile = ck_tile::sequence<warp_x, warp_y>;
using thread_tile = ck_tile::sequence<thread_x, thread_y>;
using ts_problem =
ck_tile::BatchedTransposeProblem<ts_type, block_tile, warp_tile, thread_tile>;
ck_tile::BatchedTransposeProblem<ts_type, block_tile, warp_tile, thread_tile, kPadM, kPadN>;
using ts_pipeline = ck_tile::BatchedTransposePipeline<ts_problem>;
using kernel = ck_tile::BatchedTransposeKernel<ts_pipeline>;
@@ -35,25 +34,40 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
printf("Grid: %u %u %u\n", grids.x, grids.y, grids.z);
printf("Block: %u %u %u\n", blocks.x, blocks.y, blocks.z);
printf("kargs: kargs.batch %d kargs.height %d kargs.width %d kargs.dim_strid %d\n",
kargs.batch,
kargs.height,
kargs.width,
kargs.dim_stride);
printf("Launching Kernel...\n");
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
printf("Kernel finished...\n");
return ave_time;
}
// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y
#define FOREACH_TRANSPOSE_PARAM(F) \
F(fp16, ck_tile::fp16_t, 16, 16, 8, 8, 1, 1) \
F(bf16, ck_tile::bf16_t, 16, 16, 8, 8, 1, 1) \
F(fp32, ck_tile::fp32_t, 16, 16, 8, 8, 1, 1) \
F(int8, ck_tile::int8_t, 16, 16, 8, 8, 1, 1)
#define FOREACH_TRANSPOSE_PARAM(F) \
F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, true, true) \
F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, false, false) \
F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, true, true) \
F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, false, false) \
F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, true, true) \
F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, false, false)
// Macro that defines one static function per line
#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY) \
static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY( \
batched_transpose_kargs& a, ck_tile::stream_config& s) \
{ \
return batched_transpose_dispatch<REAL_TYPE, BX, BY, WX, WY, TX, TY>(a, s); \
#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY, PADM, PADN) \
static float \
transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY##_##PADM##_##PADN( \
batched_transpose_kargs& a, ck_tile::stream_config& s) \
{ \
return batched_transpose_dispatch<REAL_TYPE, BX, BY, WX, WY, TX, TY, PADM, PADN>(a, s); \
}
FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN)
@@ -62,21 +76,38 @@ float batched_transpose(batched_transpose_trait t,
batched_transpose_kargs a,
ck_tile::stream_config s)
{
if(t.type == "fp16")
if(t.type == "fp8")
{
return transpose_fn_fp16_16_16_8_8_1_1(a, s);
if(a.height % 64 == 0 && a.width % 64 == 0)
{
return transpose_fn_fp8_64_64_64_64_8_8_false_false(a, s);
}
else
{
return transpose_fn_fp8_64_64_64_64_8_8_true_true(a, s);
}
}
else if(t.type == "fp16")
{
if(a.height % 64 == 0 && a.width % 64 == 0)
{
return transpose_fn_fp16_64_64_64_64_8_8_false_false(a, s);
}
else
{
return transpose_fn_fp16_64_64_64_64_8_8_true_true(a, s);
}
}
else if(t.type == "bf16")
{
return transpose_fn_bf16_16_16_8_8_1_1(a, s);
}
else if(t.type == "fp32")
{
return transpose_fn_fp32_16_16_8_8_1_1(a, s);
}
else if(t.type == "int8")
{
return transpose_fn_int8_16_16_8_8_1_1(a, s);
if(a.height % 64 == 0 && a.width % 64 == 0)
{
return transpose_fn_bf16_64_64_64_64_8_8_false_false(a, s);
}
else
{
return transpose_fn_bf16_64_64_64_64_8_8_true_true(a, s);
}
}
return -1;
}
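
The dispatch logic above boils down to one rule: when both height and width are multiples of the 64-element block tile, take the unpadded, fully vectorized instantiation; otherwise fall back to the padded one. A minimal sketch of that selection, with `run_unpadded`/`run_padded` as hypothetical stand-ins for the generated `transpose_fn_*` entry points:

```cpp
#include <cstdint>

// Hypothetical stand-ins for the generated transpose_fn_*_false_false and
// transpose_fn_*_true_true entry points.
inline float run_unpadded() { return 0.f; } // kPadM = kPadN = false, vectorized
inline float run_padded()   { return 0.f; } // kPadM = kPadN = true, scalar access

// Selection rule sketched from the dispatcher above: the 64x64 block tile
// must divide the problem exactly for the unpadded path to be legal.
inline float dispatch(uint32_t height, uint32_t width)
{
    constexpr uint32_t kTile = 64;
    const bool exact = (height % kTile == 0) && (width % kTile == 0);
    return exact ? run_unpadded() : run_padded();
}
```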

View File

@@ -21,13 +21,13 @@ void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
std::cout << "[";
for(size_t i = 0; i < len[0]; i++)
{
std::cout << i << ": [";
std::cout << "Batch " << i << ":" << std::endl;
for(size_t j = 0; j < len[1]; j++)
{
std::cout << j << ": [";
std::cout << " Channel " << j << ":" << std::endl;
for(size_t k = 0; k < len[2]; k++)
{
std::cout << k << ": [";
std::cout << " Row " << k << ": ";
for(size_t v = 0; v < len[3]; v++)
{
if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
@@ -41,15 +41,15 @@ void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
}
else
{
std::cout << x(std::vector<std::size_t>{i, j, k, v}) << " ";
std::cout << static_cast<int>(x(std::vector<std::size_t>{i, j, k, v}))
<< " ";
}
}
std::cout << "]" << std::endl;
std::cout << std::endl;
}
std::cout << "]" << std::endl;
}
std::cout << std::endl;
}
std::cout << "]" << std::endl;
std::cout << "--------------------" << std::endl;
}
#endif
@@ -93,12 +93,14 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "whether do CPU validation or not")
.insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
.insert("N", "2", "input batch size. ")
.insert("C", "16", "input channel size.")
.insert("H", "1", "input height size.")
.insert("W", "16", "input width size. ")
.insert("N", "1", "input batch size. ")
.insert("C", "64", "input channel size.")
.insert("H", "18", "input height size.")
.insert("W", "64", "input width size. ")
.insert("layout_in", "NCHW", "input tensor data layout - NCHW by default")
.insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "t to 1 will print kernel name");
@@ -115,6 +117,8 @@ bool run_batched_transpose(ck_tile::ArgParser args)
int C = args.get_int("C");
int H = args.get_int("H");
int W = args.get_int("W");
int n_warmup = args.get_int("warmup");
int n_repeat = args.get_int("repeat");
std::string layout_in = args.get_str("layout_in");
std::string layout_out = args.get_str("layout_out");
int seed = args.get_int("seed");
@@ -177,7 +181,7 @@ bool run_batched_transpose(ck_tile::ArgParser args)
return a_;
}();
ck_tile::stream_config sc{nullptr, true};
ck_tile::stream_config sc{nullptr, true, n_warmup, n_repeat};
auto ms = batched_transpose(trait, karg, sc);
@@ -202,7 +206,8 @@ bool run_batched_transpose(ck_tile::ArgParser args)
layout_in.c_str(),
ms);
if(ms < 0)
printf("not supported\n");
printf("------------------------------------not "
"supported-------------------------------------\n");
fflush(stdout);
if(ms < 0)
@@ -227,7 +232,9 @@ bool run_batched_transpose(ck_tile::ArgParser args)
rtn &= ck_tile::check_err(
y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol);
}
printf("valid:%s\n", rtn ? "y" : "n");
printf("-----------------------------------------------------------------------valid:%s--------"
"--------------------------------------------------------------------\n",
rtn ? "y" : "n");
fflush(stdout);
return rtn;
}
@@ -240,9 +247,9 @@ int main(int argc, char** argv)
std::string prec = args.get_str("pr");
bool r = true;
if(prec.compare("fp32") == 0)
if(prec.compare("fp8") == 0)
{
r &= run_batched_transpose<float>(args);
r &= run_batched_transpose<ck_tile::fp8_t>(args);
}
else if(prec.compare("fp16") == 0)
{
@@ -252,10 +259,6 @@ int main(int argc, char** argv)
{
r &= run_batched_transpose<ck_tile::bf16_t>(args);
}
else if(prec.compare("int8") == 0)
{
r &= run_batched_transpose<ck_tile::int8_t>(args);
}
return r ? 0 : -1;
}

View File

@@ -0,0 +1,11 @@
#!/bin/sh
EXE=./build/bin/tile_example_batched_transpose
for pr in "fp8" "fp16" "bf16"; do
$EXE -pr=$pr -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=1 -C=4096 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC'
done

View File

@@ -0,0 +1,38 @@
#!/bin/bash
#
# in order to run this script you'd first need to build the tile_example_batched_transpose executables in ../build/bin/
#
# run the script as "./run_full_test.sh <tag for your test environment> <branch name> <host name> <gpu_arch>
# input arguments:
# environment tag : a string describing the specifics of your test environment
# branch name : name of the branch in git repo (git status | grep -e 'On branch')
# host name : $hostname
# gpu architecture: e.g., gfx90a, or gfx942, etc.
#get the command line arguments:
export env_type=$1
echo 'Environment type: ' $env_type
export branch=$2
echo 'Branch name: ' $branch
export host_name=$3
echo 'Host name: ' $host_name
export GPU_arch=$4
echo 'GPU_arch: ' $GPU_arch
function print_log_header(){
rm -f $1;
echo 'On branch ' $3 &> $1;
echo 'Node name: ' $4 >> $1;
#get GPU_arch and number of compute units from rocminfo
echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
rocminfo | grep "Compute Unit:" >> $1;
hipcc --version | grep -e 'HIP version' >> $1;
echo 'Environment type: ' $2 >> $1;
/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
}
#run verification tests
example/ck_tile/35_batched_transpose/script/smoke_test.sh
#run performance benchmarks

View File

@@ -2,10 +2,26 @@
EXE=./build/bin/tile_example_batched_transpose
for pr in "fp32" "fp16" "int8" ; do
for pr in "fp8" "fp16" "bf16"; do
$EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC'
done
$EXE -pr=$pr -N=1 -C=1 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=1 -C=1 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=128 -C=1024 -H=64 -W=64 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=128 -C=1024 -H=64 -W=64 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=16 -C=64 -H=32 -W=128 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=16 -C=64 -H=128 -W=32 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=1 -C=2048 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=1 -C=2048 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=1 -C=1 -H=1024 -W=1024 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=1 -C=1 -H=1024 -W=1024 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=1 -C=64 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=1 -C=64 -H=1024 -W=1 -layout_in='NHWC' -layout_out='NCHW'
done

View File

@@ -384,22 +384,6 @@ struct tensor_view
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_view{");
// buf_
printf("buf_: ");
print(buf_);
printf(", ");
// desc_
printf("desc_: ");
print(desc_);
printf("}");
}
// member
buffer_view buf_;
TensorDesc desc_;
@@ -494,6 +478,7 @@ template <typename TensorView,
CK_TILE_HOST_DEVICE constexpr auto
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
{
constexpr index_t num_dim = DoPads::size();
static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),

View File

@@ -85,7 +85,12 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
// SFC
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
[&](auto i) {
if constexpr(vec_length_in == 1)
return 1;
else
return (i == y_dim_vec_in || i == y_dim_vec_out) ? y_lengths[i] : 1;
},
number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
@@ -103,13 +108,19 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
// loop over SFC
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...] in the order of input tensor
constexpr auto idx_y = SFC_Y::get_index(iAccess);
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y);
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
constexpr auto idx_y_in =
generate_tuple([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
static_assert(in_offset % vec_length_in == 0);
constexpr auto idx_y_out_tmp =
generate_array([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
constexpr auto idx_y_out =
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
if constexpr(vec_length_in == 1)
{
out_tensor.get_thread_buffer()[number<out_offset>{}] =
in_tensor.get_thread_buffer()[number<in_offset>{}];
}
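
The essential fix in this hunk: the output offset is now computed from the space-filling-curve index reordered into the output tensor's dimension order (via `container_reorder_given_new2old`), instead of reusing the input-order index directly. A plain-C++ stand-in for that reorder, assuming the convention that element `i` of the result comes from position `new2old[i]` of the source:

```cpp
#include <array>
#include <cstddef>

// Plain-C++ stand-in for container_reorder_given_new2old: element i of the
// result is read from position new2old[i] of the source index.
template <std::size_t N>
constexpr std::array<std::size_t, N>
reorder_new2old(const std::array<std::size_t, N>& idx,
                const std::array<std::size_t, N>& new2old)
{
    std::array<std::size_t, N> out{};
    for(std::size_t i = 0; i < N; ++i)
        out[i] = idx[new2old[i]];
    return out;
}

// For a 2D transpose, new2old = {1, 0}: input index {row, col} becomes
// output index {col, row} before the output offset is computed.
static_assert(reorder_new2old<2>({3, 7}, {1, 0})[0] == 7, "dim 0 reads old dim 1");
```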

View File

@@ -19,7 +19,6 @@ struct BatchedTransposeHostArgs
index_t batch;
index_t height;
index_t width;
// index_t dim_blocks;
index_t dim_stride;
index_t dim_block_h;
index_t dim_block_w;
@@ -28,8 +27,10 @@ struct BatchedTransposeHostArgs
template <typename Pipeline_>
struct BatchedTransposeKernel
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
CK_TILE_DEVICE static index_t counter = 0;
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using Type = typename Problem::InputType;
@@ -46,11 +47,11 @@ struct BatchedTransposeKernel
using Kargs = BatchedTransposeKargs;
using Hargs = BatchedTransposeHostArgs;
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
CK_TILE_HOST static constexpr auto GridSize(const Hargs& host_args)
{
size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w;
size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h;
size_t grid_size_z = h.batch;
size_t grid_size_x = (host_args.height + host_args.dim_block_h - 1) / host_args.dim_block_h;
size_t grid_size_y = (host_args.width + host_args.dim_block_w - 1) / host_args.dim_block_w;
size_t grid_size_z = host_args.batch;
return dim3(grid_size_x, grid_size_y, grid_size_z);
}
@@ -70,58 +71,52 @@ struct BatchedTransposeKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr ck_tile::index_t VectorSizeInput = Problem::VectorSizeInput;
static constexpr ck_tile::index_t VectorSizeOutput = Problem::VectorSizeOutput;
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
const auto iDim = blockIdx.z;
static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread;
static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread;
static_assert(kMPerThread == 1 && kNPerThread == 1);
const auto iDim = blockIdx.z;
const auto x_m_n = [&]() {
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
make_tuple(kargs.height, kargs.width),
make_tuple(kargs.width, 1),
number<kNPerThread>{}, // TODO thread load value
number<VectorSizeInput>{},
number<1>{});
return pad_tensor_view(x_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{});
sequence<kPadN, kPadM>{});
}();
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
const auto y_n_m = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
make_tuple(kargs.width, kargs.height),
make_tuple(kargs.height, 1),
number<kMPerThread>{},
number<VectorSizeOutput>{},
number<1>{});
return pad_tensor_view(y_dram_naive,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
sequence<kPadN, kPadM>{});
sequence<kPadM, kPadN>{});
}();
auto x_block_window =
make_tile_window(x_m_n,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{static_cast<ck_tile::index_t>(iM * kMPerBlock),
static_cast<ck_tile::index_t>(iN * kNPerBlock)});
auto x_block_window = make_tile_window(
x_m_n,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{static_cast<ck_tile::index_t>(iM), static_cast<ck_tile::index_t>(iN)});
auto y_block_window =
make_tile_window(y_n_m,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
{static_cast<ck_tile::index_t>(iN * kNPerBlock),
static_cast<ck_tile::index_t>(iM * kMPerBlock)});
auto y_block_window = make_tile_window(
y_n_m,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
{static_cast<ck_tile::index_t>(iN), static_cast<ck_tile::index_t>(iM)});
Pipeline{}(x_block_window, y_block_window);
}
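
Two coupled fixes are visible here: `GridSize` now tiles the M (height) dimension along `blockIdx.x` and N (width) along `blockIdx.y`, and the tile-window origins use `iM`/`iN` directly, since `blockIdx.x * kMPerBlock` is already an element offset (the old code multiplied by the tile size a second time). A host-side sketch of the corrected index math, assuming `dim_block_h`/`dim_block_w` now carry the tile extents set by the dispatcher:

```cpp
#include <cstdint>

struct Grid { uint32_t x, y, z; };

// Ceil-divide the problem into 2D tiles; one z-slice per batch.
inline Grid grid_size(uint32_t batch, uint32_t height, uint32_t width,
                      uint32_t tile_h, uint32_t tile_w)
{
    return { (height + tile_h - 1) / tile_h,  // tiles along M (height)
             (width  + tile_w - 1) / tile_w,  // tiles along N (width)
             batch };
}

// Per-block element origin: scale blockIdx by the tile size exactly once.
inline uint32_t block_origin(uint32_t block_idx, uint32_t tile)
{
    return block_idx * tile;
}
```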

View File

@@ -29,24 +29,18 @@ struct BatchedTransposePipeline
{
auto inp_win =
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
auto input_tile = load_tile(inp_win);
auto output_tile = make_static_distributed_tensor<InputType>(
Policy::template MakeOutputDistribution<Problem>());
transpose_tile2d(output_tile, input_tile);
auto out_win =
make_tile_window(out_window, Policy::template MakeOutputDistribution<Problem>());
auto x = load_tile(inp_win); // x->thread input_win->block
auto y = make_static_distributed_tensor<InputType>(
Policy::template MakeOutputDistribution<Problem>());
constexpr auto span_2d_x = decltype(x)::get_distributed_spans();
sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx1, idx0);
y(i_j_idx) = x(i_j_idx);
});
});
store_tile(out_win, y);
store_tile(out_win, output_tile);
}
};
} // namespace ck_tile

View File

@@ -14,31 +14,34 @@ struct BatchedTransposePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
{
using S = Problem;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::kMPerBlock;
constexpr index_t NPerBlock = Problem::kNPerBlock;
constexpr index_t VecLoadSize = Problem::VectorSizeInput;
using TileEncodingPattern =
TileDistributionEncodingPattern2D<BlockSize,
MPerBlock,
NPerBlock,
VecLoadSize,
tile_distribution_pattern::thread_raked>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
{
using S = Problem;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>,
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>>,
tuple<sequence<2, 1>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<2, 1>,
sequence<2, 2>>{});
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::kMPerBlock;
constexpr index_t NPerBlock = Problem::kNPerBlock;
constexpr index_t VecLoadSize = Problem::VectorSizeOutput;
using TileEncodingPattern =
TileDistributionEncodingPattern2D<BlockSize,
NPerBlock,
MPerBlock,
VecLoadSize,
tile_distribution_pattern::thread_raked>;
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
}
};
} // namespace ck_tile

View File

@@ -4,7 +4,6 @@
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
#define VectorLoadSize 16
@@ -12,11 +11,11 @@
namespace ck_tile {
template <typename InputType_,
typename BlockTile, // Sequence<...
typename WarpTile, // Sequence<...
typename ThreadTile, // Sequence<...
bool kPadM_ = true,
bool kPadN_ = true>
typename BlockTile, // Sequence<...
typename WarpTile, // Sequence<...
typename ThreadTile,
bool kPadM_ = false,
bool kPadN_ = false> // Sequence<...
struct BatchedTransposeProblem
{
using InputType = remove_cvref_t<InputType_>;
@@ -42,7 +41,7 @@ struct BatchedTransposeProblem
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr index_t AlignmentM = kPadM ? VectorLoadSize / sizeof(InputType) : 1; // TODO
static constexpr index_t AlignmentN = kPadN ? VectorLoadSize / sizeof(InputType) : 1;
static constexpr index_t VectorSizeInput = kPadM ? 1 : VectorLoadSize / sizeof(InputType);
static constexpr index_t VectorSizeOutput = kPadN ? 1 : VectorLoadSize / sizeof(InputType);
};
} // namespace ck_tile
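
The `VectorSizeInput`/`VectorSizeOutput` definitions above encode the vectorization rule: an unpadded dimension gets a full 16-byte access of `VectorLoadSize / sizeof(InputType)` elements, while padding collapses the access to scalar. A small self-contained check of that arithmetic, using plain integer types as stand-ins for `fp8_t`/`fp16_t`:

```cpp
#include <cstdint>
#include <cstdio>

// Mirror of the VectorSizeInput/VectorSizeOutput rule: padding forces
// scalar access; otherwise a full 16-byte load is split across elements.
constexpr int VectorLoadSize = 16; // bytes

template <typename T>
constexpr int vector_size(bool padded)
{
    return padded ? 1 : VectorLoadSize / static_cast<int>(sizeof(T));
}

int main()
{
    std::printf("fp8  unpadded: %d\n", vector_size<uint8_t>(false));  // 16
    std::printf("fp16 unpadded: %d\n", vector_size<uint16_t>(false)); // 8
    std::printf("fp16 padded:   %d\n", vector_size<uint16_t>(true));  // 1
    return 0;
}
```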