[CK_BUILDER] Add bwd weight factories (#3509)

* Add placeholder test.

* Initial conv bwd weight factory.

* Conv builder test refactoring.

* Add missing pieces to bwd weight factory.

* Improve compile-time error message when no matching factory is found.

* Use a macro to ensure automatic matching between concepts and their string representations.

* Improve compile-time diagnostics.

* Small improvements.

* Improve missing member/wrong type compile-time errors.

* Improve compile-time diagnostics.

* Concept bug fixes.

* Remove debug assert.

* Update algorithm signature diagnostics.

* Factory bug fixes.

* First functional version of bwd weight conv factory.

* Refactor handling of GEMM-K batch template parameter in conv bwd weight factory.

* Concept improvements.

* Improve concept diagnostics.

* Introduce a common size type for concepts.

* Update compile-time diagnostics to use the size type.

* Update conv specialization enum.

* Fix fwd conv builder tests.

* Fix smoke tests.

* Split bwd weight and bwd data tests into separate targets.

* Clean-up CK Tile builder tests.

* Add bwd weight XDL CShuffle V3 factory.

* Build conv bwd weight v3 instances successfully.

* Add instance traits for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3.

* Test fix.

* Add instance traits for bwd weight algorithms.

* Add unit tests for instance strings.

* Build new instance traits unit tests but exclude WMMA for now.

* Add factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.

* Conv bwd weight DL factory.

* Final implementation for bwd weight DL factory.

* Add test for creating DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance.

* Add factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle.

* Treat ref algorithm the same way as real algorithms in the dispatcher.

* Refactor large tensor support and WMMA configuration.

* Add factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3.

* Update Readme.

* Fix WMMA bwd weight tests.

* Add factory and tests for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3.

* Factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffle.

* Dispatching for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle.

* Add factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3.

* Fix DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 factory and compute types for input and output tensors in bwd weight convs.

* Fix fwd factories after refactoring.

* clang-format

* Move compile-time diagnostics to a separate branch.

* Fix ref algorithm dispatching.

* Fix smoke tests.

* clang-format

* Fix factory for regular WMMA conv bwd weight.

* Clarify builder Readme.

* Remove obsolete test file.

* Fix test after merge.

* clang-format

* Remove the C++26 extensions.

* Unify conv elementwise ops and layout definitions for fwd and bwd directions.

* Remove old layout and elementwise ops.

* Unify handling of conv tensor types between fwd and bwd directions.

* Unify block transfer for fwd and bwd directions. Rename ThreadSliceDim to ThreadClusterRank.

* Make BlockTransferDescriptor concept parametrized. Introduce a common TileTransferParameters concept for conv algorithms.

* clang-format

---------

Co-authored-by: Ville Pietilä <>
This commit is contained in:
Ville Pietilä
2026-01-13 18:12:38 +02:00
committed by GitHub
parent 710fa1fd3d
commit 9908a87c31
69 changed files with 2956 additions and 832 deletions


@@ -124,7 +124,7 @@ add_ck_builder_test(test_ckb_conv_description
# Verifies that GetInstanceString() methods and other functions produce valid kernel code.
# Tests various convolution types:
# - Group convolution (v3, standard, large tensor, WMMA, DL variants)
# - Backward weight group convolution (XDL)
# - Backward weight group convolution (XDL standard, XDL v3, WMMA, DL, multiple D, two-stage variants)
# Requires kernel compilation to validate the generated strings through the base class.
set(INSTANCE_STRING_TESTS
@@ -167,10 +167,35 @@ add_ck_builder_test(test_ckb_build_fwd_instances
conv/ck/test_ckb_conv_fwd_3d_fp16.cpp
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp)
)
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
set(BWD_WEIGHT_TESTS
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp
conv/ck/test_ckb_conv_bwd_weight_dl.cpp
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
)
if (CK_USE_WMMA)
list(APPEND BWD_WEIGHT_TESTS
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp
conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp
)
endif()
add_ck_builder_test(test_ckb_build_bwd_weight_instances ${BWD_WEIGHT_TESTS})
target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility)
add_ck_builder_test(test_ckb_build_bwd_data_instances
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp
)
target_link_libraries(test_ckb_build_bwd_data_instances PRIVATE utility)
################################################################################
# FACTORY TESTS - Expensive Regression Tests (Full MIOpen Kernel Set)
@@ -224,6 +249,8 @@ endforeach()
set(CKB_REGRESSION_TESTS
test_ckb_instance_string
test_ckb_build_fwd_instances
test_ckb_build_bwd_weight_instances
test_ckb_build_bwd_data_instances
test_ckb_testing_utils
# test_ckb_factory_grouped_convolution_forward_convscale
# test_ckb_factory_grouped_convolution_forward_scaleadd_ab


@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl{}
.with_thread_block(cku::ThreadBlock_256_128x128x16)
.with_bwd_specialization(cku::ConvSpecialization::DEFAULT)
.with_dl_thread_config(cku::DlThreadConfig_16x1x4x4x1)
.with_dl_thread_cluster(cku::DlThreadCluster_8x2)
.with_dl_transfer(cku::DlTransfer5D);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(BwdWeight_2DBf16_DL, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Dl",
expected_transfer_parameters,
"Default",
"GNHWC,GKYXC,GNHWK",
"PassThrough,PassThrough,PassThrough"});
}


@@ -0,0 +1,42 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 3,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNDHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNDHWK}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(BwdWeight_3DFp16_MultiD_Wmma_ShuffleV3_GNHWC, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3",
expected_transfer_parameters,
"Default",
"GNDHWC,GKZYXC,GNDHWK",
"PassThrough,PassThrough,PassThrough",
"fp16,fp16>"}); // check compute types
}


@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle{}
.with_thread_block(cku::ThreadBlock_256_128x128x8)
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(cku::BwdTransfer_4x64x1)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(BwdWeight_2DFp16_MultiD_CShuffle_GNHWC, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle",
expected_transfer_parameters,
"Default",
"GNHWC,GKYXC,GNHWK",
"PassThrough,PassThrough,PassThrough",
"fp16,fp16>"}); // check compute types
}


