[CK_BUILDER] Add grouped conv bwd ck tile traits (#3281)

* [CK_BUILDER] Add grouped conv bwd ck tile traits

* copilot fixes
This commit is contained in:
Bartłomiej Kocot
2025-11-25 14:57:43 +01:00
committed by GitHub
parent ab0101c59c
commit 9ac2666d5b
10 changed files with 583 additions and 13 deletions

View File

@@ -133,7 +133,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GemmPipelineProblem::TransposeC,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,

View File

@@ -0,0 +1,140 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// InstanceTraits specialization for GroupedConvolutionBackwardDataKernel
//
// CRITICAL MAINTENANCE NOTE:
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
// ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp
// "In sync" means that the template parameter order, names, and types in the declaration below
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
// difficult to diagnose. Always update both files together and review changes carefully.
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
// Forward declaration of the backward-data kernel template, so that this header does not have to
// include the kernel header (which would create a circular dependency).
// The template parameter list below must stay identical to the one on the kernel definition.
namespace ck_tile {
template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct GroupedConvolutionBackwardDataKernel;
} // namespace ck_tile
namespace ck_tile {
namespace reflect {
// Specialization for GroupedConvolutionBackwardDataKernel
template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct InstanceTraits<ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>>
{
// CK Tile Conv Traits
// Spatial dimension
static constexpr int kSpatialDim = GroupedConvTraitsType_::NDimSpatial;
// Specialization
static constexpr ck_tile::ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
// DataType types
using InLayout = typename GroupedConvTraitsType_::InLayout;
using WeiLayout = typename GroupedConvTraitsType_::WeiLayout;
using DsLayout = typename GroupedConvTraitsType_::DsLayout;
using OutLayout = typename GroupedConvTraitsType_::OutLayout;
// Vector size
static constexpr int kVectorSizeA = GroupedConvTraitsType_::VectorSizeA;
static constexpr int kVectorSizeB = GroupedConvTraitsType_::VectorSizeB;
static constexpr int kVectorSizeC = GroupedConvTraitsType_::VectorSizeC;
// Num Groups To Merge
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// Split image (large tensors)
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
// TilePartitioner
// Block configuration
static constexpr int kMPerBlock = TilePartitioner_::MPerBlock;
static constexpr int kNPerBlock = TilePartitioner_::NPerBlock;
static constexpr int kKPerBlock = TilePartitioner_::KPerBlock;
static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{});
static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{});
static constexpr int kKWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{});
static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{});
static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{});
static constexpr int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{});
// Data types
using ADataType = typename GemmPipeline_::ADataType;
using BDataType = typename GemmPipeline_::BDataType;
// Gemm Pipeline
using GemmPipeline = GemmPipeline_;
static constexpr ck_tile::GemmPipelineScheduler kPipelineScheduler = GemmPipeline_::Scheduler;
static constexpr bool kDoubleSmemBuffer = GemmPipeline_::DoubleSmemBuffer;
static constexpr int kNumWaveGroups = GemmPipeline_::NumWaveGroups;
// Epilogue Pipeline
using AccDataType = typename EpiloguePipeline_::AccDataType;
using EDataType = typename EpiloguePipeline_::ODataType;
using DsDataType = typename EpiloguePipeline_::DsDataType;
using CDEElementwiseOperation = typename EpiloguePipeline_::CDElementwise;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << "GroupedConvolutionBackwardDataKernel";
// Template parameters in exact order matching InstanceTraits member order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << ","
<< ck_tile::getConvSpecializationString(ConvSpecialization); // 2. ConvSpecialization
oss << "," << detail::layout_name<InLayout>(); // 3. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 4. WeiLayout
oss << "," << detail::tuple_name<DsLayout>(); // 5. DsLayout
oss << "," << detail::layout_name<OutLayout>(); // 6. OutLayout
oss << "," << kVectorSizeA; // 7. VectorSizeA
oss << "," << kVectorSizeB; // 8. VectorSizeB
oss << "," << kVectorSizeC; // 9. VectorSizeC
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
oss << "," << kMPerBlock; // 12. MPerBlock
oss << "," << kNPerBlock; // 13. NPerBlock
oss << "," << kKPerBlock; // 14. KPerBlock
oss << "," << kMWarp; // 15. MWarp
oss << "," << kNWarp; // 16. NWarp
oss << "," << kKWarp; // 17. KWarp
oss << "," << kMWarpTile; // 18. MWarpTile
oss << "," << kNWarpTile; // 19. NWarpTile
oss << "," << kKWarpTile; // 20. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
// CDEElementwiseOperation
oss << ">";
return oss.str();
}
};
} // namespace reflect
} // namespace ck_tile

