mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Padded Generic Kernel Instance (#730)
* Add NumReduceDim template parameter to DeviceSoftmax and Softmax client API to simplify instances collecting * Move the generic kernel instance to be the first of the instance list for elementwise op of normalization * Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax * Add testing of GetGenericInstance() in client_example of Softmax * Revert "Add testing of GetGenericInstance() in client_example of Softmax" This reverts commit f629cd9a93. * Revert "Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax" This reverts commit a9f0d000eb. * Support generic kernel instance to be the first instance returned by GetInstances() for GroupNorm * Move generic kernel instance to separate tuple for elementwise op of normalization * Remove un-used files for softmax instance * Store generic kernel instance to separate tuple for softmax * Add IsSupported checking for generic instance to client example of softmax * Replace the get_device_normalize_from_mean_meansquare_instances() by the DeviceOperationInstanceFactory class for elementwise-normalization * clang-format fix * Remove int8 from softmax instances --------- Co-authored-by: zjing14 <zhangjing14@gmail.com> [ROCm/composable_kernel commit:0d9118226b]
This commit is contained in:
@@ -40,7 +40,11 @@ template <> std::string type_to_string<int8_t>() { return "int8"; }
|
||||
template <> std::string type_to_string<int32_t>() { return "int32"; }
|
||||
// clang-format on
|
||||
|
||||
template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
bool profile_softmax_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
@@ -54,7 +58,13 @@ bool profile_softmax_impl(int do_verification,
|
||||
if(Rank != in_length.size())
|
||||
{
|
||||
throw std::runtime_error("Input tensor rank is different from template argument Rank!");
|
||||
}
|
||||
};
|
||||
|
||||
if(NumReduceDim != reduce_dims.size())
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Input reduce_dims rank is different from template argument NumReduceDim!");
|
||||
};
|
||||
|
||||
Tensor<InDataType> in = in_strides.empty() ? Tensor<InDataType>(in_length)
|
||||
: Tensor<InDataType>(in_length, in_strides);
|
||||
@@ -92,8 +102,13 @@ bool profile_softmax_impl(int do_verification,
|
||||
|
||||
// add device softmax instances
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using DeviceOp = tensor_operation::device::
|
||||
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;
|
||||
using DeviceOp = tensor_operation::device::DeviceSoftmax<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
// get device op instances
|
||||
const auto instances = tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
@@ -112,13 +127,6 @@ bool profile_softmax_impl(int do_verification,
|
||||
|
||||
for(auto& inst_ptr : instances)
|
||||
{
|
||||
// Is this user's responsibility to check if problem mismatches kernel instance (ie. rank 3
|
||||
// problem to rank 4 kernel) other than invoking IsSupportedArgument()?
|
||||
if(!(inst_ptr->GetNumReduceDim() == static_cast<index_t>(reduce_dims.size())))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
auto argument_ptr = inst_ptr->MakeArgumentPointer(in_tensor_lengths,
|
||||
in_tensor_strides,
|
||||
reduce_dims,
|
||||
|
||||
Reference in New Issue
Block a user