mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-12 01:08:27 +00:00
[CK] add composable kernel support on gfx1250 (#6978) ## Motivation Add composable kernel support on gfx1250. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Qun Lin <qlin@amd.com> Co-authored-by: jialuo12_amdeng <jia.luo@amd.com> Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com> Co-authored-by: hsivasun_amdeng <haresh.sivasuntharampillai@amd.com>
119 lines
3.6 KiB
C++
119 lines
3.6 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include "gtest/gtest.h"
|
|
#include "profiler/profile_permute_scale_impl.hpp"
|
|
|
|
// Short aliases for the element types used by the test type list below.
using F16 = ck::half_t;
|
|
using F32 = float;
|
|
using ck::index_t;
|
|
// Overwritten from argv[1] in main() when two extra arguments are supplied.
// NOTE(review): not read anywhere else in the visible part of this file.
static ck::index_t param_mask = 0xffff;
|
|
// Overwritten from argv[2] in main() and forwarded by TestPermute::Run as the
// last argument of profile_permute_scale_impl. The usage text printed by
// main() describes -1 as "all".
static ck::index_t instance_index = -1;
|
|
template <typename Tuple>
|
|
class TestPermute : public ::testing::Test
|
|
{
|
|
protected:
|
|
using ADataType = std::tuple_element_t<0, Tuple>;
|
|
using BDataType = std::tuple_element_t<1, Tuple>;
|
|
|
|
constexpr bool skip_case()
|
|
{
|
|
#ifndef CK_ENABLE_FP16
|
|
if constexpr(ck::is_same_v<ADataType, F16> || ck::is_same_v<BDataType, F16>)
|
|
{
|
|
return true;
|
|
}
|
|
#endif
|
|
#ifndef CK_ENABLE_FP32
|
|
if constexpr(ck::is_same_v<ADataType, F32> || ck::is_same_v<BDataType, F32>)
|
|
{
|
|
return true;
|
|
}
|
|
#endif
|
|
return false;
|
|
}
|
|
|
|
template <ck::index_t NDims>
|
|
void Run(std::vector<ck::index_t> lengths,
|
|
std::vector<ck::index_t> input_strides,
|
|
std::vector<ck::index_t> output_strides)
|
|
{
|
|
if(!skip_case())
|
|
{
|
|
bool success = ck::profiler::profile_permute_scale_impl<ADataType, BDataType, NDims>(
|
|
true, 2, false, false, lengths, input_strides, output_strides, instance_index);
|
|
EXPECT_TRUE(success);
|
|
}
|
|
}
|
|
};
|
|
|
|
// (ADataType, BDataType) pairs the TestPermute fixture is instantiated with.
using KernelTypes = ::testing::Types<std::tuple<F16, F16>, std::tuple<F32, F32>>;
|
|
|
|
// Declares TestPermute as a typed test suite over KernelTypes.
TYPED_TEST_SUITE(TestPermute, KernelTypes);
|
|
// Rank-1 cases: unit strides, an output stride of 2, and a single element.
TYPED_TEST(TestPermute, Test1D)
{
    using Dims                  = std::vector<ck::index_t>;
    constexpr ck::index_t kRank = 1;

    const Dims sixteen{16};
    const Dims unit{1};
    const Dims stride_two{2};

    this->template Run<kRank>(sixteen, unit, unit);
    this->template Run<kRank>(sixteen, unit, stride_two);
    this->template Run<kRank>(unit, unit, unit);
}
|
|
|
|
// Rank-2 cases: an 8x16 problem permuted in both directions, then an
// all-ones problem.
TYPED_TEST(TestPermute, Test2D)
{
    using Dims                  = std::vector<ck::index_t>;
    constexpr ck::index_t kRank = 2;

    const Dims lengths{8, 16};
    const Dims row_major{16, 1}; // packed, last dimension fastest
    const Dims col_major{1, 8};  // packed, first dimension fastest
    const Dims ones{1, 1};

    this->template Run<kRank>(lengths, row_major, col_major);
    this->template Run<kRank>(lengths, col_major, row_major);
    this->template Run<kRank>(ones, ones, ones);
}
|
|
|
|
// Rank-3 cases: an 8x2x8 problem permuted in both directions, then an
// all-ones problem.
TYPED_TEST(TestPermute, Test3D)
{
    using Dims                  = std::vector<ck::index_t>;
    constexpr ck::index_t kRank = 3;

    const Dims lengths{8, 2, 8};
    const Dims row_major{16, 8, 1}; // packed, last dimension fastest
    const Dims col_major{1, 8, 16}; // packed, first dimension fastest
    const Dims ones{1, 1, 1};

    this->template Run<kRank>(lengths, row_major, col_major);
    this->template Run<kRank>(lengths, col_major, row_major);
    this->template Run<kRank>(ones, ones, ones);
}
|
|
|
|
// Rank-4 cases: an 8x2x3x8 problem permuted in both directions, then an
// all-ones problem.
TYPED_TEST(TestPermute, Test4D)
{
    using Dims                  = std::vector<ck::index_t>;
    constexpr ck::index_t kRank = 4;

    const Dims lengths{8, 2, 3, 8};
    const Dims row_major{48, 24, 8, 1}; // packed, last dimension fastest
    const Dims col_major{1, 8, 16, 48}; // packed, first dimension fastest
    const Dims ones{1, 1, 1, 1};

    this->template Run<kRank>(lengths, row_major, col_major);
    this->template Run<kRank>(lengths, col_major, row_major);
    this->template Run<kRank>(ones, ones, ones);
}
|
|
|
|
// Rank-5 cases: an 8x2x3x4x8 problem permuted in both directions, then an
// all-ones problem.
TYPED_TEST(TestPermute, Test5D)
{
    using Dims                  = std::vector<ck::index_t>;
    constexpr ck::index_t kRank = 5;

    const Dims lengths{8, 2, 3, 4, 8};
    const Dims row_major{192, 96, 32, 8, 1}; // packed, last dimension fastest
    const Dims col_major{1, 8, 16, 48, 192}; // packed, first dimension fastest
    const Dims ones{1, 1, 1, 1, 1};

    this->template Run<kRank>(lengths, row_major, col_major);
    this->template Run<kRank>(lengths, col_major, row_major);
    this->template Run<kRank>(ones, ones, ones);
}
|
|
|
|
// Rank-6 cases: an 8x2x3x4x5x8 problem permuted in both directions, then an
// all-ones problem.
TYPED_TEST(TestPermute, Test6D)
{
    using Dims                  = std::vector<ck::index_t>;
    constexpr ck::index_t kRank = 6;

    const Dims lengths{8, 2, 3, 4, 5, 8};
    const Dims row_major{960, 480, 160, 40, 8, 1}; // packed, last dimension fastest
    const Dims col_major{1, 8, 16, 48, 192, 960};  // packed, first dimension fastest
    const Dims ones{1, 1, 1, 1, 1, 1};

    this->template Run<kRank>(lengths, row_major, col_major);
    this->template Run<kRank>(lengths, col_major, row_major);
    this->template Run<kRank>(ones, ones, ones);
}
|
|
|
|
int main(int argc, char** argv)
|
|
{
|
|
testing::InitGoogleTest(&argc, argv);
|
|
if(argc == 1) {}
|
|
else if(argc == 3)
|
|
{
|
|
param_mask = strtol(argv[1], nullptr, 0);
|
|
instance_index = atoi(argv[2]);
|
|
}
|
|
else
|
|
{
|
|
std::cout << "Usage of " << argv[0] << std::endl;
|
|
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
|
}
|
|
return RUN_ALL_TESTS();
|
|
}
|