View File

@@ -0,0 +1,140 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// InstanceTraits specialization for GroupedConvolutionBackwardWeightKernel
//
// CRITICAL MAINTENANCE NOTE:
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
// ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
// "In sync" means that the template parameter order, names, and types in the declaration below
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
// difficult to diagnose. Always update both files together and review changes carefully.
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
// Forward declaration of the backward-weight kernel template, so that this header does not have to
// include the kernel header (which would create a circular dependency).
// The template parameter list below must stay identical to the one on the kernel definition.
namespace ck_tile {
template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct GroupedConvolutionBackwardWeightKernel;
} // namespace ck_tile
namespace ck_tile {
namespace reflect {
// Specialization for GroupedConvolutionBackwardWeightKernel
template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>>
{
// CK Tile Conv Traits
// Spatial dimension
static constexpr int kSpatialDim = GroupedConvTraitsType_::NDimSpatial;
// Specialization
static constexpr ck_tile::ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
// DataType types
using InLayout = typename GroupedConvTraitsType_::InLayout;
using WeiLayout = typename GroupedConvTraitsType_::WeiLayout;
using DsLayout = typename GroupedConvTraitsType_::DsLayout;
using OutLayout = typename GroupedConvTraitsType_::OutLayout;
// Vector size
static constexpr int kVectorSizeA = GroupedConvTraitsType_::VectorSizeA;
static constexpr int kVectorSizeB = GroupedConvTraitsType_::VectorSizeB;
static constexpr int kVectorSizeC = GroupedConvTraitsType_::VectorSizeC;
// Num Groups To Merge
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// Split image (large tensors)
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
// TilePartitioner
// Block configuration
static constexpr int kMPerBlock = TilePartitioner_::MPerBlock;
static constexpr int kNPerBlock = TilePartitioner_::NPerBlock;
static constexpr int kKPerBlock = TilePartitioner_::KPerBlock;
static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{});
static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{});
static constexpr int kKWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{});
static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{});
static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{});
static constexpr int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{});
// Data types
using ADataType = typename GemmPipeline_::ADataType;
using BDataType = typename GemmPipeline_::BDataType;
// Gemm Pipeline
using GemmPipeline = GemmPipeline_;
static constexpr ck_tile::GemmPipelineScheduler kPipelineScheduler = GemmPipeline_::Scheduler;
static constexpr bool kDoubleSmemBuffer = GemmPipeline_::DoubleSmemBuffer;
static constexpr int kNumWaveGroups = GemmPipeline_::NumWaveGroups;
// Epilogue Pipeline
using AccDataType = typename EpiloguePipeline_::AccDataType;
using EDataType = typename EpiloguePipeline_::ODataType;
using DsDataType = typename EpiloguePipeline_::DsDataType;
using CDEElementwiseOperation = typename EpiloguePipeline_::CDElementwise;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << "GroupedConvolutionBackwardWeightKernel";
// Template parameters in exact order matching InstanceTraits member order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << ","
<< ck_tile::getConvSpecializationString(ConvSpecialization); // 2. ConvSpecialization
oss << "," << detail::layout_name<InLayout>(); // 3. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 4. WeiLayout
oss << "," << detail::tuple_name<DsLayout>(); // 5. DsLayout
oss << "," << detail::layout_name<OutLayout>(); // 6. OutLayout
oss << "," << kVectorSizeA; // 7. VectorSizeA
oss << "," << kVectorSizeB; // 8. VectorSizeB
oss << "," << kVectorSizeC; // 9. VectorSizeC
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
oss << "," << kMPerBlock; // 12. MPerBlock
oss << "," << kNPerBlock; // 13. NPerBlock
oss << "," << kKPerBlock; // 14. KPerBlock
oss << "," << kMWarp; // 15. MWarp
oss << "," << kNWarp; // 16. NWarp
oss << "," << kKWarp; // 17. KWarp
oss << "," << kMWarpTile; // 18. MWarpTile
oss << "," << kNWarpTile; // 19. NWarpTile
oss << "," << kKWarpTile; // 20. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
// CDEElementwiseOperation
oss << ">";
return oss.str();
}
};
} // namespace reflect
} // namespace ck_tile

