[CK_BUILDER] Add grouped conv fwd ck tile traits (#3183)

* [CK BUILDER] Add grouped conv fwd ck tile traits

* Update instance_traits_tile_grouped_convolution_forward.hpp

* Update grouped_convolution_forward_kernel.hpp
This commit is contained in:
Bartłomiej Kocot
2025-11-11 22:55:33 +01:00
committed by GitHub
parent b145a5fe80
commit 92c1f4981a
18 changed files with 433 additions and 15 deletions

View File

@@ -15,6 +15,9 @@
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/utility/loop_scheduler.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
#include <ck_tile/ops/gemm.hpp>
#include "ck_tile/ops/epilogue.hpp"
#include <ck_tile/ops/grouped_convolution.hpp>
namespace ck_tile::reflect::conv {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
// Forward declaration to avoid circular dependency

View File

@@ -0,0 +1,140 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// InstanceTraits specialization for GroupedConvolutionForwardKernel
//
// CRITICAL MAINTENANCE NOTE:
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
// ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
// "In sync" means that the template parameter order, names, and types in the declaration below
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
// difficult to diagnose. Always update both files together and review changes carefully.
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
// Forward declaration to avoid circular dependency.
namespace ck_tile::device {
template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct GroupedConvolutionForwardKernel;
} // namespace ck_tile::device
namespace ck_tile {
namespace reflect {
// Specialization for GroupedConvolutionForwardKernel
template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct InstanceTraits<ck_tile::device::GroupedConvolutionForwardKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>>
{
// CK Tile Conv Traits
// Spatial dimension
static constexpr int kSpatialDim = GroupedConvTraitsType_::NDimSpatial;
// Specialization
static constexpr ck_tile::ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
// DataType types
using InLayout = typename GroupedConvTraitsType_::InLayout;
using WeiLayout = typename GroupedConvTraitsType_::WeiLayout;
using DsLayout = typename GroupedConvTraitsType_::DsLayout;
using OutLayout = typename GroupedConvTraitsType_::OutLayout;
// Vector size
static constexpr int kVectorSizeA = GroupedConvTraitsType_::VectorSizeA;
static constexpr int kVectorSizeB = GroupedConvTraitsType_::VectorSizeB;
static constexpr int kVectorSizeC = GroupedConvTraitsType_::VectorSizeC;
// Num Groups To Merge
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// Split image (large tensors)
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
// TilePartitioner
// Block configuration
static constexpr int kMPerBlock = TilePartitioner_::MPerBlock;
static constexpr int kNPerBlock = TilePartitioner_::NPerBlock;
static constexpr int kKPerBlock = TilePartitioner_::KPerBlock;
static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{});
static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{});
static constexpr int kKWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{});
static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{});
static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{});
static constexpr int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{});
// Data types
using ADataType = typename GemmPipeline_::ADataType;
using BDataType = typename GemmPipeline_::BDataType;
// Gemm Pipeline
using GemmPipeline = GemmPipeline_;
static constexpr ck_tile::GemmPipelineScheduler kPipelineScheduler = GemmPipeline_::Scheduler;
static constexpr bool kDoubleSmemBuffer = GemmPipeline_::DoubleSmemBuffer;
static constexpr int kNumWaveGroups = GemmPipeline_::NumWaveGroups;
// Epilogue Pipeline
using AccDataType = typename EpiloguePipeline_::AccDataType;
using EDataType = typename EpiloguePipeline_::ODataType;
using DsDataType = typename EpiloguePipeline_::DsDataType;
using CDEElementwiseOperation = typename EpiloguePipeline_::CDElementwise;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << "GroupedConvolutionForwardKernel";
// Template parameters in exact order matching InstanceTraits member order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << ","
<< ck_tile::getConvSpecializationString(ConvSpecialization); // 2. ConvSpecialization
oss << "," << detail::layout_name<InLayout>(); // 3. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 4. WeiLayout
oss << "," << detail::tuple_name<DsLayout>(); // 5. DsLayout
oss << "," << detail::layout_name<OutLayout>(); // 6. OutLayout
oss << "," << kVectorSizeA; // 7. VectorSizeA
oss << "," << kVectorSizeB; // 8. VectorSizeB
oss << "," << kVectorSizeC; // 9. VectorSizeC
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
oss << "," << kMPerBlock; // 12. MPerBlock
oss << "," << kNPerBlock; // 13. NPerBlock
oss << "," << kKPerBlock; // 14. KPerBlock
oss << "," << kMWarp; // 15. MWarp
oss << "," << kNWarp; // 16. NWarp
oss << "," << kKWarp; // 17. KWarp
oss << "," << kMWarpTile; // 18. MWarpTile
oss << "," << kNWarpTile; // 19. NWarpTile
oss << "," << kKWarpTile; // 20. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 25. NumWaveGroups
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
// CDEElementwiseOperation
oss << ">";
return oss.str();
}
};
} // namespace reflect
} // namespace ck_tile