@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using enum ck_tile::builder::TensorLayout;
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = NGCHW}},
.weight = {.config = {.layout = GKYXC}},
.output = {.config = {.layout = NGKHW}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
.with_num_conv_groups_to_merge(2)
.with_transpose_params(2, 2);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(BwdWeight_2DFp16_TwoStage_Wmma_CShuffle_V3, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3",
expected_transfer_parameters,
"Default",
"NGCHW,GKYXC,NGKHW",
"PassThrough,PassThrough,PassThrough",
"Intrawave",
"v1",
"fp16,fp16,2,2>"}); // Check compute types and transpose params.
}


@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
.with_num_conv_groups_to_merge(2)
.with_transpose_params(2, 4);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(BwdWeight_2DBf16_TwoStage_CShuffle, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle",
expected_transfer_parameters,
"Default",
"GNHWC,GKYXC,GNHWK",
"PassThrough,PassThrough,PassThrough",
"Intrawave,v2", // pipeline versions
"bf16,bf16,2,4>"}); // compute types and transpose params
}


@@ -0,0 +1,43 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using enum ck_tile::builder::TensorLayout;
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = NGCDHW}},
.weight = {.config = {.layout = GKZYXC}},
.output = {.config = {.layout = NGKDHW}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
.with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(BwdWeight_3DBf16_Wmma_CShuffle, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Wmma_CShuffle",
expected_transfer_parameters,
"Default",
"NGCDHW,GKZYXC,NGKDHW",
"PassThrough,PassThrough,PassThrough",
"v1"});
}


@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using enum ck_tile::builder::TensorLayout;
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = NGCW}},
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = NGKW}}};
constexpr auto ALGORITHM =
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
.with_transpose_params(4, 4);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(BwdWeight_1DBf16_Wmma_CShuffle_V3, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Wmma_CShuffleV3",
expected_transfer_parameters,
"Filter1x1Stride1Pad0",
"NGCW,GKXC,NGKW",
"PassThrough,PassThrough,PassThrough",
"Intrawave",
"v1",
"bf16,bf16,4,4>"});
}


@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{}
.with_thread_block(cku::ThreadBlock_256_128x128x8)
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(cku::BwdTransfer_4x64x1)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_transpose_params(2, 2);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(BwdWeight_2DFp16_CShuffle_GNHWC, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffle",
expected_transfer_parameters,
"Default",
"GNHWC,GKYXC,GNHWK",
"PassThrough,PassThrough,PassThrough",
"fp16,fp16,2,2>"}); // check compute types and transpose params
}


@@ -0,0 +1,43 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using enum ck_tile::builder::TensorLayout;
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = NGCW}},
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = NGKW}}};
constexpr auto ALGORITHM =
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
TEST(BwdWeight_1DBf16_CShuffle_V3, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3",
expected_transfer_parameters,
"Filter1x1Stride1Pad0",
"NGCW,GKXC,NGKW",
"PassThrough,PassThrough,PassThrough",
"Intrawave",
"v2"});
}


@@ -30,11 +30,11 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_256x256x32)
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v2_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;


@@ -27,11 +27,12 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_64_64x32x32)
.with_thread_block(ThreadBlock_64_64x32x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 2, PipelineScheduler::DEFAULT);
.with_transfer(Transfer_4x16x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(2);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;


@@ -29,11 +29,13 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
.with_thread_block(FwdThreadBlock_128_64x64x64)
.with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave)
.with_transfer(FwdTransfer_4x32x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 0, PipelineScheduler::DEFAULT);
.with_thread_block(ThreadBlock_128_64x64x64)
.with_gemm_config(GemmParams_Wmma_2x1_per_wave)
.with_transfer(Transfer_4x32x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(2)
.with_gridwise_gemm_pipeline(PipelineVersion::V1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;


@@ -27,10 +27,10 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_256x256x32)
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v1_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
@@ -64,10 +64,11 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_256x256x32)
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_3x3,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v5_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;


@@ -32,11 +32,12 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_64_64x32x32)
.with_thread_block(ThreadBlock_64_64x32x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
.with_transfer(Transfer_4x16x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;


@@ -25,15 +25,16 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
.with_thread_block(FwdThreadBlock_256_128x128x16)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_thread_block(ThreadBlock_256_128x128x16)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
.with_dl_thread_cluster(DlThreadCluster_8x2)
.with_dl_transfer(DlFwdTransfer);
.with_dl_transfer(DlTransfer4D);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
expected_transfer_parameters,
"Default",
@@ -59,16 +60,17 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
.with_thread_block(FwdThreadBlock_256_128x128x16)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_thread_block(ThreadBlock_256_128x128x16)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
.with_dl_thread_cluster(DlThreadCluster_8x2)
.with_dl_transfer(DlFwdTransfer);
.with_dl_transfer(DlTransfer4D);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
expected_transfer_parameters,
"Filter1x1Pad0",


@@ -25,11 +25,11 @@ constexpr auto SIGNATURE =
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(cku::FwdThreadBlock_256_256x256x32)
.with_thread_block(cku::ThreadBlock_256_256x256x32)
.with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(cku::FwdTransfer_4x64x1)
.with_specializations(ckb::ConvFwdSpecialization::DEFAULT,
ckb::GemmSpecialization::MNKPadding)
.with_transfer(cku::Transfer_4x64x1)
.with_fwd_specializations(ckb::ConvSpecialization::DEFAULT,
ckb::GemmSpecialization::MNKPadding)
.with_block_gemm(cku::BlockGemmDesc_v3_intrawave);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;


@@ -26,11 +26,11 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_128x128x32)
.with_thread_block(ThreadBlock_256_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v4_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;


@@ -27,11 +27,12 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_256_256x128x32)
.with_thread_block(ThreadBlock_256_256x128x32)
.with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave)
.with_transfer(FwdTransfer_4x64x1_fp8)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
.with_transfer(Transfer_4x64x1_fp8)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;


