mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Padded Generic Kernel Instance (#730)
* Add NumReduceDim template parameter to DeviceSoftmax and Softmax client API to simplify instances collecting * Move the generic kernel instance to be the first of the instance list for elementwise op of normalization * Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax * Add testing of GetGenericInstance() in client_example of Softmax * Revert "Add testing of GetGenericInstance() in client_example of Softmax" This reverts commit f629cd9a93. * Revert "Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax" This reverts commit a9f0d000eb. * Support generic kernel instance to be the first instance returned by GetInstances() for GroupNorm * Move generic kernel instance to separate tuple for elementwise op of normalization * Remove unused files for softmax instance * Store generic kernel instance to separate tuple for softmax * Add IsSupported checking for generic instance to client example of softmax * Replace the get_device_normalize_from_mean_meansquare_instances() by the DeviceOperationInstanceFactory class for elementwise-normalization * clang-format fix * Remove int8 from softmax instances --------- Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -172,18 +172,19 @@ int main()
|
||||
BLayout,
|
||||
CLayout>();
|
||||
|
||||
const auto normalize_ptrs =
|
||||
ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances<
|
||||
CDataType,
|
||||
ReduceDataType,
|
||||
ReduceDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
LayerNormOutDataType>();
|
||||
|
||||
std::cout << "found " << gemm_reduce_ptrs.size()
|
||||
<< " gemm_reduceMean_reduceSquareMean instances" << std::endl;
|
||||
|
||||
using NormalizeDeviceOp = ck::tensor_operation::device::DeviceElementwise<
|
||||
ck::Tuple<CDataType, ReduceDataType, ReduceDataType, GammaDataType, BetaDataType>,
|
||||
ck::Tuple<LayerNormOutDataType>,
|
||||
ck::tensor_operation::element_wise::Normalize,
|
||||
2>;
|
||||
|
||||
const auto normalize_ptrs =
|
||||
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
NormalizeDeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl;
|
||||
|
||||
auto f_matrix_space_size =
|
||||
|
||||
Reference in New Issue
Block a user