Extend support for contraction 6D (#1207)

* Extend support for contraction up to 5D

* Extend contraction bilinear instances

* Fix interface test

* Add 6d support, remove 3d,4d,5d

* Fixes

* Fix readme

* Make defualt dim for contraction instances
This commit is contained in:
Bartłomiej Kocot
2024-04-09 23:46:21 +02:00
committed by GitHub
parent 366592b0ff
commit ced5af16f7
126 changed files with 5033 additions and 551 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <stdexcept>
#include <vector>
@@ -125,18 +125,6 @@ class ContractionDeviceOpWrapper
}
};
TEST(TestContractionInterface, IncorrectNumDims)
{
std::vector<std::vector<ck::index_t>> Dims = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}};
std::vector<std::vector<ck::index_t>> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}};
ContractionDeviceOpWrapper<F32, F32, F32, F32, 1> wrapper_1d;
ContractionDeviceOpWrapper<F32, F32, F32, F32, 2> wrapper_2d;
ContractionDeviceOpWrapper<F32, F32, F32, F32, 3> wrapper_3d;
EXPECT_FALSE(wrapper_1d.IsSupportedInstance(Dims[0], Strides[0]));
EXPECT_TRUE(wrapper_2d.IsSupportedInstance(Dims[1], Strides[1]));
EXPECT_FALSE(wrapper_3d.IsSupportedInstance(Dims[2], Strides[2]));
}
TEST(TestContractionInterface, IncorrectDataTypes)
{
std::vector<ck::index_t> Dims = {4, 4, 4, 4};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
@@ -23,8 +23,11 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale;
template <ck::index_t NDims>
struct Dimensions
{
constexpr static ck::index_t NumDimMNK = NDims;
std::vector<ck::index_t> M;
std::vector<ck::index_t> N;
std::vector<ck::index_t> K;
@@ -42,53 +45,58 @@ class TestContraction : public ::testing::Test
using ComputeDataType = std::tuple_element_t<5, Tuple>;
using CDElementOp = std::tuple_element_t<6, Tuple>;
std::vector<Dimensions> dimension_list = {{{32, 32}, {32, 32}, {32, 32}},
{{16, 16}, {32, 32}, {16, 16}}};
std::vector<ck::index_t> init_methods = {1, 2};
std::unique_ptr<CDElementOp> p_cd_element_op;
void Run()
template <ck::index_t NumDim>
void Run(Dimensions<NumDim> dimension_params)
{
for(auto& dimension_params : dimension_list)
constexpr ck::index_t NumDimMNK = ck::remove_cvref_t<decltype(dimension_params)>::NumDimMNK;
std::vector<ck::index_t> StridesA(2 * NumDim);
std::vector<ck::index_t> StridesB(2 * NumDim);
std::vector<ck::index_t> StridesC(2 * NumDim);
std::vector<ck::index_t> StridesD(2 * NumDim);
const auto& M = dimension_params.M;
const auto& N = dimension_params.N;
const auto& K = dimension_params.K;
auto merge_dims = [](const std::vector<ck::index_t>& dims01,
const std::vector<ck::index_t>& dims23) {
std::vector<ck::index_t> dims_szt(dims01.begin(), dims01.end());
dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end());
return dims_szt;
};
assign_default_strides(ALayout{}, StridesA, merge_dims(M, K));
assign_default_strides(BLayout{}, StridesB, merge_dims(N, K));
assign_default_strides(CDLayout{}, StridesC, merge_dims(M, N));
assign_default_strides(CDLayout{}, StridesD, merge_dims(M, N));
for(const ck::index_t init_method : init_methods)
{
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC;
std::vector<ck::index_t> StridesD;
const auto& M = dimension_params.M;
const auto& N = dimension_params.N;
const auto& K = dimension_params.K;
assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]});
assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]});
for(const ck::index_t init_method : init_methods)
{
bool pass =
ck::profiler::profile_contraction_impl<ALayout,
BLayout,
CDLayout,
DataType,
ComputeDataType,
DTupleDataType,
CDElementOp>(true /*do_verification*/,
init_method,
false /*do_logs*/,
false /*time_kernel*/,
*p_cd_element_op,
dimension_params.M,
dimension_params.N,
dimension_params.K,
StridesA,
StridesB,
StridesC,
StridesD);
EXPECT_TRUE(pass);
}
bool pass =
ck::profiler::profile_contraction_impl<NumDimMNK,
ALayout,
BLayout,
CDLayout,
DataType,
ComputeDataType,
DTupleDataType,
CDElementOp>(true /*do_verification*/,
init_method,
false /*do_logs*/,
false /*time_kernel*/,
*p_cd_element_op,
dimension_params.M,
dimension_params.N,
dimension_params.K,
StridesA,
StridesB,
StridesC,
StridesD);
EXPECT_TRUE(pass);
}
}
};
@@ -122,17 +130,31 @@ TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);
TYPED_TEST(TestContractionBilinear, bilinear)
{
this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
this->Run();
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
this->Run();
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
}
TYPED_TEST(TestContractionScale, scale)
{
this->p_cd_element_op = std::make_unique<Scale>(1.f);
this->Run();
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
this->p_cd_element_op = std::make_unique<Scale>(0.5f);
this->Run();
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
}
template <typename Tuple>
@@ -165,15 +187,29 @@ TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecis
TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear)
{
this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
this->Run();
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
this->Run();
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
}
TYPED_TEST(TestContractionScaleMixedPrecision, scale)
{
this->p_cd_element_op = std::make_unique<Scale>(1.f);
this->Run();
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
this->p_cd_element_op = std::make_unique<Scale>(0.5f);
this->Run();
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
}