mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Split-K autodeduction (#3351)
* First version of split-K autodeduction. * Fix circular dependency and kernel construction. * Fix tolerance calculation for bwd weight example. * Simplify kernel construction. * Fix kernel launching bug for split-K autodeduce. * Add split-K autodeduction support for the two stage example. * Fix a corner case. * Fix clang-format. * Fix clang-format for inc files. * Add missing header. * Prevent too large split-K values. * Fix formatting. * Add unit tests for IsSupportedArgument in grouped bwd conv. * clang-format. * Fix merge conflicts. * Address feedback from code review. * clang-format * Fix new tests after merge. --------- Co-authored-by: Ville Pietilä <>
This commit is contained in:
@@ -38,3 +38,4 @@ add_subdirectory(atomic_add_op)
|
||||
add_subdirectory(fmha)
|
||||
add_subdirectory(gemm_tile_engine)
|
||||
add_subdirectory(pooling)
|
||||
add_subdirectory(grouped_conv)
|
||||
|
||||
7
test/ck_tile/grouped_conv/CMakeLists.txt
Normal file
7
test/ck_tile/grouped_conv/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# Build this test only for the GPU architectures below.
# NOTE(review): an earlier comment said ck_tile is "only built on gfx9",
# but the condition also matches gfx11/gfx12 — confirm the intended set.
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_grouped_conv_bwd_weight test_ck_tile_grouped_conv_bwd_weight.cpp)
endif()
|
||||
@@ -0,0 +1,249 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
// Tile/wave configuration shared by every kernel instantiated in this test.
struct TestConvConfig
{
    // Vector access widths (elements per memory transaction) for A, B and C.
    static constexpr index_t VectorSizeA = 4;
    static constexpr index_t VectorSizeB = 8;
    static constexpr index_t VectorSizeC = 8;

    // Per-workgroup GEMM tile extents.
    static constexpr index_t M_Tile = 128;
    static constexpr index_t N_Tile = 128;
    static constexpr index_t K_Tile = 32;

    // Wave partitioning of the workgroup tile.
    static constexpr index_t M_Warp = 2;
    static constexpr index_t N_Warp = 2;
    static constexpr index_t K_Warp = 1;

    // Per-wave tile extents.
    static constexpr index_t M_Warp_Tile = 16;
    static constexpr index_t N_Warp_Tile = 16;
    static constexpr index_t K_Warp_Tile = 16;

    // Pipeline selection: single LDS buffer, COMPUTE_V3 pipeline,
    // one wave group, no group merging, intrawave scheduling.
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr GemmPipeline Pipeline = GemmPipeline::COMPUTE_V3;
    static constexpr index_t NumWaveGroups = 1;
    static constexpr index_t NumGroupsToMerge = 1;
    static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
};
|
||||
|
||||
// Helper to build the full grouped-conv bwd-weight kernel type from the tile
// configuration. The chain is: GEMM shape -> conv traits -> tile partitioner
// -> universal traits -> pipeline problem -> pipeline -> epilogue -> kernel.
//
// Template parameters:
//   PrecType    - element type used for all of A/B/C (accumulation is float)
//   ConvConfig  - tile/wave configuration (see TestConvConfig above)
//   In/Wei/OutLayout - convolution tensor layouts
//   MemOp       - epilogue memory operation (set, or atomic_add for split-K
//                 partial accumulation)
//   NDimSpatial - number of spatial dimensions (2 = H,W)
template <typename PrecType,
          typename ConvConfig,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          memory_operation_enum MemOp = memory_operation_enum::set,
          index_t NDimSpatial = 2>
struct BuildKernel
{
    // Block / wave / wave-tile extents of the implicit GEMM.
    using GemmShape = TileGemmShape<
        sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
        sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
        sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>;

    // Convolution-to-GEMM mapping traits; tuple<> means no extra D tensors.
    using ConvTraits = GroupedConvTraits<NDimSpatial,
                                         ConvolutionSpecialization::Default,
                                         InLayout,
                                         WeiLayout,
                                         tuple<>,
                                         OutLayout,
                                         ConvConfig::VectorSizeA,
                                         ConvConfig::VectorSizeB,
                                         ConvConfig::VectorSizeC,
                                         ConvConfig::NumGroupsToMerge>;

