mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_BUILDER] First fwd convolution builder implementation (#3070)
* Add experimental builder infrastructure for composable_kernel - Add experimental/builder directory with README documentation. - Create initial test infrastructure with CMakeLists.txt and placeholder test. - Update root CMakeLists.txt to support CK_EXPERIMENTAL_BUILDER option. - Update .gitignore to not treat `experimental/builder` as a CMake build directory. This establishes the directory structure for a high-level builder pattern that will provide a semantically-clear interface for constructing CK operations, with initial focus on convolution kernels for MIOpen integration. * Fix clang formatting. * Fix CMake build infrastructure for experimental builder - Add experimental/builder CMakeLists.txt with proper subdirectory structure - Add placeholder include/ck_tile/builder CMakeLists.txt for header installation - Fix gtest.cmake to use include_guard to prevent multiple inclusions - Update root CMakeLists.txt to include full builder directory instead of just tests * Scope C++20 settingto the test code Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Remove redundant GTest::gtest linkage Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Introduce basic types, and convolution algorithm concepts and limits. * Add convolution signature concepts. * Add convolution factory. * Finalize conv factory implementation for fwd convolutions. * Add type definitions for testing. * Add placeholder test. * Add convolution builder definition. * Fully functional fwd conv builder. * Test improvements. * Clean-up include headers. * Enable the limit checks for the convolution algorithm parameters. * Remove dead code. * clang formatting. * Add more tests and missing conv specialization argument. * clang formatting. * Add explicit handling of the tensor layouts. * Add complete 2D/3D layout support to CK Builder - Add missing 2D layouts: GNHWC_GKYXC_GNHWK, NGCHW_GKCYX_NGKHW - Add missing 3D layout: GNDHWC_GKZYXC_GNDHWK - Add 1D layouts (NWGC, NGCW, GNWC, NGCW_GKCX) for future support - Add 3 tests for new 2D/3D layouts - All tests pass (5/5) * Add tests for remaining 2D/3D layouts - Add test for 2D NGCHW_GKYXC_NGKHW (channels-first) with Filter1x1Stride1Pad0 - Add test for 3D NDHWGC_GKZYXC_NDHWGK (channels-last) - All 7 tests pass (complete coverage for all 2D/3D forward layouts) * Change enum converters to consteval. * 7 tests with pipeline and specialization| Test # | Dim | Type | Layout | Pipeline | Specialization | |--------|-----|------|----------------------|----------|-------------------------| | 1 | 2D | BF16 | NHWGC_GKYXC_NHWGK | V1 | DEFAULT | | 2 | 2D | FP16 | GNHWC_GKYXC_GNHWK | V3 | FILTER_1X1_PAD0 | | 3 | 2D | FP32 | NGCHW_GKCYX_NGKHW | V4 | FILTER_1X1_STRIDE1_PAD0 | | 4 | 2D | BF16 | NHWGC_GKYXC_NHWGK | V5 | FILTER_3x3 | | 5 | 3D | FP32 | NGCDHW_GKCZYX_NGKDHW | V1 | FILTER_1X1_PAD0 | | 6 | 3D | BF16 | GNDHWC_GKZYXC_GNDHWK | V3 | DEFAULT | | 7 | 3D | FP16 | NDHWGC_GKZYXC_NDHWGK | V4 | FILTER_1X1_PAD0 | * Add missing convolution layouts and provide better compile-time error in instance traits. * Fix clang formatting. * Changed I8 -> S8. * Fix signature. * Rename concepts and corresponding members. * Rename LDS related parameters. * Remove ODD_C specialization. Add V2 pipeline. * Add missing types. * Add elementwise operation to the conv signature. * Improve compile-time error message for unsupported elementwise ops. * Separate different fwd conv builder tests into separate compilation units. * Fix layout to string and add name to old CK PassThrough elementwise op. * Enable both CK and CK Tile tensor layouts in instance traits. * Fix clang-format. --------- Co-authored-by: John Shumway <jshumway@amd.com> Co-authored-by: John Shumway <john.shumwayjr@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: JH-Leon-KIM-AMD <jeonghyun.kim@amd.com>
This commit is contained in:
@@ -7,6 +7,7 @@ function(add_ck_builder_test test_name)
|
||||
target_include_directories(${test_name} PRIVATE
|
||||
"${PROJECT_SOURCE_DIR}/experimental/builder/include"
|
||||
"${PROJECT_SOURCE_DIR}/include"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
)
|
||||
target_compile_options(${test_name} PRIVATE
|
||||
-Wno-global-constructors
|
||||
@@ -24,3 +25,11 @@ add_ck_builder_test(test_get_instance_string
|
||||
test_get_instance_string.cpp)
|
||||
|
||||
add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp)
|
||||
|
||||
add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp)
|
||||
47
experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp
Normal file
47
experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp
Normal file
@@ -0,0 +1,47 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DBF16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT
|
||||
TEST_F(FwdConv2DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
|
||||
TEST_F(FwdConv2DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V5,
|
||||
ConvFwdSpecialization::FILTER_3x3>();
|
||||
}
|
||||
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp
Normal file
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DFP16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(FwdConv2DFP16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp
Normal file
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DFP32Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(FwdConv2DFP32Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
}
|
||||
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp
Normal file
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DBF16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT
|
||||
TEST_F(FwdConv3DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp
Normal file
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DFP16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0
|
||||
TEST_F(FwdConv3DFP16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp
Normal file
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DFP32Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
|
||||
TEST_F(FwdConv3DFP32Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
119
experimental/builder/test/impl/conv_algorithm_types.hpp
Normal file
119
experimental/builder/test/impl/conv_algorithm_types.hpp
Normal file
@@ -0,0 +1,119 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
|
||||
// Convenience struct for a tuple of m, n, and k values.
|
||||
template <typename T>
|
||||
struct MNK
|
||||
{
|
||||
T m{};
|
||||
T n{};
|
||||
T k{};
|
||||
};
|
||||
|
||||
// Specify thread block dimensions for a GEMM.
|
||||
struct ThreadBlock
|
||||
{
|
||||
// Thread block size.
|
||||
size_t block_size;
|
||||
// Size of the submatrix problem in a thread block.
|
||||
MNK<size_t> tile_size;
|
||||
};
|
||||
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
|
||||
|
||||
// Describe gridwise GEMM parameters.
|
||||
struct GridwiseGemm
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
size_t m_per_xdl = 0;
|
||||
size_t n_per_xdl = 0;
|
||||
size_t m_xdl_per_wave = 0;
|
||||
size_t n_xdl_per_wave = 0;
|
||||
};
|
||||
static_assert(ckb::GridwiseGemmDescriptor<GridwiseGemm>);
|
||||
|
||||
// Describe Aand B block transfer thread cluster lengths.
|
||||
struct BlockTransfer
|
||||
{
|
||||
size_t k0;
|
||||
size_t m_n;
|
||||
size_t k1;
|
||||
};
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer>);
|
||||
|
||||
// Describe C block transfer thread cluster lengths.
|
||||
struct ThreadCluster
|
||||
{
|
||||
size_t m_block;
|
||||
size_t m_wave_per_xdl;
|
||||
size_t n_block;
|
||||
size_t n_wave_per_xdl;
|
||||
};
|
||||
static_assert(ThreadClusterDescriptor<ThreadCluster>);
|
||||
|
||||
struct LdsTransfer
|
||||
{
|
||||
size_t src_vector_dim;
|
||||
size_t src_scalar_per_vector;
|
||||
size_t lds_dst_scalar_per_vector;
|
||||
bool is_direct_load;
|
||||
bool lds_padding;
|
||||
};
|
||||
static_assert(LdsTransferDescriptor<LdsTransfer>);
|
||||
|
||||
struct Epilogue
|
||||
{
|
||||
size_t m_xdl_per_wave_per_shuffle;
|
||||
size_t n_xdl_per_wave_per_shuffle;
|
||||
size_t scalar_per_vector;
|
||||
};
|
||||
static_assert(EpilogueDescriptor<Epilogue>);
|
||||
|
||||
struct AccessOrder
|
||||
{
|
||||
std::array<size_t, 3> order;
|
||||
};
|
||||
static_assert(AccessOrderDescriptor<AccessOrder>);
|
||||
|
||||
struct BlockTransferABC
|
||||
{
|
||||
BlockTransfer block_transfer_a;
|
||||
BlockTransfer block_transfer_b;
|
||||
ThreadCluster thread_cluster_dims_c;
|
||||
LdsTransfer lds_transfer_a;
|
||||
LdsTransfer lds_transfer_b;
|
||||
Epilogue epilogue_c;
|
||||
AccessOrder block_transfer_access_order_a;
|
||||
AccessOrder block_transfer_access_order_b;
|
||||
AccessOrder src_access_order_a;
|
||||
AccessOrder src_access_order_b;
|
||||
};
|
||||
|
||||
struct ConvAlgorithm
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
BlockGemmPipelineVersion pipeline_version;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesGridwiseGemm<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesSourceAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesGemmPipelineVersion<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm>);
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
23
experimental/builder/test/impl/conv_signature_types.hpp
Normal file
23
experimental/builder/test/impl/conv_signature_types.hpp
Normal file
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
template <typename GroupConvLayout>
|
||||
struct ConvSignature
|
||||
{
|
||||
int spatial_dim;
|
||||
ConvDirection direction;
|
||||
GroupConvLayout layout;
|
||||
DataType data_type;
|
||||
ElementwiseOperation elementwise_operation;
|
||||
};
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout1D>>);
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout2D>>);
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout3D>>);
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
103
experimental/builder/test/utils/ckb_conv_test_common.hpp
Normal file
103
experimental/builder/test/utils/ckb_conv_test_common.hpp
Normal file
@@ -0,0 +1,103 @@
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
// Common test base class
|
||||
class FwdConvBuilderTestBase : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
// Common test implementation
|
||||
template <auto FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
BlockGemmPipelineVersion FwdPipelineVersion,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test()
|
||||
{
|
||||
constexpr GridwiseGemm FwdGemmParams{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr ConvAlgorithm FwdConvAlgorithm{.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.pipeline_version = FwdPipelineVersion,
|
||||
.fwd_specialization = FwdConvSpecialization};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
|
||||
|
||||
// Verify pipeline version is correct
|
||||
if(FwdPipelineVersion == BlockGemmPipelineVersion::V1)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
// Common thread block configurations
|
||||
constexpr ThreadBlock DefaultThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock SmallThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
Reference in New Issue
Block a user