Group norm (#417)

* Add groupnorm example by layernorm
1.  Reference is not ready
2. shape of gamma and beta need to be fix

* Let shape of gamma and beta can be same as x

* Modify test, instance and client example

* [What] Fix bug of layernorm for greater than 2 dimension.
[Why] We need to get upper length from merge transform instead of embed transform.

* Add reference for groupnorm

* Fuse sigmoid after groupnorm

* [What] Rename original layernorm into layernorm2d
[Why] Prepare to add groupnorm using layernorm5d

* clang-format

* Add groupnorm test

* Refine error message

* Add groupnorm ckProfiler

* Test groupnorm kernel from device_instance

* update example

* upadte profiler

* Fix test naming

* Fix argc number

* Move descriptor and sweeponce to argument for quick debugging

Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
rocking5566
2022-09-20 11:30:46 +08:00
committed by GitHub
parent f584ab0c54
commit 4eba345f6e
24 changed files with 1218 additions and 416 deletions

View File

@@ -17,34 +17,40 @@ using F32 = float;
using Pass = ck::tensor_operation::element_wise::PassThrough;
template <index_t Rank, index_t Reduce>
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_layernorm_f16_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1>, // fallback kernel
DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 2, 2, 2>, // fallback kernel
DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4, 4, 4>, // fallback kernel
DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 8, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 8, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 8, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 8, 8, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 8, 8, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 8, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 8, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 8, 8, 8>
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
DeviceLayernormImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel
DeviceLayernormImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel
DeviceLayernormImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 1, 4, 1, 4, 4>, // fallback kernel
DeviceLayernormImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
DeviceLayernormImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>
// clang-format on
>;
void add_device_layernorm_f16_rank2_instances(
std::vector<DeviceLayernormPtr<F16, F16, F16, F32, F16, Pass, 2, 1>>& instances)
void add_device_layernorm_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, Pass, 2, 1>>>& instances)
{
add_device_operation_instances(instances, device_layernorm_f16_instances<2, 1>{});
add_device_operation_instances(instances, device_layernorm_f16_instances<Pass, 2, 1>{});
}
void add_device_layernorm_f16_rank4_instances(
std::vector<DeviceLayernormPtr<F16, F16, F16, F32, F16, Pass, 4, 3>>& instances)
void add_device_layernorm_rank_4_3_f16_instances(
std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, Pass, 4, 3>>>& instances)
{
add_device_operation_instances(instances, device_layernorm_f16_instances<4, 3>{});
add_device_operation_instances(instances, device_layernorm_f16_instances<Pass, 4, 3>{});
}
void add_device_layernorm_rank_5_3_f16_instances(
std::vector<std::unique_ptr<DeviceLayernorm<F16, F16, F16, F32, F16, Pass, 5, 3>>>& instances)
{
add_device_operation_instances(instances, device_layernorm_f16_instances<Pass, 5, 3>{});
}
} // namespace instance

View File

@@ -16,33 +16,39 @@ using F32 = float;
using Pass = ck::tensor_operation::element_wise::PassThrough;
template <index_t Rank, index_t Reduce>
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_layernorm_f32_instances = std::tuple<
// clang-format off
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1>, // fallback kernel
DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 2, 2, 2>, // fallback kernel
DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 4, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 4, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 4, 4, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 4, 4, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 4, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 4, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 4, 4, 4>
DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel
DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel
DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 2, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
DeviceLayernormImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>
// clang-format on
>;
void add_device_layernorm_f32_rank2_instances(
std::vector<DeviceLayernormPtr<F32, F32, F32, F32, F32, Pass, 2, 1>>& instances)
void add_device_layernorm_rank_2_1_f32_instances(
std::vector<std::unique_ptr<DeviceLayernorm<F32, F32, F32, F32, F32, Pass, 2, 1>>>& instances)
{
add_device_operation_instances(instances, device_layernorm_f32_instances<2, 1>{});
add_device_operation_instances(instances, device_layernorm_f32_instances<Pass, 2, 1>{});
}
void add_device_layernorm_f32_rank4_instances(
std::vector<DeviceLayernormPtr<F32, F32, F32, F32, F32, Pass, 4, 3>>& instances)
void add_device_layernorm_rank_4_3_f32_instances(
std::vector<std::unique_ptr<DeviceLayernorm<F32, F32, F32, F32, F32, Pass, 4, 3>>>& instances)
{
add_device_operation_instances(instances, device_layernorm_f32_instances<4, 3>{});
add_device_operation_instances(instances, device_layernorm_f32_instances<Pass, 4, 3>{});
}
void add_device_layernorm_rank_5_3_f32_instances(
std::vector<std::unique_ptr<DeviceLayernorm<F32, F32, F32, F32, F32, Pass, 5, 3>>>& instances)
{
add_device_operation_instances(instances, device_layernorm_f32_instances<Pass, 5, 3>{});
}
} // namespace instance