mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Revert "Add support for mixed precision in contraction scale and bilinear" (#967)
* Revert "Add support for mixed precision in contraction scale and bilinear (#936)"
This reverts commit f07485060e.
* revert commits #957 and #960
This commit is contained in:
@@ -10,12 +10,9 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/profile_contraction_impl.hpp"
|
||||
#include "profiler/profile_contraction_utils.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using F64 = double;
|
||||
using F32 = float;
|
||||
using F64 = double;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -23,49 +20,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
|
||||
struct Dimensions
|
||||
struct MemoryParams
|
||||
{
|
||||
std::vector<ck::index_t> M;
|
||||
std::vector<ck::index_t> N;
|
||||
std::vector<ck::index_t> K;
|
||||
std::vector<ck::index_t> StridesA;
|
||||
std::vector<ck::index_t> StridesB;
|
||||
std::vector<ck::index_t> StridesC;
|
||||
std::vector<ck::index_t> StridesD;
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestContraction : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CDLayout = std::tuple_element_t<2, Tuple>;
|
||||
using DataType = std::tuple_element_t<3, Tuple>;
|
||||
using DTupleDataType = std::tuple_element_t<4, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDElementOp = std::tuple_element_t<6, Tuple>;
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CDLayout = std::tuple_element_t<2, Tuple>;
|
||||
using DataType = std::tuple_element_t<3, Tuple>;
|
||||
using DTupleDataType = std::tuple_element_t<4, Tuple>;
|
||||
using CDElementOp = std::tuple_element_t<5, Tuple>;
|
||||
|
||||
std::vector<Dimensions> dimension_list = {{{32, 32}, {32, 32}, {32, 32}},
|
||||
{{16, 16}, {32, 32}, {16, 16}}};
|
||||
std::vector<MemoryParams> list_of_memory_params = {{{32, 32},
|
||||
{32, 32},
|
||||
{32, 32},
|
||||
{32768, 1024, 32, 1},
|
||||
{32768, 1024, 32, 1},
|
||||
{32768, 1024, 32, 1},
|
||||
{32768, 1024, 32, 1}},
|
||||
{{16, 16},
|
||||
{32, 32},
|
||||
{16, 16},
|
||||
{4096, 256, 16, 1},
|
||||
{16, 1, 8192, 256},
|
||||
{16384, 1024, 32, 1},
|
||||
{16384, 1024, 32, 1}}};
|
||||
|
||||
std::vector<ck::index_t> init_methods = {1, 2};
|
||||
std::vector<ck::index_t> init_methods = {0, 1, 2};
|
||||
std::unique_ptr<CDElementOp> p_cd_element_op;
|
||||
|
||||
void Run()
|
||||
{
|
||||
for(auto& dimension_params : dimension_list)
|
||||
for(auto& memory_params : list_of_memory_params)
|
||||
{
|
||||
std::vector<ck::index_t> StridesA;
|
||||
std::vector<ck::index_t> StridesB;
|
||||
std::vector<ck::index_t> StridesC;
|
||||
std::vector<ck::index_t> StridesD;
|
||||
|
||||
const auto& M = dimension_params.M;
|
||||
const auto& N = dimension_params.N;
|
||||
const auto& K = dimension_params.K;
|
||||
|
||||
assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]});
|
||||
assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]});
|
||||
assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]});
|
||||
assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]});
|
||||
|
||||
for(const ck::index_t init_method : init_methods)
|
||||
{
|
||||
bool pass =
|
||||
@@ -73,20 +70,19 @@ class TestContraction : public ::testing::Test
|
||||
BLayout,
|
||||
CDLayout,
|
||||
DataType,
|
||||
ComputeDataType,
|
||||
DTupleDataType,
|
||||
CDElementOp>(true /*do_verification*/,
|
||||
init_method,
|
||||
false /*do_logs*/,
|
||||
false /*time_kernel*/,
|
||||
*p_cd_element_op,
|
||||
dimension_params.M,
|
||||
dimension_params.N,
|
||||
dimension_params.K,
|
||||
StridesA,
|
||||
StridesB,
|
||||
StridesC,
|
||||
StridesD);
|
||||
memory_params.M,
|
||||
memory_params.N,
|
||||
memory_params.K,
|
||||
memory_params.StridesA,
|
||||
memory_params.StridesB,
|
||||
memory_params.StridesC,
|
||||
memory_params.StridesD);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
}
|
||||
@@ -103,18 +99,24 @@ class TestContractionBilinear : public TestContraction<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op) \
|
||||
std::tuple<Row, Row, Row, dt, tuple_dt, compute_dt, op>, \
|
||||
std::tuple<Row, Col, Row, dt, tuple_dt, compute_dt, op>, \
|
||||
std::tuple<Col, Row, Row, dt, tuple_dt, compute_dt, op>, \
|
||||
std::tuple<Col, Col, Row, dt, tuple_dt, compute_dt, op>
|
||||
|
||||
using BilinearKernelTypes =
|
||||
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F32, Bilinear),
|
||||
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F64, Bilinear)>;
|
||||
::testing::Types<std::tuple<Row, Row, Row, F32, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Row, Col, Row, F32, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Col, Row, Row, F32, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Col, Col, Row, F32, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Row, Row, Row, F64, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Row, Col, Row, F64, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Col, Row, Row, F64, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Col, Col, Row, F64, ck::Tuple<F32>, Bilinear>>;
|
||||
|
||||
using ScaleKernelTypes = ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F32, Scale),
|
||||
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale)>;
|
||||
using ScaleKernelTypes = ::testing::Types<std::tuple<Row, Row, Row, F32, ck::Tuple<>, Scale>,
|
||||
std::tuple<Row, Col, Row, F32, ck::Tuple<>, Scale>,
|
||||
std::tuple<Col, Row, Row, F32, ck::Tuple<>, Scale>,
|
||||
std::tuple<Col, Col, Row, F32, ck::Tuple<>, Scale>,
|
||||
std::tuple<Row, Row, Row, F64, ck::Tuple<>, Scale>,
|
||||
std::tuple<Row, Col, Row, F64, ck::Tuple<>, Scale>,
|
||||
std::tuple<Col, Row, Row, F64, ck::Tuple<>, Scale>,
|
||||
std::tuple<Col, Col, Row, F64, ck::Tuple<>, Scale>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes);
|
||||
TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);
|
||||
@@ -134,46 +136,3 @@ TYPED_TEST(TestContractionScale, scale)
|
||||
this->p_cd_element_op = std::make_unique<Scale>(0.5f);
|
||||
this->Run();
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
class TestContractionScaleMixedPrecision : public TestContraction<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestContractionBilinearMixedPrecision : public TestContraction<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
using BilinearKernelTypesMixedPrecision =
|
||||
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F16, Bilinear),
|
||||
ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, BF16, Bilinear),
|
||||
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F32, Bilinear),
|
||||
ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<F16>, F32, Bilinear),
|
||||
ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<BF16>, F32, Bilinear)>;
|
||||
|
||||
using ScaleKernelTypesMixedPrecision =
|
||||
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F16, Scale),
|
||||
ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, BF16, Scale),
|
||||
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F32, Scale),
|
||||
ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<>, F32, Scale),
|
||||
ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<>, F32, Scale)>;
|
||||
|
||||
TYPED_TEST_SUITE(TestContractionBilinearMixedPrecision, BilinearKernelTypesMixedPrecision);
|
||||
TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecision);
|
||||
|
||||
TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear)
|
||||
{
|
||||
this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
|
||||
this->Run();
|
||||
this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestContractionScaleMixedPrecision, scale)
|
||||
{
|
||||
this->p_cd_element_op = std::make_unique<Scale>(1.f);
|
||||
this->Run();
|
||||
this->p_cd_element_op = std::make_unique<Scale>(0.5f);
|
||||
this->Run();
|
||||
}
|
||||
|
||||
@@ -34,11 +34,11 @@ class ContractionInstanceWrapper
|
||||
static constexpr ck::index_t NumDim = 2;
|
||||
// clang-format off
|
||||
using ContractionDeviceInstance = ck::tensor_operation::device::
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>;
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>;
|
||||
// clang-format on
|
||||
|
||||
bool isSupported(std::vector<ck::index_t>& ADims,
|
||||
|
||||
Reference in New Issue
Block a user