    using TilePartitioner = GemmSpatiallyLocalTilePartitioner<GemmShape, 8, 4>;

    // Padding/layout/persistence knobs; most are fixed by the conv traits.
    using GemmUniversalTraits =
        TileGemmUniversalTraits<ConvTraits::FixedGemmParams::kPadM,
                                ConvTraits::FixedGemmParams::kPadN,
                                ConvTraits::FixedGemmParams::kPadK,
                                ConvConfig::DoubleSmemBuffer,
                                typename ConvTraits::AsLayoutBwdWeight,
                                typename ConvTraits::BsLayoutBwdWeight,
                                typename ConvTraits::CLayoutBwdWeight,
                                ConvTraits::FixedGemmParams::TransposeC,
                                ConvTraits::FixedGemmParams::UseStructuredSparsity,
                                ConvTraits::FixedGemmParams::Persistent,
                                ConvConfig::NumWaveGroups>;

    // In bwd weight the output gradient acts as A and the input as B;
    // the weight gradient is C. Accumulation happens in float.
    using GemmPipelineProblem =
        GemmPipelineProblem<PrecType, // OutDataType (A in bwd weight)
                            PrecType, // InDataType (B in bwd weight)
                            float,    // AccDataType
                            GemmShape,
                            typename ConvTraits::template GroupedConvImplicitGemmTraitsBwdWeight<
                                ConvConfig::NumWaveGroups>,
                            element_wise::PassThrough,
                            element_wise::PassThrough,
                            PrecType, // WeiDataType (C in bwd weight)
                            ConvTraits::FixedGemmParams::FixedVectorSize,
                            ConvTraits::VectorSizeA,
                            ConvTraits::VectorSizeB>;

    using UniversalGemmProblem =
        UniversalGemmPipelineProblem<PrecType,
                                     PrecType,
                                     float,
                                     GemmShape,
                                     GemmUniversalTraits,
                                     ConvConfig::Scheduler,
                                     element_wise::PassThrough,
                                     element_wise::PassThrough,
                                     PrecType,
                                     ConvTraits::FixedGemmParams::FixedVectorSize,
                                     ConvTraits::VectorSizeA,
                                     ConvTraits::VectorSizeB>;

    using GemmPipeline = GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;

    // Epilogue writes C with MemOp (set or atomic_add for split-K partials).
    using EpilogueProblem = CShuffleEpilogueProblem<PrecType,
                                                    PrecType,
                                                    tuple<>,
                                                    float,
                                                    PrecType,
                                                    typename ConvTraits::ImplicitGemmDsLayout,
                                                    typename ConvTraits::FixedGemmParams::ELayout,
                                                    element_wise::PassThrough,
                                                    TilePartitioner::MPerBlock,
                                                    TilePartitioner::NPerBlock,
                                                    ConvConfig::M_Warp,
                                                    ConvConfig::N_Warp,
                                                    ConvConfig::M_Warp_Tile,
                                                    ConvConfig::N_Warp_Tile,
                                                    ConvConfig::K_Warp_Tile,
                                                    ConvTraits::FixedGemmParams::TransposeC,
                                                    MemOp,
                                                    ConvConfig::NumWaveGroups,
                                                    ConvTraits::FixedGemmParams::FixedVectorSize,
                                                    ConvTraits::VectorSizeC>;

    using Epilogue = CShuffleEpilogue<EpilogueProblem>;

