mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
add batched_transpose implement (#1660)
* add batched_transpose implement --------- Co-authored-by: root <root@ctr-ubbsmc16.amd.com> Co-authored-by: ThruptiRajLakshmanaGowda <tlakshma@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
9
example/ck_tile/35_batched_transpose/CMakeLists.txt
Normal file
9
example/ck_tile/35_batched_transpose/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
set(TARGET_NAME tile_example_batched_transpose)
|
||||
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL batched_transpose_example.cpp batched_transpose_api.cpp)
|
||||
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
target_compile_options(tile_example_batched_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS})
|
||||
|
||||
27
example/ck_tile/35_batched_transpose/README.md
Normal file
27
example/ck_tile/35_batched_transpose/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# Batched Transpose
|
||||
This folder contains example for batched Transpose using ck_tile tile-programming implementation. Currently, it supports the batched transpose with NCHW to NHWC or NHWC to NCHW. So in this way from NCHW you could transpose to either NHWC or NWCH(two transposes). Now the transpose read with single data point. We would soon put it in vectorized transpose.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
# Make the transpose executable
|
||||
make tile_example_batched_transpose -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_batched_transpose`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-N input batch size (default:2)
|
||||
-C input channel size. (default:16)
|
||||
-H input height size. (default:1)
|
||||
-W input width size. (default:16)
|
||||
-v whether do CPU validation or not (default: 1)
|
||||
-layout_in input tensor data layout - NCHW by default
|
||||
-layout_out output tensor data layout - NHWC by default
|
||||
-seed seed to be used, -1 means random every time (default:-1)
|
||||
-k_name t to 1 will print kernel name (default:0)
|
||||
```
|
||||
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "batched_transpose_example.hpp"
|
||||
#include <iostream>
|
||||
|
||||
template <typename ts_type,
|
||||
ck_tile::index_t block_x,
|
||||
ck_tile::index_t block_y,
|
||||
ck_tile::index_t warp_x,
|
||||
ck_tile::index_t warp_y,
|
||||
ck_tile::index_t thread_x,
|
||||
ck_tile::index_t thread_y>
|
||||
float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s)
|
||||
{
|
||||
uint32_t dim_block_h = (a.height + block_y - 1) / block_y;
|
||||
uint32_t dim_block_w = (a.width + block_x - 1) / block_x;
|
||||
uint32_t dim_stride = a.height * a.width;
|
||||
|
||||
a.dim_stride = dim_stride;
|
||||
a.dim_block_h = dim_block_h;
|
||||
a.dim_block_w = dim_block_w;
|
||||
|
||||
using block_tile = ck_tile::sequence<block_x, block_y>;
|
||||
using warp_tile = ck_tile::sequence<warp_x, warp_y>;
|
||||
using thread_tile = ck_tile::sequence<thread_x, thread_y>;
|
||||
|
||||
using ts_problem =
|
||||
ck_tile::BatchedTransposeProblem<ts_type, block_tile, warp_tile, thread_tile>;
|
||||
using ts_pipeline = ck_tile::BatchedTransposePipeline<ts_problem>;
|
||||
|
||||
using kernel = ck_tile::BatchedTransposeKernel<ts_pipeline>;
|
||||
|
||||
auto kargs = kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = kernel::GridSize(a);
|
||||
constexpr dim3 blocks = kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y
|
||||
#define FOREACH_TRANSPOSE_PARAM(F) \
|
||||
F(fp16, ck_tile::fp16_t, 16, 16, 8, 8, 1, 1) \
|
||||
F(bf16, ck_tile::bf16_t, 16, 16, 8, 8, 1, 1) \
|
||||
F(fp32, ck_tile::fp32_t, 16, 16, 8, 8, 1, 1) \
|
||||
F(int8, ck_tile::int8_t, 16, 16, 8, 8, 1, 1)
|
||||
|
||||
// Macro that defines one static function per line
|
||||
#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY) \
|
||||
static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY( \
|
||||
batched_transpose_kargs& a, ck_tile::stream_config& s) \
|
||||
{ \
|
||||
return batched_transpose_dispatch<REAL_TYPE, BX, BY, WX, WY, TX, TY>(a, s); \
|
||||
}
|
||||
|
||||
FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN)
|
||||
|
||||
float batched_transpose(batched_transpose_trait t,
|
||||
batched_transpose_kargs a,
|
||||
ck_tile::stream_config s)
|
||||
{
|
||||
if(t.type == "fp16")
|
||||
{
|
||||
return transpose_fn_fp16_16_16_8_8_1_1(a, s);
|
||||
}
|
||||
else if(t.type == "bf16")
|
||||
{
|
||||
return transpose_fn_bf16_16_16_8_8_1_1(a, s);
|
||||
}
|
||||
else if(t.type == "fp32")
|
||||
{
|
||||
return transpose_fn_fp32_16_16_8_8_1_1(a, s);
|
||||
}
|
||||
else if(t.type == "int8")
|
||||
{
|
||||
return transpose_fn_int8_16_16_8_8_1_1(a, s);
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <time.h>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "batched_transpose_example.hpp"
|
||||
|
||||
#if 0
|
||||
template <typename T>
|
||||
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
|
||||
{
|
||||
auto len = x.get_lengths();
|
||||
assert(len.size() == 4);
|
||||
std::cout << "[";
|
||||
for(size_t i = 0; i < len[0]; i++)
|
||||
{
|
||||
std::cout << i << ": [";
|
||||
for(size_t j = 0; j < len[1]; j++)
|
||||
{
|
||||
std::cout << j << ": [";
|
||||
for(size_t k = 0; k < len[2]; k++)
|
||||
{
|
||||
std::cout << k << ": [";
|
||||
for(size_t v = 0; v < len[3]; v++)
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
|
||||
{
|
||||
auto m =
|
||||
ck_tile::type_convert<float>(x(std::vector<std::size_t>{i, j, k, v}));
|
||||
|
||||
std::cout << m;
|
||||
if(v != len[3] - 1)
|
||||
std::cout << ",";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << x(std::vector<std::size_t>{i, j, k, v}) << " ";
|
||||
}
|
||||
}
|
||||
std::cout << "]" << std::endl;
|
||||
}
|
||||
std::cout << "]" << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << "--------------------" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
{
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
unsigned max_rounding_point_distance = 0;
|
||||
double atol = 2e-3;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
unsigned max_rounding_point_distance = 1;
|
||||
double atol = 0.0625;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "whether do CPU validation or not")
|
||||
.insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
|
||||
.insert("N", "2", "input batch size. ")
|
||||
.insert("C", "16", "input channel size.")
|
||||
.insert("H", "1", "input height size.")
|
||||
.insert("W", "16", "input width size. ")
|
||||
.insert("layout_in", "NCHW", "input tensor data layout - NCHW by default")
|
||||
.insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "t to 1 will print kernel name");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename Type>
|
||||
bool run_batched_transpose(ck_tile::ArgParser args)
|
||||
{
|
||||
int validate = args.get_int("v");
|
||||
std::string prec = args.get_str("pr");
|
||||
int N = args.get_int("N");
|
||||
int C = args.get_int("C");
|
||||
int H = args.get_int("H");
|
||||
int W = args.get_int("W");
|
||||
std::string layout_in = args.get_str("layout_in");
|
||||
std::string layout_out = args.get_str("layout_out");
|
||||
int seed = args.get_int("seed");
|
||||
|
||||
int dim_in[4], dim_out[4];
|
||||
int stride_dim_in[4], stride_dim_out[4];
|
||||
bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC";
|
||||
bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW";
|
||||
assert(nchw2nhwc != nhwc2nchw);
|
||||
(void)nhwc2nchw;
|
||||
|
||||
dim_in[0] = N;
|
||||
dim_in[1] = nchw2nhwc ? C : H;
|
||||
dim_in[2] = nchw2nhwc ? H : W;
|
||||
dim_in[3] = nchw2nhwc ? W : C;
|
||||
dim_out[0] = N;
|
||||
dim_out[1] = nchw2nhwc ? H : C;
|
||||
dim_out[2] = nchw2nhwc ? W : H;
|
||||
dim_out[3] = nchw2nhwc ? C : W;
|
||||
stride_dim_in[0] = C * H * W;
|
||||
stride_dim_in[1] = nchw2nhwc ? H * W : C * W;
|
||||
stride_dim_in[2] = nchw2nhwc ? W : C;
|
||||
stride_dim_in[3] = 1;
|
||||
stride_dim_out[0] = C * H * W;
|
||||
stride_dim_out[1] = nchw2nhwc ? C * W : H * W;
|
||||
stride_dim_out[2] = nchw2nhwc ? C : W;
|
||||
stride_dim_out[3] = 1;
|
||||
|
||||
if(seed < 0)
|
||||
{
|
||||
seed = std::time(nullptr);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<Type> x_host(
|
||||
{dim_in[0], dim_in[1], dim_in[2], dim_in[3]},
|
||||
{stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]});
|
||||
ck_tile::HostTensor<Type> y_host(
|
||||
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
|
||||
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
|
||||
|
||||
ck_tile::FillUniformDistribution<Type>{-.5f, .5f}(x_host);
|
||||
|
||||
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_dev.ToDevice(x_host.data());
|
||||
|
||||
auto trait = batched_transpose_trait{prec, layout_in};
|
||||
|
||||
uint32_t height = nchw2nhwc ? C : H * W;
|
||||
uint32_t width = nchw2nhwc ? H * W : C;
|
||||
|
||||
batched_transpose_kargs karg = [&]() {
|
||||
batched_transpose_kargs a_;
|
||||
a_.p_input = x_dev.GetDeviceBuffer();
|
||||
a_.p_output = y_dev.GetDeviceBuffer();
|
||||
a_.batch = N;
|
||||
a_.height = height;
|
||||
a_.width = width;
|
||||
return a_;
|
||||
}();
|
||||
|
||||
ck_tile::stream_config sc{nullptr, true};
|
||||
|
||||
auto ms = batched_transpose(trait, karg, sc);
|
||||
|
||||
std::size_t num_operations = N * C * H * (W - 1);
|
||||
std::size_t num_bytes = N * C * H * W * sizeof(Type);
|
||||
|
||||
float ave_time = ms * 1E-3;
|
||||
float gb_per_sec = num_bytes / ms * 1.E-6;
|
||||
float tflops = static_cast<float>(num_operations) / ms * 1.E-6;
|
||||
|
||||
std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H
|
||||
<< ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out
|
||||
<< " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops"
|
||||
<< gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n",
|
||||
prec.c_str(),
|
||||
N,
|
||||
C,
|
||||
H,
|
||||
W,
|
||||
layout_in.c_str(),
|
||||
ms);
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
fflush(stdout);
|
||||
|
||||
if(ms < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
y_dev.FromDevice(y_host.data());
|
||||
|
||||
bool rtn = true;
|
||||
if(validate)
|
||||
{
|
||||
// this host buffer will not copy to GPU, so no need use stride
|
||||
ck_tile::HostTensor<Type> y_ref(
|
||||
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
|
||||
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
|
||||
|
||||
ck_tile::reference_batched_transpose<Type>(x_host, y_ref, layout_in, layout_out);
|
||||
|
||||
auto [rtol, atol] = get_elimit<Type>("");
|
||||
|
||||
rtn &= ck_tile::check_err(
|
||||
y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol);
|
||||
}
|
||||
printf("valid:%s\n", rtn ? "y" : "n");
|
||||
fflush(stdout);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
std::string prec = args.get_str("pr");
|
||||
|
||||
bool r = true;
|
||||
if(prec.compare("fp32") == 0)
|
||||
{
|
||||
r &= run_batched_transpose<float>(args);
|
||||
}
|
||||
else if(prec.compare("fp16") == 0)
|
||||
{
|
||||
r &= run_batched_transpose<ck_tile::fp16_t>(args);
|
||||
}
|
||||
else if(prec.compare("bf16") == 0)
|
||||
{
|
||||
r &= run_batched_transpose<ck_tile::bf16_t>(args);
|
||||
}
|
||||
else if(prec.compare("int8") == 0)
|
||||
{
|
||||
r &= run_batched_transpose<ck_tile::int8_t>(args);
|
||||
}
|
||||
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "ck_tile/ops/batched_transpose.hpp"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#pragma once
|
||||
|
||||
struct batched_transpose_trait
|
||||
{
|
||||
std::string type;
|
||||
std::string layout;
|
||||
};
|
||||
|
||||
struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
float batched_transpose(batched_transpose_trait t,
|
||||
batched_transpose_kargs a,
|
||||
ck_tile::stream_config s);
|
||||
11
example/ck_tile/35_batched_transpose/script/smoke_test.sh
Executable file
11
example/ck_tile/35_batched_transpose/script/smoke_test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/sh
|
||||
|
||||
EXE=./build/bin/tile_example_batched_transpose
|
||||
|
||||
for pr in "fp32" "fp16" "int8" ; do
|
||||
$EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC'
|
||||
$EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW'
|
||||
$EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW'
|
||||
$EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC'
|
||||
$EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC'
|
||||
done
|
||||
Reference in New Issue
Block a user