View File

@@ -28,6 +28,10 @@
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck_tile/ops/gemm.hpp>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
namespace ck_tile::reflect::detail {
@@ -38,7 +42,7 @@ namespace impl {
template <typename T>
consteval std::string_view type_name_impl()
{
if constexpr(std::is_same_v<T, ck::half_t>)
if constexpr(std::is_same_v<T, ck::half_t> || std::is_same_v<T, ck_tile::half_t>)
return "fp16";
else if constexpr(std::is_same_v<T, float>)
return "fp32";
@@ -50,11 +54,11 @@ consteval std::string_view type_name_impl()
return "s8";
else if constexpr(std::is_same_v<T, int32_t>)
return "s32";
else if constexpr(std::is_same_v<T, ck::bhalf_t>)
else if constexpr(std::is_same_v<T, ck::bhalf_t> || std::is_same_v<T, ck_tile::bf16_t>)
return "bf16";
else if constexpr(std::is_same_v<T, ck::f8_t>)
else if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck_tile::fp8_t>)
return "fp8";
else if constexpr(std::is_same_v<T, ck::bf8_t>)
else if constexpr(std::is_same_v<T, ck::bf8_t> || std::is_same_v<T, ck_tile::bf8_t>)
return "bf8";
else
return std::string_view{}; // Return empty for supported types
@@ -168,6 +172,17 @@ constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineSchedule
}
}
constexpr std::string_view pipeline_scheduler_name(ck_tile::GemmPipelineScheduler sched)
{
using enum ck_tile::GemmPipelineScheduler;
switch(sched)
{
case Default: return "Default";
case Intrawave: return "Intrawave";
case Interwave: return "Interwave";
}
}
// Convert BlockGemmPipelineVersion enum to string
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
{
@@ -206,6 +221,26 @@ constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched)
}
}
// Convert TailNumber enum to string
constexpr std::string_view tail_number_name(ck_tile::TailNumber tail_num)
{
using enum ck_tile::TailNumber;
switch(tail_num)
{
case Odd: return "Odd";
case Even: return "Even";
case One: return "One";
case Two: return "Two";
case Three: return "Three";
case Four: return "Four";
case Five: return "Five";
case Six: return "Six";
case Seven: return "Seven";
case Empty: return "Empty";
case Full: return "Full";
}
}
// Convert std::array to string
template <typename T, std::size_t N>
inline std::string array_to_string(const std::array<T, N>& arr)
@@ -356,17 +391,53 @@ constexpr std::string tuple_name()
}(static_cast<T*>(nullptr));
}
template <typename T>
requires requires { []<typename... Ts>(ck_tile::tuple<Ts...>*) {}(static_cast<T*>(nullptr)); }
constexpr std::string tuple_name()
{
return []<typename... Ts>(ck_tile::tuple<Ts...>*) constexpr {
if constexpr(sizeof...(Ts) == 0)
{
return std::string("EmptyTuple");
}
else if constexpr((IsLayoutType<Ts> && ...))
{
// Lambda wrapper for layout_name
auto layout_name_fn = []<typename U>() { return layout_name<U>(); };
return detail::build_list_string<decltype(layout_name_fn), Ts...>("tuple",
layout_name_fn);
}
else if constexpr((IsDataType<Ts> && ...))
{
// Lambda wrapper for type_name
auto type_name_fn = []<typename U>() { return type_name<U>(); };
return detail::build_list_string<decltype(type_name_fn), Ts...>("tuple", type_name_fn);
}
else
{
static_assert((IsLayoutType<Ts> && ...) || (IsDataType<Ts> && ...),
"tuple elements must be all layouts or all data types, not mixed");
return std::string{}; // unreachable
}
}(static_cast<T*>(nullptr));
}
// Concept to check if a type is a ck::Tuple
template <typename T>
concept IsCkTuple =
requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); };
// Concept to check if a type is a ck_tile::tuple
template <typename T>
concept IsCkTileTuple =
requires { []<typename... Ts>(ck_tile::tuple<Ts...>*) {}(static_cast<T*>(nullptr)); };
// Deduces whether to use tuple_name or type_name
// Handles both scalar data types and ck::Tuple types
template <typename T>
constexpr std::string type_or_type_tuple_name()
{
if constexpr(IsCkTuple<T>)
if constexpr(IsCkTuple<T> || IsCkTileTuple<T>)
{
return tuple_name<T>();
}