Padded Generic Kernel Instance (#730)

* Add NumReduceDim template parameter to DeviceSoftmax and Softmax client API to simplify instances collecting

* Move the generic kernel instance to be the first of the instance list for elementwise op of normalization

* Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax

* Add testing of GetGenericInstance() in client_example of Softmax

* Revert "Add testing of GetGenericInstance() in client_example of Softmax"

This reverts commit f629cd9a93.

* Revert "Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax"

This reverts commit a9f0d000eb.

* Support generic kernel instance to be the first instance returned by GetInstances() for GroupNorm

* Move generic kernel instance to separate tuple for elementwise op of normalization

* Remove un-used files for softmax instance

* Store generic kernel instance to separate tuple for softmax

* Add IsSupported checking for generic instance to client example of softmax

* Replace the get_device_normalize_from_mean_meansquare_instances() by the DeviceOperationInstanceFactory class for elementwise-normalization

* clang-format fix

* Remove int8 from softmax instances

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>

[ROCm/composable_kernel commit: 0d9118226b]
This commit is contained in:
Qianfeng
2023-06-17 12:43:11 +08:00
committed by GitHub
parent 033da551a7
commit eebebd33c6
76 changed files with 552 additions and 790 deletions

View File

@@ -40,7 +40,11 @@ template <> std::string type_to_string<int8_t>() { return "int8"; }
template <> std::string type_to_string<int32_t>() { return "int32"; }
// clang-format on
template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim>
bool profile_softmax_impl(int do_verification,
int init_method,
bool do_log,
@@ -54,7 +58,13 @@ bool profile_softmax_impl(int do_verification,
if(Rank != in_length.size())
{
throw std::runtime_error("Input tensor rank is different from template argument Rank!");
}
};
if(NumReduceDim != reduce_dims.size())
{
throw std::runtime_error(
"Input reduce_dims rank is different from template argument NumReduceDim!");
};
Tensor<InDataType> in = in_strides.empty() ? Tensor<InDataType>(in_length)
: Tensor<InDataType>(in_length, in_strides);
@@ -92,8 +102,13 @@ bool profile_softmax_impl(int do_verification,
// add device softmax instances
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceOp = tensor_operation::device::
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;
using DeviceOp = tensor_operation::device::DeviceSoftmax<InDataType,
AccDataType,
OutDataType,
PassThrough,
PassThrough,
Rank,
NumReduceDim>;
// get device op instances
const auto instances = tensor_operation::device::instance::DeviceOperationInstanceFactory<
@@ -112,13 +127,6 @@ bool profile_softmax_impl(int do_verification,
for(auto& inst_ptr : instances)
{
// Is this user's responsibility to check if problem mismatches kernel instance (ie. rank 3
// problem to rank 4 kernel) other than invoking IsSupportedArgument()?
if(!(inst_ptr->GetNumReduceDim() == static_cast<index_t>(reduce_dims.size())))
{
continue;
}
auto argument_ptr = inst_ptr->MakeArgumentPointer(in_tensor_lengths,
in_tensor_strides,
reduce_dims,