mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
81 lines
1.9 KiB
C++
81 lines
1.9 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
// Calls the fixture's Run() once with a single reduce dimension: index Rank - 1.
// NOTE(review): Run() is defined on the fixture outside this file; presumably it
// executes and verifies softmax over the given dimensions - confirm there.
TYPED_TEST(TestSoftmax, ReduceOutermostDim)
{
    const std::vector<ck::index_t> last_dim_only{this->Rank - 1};
    this->Run(last_dim_only);
}
|
|
|
|
// Calls the fixture's Run() once per dimension index in [0, Rank - 1), each time
// with that single index as the only reduce dimension. Index Rank - 1 is
// excluded by the loop bound.
TYPED_TEST(TestSoftmax, ReduceMiddleDim)
{
    const auto end_dim = this->Rank - 1;
    for(int d = 0; d < end_dim; d++)
    {
        const std::vector<ck::index_t> single_dim{d};
        this->Run(single_dim);
    }
}
|
|
|
|
// Calls the fixture's Run() once per dimension index d in [0, Rank - 1), each
// time reducing over the pair {d, Rank - 1}.
TYPED_TEST(TestSoftmax, ReduceMultipleDimsWithOutermost)
{
    int d = 0;
    while(d < this->Rank - 1)
    {
        const std::vector<ck::index_t> dim_pair{d, this->Rank - 1};
        this->Run(dim_pair);
        ++d;
    }
}
|
|
|
|
// Calls the fixture's Run() with several fixed sets of reduce dimensions,
// gated on the fixture's Rank:
//   Rank >= 3 : {0, 1}
//   Rank >= 4 : additionally {0, 2}, then {0, 1, 2}
// Nothing is run when Rank is below 3.
TYPED_TEST(TestSoftmax, ReduceMultipleMiddleDims)
{
    const std::vector<ck::index_t> dims_0_1{0, 1};
    if(this->Rank >= 3)
    {
        this->Run(dims_0_1);
    }

    if(this->Rank >= 4)
    {
        const std::vector<ck::index_t> dims_0_2{0, 2};
        this->Run(dims_0_2);

        const std::vector<ck::index_t> dims_0_1_2{0, 1, 2};
        this->Run(dims_0_1_2);
    }
}
|
|
|
|
// Calls the fixture's Run() with every dimension index 0, 1, ..., Rank - 1 as
// the reduce dimensions.
TYPED_TEST(TestSoftmax, ReduceAllDims)
{
    // Size the vector to Rank entries, then fill it with consecutive indices from 0.
    std::vector<ck::index_t> all_dims(this->Rank);
    std::iota(all_dims.begin(), all_dims.end(), 0);

    this->Run(all_dims);
}
|
|
|
|
// Replaces the fixture's in_lengths_ with one set of lengths, {3, 63, 1032}
// (or {1, 3, 63, 1032} when Rank >= 4), then calls Run() twice: first reducing
// over index Rank - 1, then over index Rank - 2.
// NOTE(review): in_lengths_ is declared on the fixture outside this file;
// presumably it is the list of input tensor shapes Run() iterates - confirm.
TYPED_TEST(TestSoftmax, ReduceOddLengths)
{
    if(this->Rank >= 4)
    {
        this->in_lengths_ = {{1, 3, 63, 1032}};
    }
    else
    {
        this->in_lengths_ = {{3, 63, 1032}};
    }

    this->Run({this->Rank - 1});
    this->Run({this->Rank - 2});
}
|
|
|
|
int main(int argc, char** argv)
|
|
{
|
|
testing::InitGoogleTest(&argc, argv);
|
|
if(argc == 1) {}
|
|
else if(argc == 3)
|
|
{
|
|
param_mask = strtol(argv[1], nullptr, 0);
|
|
instance_index = atoi(argv[2]);
|
|
}
|
|
else
|
|
{
|
|
std::cout << "Usage of " << argv[0] << std::endl;
|
|
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
|
}
|
|
return RUN_ALL_TESTS();
|
|
}
|