View File

@@ -16,7 +16,7 @@
#include "instance_traits_util.hpp"
// Forward declaration to avoid circular dependency.
namespace ck_tile::device {
namespace ck_tile {
template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
@@ -24,7 +24,7 @@ template <typename GroupedConvTraitsType_,
typename EpiloguePipeline_>
struct GroupedConvolutionForwardKernel;
} // namespace ck_tile::device
} // namespace ck_tile
namespace ck_tile {
namespace reflect {
@@ -34,10 +34,10 @@ template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct InstanceTraits<ck_tile::device::GroupedConvolutionForwardKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>>
struct InstanceTraits<ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>>
{
// CK Tile Conv Traits
// Spatial dimension
@@ -122,7 +122,7 @@ struct InstanceTraits<ck_tile::device::GroupedConvolutionForwardKernel<GroupedCo
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 25. NumWaveGroups
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
oss << "," << detail::type_name<EDataType>(); // 28. EDataType

View File

@@ -21,6 +21,7 @@ add_ck_builder_test(test_ckb_conv_builder
test_conv_builder.cpp
test_fwd_instance_traits.cpp
test_bwd_weight_instance_traits.cpp
test_bwd_data_instance_traits.cpp
test_instance_traits_util.cpp)
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)

View File

@@ -0,0 +1,133 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <ck/ck.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp>
namespace {
// Verifies that ck_tile::reflect::instance_string<>() for a CK Tile grouped convolution
// backward-data kernel emits the kernel name followed by all 30 reflected parameters, in order.
//
// The test name carries a "BwdData" qualifier on purpose: other InstanceTraits test sources that
// are linked into the same test binary also define tile instance-string tests, and sharing one
// full test name would make test reports and --gtest_filter selections ambiguous.
TEST(InstanceTraits, TileBwdDataInstanceStringReturnsCorrectFormat)
{
    // Convolution traits: 2D, default specialization, NHWGC/GKYXC/NHWGK layouts.
    using GroupedConvTraitsType =
        ck_tile::GroupedConvTraits<2 /*NDimSpatial*/,
                                   ck_tile::ConvolutionSpecialization::Default /*ConvSpec*/,
                                   ck_tile::tensor_layout::convolution::NHWGC /*InLayout*/,
                                   ck_tile::tensor_layout::convolution::GKYXC /*WeiLayout*/,
                                   ck_tile::tuple<> /*DsLayout*/,
                                   ck_tile::tensor_layout::convolution::NHWGK /*OutLayout*/,
                                   4 /*VectorSizeA*/,
                                   4 /*VectorSizeB*/,
                                   4 /*VectorSizeC*/,
                                   1 /*NumGroupsToMerge*/,
                                   false /*EnableSplitImage*/>;

    // Block tile, warp layout and warp tile reported as fields 12-20 of the instance string.
    using GemmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
        ck_tile::sequence<4 /*M_Warp*/, 1 /*N_Warp*/, 1 /*K_Warp*/>,
        ck_tile::sequence<16 /*M_Warp_Tile*/, 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/>>;

    using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
        GemmShape,
        GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
        GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;

    // Backward-data specific GEMM layouts come from the *BwdData aliases of the conv traits.
    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
        GroupedConvTraitsType::FixedGemmParams::kPadM,
        GroupedConvTraitsType::FixedGemmParams::kPadN,
        GroupedConvTraitsType::FixedGemmParams::kPadK,
        false /*DoubleSmemBuffer*/,
        typename GroupedConvTraitsType::AsLayoutBwdData,
        typename GroupedConvTraitsType::BsLayoutBwdData,
        typename GroupedConvTraitsType::CLayoutBwdData,
        GroupedConvTraitsType::FixedGemmParams::TransposeC,
        GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
        GroupedConvTraitsType::FixedGemmParams::Persistent,
        1 /*NumWaveGroups*/>;

    using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
        ck_tile::bf16_t /*OutDataType*/,
        ck_tile::bf16_t /*WeiDataType*/,
        float /*AccDataType*/,
        GemmShape,
        GemmUniversalTraits,
        ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/,
        true /*has_hot_loop_v*/,
        ck_tile::TailNumber::Full /*tail_number_v*/,
        ck_tile::element_wise::PassThrough /*AElementwiseOperation*/,
        ck_tile::element_wise::PassThrough /*BElementwiseOperation*/,
        ck_tile::bf16_t /*InDataType*/,
        GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
        GroupedConvTraitsType::VectorSizeA,
        GroupedConvTraitsType::VectorSizeB>;

    using GemmPipeline = typename ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;

    using ConvEpilogue = ck_tile::CShuffleEpilogue<
        ck_tile::CShuffleEpilogueProblem<ck_tile::bf16_t /*OutDataType*/,
                                         ck_tile::bf16_t /*WeiDataType*/,
                                         ck_tile::tuple<> /*DsDataType*/,
                                         float /*AccDataType*/,
                                         ck_tile::bf16_t /*InDataType*/,
                                         typename GroupedConvTraitsType::ImplicitGemmDsLayout,
                                         typename GroupedConvTraitsType::FixedGemmParams::ELayout,
                                         ck_tile::element_wise::PassThrough /*CDElementWise*/,
                                         128 /*MPerBlock*/,
                                         128 /*NPerBlock*/,
                                         4 /*M_Warp*/,
                                         1 /*N_Warp*/,
                                         16 /*M_Warp_Tile*/,
                                         16 /*N_Warp_Tile*/,
                                         16 /*K_Warp_Tile*/,
                                         GroupedConvTraitsType::FixedGemmParams::TransposeC,
                                         ck_tile::memory_operation_enum::set /*memory_operation*/,
                                         1 /*kNumWaveGroups*/,
                                         GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
                                         GroupedConvTraitsType::VectorSizeC>>;

    using GroupedConvBwdDataKernel =
        ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
                                                      TilePartitioner,
                                                      GemmPipeline,
                                                      ConvEpilogue>;

    std::string instance_str = ck_tile::reflect::instance_string<GroupedConvBwdDataKernel>();

    // One literal per field; bool fields are expected as 0/1.
    std::string expected_str = "GroupedConvolutionBackwardDataKernel"
                               "<2"           // NDimSpatial
                               ",Default"     // ConvSpecialization
                               ",NHWGC"       // InLayout
                               ",GKYXC"       // WeiLayout
                               ",EmptyTuple"  // DsLayout
                               ",NHWGK"       // OutLayout
                               ",4"           // VectorSizeA
                               ",4"           // VectorSizeB
                               ",4"           // VectorSizeC
                               ",1"           // NumGroupsToMerge
                               ",0"           // EnableSplitImage
                               ",128"         // MPerBlock
                               ",128"         // NPerBlock
                               ",32"          // KPerBlock
                               ",4"           // MWarp
                               ",1"           // NWarp
                               ",1"           // KWarp
                               ",16"          // MWarpTile
                               ",16"          // NWarpTile
                               ",16"          // KWarpTile
                               ",bf16"        // ADataType
                               ",bf16"        // BDataType
                               ",COMPUTE_V3"  // BlkGemmPipelineVer
                               ",Intrawave"   // BlkGemmPipeSched
                               ",0"           // DoubleSmemBuffer
                               ",1"           // NumWaveGroups
                               ",fp32"        // AccDataType
                               ",bf16"        // EDataType
                               ",EmptyTuple"  // DsDataType
                               ",PassThrough" // CDEElementwiseOperation
                               ">";

    EXPECT_EQ(instance_str, expected_str);
}
} // anonymous namespace

View File

@@ -5,6 +5,7 @@
#include <ck/ck.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp>
namespace {
@@ -109,4 +110,126 @@ TEST(InstanceTraits, BwdWeightXdlCShuffleInstanceStringReturnsCorrectFormat)
EXPECT_EQ(instance_str, expected_str);
}
// Verifies that ck_tile::reflect::instance_string<>() for a CK Tile grouped convolution
// backward-weight kernel emits the kernel name followed by all 30 reflected parameters, in order.
//
// The test name carries a "BwdWeight" qualifier on purpose: other InstanceTraits test sources that
// are linked into the same test binary also define tile instance-string tests, and sharing one
// full test name would make test reports and --gtest_filter selections ambiguous.
TEST(InstanceTraits, TileBwdWeightInstanceStringReturnsCorrectFormat)
{
    // Convolution traits: 2D, default specialization, NHWGC/GKYXC/NHWGK layouts.
    using GroupedConvTraitsType =
        ck_tile::GroupedConvTraits<2 /*NDimSpatial*/,
                                   ck_tile::ConvolutionSpecialization::Default /*ConvSpec*/,
                                   ck_tile::tensor_layout::convolution::NHWGC /*InLayout*/,
                                   ck_tile::tensor_layout::convolution::GKYXC /*WeiLayout*/,
                                   ck_tile::tuple<> /*DsLayout*/,
                                   ck_tile::tensor_layout::convolution::NHWGK /*OutLayout*/,
                                   4 /*VectorSizeA*/,
                                   4 /*VectorSizeB*/,
                                   4 /*VectorSizeC*/,
                                   1 /*NumGroupsToMerge*/,
                                   false /*EnableSplitImage*/>;

    // Block tile, warp layout and warp tile reported as fields 12-20 of the instance string.
    using GemmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
        ck_tile::sequence<4 /*M_Warp*/, 1 /*N_Warp*/, 1 /*K_Warp*/>,
        ck_tile::sequence<16 /*M_Warp_Tile*/, 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/>>;

    using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
        GemmShape,
        GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
        GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;

    // Backward-weight specific GEMM layouts come from the *BwdWeight aliases of the conv traits.
    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
        GroupedConvTraitsType::FixedGemmParams::kPadM,
        GroupedConvTraitsType::FixedGemmParams::kPadN,
        GroupedConvTraitsType::FixedGemmParams::kPadK,
        false /*DoubleSmemBuffer*/,
        typename GroupedConvTraitsType::AsLayoutBwdWeight,
        typename GroupedConvTraitsType::BsLayoutBwdWeight,
        typename GroupedConvTraitsType::CLayoutBwdWeight,
        GroupedConvTraitsType::FixedGemmParams::TransposeC,
        GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
        GroupedConvTraitsType::FixedGemmParams::Persistent,
        1 /*NumWaveGroups*/>;

    using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
        ck_tile::bf16_t /*OutDataType*/,
        ck_tile::bf16_t /*InDataType*/,
        float /*AccDataType*/,
        GemmShape,
        GemmUniversalTraits,
        ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/,
        true /*has_hot_loop_v*/,
        ck_tile::TailNumber::Full /*tail_number_v*/,
        ck_tile::element_wise::PassThrough /*AElementwiseOperation*/,
        ck_tile::element_wise::PassThrough /*BElementwiseOperation*/,
        ck_tile::bf16_t /*WeiDataType*/,
        GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
        GroupedConvTraitsType::VectorSizeA,
        GroupedConvTraitsType::VectorSizeB>;

    using GemmPipeline = typename ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;

    using ConvEpilogue = ck_tile::CShuffleEpilogue<
        ck_tile::CShuffleEpilogueProblem<ck_tile::bf16_t /*OutDataType*/,
                                         ck_tile::bf16_t /*InDataType*/,
                                         ck_tile::tuple<> /*DsDataType*/,
                                         float /*AccDataType*/,
                                         ck_tile::bf16_t /*WeiDataType*/,
                                         typename GroupedConvTraitsType::ImplicitGemmDsLayout,
                                         typename GroupedConvTraitsType::FixedGemmParams::ELayout,
                                         ck_tile::element_wise::PassThrough /*CDElementWise*/,
                                         128 /*MPerBlock*/,
                                         128 /*NPerBlock*/,
                                         4 /*M_Warp*/,
                                         1 /*N_Warp*/,
                                         16 /*M_Warp_Tile*/,
                                         16 /*N_Warp_Tile*/,
                                         16 /*K_Warp_Tile*/,
                                         GroupedConvTraitsType::FixedGemmParams::TransposeC,
                                         ck_tile::memory_operation_enum::set /*memory_operation*/,
                                         1 /*kNumWaveGroups*/,
                                         GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
                                         GroupedConvTraitsType::VectorSizeC>>;

    using GroupedConvBwdWeiKernel =
        ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
                                                        TilePartitioner,
                                                        GemmPipeline,
                                                        ConvEpilogue>;

    std::string instance_str = ck_tile::reflect::instance_string<GroupedConvBwdWeiKernel>();

    // One literal per field; bool fields are expected as 0/1.
    std::string expected_str = "GroupedConvolutionBackwardWeightKernel"
                               "<2"           // NDimSpatial
                               ",Default"     // ConvSpecialization
                               ",NHWGC"       // InLayout
                               ",GKYXC"       // WeiLayout
                               ",EmptyTuple"  // DsLayout
                               ",NHWGK"       // OutLayout
                               ",4"           // VectorSizeA
                               ",4"           // VectorSizeB
                               ",4"           // VectorSizeC
                               ",1"           // NumGroupsToMerge
                               ",0"           // EnableSplitImage
                               ",128"         // MPerBlock
                               ",128"         // NPerBlock
                               ",32"          // KPerBlock
                               ",4"           // MWarp
                               ",1"           // NWarp
                               ",1"           // KWarp
                               ",16"          // MWarpTile
                               ",16"          // NWarpTile
                               ",16"          // KWarpTile
                               ",bf16"        // ADataType
                               ",bf16"        // BDataType
                               ",COMPUTE_V3"  // BlkGemmPipelineVer
                               ",Intrawave"   // BlkGemmPipeSched
                               ",0"           // DoubleSmemBuffer
                               ",1"           // NumWaveGroups
                               ",fp32"        // AccDataType
                               ",bf16"        // EDataType
                               ",EmptyTuple"  // DsDataType
                               ",PassThrough" // CDEElementwiseOperation
                               ">";

    EXPECT_EQ(instance_str, expected_str);
}
} // anonymous namespace

View File

@@ -799,11 +799,10 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using GroupedConvFwdKernel =
ck_tile::device::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
using GroupedConvFwdKernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
std::string instance_str = ck_tile::reflect::instance_string<GroupedConvFwdKernel>();

View File

@@ -14,6 +14,10 @@
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp"
#endif
namespace ck_tile {
/// @brief The Grouped Convolution kernel device arguments.
@@ -565,6 +569,19 @@ struct GroupedConvolutionBackwardDataKernel
// clang-format on
}
#ifdef CK_EXPERIMENTAL_BUILDER
/// @brief Returns the reflection string describing this kernel instantiation.
/// @return The result of ck_tile::reflect::instance_string for this kernel type.
/// Compilation fails if no InstanceTraits specialization is visible for this kernel type.
CK_TILE_HOST std::string GetInstanceString() const
{
    // Injected class name: refers to the current instantiation of the kernel template.
    using ThisKernel = GroupedConvolutionBackwardDataKernel;

    static_assert(ck_tile::reflect::HasInstanceTraits<ThisKernel>,
                  "Specialization of instance_traits not found. Please check that a "
                  "specialization exists in file "
                  "ck_tile/builder/reflect/"
                  "instance_traits_tile_grouped_convolution_backward_data.hpp "
                  "for the given template parameters.");

    return ck_tile::reflect::instance_string<ThisKernel>();
}
#endif
CK_TILE_HOST static auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
{
// enable batched grouped gemm

View File

@@ -14,6 +14,10 @@
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
#endif
namespace ck_tile {
/// @brief The Grouped Convolution kernel device arguments.
@@ -430,6 +434,19 @@ struct GroupedConvolutionBackwardWeightKernel
// clang-format on
}
#ifdef CK_EXPERIMENTAL_BUILDER
/// @brief Returns the reflection string describing this kernel instantiation.
/// @return The result of ck_tile::reflect::instance_string for this kernel type.
/// Compilation fails if no InstanceTraits specialization is visible for this kernel type.
CK_TILE_HOST std::string GetInstanceString() const
{
    // Injected class name: refers to the current instantiation of the kernel template.
    using ThisKernel = GroupedConvolutionBackwardWeightKernel;

    static_assert(ck_tile::reflect::HasInstanceTraits<ThisKernel>,
                  "Specialization of instance_traits not found. Please check that a "
                  "specialization exists in file "
                  "ck_tile/builder/reflect/"
                  "instance_traits_tile_grouped_convolution_backward_weight.hpp "
                  "for the given template parameters.");

    return ck_tile::reflect::instance_string<ThisKernel>();
}
#endif
CK_TILE_HOST static constexpr auto
GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{