Files
composable_kernel/test/softmax/test_softmax_ut_cases.inc

81 lines
1.9 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
// Runs the fixture with a single reduce dimension: the last index, Rank - 1.
TYPED_TEST(TestSoftmax, ReduceOutermostDim)
{
    this->Run(std::vector<ck::index_t>{this->Rank - 1});
}
// Runs the fixture once per dimension index in [0, Rank - 2], each time with
// that index as the only reduce dimension. Index Rank - 1 is not visited by
// this loop.
TYPED_TEST(TestSoftmax, ReduceMiddleDim)
{
    int axis = 0;
    while(axis < this->Rank - 1)
    {
        this->Run(std::vector<ck::index_t>{axis});
        ++axis;
    }
}
// Runs the fixture with two reduce dimensions at a time: every index in
// [0, Rank - 2] paired with the last index, Rank - 1.
TYPED_TEST(TestSoftmax, ReduceMultipleDimsWithOutermost)
{
    for(int leading = 0; leading < this->Rank - 1; leading++)
    {
        const std::vector<ck::index_t> dim_pair{leading, this->Rank - 1};
        this->Run(dim_pair);
    }
}
// Runs the fixture with fixed sets of reduce dimensions that never include the
// last index of the ranks they are enabled for:
//   Rank >= 3: {0, 1}
//   Rank >= 4: additionally {0, 2} and {0, 1, 2}
// Nothing is run when Rank < 3.
TYPED_TEST(TestSoftmax, ReduceMultipleMiddleDims)
{
    if(this->Rank >= 3)
    {
        this->Run(std::vector<ck::index_t>{0, 1});

        // The rank-4 cases are nested because Rank >= 4 implies Rank >= 3.
        if(this->Rank >= 4)
        {
            this->Run(std::vector<ck::index_t>{0, 2});
            this->Run(std::vector<ck::index_t>{0, 1, 2});
        }
    }
}
// Runs the fixture with every dimension index 0, 1, ..., Rank - 1 listed as a
// reduce dimension.
TYPED_TEST(TestSoftmax, ReduceAllDims)
{
    std::vector<ck::index_t> every_dim(this->Rank);
    // Fill with consecutive indices starting from 0.
    std::iota(every_dim.begin(), every_dim.end(), 0);
    this->Run(every_dim);
}
// Overrides the fixture's in_lengths_ with a single entry containing odd
// lengths (3 and 63) alongside 1032, then runs the fixture twice: once with the
// last index (Rank - 1) and once with the second-to-last index (Rank - 2) as
// the only reduce dimension.
// NOTE(review): in_lengths_ is presumably the list of input tensor shapes the
// fixture iterates over -- confirm against the fixture definition.
TYPED_TEST(TestSoftmax, ReduceOddLengths)
{
    if(this->Rank >= 4)
    {
        // A leading dimension of length 1 is added for the higher-rank case.
        this->in_lengths_ = {{1, 3, 63, 1032}};
    }
    else
    {
        this->in_lengths_ = {{3, 63, 1032}};
    }

    this->Run(std::vector<ck::index_t>{this->Rank - 1});
    this->Run(std::vector<ck::index_t>{this->Rank - 2});
}
// Test binary entry point.
//
// Accepted command lines (counted after InitGoogleTest has processed argv):
//   <binary>                              - keep the current param_mask / instance_index
//   <binary> <param_mask> <instance_index> - override both values
// Any other argument count prints a usage message.
//
// Returns the result of RUN_ALL_TESTS().
int main(int argc, char** argv)
{
    // Must run before inspecting argc: per the GoogleTest documentation it
    // strips the flags it recognizes from argv and updates argc accordingly.
    testing::InitGoogleTest(&argc, argv);

    switch(argc)
    {
    case 1:
        // No extra arguments: nothing to override.
        break;
    case 3:
        // Base 0 lets strtol auto-detect decimal, octal (0...) and hex (0x...) input.
        param_mask     = strtol(argv[1], nullptr, 0);
        instance_index = atoi(argv[2]);
        break;
    default:
        std::cout << "Usage of " << argv[0] << std::endl;
        std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
        // NOTE(review): execution intentionally falls through to RUN_ALL_TESTS()
        // here, so a malformed command line still runs the tests with the
        // existing settings -- confirm this is the desired behavior.
        break;
    }

    return RUN_ALL_TESTS();
}