mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Padded Generic Kernel Instance (#730)
* Add NumReduceDim template parameter to DeviceSoftmax and Softmax client API to simplify instances collecting * Move the generic kernel instance to be the first of the instance list for elementwise op of normalization * Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax * Add testing of GetGenericInstance() in client_example of Softmax * Revert "Add testing of GetGenericInstance() in client_example of Softmax" This reverts commitf629cd9a93. * Revert "Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax" This reverts commita9f0d000eb. * Support generic kernel instance to be the first instance returned by GetInstances() for GroupNorm * Move generic kernel instance to separate tuple for elementwise op of normalization * Remove un-used files for softmax instance * Store generic kernel instance to separate tuple for softmax * Add IsSupported checking for generic instance to client example of softmax * Replace the get_device_normalize_from_mean_meansquare_instances() by the DeviceOperationInstanceFactory class for elementwise-normalization * clang-format fix * Remove int8 from softmax instances --------- Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -61,8 +61,92 @@ class TestSoftmax : public ::testing::Test
|
||||
int init_method = 1; // integer value initialization
|
||||
bool log = false;
|
||||
std::vector<ck::index_t> strides; // intenionally empty, to get packed layout.
|
||||
bool pass = ck::profiler::profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank>(
|
||||
verify_, init_method, log, bench_, in_length, strides, reduce_dims, alpha, beta);
|
||||
bool pass = false;
|
||||
|
||||
if constexpr(Rank == 3)
|
||||
{
|
||||
if(reduce_dims.size() == 1)
|
||||
pass = ck::profiler::
|
||||
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 1>(verify_,
|
||||
init_method,
|
||||
log,
|
||||
bench_,
|
||||
in_length,
|
||||
strides,
|
||||
reduce_dims,
|
||||
alpha,
|
||||
beta);
|
||||
else if(reduce_dims.size() == 2)
|
||||
pass = ck::profiler::
|
||||
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 2>(verify_,
|
||||
init_method,
|
||||
log,
|
||||
bench_,
|
||||
in_length,
|
||||
strides,
|
||||
reduce_dims,
|
||||
alpha,
|
||||
beta);
|
||||
else if(reduce_dims.size() == 3)
|
||||
pass = ck::profiler::
|
||||
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 3>(verify_,
|
||||
init_method,
|
||||
log,
|
||||
bench_,
|
||||
in_length,
|
||||
strides,
|
||||
reduce_dims,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
else if constexpr(Rank == 4)
|
||||
{
|
||||
if(reduce_dims.size() == 1)
|
||||
pass = ck::profiler::
|
||||
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 1>(verify_,
|
||||
init_method,
|
||||
log,
|
||||
bench_,
|
||||
in_length,
|
||||
strides,
|
||||
reduce_dims,
|
||||
alpha,
|
||||
beta);
|
||||
else if(reduce_dims.size() == 2)
|
||||
pass = ck::profiler::
|
||||
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 2>(verify_,
|
||||
init_method,
|
||||
log,
|
||||
bench_,
|
||||
in_length,
|
||||
strides,
|
||||
reduce_dims,
|
||||
alpha,
|
||||
beta);
|
||||
else if(reduce_dims.size() == 3)
|
||||
pass = ck::profiler::
|
||||
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 3>(verify_,
|
||||
init_method,
|
||||
log,
|
||||
bench_,
|
||||
in_length,
|
||||
strides,
|
||||
reduce_dims,
|
||||
alpha,
|
||||
beta);
|
||||
else if(reduce_dims.size() == 4)
|
||||
pass = ck::profiler::
|
||||
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 4>(verify_,
|
||||
init_method,
|
||||
log,
|
||||
bench_,
|
||||
in_length,
|
||||
strides,
|
||||
reduce_dims,
|
||||
alpha,
|
||||
beta);
|
||||
};
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user