mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[Ck tile] support rmsnorm and related fusion (#1605)
* Add reduce2d new api
* Prevent user use cross warp reduction
* Fix bug of std caculation
* Add rmsnorm2d
* Add rmsnorm small example
* Remove static assert to prevent compile fail
* Add script to test performance and correctness
* Add missing cmake change
* refine naming
* refine example of rmsnorm
* Fix bug of rmsnorm
* Refine naming
* Fix cmake
* clang format
* Refine pipeline name
* Add add_rmsnorm2d_rdquant kernel
* Add reduce op
* host verification
* Fix bug of one pass pipeline
* Refine tile size
* Add two pass pipeline
* Rename two pass to three pass
* Fix bug of kSaveX == false
* Add instance library
* Add test script
* Fix bug of x verification
* Add save_x to trait
* Add README
* Move reduce2d into reduce folder
* Fix bug of welford when number of m warp > 1
* remove reduncant comment
* 1. move 06_rmsnorm2d to 10_rmsnorm2d
2. move 07_add_rmsnorm2d_rdquant to 11_add_rmsnorm2d_rdquant
* clang format and add missing header
* Add host validation of add + layernorm2d + rsquant
* Revert "Add host validation of add + layernorm2d + rsquant"
This reverts commit 936cb45797.
* Remove deprecated flag
This commit is contained in:
@@ -35,9 +35,9 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr"; // block per row
|
||||
return "bpr_op"; // block per row
|
||||
else
|
||||
return "wpr"; // warp per row
|
||||
return "wpr_op"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
|
||||
@@ -35,9 +35,9 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr"; // block per row
|
||||
return "bpr_tp"; // block per row
|
||||
else
|
||||
return "wpr"; // warp per row
|
||||
return "wpr_tp"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -118,8 +118,6 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
// x_window.foo();
|
||||
// gamma_window.foo();
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(beta_window, {stride_to_right_most_window});
|
||||
|
||||
Reference in New Issue
Block a user