mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Padded Generic Kernel Instance (#730)
* Add NumReduceDim template parameter to DeviceSoftmax and Softmax client API to simplify instances collecting * Move the generic kernel instance to be the first of the instance list for elementwise op of normalization * Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax * Add testing of GetGenericInstance() in client_example of Softmax * Revert "Add testing of GetGenericInstance() in client_example of Softmax" This reverts commitf629cd9a93. * Revert "Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax" This reverts commita9f0d000eb. * Support generic kernel instance to be the first instance returned by GetInstances() for GroupNorm * Move generic kernel instance to separate tuple for elementwise op of normalization * Remove un-used files for softmax instance * Store generic kernel instance to separate tuple for softmax * Add IsSupported checking for generic instance to client example of softmax * Replace the get_device_normalize_from_mean_meansquare_instances() by the DeviceOperationInstanceFactory class for elementwise-normalization * clang-format fix * Remove int8 from softmax instances --------- Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -38,16 +38,9 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
|
||||
OutDataType,
|
||||
InElementwiseOp,
|
||||
AccElementwiseOp,
|
||||
Rank>
|
||||
Rank,
|
||||
NumReduceDim>
|
||||
{
|
||||
static constexpr index_t kRank = Rank;
|
||||
static constexpr index_t kNumReduceDim = NumReduceDim;
|
||||
static constexpr index_t kNumInvariantDim = Rank - NumReduceDim;
|
||||
|
||||
virtual index_t GetRank() const override { return kRank; }
|
||||
|
||||
virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
|
||||
|
||||
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
|
||||
|
||||
static constexpr index_t NumSrcDim = Rank;
|
||||
@@ -287,13 +280,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
|
||||
{
|
||||
if constexpr(InSrcVectorDim == 0)
|
||||
{
|
||||
if constexpr(kNumInvariantDim == 0)
|
||||
if constexpr(NumInvariantDim == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.inStrides_[kNumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
|
||||
if(arg.inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -316,7 +309,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
|
||||
}
|
||||
|
||||
// To improve
|
||||
if(kNumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0)
|
||||
if(NumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user