update layernorm (#1570)

* port layernorm

* change warp_welford.hpp

* Update warpshuffle

* 1. Add save mean and save std back
2. Move construction of tensor_view and tile_window to operator()

* refine welford max count calculation

* unify layernorm api

* Rename file

* Remove save mean and inv std

* Revert "refine welford max count calculation"

This reverts commit 022365802b.

* Fix order of parameter

* refine welford max count calculation again

* Remove fp32 instances

* Fix bug of padding

* refactor api

* Support bf16

* Extract common function

* Refine arg of operator()

* Add kMThreadPerBlock to template parameter

* clang format

* Refine variable name

* Refine file name

* remove redundant line

* refactor layernorm2d pipeline and add block-per-block utility

* fix name

* rename more

* add more block-per-tile instance

* remove duplicated define

* update instance for 2048, 1024 case

* support up to 2048 now

* opt loading

* add n1536

* Add two pass pipeline

* format

* Fix incorrect type

* parallel compilation

* Use smaller N

* fix 2p pass

* Support Repeat_M in distribution

* Refine nameing

* Add reduce example

---------

Co-authored-by: letaoqin <letaoqin@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
This commit is contained in:
ltqin
2024-10-22 09:26:18 +08:00
committed by GitHub
parent 3f710930f6
commit 0394f8a713
59 changed files with 2917 additions and 1042 deletions

View File

@@ -6,8 +6,7 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_layernorm2d_fwd -j
```
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
@@ -20,4 +19,4 @@ args:
-e epsilon (default:1e-5)
-v cpu validation or not (default:1)
-prec precision (default:fp16)
```
```