[Ck tile] support rmsnorm and related fusion (#1605)

* Add reduce2d new api

* Prevent user use cross warp reduction

* Fix bug of std caculation

* Add rmsnorm2d

* Add rmsnorm small example

* Remove static assert to prevent compile fail

* Add script to test performance and correctness

* Add missing cmake change

* refine naming

* refine example of rmsnorm

* Fix bug of rmsnorm

* Refine naming

* Fix cmake

* clang format

* Refine pipeline name

* Add add_rmsnorm2d_rdquant kernel

* Add reduce op

* host verification

* Fix bug of one pass pipeline

* Refine tile size

* Add two pass pipeline

* Rename two pass to three pass

* Fix bug of kSaveX == false

* Add instance library

* Add test script

* Fix bug of x verification

* Add save_x to trait

* Add README

* Move reduce2d into reduce folder

* Fix bug of welford when number of m warp > 1

* remove reduncant comment

* 1. move 06_rmsnorm2d to 10_rmsnorm2d
2. move 07_add_rmsnorm2d_rdquant to 11_add_rmsnorm2d_rdquant

* clang format and add missing header

* Add host validation of add + layernorm2d + rsquant

* Revert "Add host validation of add + layernorm2d + rsquant"

This reverts commit 936cb45797.

* Remove deprecated flag
This commit is contained in:
rocking
2024-10-30 15:22:56 +08:00
committed by GitHub
parent 8632221814
commit 3d60953477
90 changed files with 4674 additions and 128 deletions

View File

@@ -35,9 +35,9 @@ struct Layernorm2dFwdPipelineOnePass
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr"; // block per row
return "bpr_op"; // block per row
else
return "wpr"; // warp per row
return "wpr_op"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()

View File

@@ -35,9 +35,9 @@ struct Layernorm2dFwdPipelineTwoPass
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr"; // block per row
return "bpr_tp"; // block per row
else
return "wpr"; // warp per row
return "wpr_tp"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -118,8 +118,6 @@ struct Layernorm2dFwdPipelineTwoPass
ck_tile::index_t stride_to_right_most_window =
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
// x_window.foo();
// gamma_window.foo();
move_tile_window(x_window, {0, -Block_N});
move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(beta_window, {stride_to_right_most_window});