mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
update layernorm (#1570)
* port layernorm
* change warp_welford.hpp
* Update warpshuffle
* 1. Add save mean and save std back
2. Move construction of tensor_view and tile_window to operator()
* refine welford max count calculation
* unify layernorm api
* Rename file
* Remove save mean and inv std
* Revert "refine welford max count calculation"
This reverts commit 022365802b.
* Fix order of parameter
* refine welford max count calculation again
* Remove fp32 instances
* Fix bug of padding
* refactor api
* Support bf16
* Extract common function
* Refine arg of operator()
* Add kMThreadPerBlock to template parameter
* clang format
* Refine variable name
* Refine file name
* remove redundant line
* refactor layernorm2d pipeline and add block-per-block utility
* fix name
* rename more
* add more block-per-tile instance
* remove duplicated define
* update instance for 2048, 1024 case
* support up to 2048 now
* opt loading
* add n1536
* Add two pass pipeline
* format
* Fix incorrect type
* parallel compilation
* Use smaller N
* fix 2p pass
* Support Repeat_M in distribution
* Refine nameing
* Add reduce example
---------
Co-authored-by: letaoqin <letaoqin@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
This commit is contained in:
19
example/ck_tile/05_reduce/CMakeLists.txt
Normal file
19
example/ck_tile/05_reduce/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
set(EXAMPLE_REDUCE "tile_example_reduce")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_REDUCE}")
|
||||
|
||||
add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL reduce.cpp)
|
||||
target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_REDUCE_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
|
||||
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
110
example/ck_tile/05_reduce/reduce.cpp
Normal file
110
example/ck_tile/05_reduce/reduce.cpp
Normal file
@@ -0,0 +1,110 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "reduce.hpp"
|
||||
#include <cstring>
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3328", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using ADataType = DataType;
|
||||
using AccDataType = float;
|
||||
using BDataType = DataType;
|
||||
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_host({m, n});
|
||||
ck_tile::HostTensor<BDataType> b_host_ref({m});
|
||||
ck_tile::HostTensor<BDataType> b_host_dev({m});
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
|
||||
|
||||
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_buf(b_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
a_buf.ToDevice(a_host.data());
|
||||
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using ThreadTile = ck_tile::sequence<8, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Kernel = ck_tile::Reduce<ADataType,
|
||||
AccDataType,
|
||||
BDataType,
|
||||
kBlockSize,
|
||||
BlockWarps,
|
||||
BlockTile,
|
||||
WarpTile,
|
||||
ThreadTile>;
|
||||
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
|
||||
m,
|
||||
n));
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * m * n + sizeof(BDataType) * m;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
ck_tile::reference_reduce<ADataType, AccDataType, BDataType>(a_host, b_host_ref);
|
||||
b_buf.FromDevice(b_host_dev.mData.data());
|
||||
pass = ck_tile::check_err(b_host_dev, b_host_ref);
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
118
example/ck_tile/05_reduce/reduce.hpp
Normal file
118
example/ck_tile/05_reduce/reduce.hpp
Normal file
@@ -0,0 +1,118 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType,
|
||||
typename AccDataType,
|
||||
typename BDataType,
|
||||
index_t kBlockSize,
|
||||
typename BlockWarps, // num warps along seq<M, N>
|
||||
typename BlockTile, // block size, seq<M, N>
|
||||
typename WarpTile, // warp size, seq<M, N>
|
||||
typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
|
||||
struct Reduce
|
||||
{
|
||||
static constexpr index_t Block_M = BlockTile::at(number<0>{});
|
||||
static constexpr index_t Block_N = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
|
||||
static constexpr index_t Warp_N = WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t Thread_M = ThreadTile::at(number<0>{});
|
||||
static constexpr index_t Thread_N = ThreadTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
|
||||
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Thread_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Thread_N;
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
__device__ static constexpr auto MakeABlockTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Repeat_M, WarpPerBlock_M, ThreadPerWarp_M, Thread_M>,
|
||||
sequence<Repeat_N, WarpPerBlock_N, ThreadPerWarp_N, Thread_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
__device__ void operator()(const ADataType* p_a, BDataType* p_b, index_t M, index_t N) const
|
||||
{
|
||||
const auto a_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, make_tuple(M, N), make_tuple(N, 1), number<Thread_N>{}, number<1>{});
|
||||
|
||||
const auto iM = get_block_id() * Block_M;
|
||||
|
||||
// A window
|
||||
auto a_block_window = make_tile_window(a_m_n,
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
{iM, 0},
|
||||
MakeABlockTileDistribution());
|
||||
|
||||
const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; };
|
||||
|
||||
const ADataType reduce_init_value = 0;
|
||||
|
||||
constexpr auto reduce_dims = sequence<1>{};
|
||||
|
||||
// Acc tile
|
||||
// TODO: support cross warp reduction
|
||||
auto acc_block_tensor = decltype(block_tile_reduce<AccDataType>(
|
||||
load_tile(a_block_window), reduce_dims, f_reduce, reduce_init_value)){};
|
||||
|
||||
// init Acc tile
|
||||
tile_elementwise_inout(
|
||||
[&](auto& acc) { acc = type_convert<AccDataType>(reduce_init_value); },
|
||||
acc_block_tensor);
|
||||
|
||||
// loop
|
||||
index_t iN = 0;
|
||||
|
||||
do
|
||||
{
|
||||
const auto a_block_tensor = load_tile(a_block_window);
|
||||
|
||||
// FIXME: support cross warp reduction
|
||||
block_tile_reduce(acc_block_tensor, a_block_tensor, reduce_dims, f_reduce);
|
||||
|
||||
move_tile_window(a_block_window, {0, Block_N});
|
||||
|
||||
iN += Block_N;
|
||||
|
||||
} while(iN < N);
|
||||
|
||||
// FIXME: support cross warp reduction
|
||||
block_tile_reduce_sync(acc_block_tensor, f_reduce);
|
||||
|
||||
// convert acc_block_tensor to b_block_tensor
|
||||
const auto b_block_tensor = tile_elementwise_in(
|
||||
[](const auto& acc) { return type_convert<BDataType>(acc); }, acc_block_tensor);
|
||||
|
||||
// B
|
||||
const auto b_m = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_b, make_tuple(M), number<32>{});
|
||||
|
||||
// B window
|
||||
auto b_block_window = make_tile_window(b_m, make_tuple(number<Block_M>{}), {iM});
|
||||
|
||||
// store B tile
|
||||
store_tile(b_block_window, b_block_tensor);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user