mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Padded Generic Kernel Instance (#730)
* Add NumReduceDim template parameter to DeviceSoftmax and Softmax client API to simplify instances collecting * Move the generic kernel instance to be the first of the instance list for elementwise op of normalization * Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax * Add testing of GetGenericInstance() in client_example of Softmax * Revert "Add testing of GetGenericInstance() in client_example of Softmax" This reverts commitf629cd9a93. * Revert "Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax" This reverts commita9f0d000eb. * Support generic kernel instance to be the first instance returned by GetInstances() for GroupNorm * Move generic kernel instance to separate tuple for elementwise op of normalization * Remove un-used files for softmax instance * Store generic kernel instance to separate tuple for softmax * Add IsSupported checking for generic instance to client example of softmax * Replace the get_device_normalize_from_mean_meansquare_instances() by the DeviceOperationInstanceFactory class for elementwise-normalization * clang-format fix * Remove int8 from softmax instances --------- Co-authored-by: zjing14 <zhangjing14@gmail.com> [ROCm/composable_kernel commit:0d9118226b]
This commit is contained in:
@@ -30,7 +30,12 @@ using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances = std:
|
||||
//###################|<in, mean, square_mean, gamma, beta>| <out>| functor| NDim| MPerThread| <in, mean, square_mean, gamma, beta ScalarPerVector>| <out ScalarPerVector>|
|
||||
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 8, Sequence<8, 1, 1, 8, 8>, Sequence<8> >,
|
||||
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 4, Sequence<4, 1, 1, 4, 4>, Sequence<4> >,
|
||||
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 2, Sequence<2, 1, 1, 2, 2>, Sequence<2> >,
|
||||
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 2, Sequence<2, 1, 1, 2, 2>, Sequence<2> >
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_generic_instance = std::tuple<
|
||||
// clang-format off
|
||||
DeviceElementwiseImpl<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 1, Sequence<1, 1, 1, 1, 1>, Sequence<1> >
|
||||
// clang-format on
|
||||
>;
|
||||
@@ -39,6 +44,9 @@ void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
|
||||
std::vector<DeviceElementwisePtr<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_generic_instance{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances{});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user