    // The fully assembled kernel type under test.
    using type =
        GroupedConvolutionBackwardWeightKernel<ConvTraits, TilePartitioner, GemmPipeline, Epilogue>;
};
|
||||
|
||||
// Helper to create 2D host args
|
||||
static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t left_pad_y,
|
||||
index_t left_pad_x,
|
||||
index_t right_pad_y,
|
||||
index_t right_pad_x,
|
||||
index_t k_batch = 1)
|
||||
{
|
||||
auto conv_param = conv::ConvParam{2,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
{Y, X},
|
||||
{Hi, Wi},
|
||||
{stride_y, stride_x},
|
||||
{dilation_y, dilation_x},
|
||||
{left_pad_y, left_pad_x},
|
||||
{right_pad_y, right_pad_x}};
|
||||
|
||||
return GroupedConvBwdWeightHostArgs{conv_param, nullptr, nullptr, {}, nullptr, k_batch};
|
||||
}
|
||||
|
||||
// Convenience overload: a small fixed 2D problem in which only k_batch varies.
static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t k_batch)
{
    // 2 groups, batch 2, 8 output/input channels, 3x3 filter over a 7x7
    // image, unit strides/dilations and symmetric padding of 1.
    constexpr index_t groups = 2, batch = 2, out_ch = 8, in_ch = 8;
    constexpr index_t filter = 3, image = 7;
    constexpr index_t stride = 1, dilation = 1, pad = 1;
    return create_2d_host_args(groups,
                               batch,
                               out_ch,
                               in_ch,
                               filter,
                               filter,
                               image,
                               image,
                               stride,
                               stride,
                               dilation,
                               dilation,
                               pad,
                               pad,
                               pad,
                               pad,
                               k_batch);
}
|
||||
|
||||
// Fixture grouping the IsSupportedArgument() checks below; it carries no
// shared state and exists only to name the test suite.
class GroupedConvBwdWeightIsSupportedArgumentTest : public ::testing::Test
{
};
|
||||
|
||||
// With the default (set) memory operation, k_batch values of 1 and 4 are
// both accepted.
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, ValidKBatch)
{
    using Kernel = typename BuildKernel<half_t,
                                        TestConvConfig,
                                        tensor_layout::convolution::NHWGC,
                                        tensor_layout::convolution::GKYXC,
                                        tensor_layout::convolution::NHWGK>::type;

    for(const index_t k_batch : {1, 4})
    {
        const auto host_args = create_2d_host_args(k_batch);
        const auto kernel_args =
            typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args);
        EXPECT_TRUE(Kernel::IsSupportedArgument(kernel_args)) << "k_batch = " << k_batch;
    }
}
|
||||
|
||||
// k_batch must be at least 1; zero is rejected.
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, InvalidKBatchLessThanOne)
{
    using Kernel = typename BuildKernel<half_t,
                                        TestConvConfig,
                                        tensor_layout::convolution::NHWGC,
                                        tensor_layout::convolution::GKYXC,
                                        tensor_layout::convolution::NHWGK>::type;

    const auto invalid_host_args = create_2d_host_args(0);
    const auto kernel_args =
        typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(invalid_host_args);
    EXPECT_FALSE(Kernel::IsSupportedArgument(kernel_args));
}
|
||||
|
||||
// An atomic_add epilogue only makes sense when split-K actually splits the
// reduction, i.e. k_batch > 1.
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreaterThanOne)
{
    using Kernel = typename BuildKernel<half_t,
                                        TestConvConfig,
                                        tensor_layout::convolution::NHWGC,
                                        tensor_layout::convolution::GKYXC,
                                        tensor_layout::convolution::NHWGK,
                                        memory_operation_enum::atomic_add>::type;

    // k_batch = 1 is rejected when the epilogue uses atomic_add.
    const auto single_split_args = create_2d_host_args(1);
    const auto rejected_kargs =
        typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(single_split_args);
    EXPECT_FALSE(Kernel::IsSupportedArgument(rejected_kargs));

    // k_batch = 2 is accepted.
    const auto two_split_args = create_2d_host_args(2);
    const auto accepted_kargs =
        typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(two_split_args);
    EXPECT_TRUE(Kernel::IsSupportedArgument(accepted_kargs));
}
|
||||
|
||||
// For a half_t output the split-K factor is capped: 128 passes, 129 fails.
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKBatch)
{
    using Kernel = typename BuildKernel<half_t,
                                        TestConvConfig,
                                        tensor_layout::convolution::NHWGC,
                                        tensor_layout::convolution::GKYXC,
                                        tensor_layout::convolution::NHWGK>::type;

    // k_batch = 128 is still within the limit.
    const auto args_at_limit = create_2d_host_args(128);
    const auto kargs_at_limit =
        typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(args_at_limit);
    EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_at_limit));

    // k_batch = 129 exceeds the limit for half_t output.
    const auto args_over_limit = create_2d_host_args(129);
    const auto kargs_over_limit =
        typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(args_over_limit);
    EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_over_limit));
}
|
||||
Reference in New Issue
Block a user