mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[Ck tile] layernorm2d fwd optimize (#1637)
* optimze small N case using vec io and using rcp div * [Ck_tile] layernorm, add param to control fastdiv; change generate codes and test pass * [Ck_tile] fix blockSize compute in Generic2dBlockShape * [Ck_tile]fix kfastfdiv template style * [Ck_tile] layernorm, fix stype in review --------- Co-authored-by: dummycoderfe <noplydummmycoder@163.com>
This commit is contained in:
@@ -38,9 +38,7 @@ namespace ck_tile {
|
||||
template <typename BlockTile_, // block size, seq<M, N>
|
||||
typename WarpPerBlock_, // num warps along seq<M, N>
|
||||
typename WarpTile_, // warp size, seq<M, N>
|
||||
typename Vector_, // contiguous pixels(vector size) along seq<M, N>
|
||||
index_t BlockSize_ =
|
||||
warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
|
||||
typename Vector_> // contiguous pixels(vector size) along seq<M, N>)>
|
||||
struct Generic2dBlockShape
|
||||
{
|
||||
// block size
|
||||
@@ -68,10 +66,12 @@ struct Generic2dBlockShape
|
||||
static_assert(Warp_M % Vector_M == 0);
|
||||
static_assert(Warp_N % Vector_N == 0);
|
||||
// num of threads along seq<M, N>, within each warp
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
static constexpr index_t ThreadPerBlock_M = Block_M / Repeat_M / Vector_M;
|
||||
static constexpr index_t ThreadPerBlock_N = Block_N / Repeat_N / Vector_N;
|
||||
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user