mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] layernorm support fused-quant/fused-add (#1604)
* add prenorm/postnorm support, refactor using generate.py * update README * update README * fix format * update some description and fix format * update format * format * use non-raw for loading * format and update n4096 * dynamic-quant ready * update readme * support fused dynamic-quant * update fused-quant, with smooth * update README * update args * update some based on comment
This commit is contained in:
77
include/ck_tile/ops/common/generic_2d_block_shape.hpp
Normal file
77
include/ck_tile/ops/common/generic_2d_block_shape.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
// clang-format off
|
||||
|
||||
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
|
||||
|
||||
Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
|
||||
+<----------------------< Repeat_N(2)>--------------------->+
|
||||
| |
|
||||
+<-- <WarpPerBlock_N(2)> -->+
|
||||
Warp_N
|
||||
+--------------+--------------+--------------+--------------+----+----------------+
|
||||
Warp_M | wrap_0 | wrap_1 | | ^ ^
|
||||
+--------------+--------------+ | <WarpPerBlock_M(2)> |
|
||||
| wrap_2 | wrap_3 | | v
|
||||
+--------------+--------------+--------------+--------------+----+ Block_M
|
||||
| | |
|
||||
+ + |
|
||||
| | | v
|
||||
+--------------+--------------+--------------+--------------+ +
|
||||
|
||||
each Warp-tile (e.g 16 thrd per row)
|
||||
|
||||
Vector_N (contiguous pixels each thrd holds along N, or vector size)
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
// clang-format on
|
||||
*/
|
||||
template <typename BlockTile_, // block size, seq<M, N>
|
||||
typename WarpPerBlock_, // num warps along seq<M, N>
|
||||
typename WarpTile_, // warp size, seq<M, N>
|
||||
typename Vector_, // contiguous pixels(vector size) along seq<M, N>
|
||||
index_t BlockSize_ =
|
||||
warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
|
||||
struct Generic2dBlockShape
|
||||
{
|
||||
// block size
|
||||
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
|
||||
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
|
||||
|
||||
// num warps along seq<M, N>, within each block
|
||||
static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
|
||||
static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});
|
||||
|
||||
// warp size
|
||||
static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
|
||||
static constexpr index_t Warp_N = WarpTile_::at(number<1>{});
|
||||
|
||||
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
|
||||
static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
|
||||
// repeat of each thread along seq<M, N>
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
// vector size along seq<M, N>
|
||||
static constexpr index_t Vector_M = Vector_::at(number<0>{});
|
||||
static constexpr index_t Vector_N = Vector_::at(number<1>{});
|
||||
|
||||
static_assert(Warp_M % Vector_M == 0);
|
||||
static_assert(Warp_N % Vector_N == 0);
|
||||
// num of threads along seq<M, N>, within each warp
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user