mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[Ck tile] support rmsnorm and related fusion (#1605)
* Add reduce2d new api
* Prevent users from using cross-warp reduction
* Fix bug of std calculation
* Add rmsnorm2d
* Add rmsnorm small example
* Remove static assert to prevent compile fail
* Add script to test performance and correctness
* Add missing cmake change
* refine naming
* refine example of rmsnorm
* Fix bug of rmsnorm
* Refine naming
* Fix cmake
* clang format
* Refine pipeline name
* Add add_rmsnorm2d_rdquant kernel
* Add reduce op
* host verification
* Fix bug of one pass pipeline
* Refine tile size
* Add two pass pipeline
* Rename two pass to three pass
* Fix bug of kSaveX == false
* Add instance library
* Add test script
* Fix bug of x verification
* Add save_x to trait
* Add README
* Move reduce2d into reduce folder
* Fix bug of welford when number of m warp > 1
* remove redundant comment
* 1. move 06_rmsnorm2d to 10_rmsnorm2d
2. move 07_add_rmsnorm2d_rdquant to 11_add_rmsnorm2d_rdquant
* clang format and add missing header
* Add host validation of add + layernorm2d + rsquant
* Revert "Add host validation of add + layernorm2d + rsquant"
This reverts commit 936cb45797.
* Remove deprecated flag
This commit is contained in:
@@ -19,9 +19,9 @@ auto create_args(int argc, char* argv[])
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using ADataType = DataType;
|
||||
using AccDataType = float;
|
||||
using BDataType = DataType;
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = DataType;
|
||||
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
@@ -29,35 +29,39 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_host({m, n});
|
||||
ck_tile::HostTensor<BDataType> b_host_ref({m});
|
||||
ck_tile::HostTensor<BDataType> b_host_dev({m});
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n});
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m});
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
|
||||
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_buf(b_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
a_buf.ToDevice(a_host.data());
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
using ReduceOp = ck_tile::ReduceOp::Add;
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using ThreadTile = ck_tile::sequence<8, 8>;
|
||||
using Vector = ck_tile::sequence<8, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
// cross warp-reduce
|
||||
// using BlockWarps = ck_tile::sequence<2, 2>;
|
||||
// using BlockTile = ck_tile::sequence<2, 1024>;
|
||||
// using WarpTile = ck_tile::sequence<1, 512>;
|
||||
// using Vector = ck_tile::sequence<1, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 512;
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Kernel = ck_tile::Reduce<ADataType,
|
||||
AccDataType,
|
||||
BDataType,
|
||||
kBlockSize,
|
||||
BlockWarps,
|
||||
BlockTile,
|
||||
WarpTile,
|
||||
ThreadTile>;
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, Vector>;
|
||||
using Porblem =
|
||||
ck_tile::Reduce2dProblem<XDataType, ComputeDataType, YDataType, Shape, ReduceOp>;
|
||||
|
||||
using Kernel = ck_tile::Reduce<Porblem>;
|
||||
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
@@ -65,12 +69,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
m,
|
||||
n));
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * m * n + sizeof(BDataType) * m;
|
||||
std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
@@ -81,9 +85,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
ck_tile::reference_reduce<ADataType, AccDataType, BDataType>(a_host, b_host_ref);
|
||||
b_buf.FromDevice(b_host_dev.mData.data());
|
||||
pass = ck_tile::check_err(b_host_dev, b_host_ref);
|
||||
ck_tile::reference_reduce<XDataType, ComputeDataType, YDataType>(
|
||||
x_host, y_host_ref, ReduceOp{});
|
||||
y_buf.FromDevice(y_host_dev.mData.data());
|
||||
pass = ck_tile::check_err(y_host_dev, y_host_ref);
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
@@ -103,8 +108,8 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
// else if(data_type == "bf16")
|
||||
// {
|
||||
// return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
// }
|
||||
}
|
||||
|
||||
@@ -5,20 +5,16 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType,
|
||||
typename AccDataType,
|
||||
typename BDataType,
|
||||
index_t kBlockSize,
|
||||
typename BlockWarps, // num warps along seq<M, N>
|
||||
template <typename BlockWarps, // num warps along seq<M, N>
|
||||
typename BlockTile, // block size, seq<M, N>
|
||||
typename WarpTile, // warp size, seq<M, N>
|
||||
typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
|
||||
struct Reduce
|
||||
typename Vector> // contiguous pixels(vector size) along seq<M, N>
|
||||
struct Reduce2dShape
|
||||
{
|
||||
static constexpr index_t Block_M = BlockTile::at(number<0>{});
|
||||
static constexpr index_t Block_N = BlockTile::at(number<1>{});
|
||||
@@ -26,93 +22,143 @@ struct Reduce
|
||||
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
|
||||
static constexpr index_t Warp_N = WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t Thread_M = ThreadTile::at(number<0>{});
|
||||
static constexpr index_t Thread_N = ThreadTile::at(number<1>{});
|
||||
static constexpr index_t Vector_M = Vector::at(number<0>{});
|
||||
static constexpr index_t Vector_N = Vector::at(number<1>{});
|
||||
|
||||
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
|
||||
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Thread_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Thread_N;
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
__device__ static constexpr auto MakeABlockTileDistribution()
|
||||
static constexpr index_t BlockSize =
|
||||
warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
};
|
||||
|
||||
template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename BlockShape_,
|
||||
typename ReduceOp_>
|
||||
struct Reduce2dProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using ReduceOp = ReduceOp_;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = BlockReduce2dDefaultPolicy>
|
||||
struct Reduce
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
|
||||
#if 0
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N)
|
||||
const
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Repeat_M, WarpPerBlock_M, ThreadPerWarp_M, Thread_M>,
|
||||
sequence<Repeat_N, WarpPerBlock_N, ThreadPerWarp_N, Thread_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
__device__ void operator()(const ADataType* p_a, BDataType* p_b, index_t M, index_t N) const
|
||||
{
|
||||
const auto a_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, make_tuple(M, N), make_tuple(N, 1), number<Thread_N>{}, number<1>{});
|
||||
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
|
||||
|
||||
const auto iM = get_block_id() * Block_M;
|
||||
const auto y_m = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_y, make_tuple(M), number<1>{});
|
||||
|
||||
// A window
|
||||
auto a_block_window = make_tile_window(a_m_n,
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
{iM, 0},
|
||||
MakeABlockTileDistribution());
|
||||
const auto iM = get_block_id() * S::Block_M;
|
||||
|
||||
auto x_window = make_tile_window(x_m_n,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto y_window = make_tile_window(y_m, make_tuple(number<S::Block_M>{}), {iM});
|
||||
|
||||
const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; };
|
||||
|
||||
const ADataType reduce_init_value = 0;
|
||||
const XDataType reduce_init_value = 0;
|
||||
|
||||
constexpr auto reduce_dims = sequence<1>{};
|
||||
|
||||
// Acc tile
|
||||
// TODO: support cross warp reduction
|
||||
auto acc_block_tensor = decltype(block_tile_reduce<AccDataType>(
|
||||
load_tile(a_block_window), reduce_dims, f_reduce, reduce_init_value)){};
|
||||
auto y_compute = decltype(block_tile_reduce<ComputeDataType>(
|
||||
load_tile(x_window), reduce_dims, f_reduce, reduce_init_value)){};
|
||||
|
||||
// init Acc tile
|
||||
tile_elementwise_inout(
|
||||
[&](auto& acc) { acc = type_convert<AccDataType>(reduce_init_value); },
|
||||
acc_block_tensor);
|
||||
set_tile(y_compute, reduce_init_value);
|
||||
|
||||
// loop
|
||||
index_t iN = 0;
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));
|
||||
|
||||
do
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto a_block_tensor = load_tile(a_block_window);
|
||||
const auto x = load_tile(x_window);
|
||||
block_tile_reduce(y_compute, x, reduce_dims, f_reduce);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
// FIXME: support cross warp reduction
|
||||
block_tile_reduce(acc_block_tensor, a_block_tensor, reduce_dims, f_reduce);
|
||||
block_tile_reduce_sync(y_compute, f_reduce);
|
||||
|
||||
move_tile_window(a_block_window, {0, Block_N});
|
||||
|
||||
iN += Block_N;
|
||||
|
||||
} while(iN < N);
|
||||
|
||||
// FIXME: support cross warp reduction
|
||||
block_tile_reduce_sync(acc_block_tensor, f_reduce);
|
||||
|
||||
// convert acc_block_tensor to b_block_tensor
|
||||
const auto b_block_tensor = tile_elementwise_in(
|
||||
[](const auto& acc) { return type_convert<BDataType>(acc); }, acc_block_tensor);
|
||||
|
||||
// B
|
||||
const auto b_m = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_b, make_tuple(M), number<32>{});
|
||||
|
||||
// B window
|
||||
auto b_block_window = make_tile_window(b_m, make_tuple(number<Block_M>{}), {iM});
|
||||
|
||||
// store B tile
|
||||
store_tile(b_block_window, b_block_tensor);
|
||||
store_tile(y_window, cast_tile<YDataType>(y_compute));
|
||||
}
|
||||
#else
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
|
||||
|
||||
const auto y_m = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_y, make_tuple(M), number<1>{});
|
||||
|
||||
const auto iM = get_block_id() * S::Block_M;
|
||||
|
||||
auto x_window = make_tile_window(x_m_n,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto y_window = make_tile_window(y_m, make_tuple(number<S::Block_M>{}), {iM});
|
||||
|
||||
__shared__ char smem[Policy::template GetSmemSize<Problem>()];
|
||||
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));
|
||||
|
||||
auto reduce_func = typename Problem::ReduceOp{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(load_tile(x_window));
|
||||
auto y_compute = block_reduce2d.template MakeYBlockTile<XTensorType>();
|
||||
set_tile(y_compute, reduce_func.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
block_reduce2d(x, y_compute, reduce_func);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_compute, reduce_func);
|
||||
block_reduce2d_cross_warp_sync(y_compute, smem, reduce_func);
|
||||
|
||||
store_tile(y_window, cast_tile<YDataType>(y_compute));
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user