mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
* Add NumReduceDim template parameter to DeviceSoftmax and Softmax client API to simplify instances collecting * Move the generic kernel instance to be the first of the instance list for elementwise op of normalization * Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax * Add testing of GetGenericInstance() in client_example of Softmax * Revert "Add testing of GetGenericInstance() in client_example of Softmax" This reverts commit f629cd9a93. * Revert "Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax" This reverts commit a9f0d000eb. * Support generic kernel instance to be the first instance returned by GetInstances() for GroupNorm * Move generic kernel instance to separate tuple for elementwise op of normalization * Remove un-used files for softmax instance * Store generic kernel instance to separate tuple for softmax * Add IsSupported checking for generic instance to client example of softmax * Replace the get_device_normalize_from_mean_meansquare_instances() by the DeviceOperationInstanceFactory class for elementwise-normalization * clang-format fix * Remove int8 from softmax instances --------- Co-authored-by: zjing14 <zhangjing14@gmail.com>
33 lines
758 B
C++
33 lines
758 B
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include <algorithm>
#include <stdexcept>
#include <tuple>
#include <vector>

#include "gtest/gtest.h"
#include "test_softmax_util.hpp"
|
|
|
|
template <ck::index_t N>
|
|
using I = ck::Number<N>;
|
|
|
|
using F16 = ck::half_t;
|
|
using F32 = float;
|
|
|
|
template <typename Tuple>
|
|
class TestSoftmax : public ck::TestSoftmax<Tuple>
|
|
{
|
|
};
|
|
|
|
// clang-format off
|
|
using KernelTypes = ::testing::Types<
|
|
// InDataType, AccDataType, OutDataType, Rank
|
|
std::tuple< F16, F32, F16, I<3>>,
|
|
std::tuple< F32, F32, F32, I<3>>
|
|
>;
|
|
// clang-format on
|
|
|
|
// Register the TestSoftmax fixture for every tuple in KernelTypes. gtest
// requires this to appear before any TYPED_TEST for the fixture. The shared
// cases are pulled in afterwards via test_softmax_ut_cases.inc (presumably
// TYPED_TEST definitions against TestSoftmax; confirm in that file).
TYPED_TEST_SUITE(TestSoftmax, KernelTypes);
|
|
|
|
#include "test_softmax_ut_cases.inc"
|