@@ -25,14 +25,13 @@ TEST(FwdConvInstances,
.output = {.config = {.layout = GNHWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
.base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_256_256x128x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::DEFAULT,
GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{}
.with_thread_block(ThreadBlock_256_256x128x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(Transfer_4x16x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
@@ -62,14 +61,14 @@ TEST(
.output = {.config = {.layout = GNHWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
.base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_128_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{}
.with_thread_block(ThreadBlock_128_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(Transfer_4x16x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;


@@ -27,10 +27,10 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_256x256x32)
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v3_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;


@@ -27,11 +27,11 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_128x128x32)
.with_thread_block(ThreadBlock_256_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v4_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -27,11 +27,11 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_256x256x32)
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v1_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -101,7 +101,7 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
@@ -229,7 +229,7 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
@@ -313,7 +313,7 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);

View File

@@ -230,7 +230,7 @@ TEST(InstanceToConvTraits, ExtractsDefaultSpecialization)
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
}
TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization)
@@ -289,8 +289,7 @@ TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization)
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
EXPECT_EQ(Traits::conv_specialization,
ck_tile::builder::ConvFwdSpecialization::FILTER_1X1_PAD0);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0);
}
// ============================================================================

View File

@@ -8,26 +8,27 @@ namespace {
using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::BACKWARD_DATA,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr ConvSignature BwdDataConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::BACKWARD_DATA,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr auto FwdConvAlgorithm =
constexpr auto BwdDataConvAlgorithm =
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(TileConvSpecialization::DEFAULT)
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
.with_tile_thread_block(TileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(FwdTileTransfer_4x4x4)
.with_tile_transfer(TileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
using Builder = ConvBuilder<BwdDataConvSignature, BwdDataConvAlgorithm>;
run_ck_tile_test<Builder>({
"grouped_convolution_backward_data",
"fp16",

View File

@@ -8,26 +8,27 @@ namespace {
using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::BACKWARD_WEIGHT,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr ConvSignature BwdWeightConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::BACKWARD_WEIGHT,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr auto FwdConvAlgorithm =
constexpr auto BwdWeightConvAlgorithm =
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(TileConvSpecialization::DEFAULT)
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
.with_tile_thread_block(TileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(FwdTileTransfer_4x4x4)
.with_tile_transfer(TileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
using Builder = ConvBuilder<BwdWeightConvSignature, BwdWeightConvAlgorithm>;
run_ck_tile_test<Builder>({
"grouped_convolution_backward_weight",
"fp16",

View File

@@ -21,9 +21,9 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(TileConvSpecialization::DEFAULT)
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
.with_tile_thread_block(TileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(FwdTileTransfer_4x4x4)
.with_tile_transfer(TileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});

View File

@@ -28,18 +28,31 @@ struct ThreadBlock
};
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
// Describe gridwise XDL GEMM parameters.
struct GridwiseXdlGemm
struct XdlParams
{
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
size_t ak1 = 0;
size_t bk1 = 0;
size_t m_per_xdl = 0;
size_t n_per_xdl = 0;
size_t m_xdl_per_wave = 0;
size_t n_xdl_per_wave = 0;
};
static_assert(ckb::GridwiseXdlGemmDescriptor<GridwiseXdlGemm>);
static_assert(ckb::GridwiseXdlGemmDescriptor<XdlParams>);
// Describe direction-specific gridwise XDL GEMM parameters.
struct GridwiseFwdXdlGemm
{
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
size_t ak1 = 0;
size_t bk1 = 0;
XdlParams xdl_params;
};
static_assert(ckb::GridwiseFwdXdlGemmDescriptor<GridwiseFwdXdlGemm>);
struct GridwiseBwdXdlGemm
{
size_t k1 = 0;
XdlParams xdl_params;
};
static_assert(ckb::GridwiseBwdXdlGemmDescriptor<GridwiseBwdXdlGemm>);
// Describe gridwise WMMA GEMM parameters.
struct GridwiseWmmaGemm
@@ -49,25 +62,36 @@ struct GridwiseWmmaGemm
size_t n_per_wmma = 0;
size_t m_wmma_per_wave = 0;
size_t n_wmma_per_wave = 0;
PipelineVersion pipeline_version;
};
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemm>);
struct BlockGemm
struct BlockGemmPipeline
{
PipelineVersion pipeline_version;
PipelineScheduler scheduler;
};
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
static_assert(ckb::BlockGemmPipelineDescriptor<BlockGemmPipeline>);
// Describe A and B block transfer thread cluster lengths.
template <size_t ThreadSliceLength = 3>
struct BlockTransfer
{
size_t k0;
size_t m_n;
size_t k1;
size_t k_batch_size;
};
static_assert(ckb::BlockTransferDescriptor<BlockTransfer>);
// Specialization for ThreadSliceLength == 3
template <>
struct BlockTransfer<3>
{
size_t k0;
size_t m_n;
size_t k1;
};
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<3>, 3>);
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<4>, 4>);
// Describe C block transfer thread cluster lengths.
struct ThreadCluster
@@ -97,31 +121,35 @@ struct Epilogue
};
static_assert(EpilogueDescriptor<Epilogue>);
template <size_t ThreadSliceLength = 3>
struct AccessOrder
{
std::array<size_t, 3> order;
std::array<size_t, ThreadSliceLength> order;
};
static_assert(AccessOrderDescriptor<AccessOrder>);
static_assert(AccessOrderDescriptor<AccessOrder<>>);
static_assert(AccessOrderDescriptor<AccessOrder<4>>);
struct TransferAB
template <size_t ThreadSliceLength = 3>
struct InputTransfer
{
BlockTransfer block_transfer;
BlockTransfer<ThreadSliceLength> block_transfer;
LdsTransfer lds_transfer;
AccessOrder block_transfer_access_order;
AccessOrder src_access_order;
AccessOrder<ThreadSliceLength> block_transfer_access_order;
AccessOrder<ThreadSliceLength> src_access_order;
};
struct TransferC
struct OutputTransfer
{
ThreadCluster thread_cluster_dims;
Epilogue epilogue;
};
struct TransferABC
template <size_t ThreadSliceLength = 3>
struct Transfer
{
TransferAB a;
TransferAB b;
TransferC c;
InputTransfer<ThreadSliceLength> a;
InputTransfer<ThreadSliceLength> b;
OutputTransfer c;
};
// DL-specific descriptors
@@ -142,17 +170,19 @@ struct DlThreadCluster
};
static_assert(ckb::DlThreadClusterDescriptor<DlThreadCluster>);
template <size_t D = 4>
struct DlBlockTransfer
{
std::array<size_t, 4> thread_slice_lengths;
std::array<size_t, 4> thread_cluster_lengths;
std::array<size_t, 4> thread_cluster_arrange_order;
std::array<size_t, 4> src_access_order;
std::array<size_t, 4> src_vector_tensor_lengths;
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
std::array<size_t, 4> dst_vector_tensor_lengths;
std::array<size_t, D> thread_slice_lengths;
std::array<size_t, D> thread_cluster_lengths;
std::array<size_t, D> thread_cluster_arrange_order;
std::array<size_t, D> src_access_order;
std::array<size_t, D> src_vector_tensor_lengths;
std::array<size_t, D> src_vector_tensor_contiguous_dim_order;
std::array<size_t, D> dst_vector_tensor_lengths;
};
static_assert(ckb::DlBlockTransferDescriptor<DlBlockTransfer>);
static_assert(ckb::DlBlockTransferDescriptor4D<DlBlockTransfer<4>>);
static_assert(ckb::DlBlockTransferDescriptor5D<DlBlockTransfer<5>>);
struct DlEpilogue
{
@@ -169,9 +199,14 @@ struct ThreadBlock_
ThreadBlock thread_block;
};
struct XdlGemm_
struct FwdXdlGemm_
{
GridwiseXdlGemm gridwise_gemm;
GridwiseFwdXdlGemm gridwise_gemm;
};
struct BwdXdlGemm_
{
GridwiseBwdXdlGemm gridwise_gemm;
};
struct WmmaGemm_
@@ -179,27 +214,48 @@ struct WmmaGemm_
GridwiseWmmaGemm gridwise_gemm;
};
template <size_t ThreadSliceLength = 3>
struct Transfer_
{
TransferABC transfer;
Transfer<ThreadSliceLength> transfer;
};
struct ConvSpecialization_
struct ConvSpecializationFwd_
{
ConvFwdSpecialization fwd_specialization;
ConvSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
};
struct ConvSpecializationBwdWeight_
{
ConvSpecialization bwd_weight_specialization;
};
struct Prefetch_
{
size_t num_gemm_k_prefetch_stages;
size_t num_groups_to_merge;
PipelineScheduler loop_scheduler;
};
struct TransposeParams_
{
size_t max_transpose_transfer_src_scalar_per_vector{1};
size_t max_transpose_transfer_dst_scalar_per_vector{1};
};
struct GemmBatchOptions_
{
size_t num_conv_groups_to_merge{1};
};
struct BlockGemm_
{
BlockGemm block_gemm;
BlockGemmPipeline block_gemm_pipeline;
};
struct GridGemm_
{
PipelineVersion pipeline_version;
};
struct DlThreadConfig_
@@ -212,33 +268,34 @@ struct DlThreadCluster_
DlThreadCluster thread_cluster;
};
struct DlBlockTransferAB
template <size_t Dim = 4>
struct DlTransfer
{
DlBlockTransfer block_transfer;
};
struct DlBlockTransferC
{
DlEpilogue epilogue;
};
struct DlTransferABC
{
DlBlockTransferAB a;
DlBlockTransferAB b;
DlBlockTransferC c;
DlBlockTransfer<Dim> a;
DlBlockTransfer<Dim> b;
DlEpilogue c;
};
template <size_t Dim = 4>
struct DlTransfer_
{
DlTransferABC transfer;
DlTransfer<Dim> transfer;
};
// Specialization wrapper for large tensor support
template <typename BaseAlgorithm>
struct LargeTensorWrapper
struct TwoStageSpecialization_
{
static constexpr ConvAlgorithmSpecialization specialization =
ConvAlgorithmSpecialization::TWO_STAGE;
};
struct MultipleDSpecialization_
{
static constexpr ConvAlgorithmSpecialization specialization =
ConvAlgorithmSpecialization::MULTIPLE_D;
};
struct LargeTensorSpecialization_
{
BaseAlgorithm base_algorithm;
static constexpr ConvAlgorithmSpecialization specialization =
ConvAlgorithmSpecialization::LARGE_TENSOR;
};
@@ -329,7 +386,11 @@ struct ConvAlgorithmTemplate : Components...
constexpr auto with_gemm_config(const GemmConfig& gemm) const
{
auto result = *this;
if constexpr(std::is_base_of_v<XdlGemm_, ConvAlgorithmTemplate>)
if constexpr(std::is_base_of_v<FwdXdlGemm_, ConvAlgorithmTemplate>)
{
result.gridwise_gemm = gemm;
}
else if constexpr(std::is_base_of_v<BwdXdlGemm_, ConvAlgorithmTemplate>)
{
result.gridwise_gemm = gemm;
}
@@ -337,46 +398,82 @@ struct ConvAlgorithmTemplate : Components...
{
result.gridwise_gemm = gemm;
}
else
{
static_assert(false, "Unrecognized GemmConfig type");
}
return result;
}
template <typename T>
constexpr auto with_transfer(const T& t) const
{
static_assert(std::is_base_of_v<Transfer_, ConvAlgorithmTemplate>);
static_assert(std::is_base_of_v<Transfer_<3>, ConvAlgorithmTemplate> ||
std::is_base_of_v<Transfer_<4>, ConvAlgorithmTemplate>);
auto result = *this;
result.transfer = t;
return result;
}
constexpr auto with_specializations(ConvFwdSpecialization fwd_spec,
GemmSpecialization gemm_spec) const
constexpr auto with_fwd_specializations(ConvSpecialization fwd_spec,
GemmSpecialization gemm_spec) const
{
static_assert(std::is_base_of_v<ConvSpecialization_, ConvAlgorithmTemplate>);
static_assert(std::is_base_of_v<ConvSpecializationFwd_, ConvAlgorithmTemplate>);
auto result = *this;
result.fwd_specialization = fwd_spec;
result.gemm_specialization = gemm_spec;
return result;
}
constexpr auto with_prefetch_config(size_t k_prefetch_stages,
size_t groups_to_merge,
PipelineScheduler scheduler) const
constexpr auto with_bwd_specialization(ConvSpecialization bwd_spec) const
{
static_assert(std::is_base_of_v<ConvSpecializationBwdWeight_, ConvAlgorithmTemplate>);
auto result = *this;
result.bwd_weight_specialization = bwd_spec;
return result;
}
constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const
{
static_assert(std::is_base_of_v<Prefetch_, ConvAlgorithmTemplate>);
auto result = *this;
result.num_gemm_k_prefetch_stages = k_prefetch_stages;
result.num_groups_to_merge = groups_to_merge;
result.loop_scheduler = scheduler;
return result;
}
constexpr auto with_transpose_params(size_t max_src_scalar_per_vector,
size_t max_dst_scalar_per_vector) const
{
static_assert(std::is_base_of_v<TransposeParams_, ConvAlgorithmTemplate>);
auto result = *this;
result.max_transpose_transfer_src_scalar_per_vector = max_src_scalar_per_vector;
result.max_transpose_transfer_dst_scalar_per_vector = max_dst_scalar_per_vector;
return result;
}
constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const
{
static_assert(std::is_base_of_v<GemmBatchOptions_, ConvAlgorithmTemplate>);
auto result = *this;
result.num_conv_groups_to_merge = num_groups_to_merge;
return result;
}
template <typename BG>
constexpr auto with_block_gemm(const BG& bg) const
{
static_assert(std::is_base_of_v<BlockGemm_, ConvAlgorithmTemplate>);
auto result = *this;
result.block_gemm = bg;
auto result = *this;
result.block_gemm_pipeline = bg;
return result;
}
constexpr auto with_gridwise_gemm_pipeline(const PipelineVersion plv) const
{
static_assert(std::is_base_of_v<GridGemm_, ConvAlgorithmTemplate>);
auto result = *this;
result.pipeline_version = plv;
return result;
}
@@ -401,7 +498,8 @@ struct ConvAlgorithmTemplate : Components...
template <typename T>
constexpr auto with_dl_transfer(const T& t) const
{
static_assert(std::is_base_of_v<DlTransfer_, ConvAlgorithmTemplate>);
static_assert(std::is_base_of_v<DlTransfer_<4>, ConvAlgorithmTemplate> ||
std::is_base_of_v<DlTransfer_<5>, ConvAlgorithmTemplate>);
auto result = *this;
result.transfer = t;
return result;
@@ -453,26 +551,49 @@ struct ConvAlgorithmTemplate : Components...
}
};
// Algorithm types
// Fwd algorithm types
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, Transfer_, ConvSpecialization_, Prefetch_>;
ConvAlgorithmTemplate<ThreadBlock_,
FwdXdlGemm_,
Transfer_<>,
ConvSpecializationFwd_,
Prefetch_,
GemmBatchOptions_>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, Transfer_, ConvSpecialization_, BlockGemm_>;
ConvAlgorithmTemplate<ThreadBlock_,
FwdXdlGemm_,
Transfer_<>,
ConvSpecializationFwd_,
BlockGemm_>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_, ConvSpecialization_, Prefetch_>;
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<>,
ConvSpecializationFwd_,
GridGemm_,
Prefetch_,
GemmBatchOptions_>;
using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
ConvAlgorithmTemplate<ThreadBlock_,
ConvSpecialization_,
ConvSpecializationFwd_,
DlThreadConfig_,
DlThreadCluster_,
DlTransfer_>;
DlTransfer_<>>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
ConvAlgorithmTemplate<ThreadBlock_,
FwdXdlGemm_,
Transfer_<>,
ConvSpecializationFwd_,
Prefetch_,
GemmBatchOptions_,
LargeTensorSpecialization_>;
// CK Tile algorithm
using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileThreadBlock_,
TileBlockGemm_,
TileTransfer_,
@@ -488,4 +609,77 @@ struct ConvAlgorithm_Reference
// GPU reference uses simple algorithm, no tile configuration needed
};
// Bwd weight algorithm types
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_,
BwdXdlGemm_,
Transfer_<4>,
ConvSpecializationBwdWeight_,
TransposeParams_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_,
BwdXdlGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_,
TransposeParams_,
GemmBatchOptions_,
TwoStageSpecialization_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_,
BwdXdlGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl =
ConvAlgorithmTemplate<ThreadBlock_,
DlThreadConfig_,
DlThreadCluster_,
DlTransfer_<5>,
ConvSpecializationBwdWeight_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_,
BwdXdlGemm_,
Transfer_<4>,
ConvSpecializationBwdWeight_,
MultipleDSpecialization_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_,
TransposeParams_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_,
TransposeParams_,
GemmBatchOptions_,
TwoStageSpecialization_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
GridGemm_,
Prefetch_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_,
MultipleDSpecialization_>;
} // namespace ck_tile::builder::test

