mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
added explanation comments
This commit is contained in:
@@ -3,6 +3,19 @@
|
||||
#include "vector_add.hpp"
|
||||
#include <cstring>
|
||||
|
||||
|
||||
// This example demonstrates how to use the ck_tile library to perform an elementwise vector addition
|
||||
// using a custom kernel. The kernel is defined in the vector_add.hpp file, and the reference implementation
|
||||
// is provided in the reference_vector_add.hpp file.
|
||||
|
||||
|
||||
|
||||
// parse command line arguments
|
||||
// -m: size of the vectors
|
||||
// -v: validation flag (1 for validation, 0 for no validation)
|
||||
// -prec: precision of the data type (fp16, fp32, int8, int32)
|
||||
// -warmup: number of warmup iterations (number of kernel launches before measuring performance)
|
||||
// -repeat: number of repeat iterations (number of kernel launches to measure performance)
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -19,45 +32,48 @@ auto create_args(int argc, char* argv[])
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = DataType;
|
||||
using XDataType = DataType; // input data type
|
||||
using ComputeDataType = float; // compute data type
|
||||
using YDataType = DataType; // output data type
|
||||
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
ck_tile::index_t m = arg_parser.get_int("m"); // size of the vectors
|
||||
int do_validation = arg_parser.get_int("v"); // do we verify the result on cpu
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host_a({m});
|
||||
ck_tile::HostTensor<XDataType> x_host_b({m});
|
||||
ck_tile::HostTensor<XDataType> x_host_a({m}); // length input vector A, if given two arguments m, n the HostTensor will be created with shape (m, n)
|
||||
ck_tile::HostTensor<XDataType> x_host_b({m}); // length input vector B, if given two arguments m, n the HostTensor will be created with shape (m, n)
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host_a);
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host_b);
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host_a); // fill the input vector A with random values
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host_b);
|
||||
|
||||
ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes()); // allocate device memory for input vector A (this a wrapper over hipMalloc)
|
||||
ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf_a.ToDevice(x_host_a.data());
|
||||
x_buf_a.ToDevice(x_host_a.data()); // copy the input vector A to device memory, this is a wrapper over hipMemcpy
|
||||
x_buf_b.ToDevice(x_host_b.data());
|
||||
|
||||
// using BlockTile = ck_tile::sequence<8192>;
|
||||
// using BlockWarps = ck_tile::sequence<4>;
|
||||
// using WarpTile = ck_tile::sequence<512>; // 8 * 64 = 512
|
||||
// using Vector = ck_tile::sequence<8>; // 8 * 16 = 128 bytes
|
||||
|
||||
// Dividing the problem into blocktile, warptile, and vector
|
||||
// The blocktile is the size of the tile that will be processed by a single block
|
||||
// The warptile is the size of the tile that will be processed by a single warp
|
||||
// The vector is the size of the tile that will be processed by a single thread
|
||||
// The problem is divided into blocks of size BlockTile, each block is further divided into warps of size WarpTile
|
||||
// and each warp is further divided into threads of size Vector
|
||||
using BlockTile = ck_tile::sequence<8192>; // Size of the block tile (Entire problem is divided into blocks of this size)
|
||||
using BlockWarps = ck_tile::sequence<8>; // How many concurrent warps are in a block (Each warp will cover some part of blockTile)
|
||||
using WarpTile = ck_tile::sequence<64>; // How many elements are covered by a warp
|
||||
using Vector = ck_tile::sequence<1>; // How many elements are covered by a thread (Each thread will cover some part of WarpTile)
|
||||
|
||||
// constexpr ck_tile::index_t kBlockSize = 256;
|
||||
// constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
|
||||
using BlockTile = ck_tile::sequence<8192>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
using Vector = ck_tile::sequence<1>;
|
||||
// Interpretation of above configurations
|
||||
// Each thread will cover 1 element (Vector)
|
||||
// Each WarpTile will cover 64 elements (WarpTile) --> since 64 threads in a warp
|
||||
// if we have 8 warps in a block (BlockWarps) then we have 8 * 64 = 512 threads in a block
|
||||
// if 8 warps are not enough to cover the entire blockTile then each of the 8 concurrent warps will iterate over the blockTile several times
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 512;
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
@@ -76,24 +92,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << "BlockWarps: " << BlockWarps::at(ck_tile::number<0>{}) << std::endl;
|
||||
std::cout << "WarpTile: " << WarpTile::at(ck_tile::number<0>{}) << std::endl;
|
||||
std::cout << "Vector: " << Vector::at(ck_tile::number<0>{}) << std::endl;
|
||||
std::cout << "Repeat: " << Shape::Repeat_M << std::endl;
|
||||
std::cout << "Repeat: " << Shape::Repeat_M << std::endl; // number of times a warp will iterate over the blockTile, covering different parts of the blockTile
|
||||
std::cout << "Threads per Block: " << kBlockSize << std::endl;
|
||||
std::cout << "ThreadBlocks per CU: " << kBlockPerCu << std::endl;
|
||||
|
||||
// What is a Problem in CKTile?
|
||||
// A Problem defines the shape of the data, the precision of the data
|
||||
using Problem =
|
||||
ck_tile::MultiplyVectorProblem<XDataType, ComputeDataType, YDataType, Shape>;
|
||||
|
||||
// What is a Policy in CKTile?
|
||||
// A Policy defines how to map the data between threads and data in memory
|
||||
|
||||
// The kernel is the function that will be executed on the device
|
||||
// It requires a Problem and Policy to be defined
|
||||
using Kernel = ck_tile::MultiplyVectorKernel<Problem>;
|
||||
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()),
|
||||
static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
m));
|
||||
// The kernel is launched with the following parameters:
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, // wrapper over hipStreamCreate
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>( // numOfThreadsPerBlock, numOfBlocksPerCU
|
||||
Kernel{}, // kernel
|
||||
kGridSize, // number of blocks in the grid
|
||||
kBlockSize, // number of threads in a block
|
||||
0, // shared memory size
|
||||
static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()), // input vector A
|
||||
static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()), // input vector B
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()), // output vector
|
||||
m));
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * m + sizeof(YDataType) * m;
|
||||
|
||||
|
||||
@@ -8,29 +8,32 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockWarps, // num warps along seq<M, N>
|
||||
typename BlockTile, // block size, seq<M, N>
|
||||
typename WarpTile, // warp size, seq<M, N>
|
||||
typename Vector> // contiguous pixels(vector size) along seq<M, N>
|
||||
|
||||
// struct that holds the tile size of the block, warp, and vector
|
||||
// and the number of warps per block
|
||||
// and the number of threads per warp
|
||||
// and the number of times the warp tile is repeated in the block tile
|
||||
// and the block size
|
||||
template <typename BlockWarps,
|
||||
typename BlockTile,
|
||||
typename WarpTile,
|
||||
typename Vector>
|
||||
struct MultiplyVector
|
||||
{
|
||||
static constexpr index_t Block_M = BlockTile::at(number<0>{});
|
||||
//static constexpr index_t Block_N = BlockTile::at(number<1>{});
|
||||
static constexpr index_t Block_M = BlockTile::at(number<0>{});
|
||||
|
||||
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
|
||||
//static constexpr index_t Warp_N = WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t Vector_M = Vector::at(number<0>{});
|
||||
//static constexpr index_t Vector_N = Vector::at(number<1>{});
|
||||
|
||||
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
|
||||
//static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
|
||||
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
|
||||
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
//static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
//static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); // Number of times the warp tile is repeated in the block tile
|
||||
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
@@ -48,6 +51,8 @@ struct MultiplyVectorProblem
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
};
|
||||
|
||||
|
||||
// data mapping beween threads and memory
|
||||
struct AddDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
@@ -56,12 +61,13 @@ struct AddDefaultPolicy
|
||||
using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>>,
|
||||
tuple<sequence<1>, sequence<1>>,
|
||||
tuple<sequence<1>, sequence<2>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 3>>{});
|
||||
sequence<>, // Replicate
|
||||
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>>, // Hierarchical
|
||||
tuple<sequence<1>, sequence<1>>, // Parallel
|
||||
tuple<sequence<1>, sequence<2>>, // Parallel
|
||||
sequence<1, 1>, // Yield
|
||||
sequence<0, 3>>{} // Yield
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -75,12 +81,17 @@ struct MultiplyVectorKernel
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
|
||||
|
||||
// body of the kernel
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x_a, const XDataType* p_x_b, YDataType* p_y, index_t M) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// create tensor view for the input and output data, this defines how the data is laid out in memory
|
||||
const auto x_m_n_a = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x_a, make_tuple(M), make_tuple(1), number<S::Vector_M>{});
|
||||
p_x_a, make_tuple(M), make_tuple(1), number<S::Vector_M>{}); // raw pointer, shape of the tensor, stride of the tensor, and lastGarunteedVectorLength
|
||||
|
||||
// lastGarunteedVectorLength --> intuitively, this is the number of elements in the last dimension of the tensor that are guaranteed to be fetched by same thread
|
||||
|
||||
const auto x_m_n_b = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x_b, make_tuple(M), make_tuple(1), number<S::Vector_M>{});
|
||||
@@ -88,8 +99,11 @@ struct MultiplyVectorKernel
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y, make_tuple(M), make_tuple(1), number<S::Vector_M>{});
|
||||
|
||||
|
||||
// origin of the block tile
|
||||
const auto iM = get_block_id() * S::Block_M;
|
||||
|
||||
// creating tile windows for the input and output data
|
||||
auto x_window_a = make_tile_window(x_m_n_a,
|
||||
make_tuple(number<S::Block_M>{}),
|
||||
{iM},
|
||||
@@ -112,9 +126,9 @@ struct MultiplyVectorKernel
|
||||
|
||||
|
||||
|
||||
// Process the vector multiplication
|
||||
// Process the vector add
|
||||
constexpr auto spans = decltype(xa)::get_distributed_spans(); // shape of the tile
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx) { // iterate over the tile // idx+=4
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx) { // iterate over the tile
|
||||
const auto tile_idx = make_tuple(idx);
|
||||
const auto a_val = type_convert<ComputeDataType>(xa[tile_idx]);
|
||||
const auto b_val = type_convert<ComputeDataType>(xb[tile_idx]);
|
||||
|
||||
Reference in New Issue
Block a user