mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add instance for elementwise normalization (#573)
* added instances for large N * add instance for elementwise normalization * added a support restriction in device_elementwise_normalization_impl.hpp
This commit is contained in:
@@ -533,6 +533,11 @@ struct DeviceElementwiseNormalizationImpl
|
||||
return (false);
|
||||
}
|
||||
|
||||
if(p_arg_->x_lds_size_ >= 65536)
|
||||
{
|
||||
return (false);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
|
||||
@@ -23,6 +23,11 @@ template <typename XElementwise, typename YElementwise, index_t Rank, index_t Re
|
||||
using device_elementwise_normalization_f16_instances =
|
||||
std::tuple <
|
||||
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
|
||||
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel for large N
|
||||
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel for large N
|
||||
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>, // fallback kernel for large N
|
||||
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>, // fallback kernel for large N
|
||||
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 4, 64, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel for large N
|
||||
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel
|
||||
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel
|
||||
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 1, 4, 1, 4, 4>, // fallback kernel
|
||||
|
||||
@@ -23,7 +23,7 @@ class TestElementwiseLayernorm : public ::testing::Test
|
||||
{
|
||||
// M, N
|
||||
std::vector<std::vector<ck::index_t>> lengths = {
|
||||
{1, 1}, {25, 16}, {39, 777}, {100, 200}, {1024, 1024}, {48 * 256, 2048}};
|
||||
{1, 1}, {25, 16}, {39, 777}, {100, 200}, {1024, 1024}, {48 * 256, 2048}, {4096, 8192}};
|
||||
|
||||
for(auto length : lengths)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user