View File

@@ -120,14 +120,12 @@ struct DefaultAlgorithm
ckb::test::ThreadBlock thread_block{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 16,
.n_per_xdl = 16,
.m_xdl_per_wave = 8,
.n_xdl_per_wave = 8};
ckb::test::GridwiseFwdXdlGemm gridwise_gemm{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 16, .n_per_xdl = 16, .m_xdl_per_wave = 8, .n_xdl_per_wave = 8}};
ckb::test::TransferABC transfer{
ckb::test::Transfer<> transfer{
.a =
{
.block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2},
@@ -161,10 +159,11 @@ struct DefaultAlgorithm
},
};
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT;
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4,
.scheduler =
ckb::PipelineScheduler::INTRAWAVE};
};
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);

View File

@@ -38,11 +38,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK)
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = NWGK}}};
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -57,11 +57,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW)
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = NGKW}}};
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -76,11 +76,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK)
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = GNWK}}};
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -95,11 +95,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW)
.weight = {.config = {.layout = GKCX}},
.output = {.config = {.layout = NGKW}}};
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -114,11 +114,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW)
.weight = {.config = {.layout = GKYXC}},
.output = {.config = {.layout = NGKHW}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -133,11 +133,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK)
.weight = {.config = {.layout = GKYXC}},
.output = {.config = {.layout = NHWGK}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -152,11 +152,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK)
.weight = {.config = {.layout = GKYXC}},
.output = {.config = {.layout = GNHWK}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -171,11 +171,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW)
.weight = {.config = {.layout = GKCYX}},
.output = {.config = {.layout = NGKHW}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCYX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCYX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -190,11 +190,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW)
.weight = {.config = {.layout = GKCZYX}},
.output = {.config = {.layout = NGKDHW}}};
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 3>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCDHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCZYX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKDHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCDHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCZYX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKDHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -209,11 +209,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK)
.weight = {.config = {.layout = GKZYXC}},
.output = {.config = {.layout = NDHWGK}}};
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 3>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NDHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NDHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -228,11 +228,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK)
.weight = {.config = {.layout = GKZYXC}},
.output = {.config = {.layout = GNDHWK}}};
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 3>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNDHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNDHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNDHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNDHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -273,7 +273,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout)
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
MockAuxiliaryTensorConfig{.layout = G_K_strided}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
EXPECT_EQ(AuxLayouts::Size, 1);
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K>;
@@ -287,7 +287,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout)
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
EXPECT_EQ(AuxLayouts::Size, 1);
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::GC>;
@@ -301,7 +301,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout)
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
MockAuxiliaryTensorConfig{.layout = G_C_strided}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
EXPECT_EQ(AuxLayouts::Size, 1);
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_C>;
@@ -316,7 +316,7 @@ TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors)
MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided},
MockAuxiliaryTensorConfig{.layout = GC}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
EXPECT_EQ(AuxLayouts::Size, 2);
using ExpectedType =
@@ -333,7 +333,7 @@ TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors)
MockAuxiliaryTensorConfig{.layout = GC},
MockAuxiliaryTensorConfig{.layout = G_C_strided}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
EXPECT_EQ(AuxLayouts::Size, 3);
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K,
@@ -349,7 +349,7 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution)
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
MockAuxiliaryTensorConfig{.layout = G_K_strided}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 1, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 1>;
EXPECT_EQ(AuxLayouts::Size, 1);
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K>;
@@ -363,7 +363,7 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution)
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
MockAuxiliaryTensorConfig{.layout = GC}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 3, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 3>;
EXPECT_EQ(AuxLayouts::Size, 1);
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::GC>;
@@ -387,11 +387,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K)
.operation =
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
@@ -414,11 +414,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC)
.operation =
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NHWGK>));
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::GC>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
@@ -442,11 +442,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors)
.operation = OutputOp{.elementwise_operation =
ElementwiseOperation::SCALEADD_SCALEADD_RELU}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNHWK>));
using ExpectedDsLayout =
ck::Tuple<ck::tensor_layout::convolution::G_K, ck::tensor_layout::convolution::GC>;
@@ -470,11 +470,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias)
.operation =
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NWGK>));
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
@@ -497,11 +497,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias)
.operation = OutputOp{.elementwise_operation =
ElementwiseOperation::BIAS_BNORM_CLAMP}}};
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 3>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NDHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NDHWGK>));
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_C>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));


