// File: composable_kernel/experimental/builder/test/unit_conv_tuning_params.cpp (C++)

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
namespace {
namespace ckb = ::ck_tile::builder;
using namespace ck_tile::builder;
using namespace ck_tile::builder::factory::internal;
// Checks that SetBlockGemm maps the builder-level pipeline version and scheduler
// found in an algorithm's `gemm_pipeline` member onto the matching ck:: enumerators.
TEST(ConvTuningParams, AssignsBlockGemmParams)
{
    // Minimal algorithm description: a single `gemm_pipeline` member carrying the
    // block-GEMM settings requested for this test.
    struct Algorithm
    {
        struct BlockGemm
        {
            size_t num_conv_groups_to_merge       = 1;
            size_t num_gemm_k_prefetch_stages     = 1;
            ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3;
            ckb::PipelineScheduler scheduler      = ckb::PipelineScheduler::INTRAWAVE;
        };
        BlockGemm gemm_pipeline;
    };
    constexpr Algorithm kV3IntrawaveAlgorithm{};

    constexpr auto converted = SetBlockGemm<kV3IntrawaveAlgorithm>();

    // Only the two enum-valued results are asserted; the size_t fields above are
    // left at 1 and are not checked by this test.
    EXPECT_EQ(converted.pipeline_version, ck::BlockGemmPipelineVersion::v3);
    EXPECT_EQ(converted.scheduler, ck::BlockGemmPipelineScheduler::Intrawave);
}
// Checks that SetLoopScheduler turns a builder-level INTERWAVE scheduler request
// into ck::LoopScheduler::Interwave.
TEST(ConvTuningParams, AssignsLoopSchedulerParam)
{
    // Algorithm description whose `gemm_pipeline` only specifies the scheduler.
    struct Algorithm
    {
        struct GemmPipeline
        {
            ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTERWAVE;
        };
        GemmPipeline gemm_pipeline;
    };
    constexpr Algorithm kInterwaveAlgorithm{};

    constexpr auto selected_scheduler = SetLoopScheduler<kInterwaveAlgorithm>();

    EXPECT_EQ(selected_scheduler, ck::LoopScheduler::Interwave);
}
// Checks that SetGridwiseGemmPipelineVersion turns a builder-level V4 pipeline
// version into ck::PipelineVersion::v4.
TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion)
{
    // Algorithm description whose `gemm_pipeline` only specifies the pipeline version.
    struct Algorithm
    {
        struct GemmPipeline
        {
            ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
        };
        GemmPipeline gemm_pipeline;
    };
    constexpr Algorithm kV4Algorithm{};

    constexpr auto gridwise_version = SetGridwiseGemmPipelineVersion<kV4Algorithm>();

    EXPECT_EQ(gridwise_version, ck::PipelineVersion::v4);
}
// Checks that SetGemmSpecialization turns a builder-level MNKPadding request into
// ck::tensor_operation::device::GemmSpecialization::MNKPadding.
TEST(ConvTuningParams, AssignsGemmSpecialization)
{
    // Algorithm description exposing the specialization as a direct member.
    struct Algorithm
    {
        ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::MNKPadding;
    };
    constexpr Algorithm kPaddedAlgorithm{};

    constexpr auto device_gemm_spec = SetGemmSpecialization<kPaddedAlgorithm>();

    EXPECT_EQ(device_gemm_spec, ck::tensor_operation::device::GemmSpecialization::MNKPadding);
}
// Checks that SetBlockGemmPipelineVersion turns a builder-level V2 pipeline version
// into ck::BlockGemmPipelineVersion::v2.
TEST(ConvTuningParams, AssignsBlockGemmPipelineVersion)
{
    // Here the pipeline version is a direct member of the algorithm description,
    // not a field of a nested `gemm_pipeline` object.
    struct Algorithm
    {
        ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V2;
    };
    constexpr Algorithm kV2Algorithm{};

    constexpr auto block_gemm_version = SetBlockGemmPipelineVersion<kV2Algorithm>();

    EXPECT_EQ(block_gemm_version, ck::BlockGemmPipelineVersion::v2);
}
// Checks that SetFwdConvSpecialization turns a builder-level
// FILTER_1X1_STRIDE1_PAD0 request into
// ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0.
TEST(ConvTuningParams, AssignsFwdConvSpecialization)
{
    // Algorithm description exposing the forward specialization as a direct member.
    struct Algorithm
    {
        ckb::ConvSpecialization fwd_specialization =
            ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0;
    };
    constexpr Algorithm kFilter1x1Algorithm{};

    constexpr auto device_conv_spec = SetFwdConvSpecialization<kFilter1x1Algorithm>();

    EXPECT_EQ(device_conv_spec,
              ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);
}
} // namespace