Batchnorm splitk single kernel (#771)

* Use dim 0 as faster dim for writing mean/var/count workspace in batchnorm multiblock method [performance]

* Add CountDataType as template parameter in blockwise_welford

* Add utility/get_shift.hpp

* Add BatchNorm multiblock single-kernel implementation

* Add smem inline assembly based implementation of gms_init/gms_barrier/gms_reset for gfx90a

* Renaming in device_batchnorm_forward_impl.hpp

* Tiny fix in the batchnorm_fwd profiler

* Revert "Add smem inline assembly based implementation of gms_init/gms_barrier/gms_reset for gfx90a"

This reverts commit d16d00919c.

* Use the old two-kernel batchnorm multiblock method for gfx1030

* Use the old two-kernel batchnorm multiblock method for gfx908

* use the single-kernel batchnorm multiblock method only for gfx90a

* Remove get_wave_id() from utility/get_id.hpp since it is not used

* Set true for testing running mean/variance and saving mean/invvariance in the examples

* Fix to copy-right words

* Remove un-needed including in utility/get_id.hpp

* Add comments to workgroup_synchronization.hpp

* Remove un-used codes in gridwise_multiblock_batchnorm_forward.hpp

* Renaming in the kernels

* Remove un-used kernel file

[ROCm/composable_kernel commit: 8f5cafaf04]
This commit is contained in:
Qianfeng
2023-07-06 23:58:55 +08:00
committed by GitHub
parent da8a7b63ec
commit b7192d8e4c
14 changed files with 2330 additions and 119 deletions

View File

@@ -148,7 +148,7 @@ int profile_batchnorm_forward(int argc, char* argv[])
{
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
{
profile_batchnorm_forward_impl<F16, F16, F32, F16, F16, F16, 4, 3>(
profile_batchnorm_forward_impl<F16, F16, F32, F16, F16, F32, 4, 3>(
arg_parser.do_verification,
arg_parser.init_method,
arg_parser.do_dumpout,