mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Standalone softmax kernel (#284)
* initial stub for standalone softmax
* start device_softmax_mk_to_mk as a wrapper to device_reduce_mk_to_m
* host softmax validates
* compiles; to implement beta scaling
* use NaN trick to efficiently ignore OOB values during sum of exponentials
* freeload device_reduce's utility functions
* clean up interface
* adding prior value (beta scaling)
* remove restriction related to perf considerations
* apply clang-format
* clean; disable diagnostics
* resolve conflicts
* add exp wrapper
* honor HostTensorDesc interface; allow implicit cast from different vector<T> type
* test softmax for fp16/fp32
* update readme
* amend commit NaN trick
* remove redundant param added during development
* format
* replace ScalarDataType with AccDataType
* separate out test programs by precision type
* move softmax sample code to its own folder
* format
* keep up with recent changes in reduction API
* remove extra header
[ROCm/composable_kernel commit: 15c89e81f0]
This commit is contained in:
26
test/softmax/test_softmax_fp32.cpp
Normal file
26
test/softmax/test_softmax_fp32.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_softmax_util.hpp"
|
||||
|
||||
template <ck::index_t N>
|
||||
using I = ck::Number<N>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestSoftmaxFP32 : public ck::TestSoftmax<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
|
||||
std::tuple<float, float, float, I<3>, I<1>, I<256>, I<8>, I<32>, I<1>, I<4>, I<1>, I<4>, I<4>>,
|
||||
std::tuple<float, float, float, I<3>, I<1>, I<256>, I<4>, I<64>, I<1>, I<4>, I<1>, I<4>, I<4>>,
|
||||
std::tuple<float, float, float, I<3>, I<1>, I<256>, I<2>, I<128>, I<1>, I<4>, I<1>, I<4>, I<4>>,
|
||||
std::tuple<float, float, float, I<3>, I<1>, I<256>, I<1>, I<256>, I<1>, I<4>, I<1>, I<4>, I<4>>,
|
||||
std::tuple<float, float, float, I<3>, I<2>, I<256>, I<8>, I<32>, I<1>, I<4>, I<1>, I<4>, I<4>>,
|
||||
std::tuple<float, float, float, I<3>, I<2>, I<256>, I<4>, I<64>, I<1>, I<4>, I<1>, I<4>, I<4>>,
|
||||
std::tuple<float, float, float, I<3>, I<2>, I<256>, I<2>, I<128>, I<1>, I<4>, I<1>, I<4>, I<4>>,
|
||||
std::tuple<float, float, float, I<3>, I<2>, I<256>, I<1>, I<256>, I<1>, I<4>, I<1>, I<4>, I<4>>
|
||||
>;
|
||||
// clang-format on
|
||||
TYPED_TEST_SUITE(TestSoftmaxFP32, KernelTypes);
|
||||
TYPED_TEST(TestSoftmaxFP32, Test_FP32) { this->Run(); }
|
||||
Reference in New Issue
Block a user