mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Extend support for contraction 6D (#1207)
* Extend support for contraction up to 5D * Extend contraction bilinear instances * Fix interface test * Add 6d support, remove 3d,4d,5d * Fixes * Fix readme * Make defualt dim for contraction instances
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
@@ -125,18 +125,6 @@ class ContractionDeviceOpWrapper
|
||||
}
|
||||
};
|
||||
|
||||
TEST(TestContractionInterface, IncorrectNumDims)
|
||||
{
|
||||
std::vector<std::vector<ck::index_t>> Dims = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}};
|
||||
std::vector<std::vector<ck::index_t>> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}};
|
||||
ContractionDeviceOpWrapper<F32, F32, F32, F32, 1> wrapper_1d;
|
||||
ContractionDeviceOpWrapper<F32, F32, F32, F32, 2> wrapper_2d;
|
||||
ContractionDeviceOpWrapper<F32, F32, F32, F32, 3> wrapper_3d;
|
||||
EXPECT_FALSE(wrapper_1d.IsSupportedInstance(Dims[0], Strides[0]));
|
||||
EXPECT_TRUE(wrapper_2d.IsSupportedInstance(Dims[1], Strides[1]));
|
||||
EXPECT_FALSE(wrapper_3d.IsSupportedInstance(Dims[2], Strides[2]));
|
||||
}
|
||||
|
||||
TEST(TestContractionInterface, IncorrectDataTypes)
|
||||
{
|
||||
std::vector<ck::index_t> Dims = {4, 4, 4, 4};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
@@ -23,8 +23,11 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
|
||||
template <ck::index_t NDims>
|
||||
struct Dimensions
|
||||
{
|
||||
constexpr static ck::index_t NumDimMNK = NDims;
|
||||
|
||||
std::vector<ck::index_t> M;
|
||||
std::vector<ck::index_t> N;
|
||||
std::vector<ck::index_t> K;
|
||||
@@ -42,53 +45,58 @@ class TestContraction : public ::testing::Test
|
||||
using ComputeDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDElementOp = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
std::vector<Dimensions> dimension_list = {{{32, 32}, {32, 32}, {32, 32}},
|
||||
{{16, 16}, {32, 32}, {16, 16}}};
|
||||
|
||||
std::vector<ck::index_t> init_methods = {1, 2};
|
||||
std::unique_ptr<CDElementOp> p_cd_element_op;
|
||||
|
||||
void Run()
|
||||
template <ck::index_t NumDim>
|
||||
void Run(Dimensions<NumDim> dimension_params)
|
||||
{
|
||||
for(auto& dimension_params : dimension_list)
|
||||
constexpr ck::index_t NumDimMNK = ck::remove_cvref_t<decltype(dimension_params)>::NumDimMNK;
|
||||
|
||||
std::vector<ck::index_t> StridesA(2 * NumDim);
|
||||
std::vector<ck::index_t> StridesB(2 * NumDim);
|
||||
std::vector<ck::index_t> StridesC(2 * NumDim);
|
||||
std::vector<ck::index_t> StridesD(2 * NumDim);
|
||||
|
||||
const auto& M = dimension_params.M;
|
||||
const auto& N = dimension_params.N;
|
||||
const auto& K = dimension_params.K;
|
||||
|
||||
auto merge_dims = [](const std::vector<ck::index_t>& dims01,
|
||||
const std::vector<ck::index_t>& dims23) {
|
||||
std::vector<ck::index_t> dims_szt(dims01.begin(), dims01.end());
|
||||
dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end());
|
||||
return dims_szt;
|
||||
};
|
||||
|
||||
assign_default_strides(ALayout{}, StridesA, merge_dims(M, K));
|
||||
assign_default_strides(BLayout{}, StridesB, merge_dims(N, K));
|
||||
assign_default_strides(CDLayout{}, StridesC, merge_dims(M, N));
|
||||
assign_default_strides(CDLayout{}, StridesD, merge_dims(M, N));
|
||||
|
||||
for(const ck::index_t init_method : init_methods)
|
||||
{
|
||||
std::vector<ck::index_t> StridesA;
|
||||
std::vector<ck::index_t> StridesB;
|
||||
std::vector<ck::index_t> StridesC;
|
||||
std::vector<ck::index_t> StridesD;
|
||||
|
||||
const auto& M = dimension_params.M;
|
||||
const auto& N = dimension_params.N;
|
||||
const auto& K = dimension_params.K;
|
||||
|
||||
assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]});
|
||||
assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]});
|
||||
assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]});
|
||||
assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]});
|
||||
|
||||
for(const ck::index_t init_method : init_methods)
|
||||
{
|
||||
bool pass =
|
||||
ck::profiler::profile_contraction_impl<ALayout,
|
||||
BLayout,
|
||||
CDLayout,
|
||||
DataType,
|
||||
ComputeDataType,
|
||||
DTupleDataType,
|
||||
CDElementOp>(true /*do_verification*/,
|
||||
init_method,
|
||||
false /*do_logs*/,
|
||||
false /*time_kernel*/,
|
||||
*p_cd_element_op,
|
||||
dimension_params.M,
|
||||
dimension_params.N,
|
||||
dimension_params.K,
|
||||
StridesA,
|
||||
StridesB,
|
||||
StridesC,
|
||||
StridesD);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
bool pass =
|
||||
ck::profiler::profile_contraction_impl<NumDimMNK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CDLayout,
|
||||
DataType,
|
||||
ComputeDataType,
|
||||
DTupleDataType,
|
||||
CDElementOp>(true /*do_verification*/,
|
||||
init_method,
|
||||
false /*do_logs*/,
|
||||
false /*time_kernel*/,
|
||||
*p_cd_element_op,
|
||||
dimension_params.M,
|
||||
dimension_params.N,
|
||||
dimension_params.K,
|
||||
StridesA,
|
||||
StridesB,
|
||||
StridesC,
|
||||
StridesD);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -122,17 +130,31 @@ TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);
|
||||
TYPED_TEST(TestContractionBilinear, bilinear)
|
||||
{
|
||||
this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
|
||||
this->Run();
|
||||
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
|
||||
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
|
||||
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
|
||||
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
|
||||
|
||||
this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
|
||||
this->Run();
|
||||
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
|
||||
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
|
||||
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
|
||||
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
|
||||
}
|
||||
|
||||
TYPED_TEST(TestContractionScale, scale)
|
||||
{
|
||||
this->p_cd_element_op = std::make_unique<Scale>(1.f);
|
||||
this->Run();
|
||||
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
|
||||
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
|
||||
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
|
||||
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
|
||||
|
||||
this->p_cd_element_op = std::make_unique<Scale>(0.5f);
|
||||
this->Run();
|
||||
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
|
||||
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
|
||||
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
|
||||
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
@@ -165,15 +187,29 @@ TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecis
|
||||
TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear)
|
||||
{
|
||||
this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
|
||||
this->Run();
|
||||
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
|
||||
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
|
||||
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
|
||||
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
|
||||
|
||||
this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
|
||||
this->Run();
|
||||
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
|
||||
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
|
||||
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
|
||||
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
|
||||
}
|
||||
|
||||
TYPED_TEST(TestContractionScaleMixedPrecision, scale)
|
||||
{
|
||||
this->p_cd_element_op = std::make_unique<Scale>(1.f);
|
||||
this->Run();
|
||||
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
|
||||
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
|
||||
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
|
||||
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
|
||||
|
||||
this->p_cd_element_op = std::make_unique<Scale>(0.5f);
|
||||
this->Run();
|
||||
this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
|
||||
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
|
||||
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
|
||||
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user