@@ -19,7 +19,7 @@ TEST(ConvTuningParams, AssignsBlockGemmParams)
{
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3;
ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE;
} block_gemm;
} block_gemm_pipeline;
} kAlgorithm;
constexpr auto block_gemm = SetBlockGemm<kAlgorithm>();
@@ -42,10 +42,7 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion)
{
constexpr struct Algorithm
{
struct GridwiseGemm
{
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
} gridwise_gemm;
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
} kAlgorithm;
constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion<kAlgorithm>();
@@ -78,8 +75,8 @@ TEST(ConvTuningParams, AssignsFwdConvSpecialization)
{
constexpr struct Algorithm
{
ckb::ConvFwdSpecialization fwd_specialization =
ckb::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
ckb::ConvSpecialization fwd_specialization =
ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0;
} kAlgorithm;
constexpr auto conv_spec = SetFwdConvSpecialization<kAlgorithm>();


@@ -15,31 +15,42 @@ using namespace test;
constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
constexpr DlThreadConfig DlThreadConfig_16x1x4x4x1{
.k0_per_block = 16, .k1 = 1, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, 1, 1, 2},
.thread_cluster_lengths = {2, 1, 128, 1},
.thread_cluster_arrange_order = {1, 2, 0, 3},
.src_access_order = {1, 2, 0, 3},
.src_vector_tensor_lengths = {4, 1, 1, 2},
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
constexpr DlBlockTransfer<4> DlBlockTransfer_8x1x1x2{
.thread_slice_lengths = {8, 1, 1, 2},
.thread_cluster_lengths = {2, 1, 128, 1},
.thread_cluster_arrange_order = {1, 2, 0, 3},
.src_access_order = {1, 2, 0, 3},
.src_vector_tensor_lengths = {4, 1, 1, 2},
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
constexpr DlTransferABC DlFwdTransfer{.a =
{
.block_transfer = DlBlockTransferAB,
},
.b =
{
.block_transfer = DlBlockTransferAB,
},
.c = {
.epilogue = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
.src_dst_vector_dim = 5,
.dst_scalar_per_vector = 4},
}};
constexpr DlTransfer<4> DlTransfer4D{.a = DlBlockTransfer_8x1x1x2,
.b = DlBlockTransfer_8x1x1x2,
.c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
.src_dst_vector_dim = 5,
.dst_scalar_per_vector = 4}};
constexpr TransferABC FwdTransfer_4x64x1{
constexpr DlBlockTransfer<5> DlBlockTransfer_1x8x1x1x1{
.thread_slice_lengths = {1, 8, 1, 1, 1},
.thread_cluster_lengths = {1, 2, 1, 128, 1},
.thread_cluster_arrange_order = {0, 2, 3, 1, 4},
.src_access_order = {0, 2, 3, 1, 4},
.src_vector_tensor_lengths = {1, 1, 1, 1, 1},
.src_vector_tensor_contiguous_dim_order = {0, 2, 3, 1, 4},
.dst_vector_tensor_lengths = {1, 1, 1, 1, 1}};
constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1,
.b = DlBlockTransfer_1x8x1x1x1,
.c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
.src_dst_vector_dim = 5,
.dst_scalar_per_vector = 1}};
constexpr Transfer<> Transfer_4x64x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
@@ -72,7 +83,73 @@ constexpr TransferABC FwdTransfer_4x64x1{
},
};
constexpr TransferABC FwdTransfer_4x64x1_fp8{
constexpr Transfer<4> BwdTransfer_4x64x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 4,
.is_direct_load = false,
.lds_padding = true},
.block_transfer_access_order = {0, 3, 1, 2},
.src_access_order = {0, 2, 1, 3},
},
.b =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 4,
.is_direct_load = false,
.lds_padding = true},
.block_transfer_access_order = {0, 3, 1, 2},
.src_access_order = {0, 2, 1, 3},
},
.c =
{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
},
};
constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1},
.lds_transfer = {.src_vector_dim = 1,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 2,
.is_direct_load = false,
.lds_padding = false},
.block_transfer_access_order = {2, 0, 1},
.src_access_order = {1, 0, 2},
},
.b =
{
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
.lds_transfer = {.src_vector_dim = 1,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 2,
.is_direct_load = false,
.lds_padding = false},
.block_transfer_access_order = {2, 0, 1},
.src_access_order = {1, 0, 2},
},
.c =
{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 2},
},
};
constexpr Transfer<> Transfer_4x64x1_fp8{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
@@ -105,7 +182,7 @@ constexpr TransferABC FwdTransfer_4x64x1_fp8{
},
};
constexpr TransferABC FwdTransfer_4x16x1{
constexpr Transfer<> Transfer_4x16x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
@@ -139,7 +216,7 @@ constexpr TransferABC FwdTransfer_4x16x1{
},
};
constexpr TransferABC FwdTransfer_4x32x1{
constexpr Transfer<> Transfer_4x32x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
@@ -172,59 +249,80 @@ constexpr TransferABC FwdTransfer_4x32x1{
},
};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{
.k1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2};
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{
.k1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x2_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2};
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}};
constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8,
.m_per_wmma = 32,
.n_per_wmma = 32,
.m_wmma_per_wave = 2,
.n_wmma_per_wave = 1,
.pipeline_version = PipelineVersion::V1};
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x2_per_wave{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};
constexpr ThreadBlock FwdThreadBlock_256_256x256x32{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}};
constexpr ThreadBlock FwdThreadBlock_256_256x128x32{.block_size = 256,
.tile_size = {.m = 256, .n = 128, .k = 32}};
constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{
.k1 = 8, .m_per_wmma = 32, .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
constexpr ThreadBlock FwdThreadBlock_256_128x128x32{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{
.k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
constexpr ThreadBlock FwdThreadBlock_256_128x128x16{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 16}};
constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_64_64x32x32{.block_size = 64,
.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr ThreadBlock ThreadBlock_256_256x128x32{.block_size = 256,
.tile_size = {.m = 256, .n = 128, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_128_128x128x32{.block_size = 128,
.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr ThreadBlock ThreadBlock_256_128x128x32{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_128_64x64x64{.block_size = 128,
.tile_size = {.m = 64, .n = 64, .k = 64}};
constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 16}};
constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 8}};
constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64,
.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ThreadBlock ThreadBlock_64_32x32x32{.block_size = 64,
.tile_size = {.m = 32, .n = 32, .k = 32}};
constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128,
.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128,
.tile_size = {.m = 64, .n = 64, .k = 64}};
constexpr BlockGemmPipeline BlockGemmDesc_v1_intrawave = {
.pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE};
constexpr BlockGemmPipeline BlockGemmDesc_v2_intrawave = {
.pipeline_version = PipelineVersion::V2, .scheduler = PipelineScheduler::INTRAWAVE};
constexpr BlockGemmPipeline BlockGemmDesc_v3_intrawave = {
.pipeline_version = PipelineVersion::V3, .scheduler = PipelineScheduler::INTRAWAVE};
constexpr BlockGemmPipeline BlockGemmDesc_v4_intrawave = {
.pipeline_version = PipelineVersion::V4, .scheduler = PipelineScheduler::INTRAWAVE};
constexpr BlockGemmPipeline BlockGemmDesc_v5_intrawave = {
.pipeline_version = PipelineVersion::V5, .scheduler = PipelineScheduler::INTRAWAVE};
} // namespace ck_tile::builder::test_utils


@@ -12,35 +12,35 @@ namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
constexpr TileTransfer FwdTileTransfer_1x1x1{
constexpr TileTransfer TileTransfer_1x1x1{
.a_scalar_per_vector = 1,
.b_scalar_per_vector = 1,
.c_scalar_per_vector = 1,
};
constexpr TileTransfer FwdTileTransfer_4x4x4{
constexpr TileTransfer TileTransfer_4x4x4{
.a_scalar_per_vector = 4,
.b_scalar_per_vector = 4,
.c_scalar_per_vector = 4,
};
constexpr TileTransfer FwdTileTransfer_8x8x8{
constexpr TileTransfer TileTransfer_8x8x8{
.a_scalar_per_vector = 8,
.b_scalar_per_vector = 8,
.c_scalar_per_vector = 8,
};
constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr TileThreadBlock TileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
constexpr TileThreadBlock TileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr TileThreadBlock TileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
constexpr TileThreadBlock TileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr TileThreadBlock TileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
constexpr TileThreadBlock TileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = {
.warps = {.m = 2, .n = 2, .k = 1},


@@ -54,7 +54,7 @@ inline std::string to_string<PipelineScheduler>(PipelineScheduler t)
}
template <>
inline std::string to_string<ConvFwdSpecialization>(ConvFwdSpecialization t)
inline std::string to_string<ConvSpecialization>(ConvSpecialization t)
{
std::ostringstream oss;
oss << t;
@@ -86,11 +86,20 @@ inline std::string to_string<ThreadBlock>(ThreadBlock t)
}
template <>
inline std::string to_string<GridwiseXdlGemm>(GridwiseXdlGemm t)
inline std::string to_string<GridwiseBwdXdlGemm>(GridwiseBwdXdlGemm t)
{
std::ostringstream oss;
oss << t.ak1 << "," << t.bk1 << "," << t.m_per_xdl << "," << t.n_per_xdl << ","
<< t.m_xdl_per_wave << "," << t.n_xdl_per_wave;
oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << ","
<< t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
return oss.str();
}
template <>
inline std::string to_string<GridwiseFwdXdlGemm>(GridwiseFwdXdlGemm t)
{
std::ostringstream oss;
oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl
<< "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
return oss.str();
}
@@ -104,17 +113,29 @@ inline std::string to_string<GridwiseWmmaGemm>(GridwiseWmmaGemm t)
}
template <>
inline std::string to_string<BlockGemm>(BlockGemm t)
inline std::string to_string<BlockGemmPipeline>(BlockGemmPipeline t)
{
std::ostringstream oss;
oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version);
return oss.str();
}
template <>
inline std::string to_string<BlockTransfer>(BlockTransfer t)
template <size_t ThreadClusterRank>
inline std::string to_string(BlockTransfer<ThreadClusterRank> t)
{
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
if constexpr(ThreadClusterRank == 4)
{
return array_to_seq(std::array<size_t, 4>{t.k_batch_size, t.k0, t.m_n, t.k1});
}
else if constexpr(ThreadClusterRank == 3)
{
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
}
else
{
static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4,
"Unsupported ThreadClusterRank");
}
}
template <>
@@ -134,14 +155,14 @@ inline std::string to_string<LdsTransfer>(LdsTransfer t)
return oss.str();
}
template <>
inline std::string to_string<AccessOrder>(AccessOrder t)
template <size_t N>
inline std::string to_string(AccessOrder<N> t)
{
return array_to_seq(t.order);
}
template <>
inline std::string to_string<TransferAB>(TransferAB t)
template <size_t N = 3>
inline std::string to_string(InputTransfer<N> t)
{
std::ostringstream oss;
oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << ","
@@ -152,7 +173,7 @@ inline std::string to_string<TransferAB>(TransferAB t)
}
template <>
inline std::string to_string<TransferC>(TransferC t)
inline std::string to_string<OutputTransfer>(OutputTransfer t)
{
std::ostringstream oss;
oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << ","
@@ -160,8 +181,8 @@ inline std::string to_string<TransferC>(TransferC t)
return oss.str();
}
template <>
inline std::string to_string<TransferABC>(TransferABC t)
template <size_t N = 3>
inline std::string to_string(Transfer<N> t)
{
std::ostringstream oss;
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
@@ -185,7 +206,19 @@ inline std::string to_string<DlThreadCluster>(DlThreadCluster t)
}
template <>
inline std::string to_string<DlBlockTransfer>(DlBlockTransfer t)
inline std::string to_string<DlBlockTransfer<4>>(DlBlockTransfer<4> t)
{
std::ostringstream oss;
oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths)
<< "," << array_to_seq(t.thread_cluster_arrange_order) << ","
<< array_to_seq(t.src_access_order) << "," << array_to_seq(t.src_vector_tensor_lengths)
<< "," << array_to_seq(t.src_vector_tensor_contiguous_dim_order) << ","
<< array_to_seq(t.dst_vector_tensor_lengths);
return oss.str();
}
template <>
inline std::string to_string<DlBlockTransfer<5>>(DlBlockTransfer<5> t)
{
std::ostringstream oss;
oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths)
@@ -206,19 +239,24 @@ inline std::string to_string<DlEpilogue>(DlEpilogue t)
}
template <>
inline std::string to_string<DlBlockTransferAB>(DlBlockTransferAB t)
inline std::string to_string<TransposeParams_>(TransposeParams_ t)
{
return to_string(t.block_transfer);
std::ostringstream oss;
oss << t.max_transpose_transfer_src_scalar_per_vector << ","
<< t.max_transpose_transfer_dst_scalar_per_vector;
return oss.str();
}
template <>
inline std::string to_string<DlBlockTransferC>(DlBlockTransferC t)
inline std::string to_string<DlTransfer<4>>(DlTransfer<4> t)
{
return to_string(t.epilogue);
std::ostringstream oss;
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
return oss.str();
}
template <>
inline std::string to_string<DlTransferABC>(DlTransferABC t)
inline std::string to_string<DlTransfer<5>>(DlTransfer<5> t)
{
std::ostringstream oss;
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
@@ -234,7 +272,13 @@ inline std::string to_string<ThreadBlock_>(ThreadBlock_ t)
}
template <>
inline std::string to_string<XdlGemm_>(XdlGemm_ t)
inline std::string to_string<FwdXdlGemm_>(FwdXdlGemm_ t)
{
return to_string(t.gridwise_gemm);
}
template <>
inline std::string to_string<BwdXdlGemm_>(BwdXdlGemm_ t)
{
return to_string(t.gridwise_gemm);
}
@@ -245,33 +289,40 @@ inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
return to_string(t.gridwise_gemm);
}
template <>
inline std::string to_string<Transfer_>(Transfer_ t)
template <size_t ThreadClusterRank = 3>
inline std::string to_string(Transfer_<ThreadClusterRank> t)
{
return to_string(t.transfer);
}
template <>
inline std::string to_string<ConvSpecialization_>(ConvSpecialization_ t)
inline std::string to_string<ConvSpecializationFwd_>(ConvSpecializationFwd_ t)
{
std::ostringstream oss;
oss << to_string(t.fwd_specialization) << "," << to_string(t.gemm_specialization);
return oss.str();
}
template <>
inline std::string to_string<ConvSpecializationBwdWeight_>(ConvSpecializationBwdWeight_ t)
{
std::ostringstream oss;
oss << to_string(t.bwd_weight_specialization);
return oss.str();
}
template <>
inline std::string to_string<Prefetch_>(Prefetch_ t)
{
std::ostringstream oss;
oss << t.num_gemm_k_prefetch_stages << "," << t.num_groups_to_merge << ","
<< to_string(t.loop_scheduler);
oss << t.num_gemm_k_prefetch_stages << "," << to_string(t.loop_scheduler);
return oss.str();
}
template <>
inline std::string to_string<BlockGemm_>(BlockGemm_ t)
{
return to_string(t.block_gemm);
return to_string(t.block_gemm_pipeline);
}
template <>
@@ -287,7 +338,13 @@ inline std::string to_string<DlThreadCluster_>(DlThreadCluster_ t)
}
template <>
inline std::string to_string<DlTransfer_>(DlTransfer_ t)
inline std::string to_string<DlTransfer_<4>>(DlTransfer_<4> t)
{
return to_string(t.transfer);
}
template <>
inline std::string to_string<DlTransfer_<5>>(DlTransfer_<5> t)
{
return to_string(t.transfer);
}
@@ -299,8 +356,8 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<XdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_>(t));
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
@@ -309,8 +366,8 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<XdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_>(t));
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
@@ -320,7 +377,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CS
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
<< "," << to_string(static_cast<Transfer_>(t));
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
@@ -332,7 +389,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
<< to_string(static_cast<DlThreadConfig_>(t)) << ","
<< to_string(static_cast<DlThreadCluster_>(t)) << ","
<< to_string(static_cast<DlTransfer_>(t));
<< to_string(static_cast<DlTransfer_<4>>(t));
return oss.str();
}
@@ -340,7 +397,102 @@ template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor>(
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor t)
{
return to_string(t.base_algorithm);
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle>(
ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_<4>>(t));
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3>(
ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle>(
ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3>(
ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3>(
ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3>(
ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle>(
ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl>(
ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
<< to_string(static_cast<DlThreadConfig_>(t)) << ","
<< to_string(static_cast<DlThreadCluster_>(t)) << ","
<< to_string(static_cast<DlTransfer_<5>>(t));
return oss.str();
}
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle>(
ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_<4>>(t));
return oss.str();
}
} // namespace ck_tile::builder::test