mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_BUILDER] Add grouped conv fwd ck tile traits (#3183)
* [CK BUILDER] Add grouped conv fwd ck tile traits * Update instance_traits_tile_grouped_convolution_forward.hpp * Update grouped_convolution_forward_kernel.hpp
This commit is contained in:
@@ -15,6 +15,9 @@
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/utility/loop_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
|
||||
#include <ck_tile/ops/gemm.hpp>
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include <ck_tile/ops/grouped_convolution.hpp>
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
#include "instance_traits_util.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// InstanceTraits specialization for GroupedConvolutionForwardKernel
|
||||
//
|
||||
// CRITICAL MAINTENANCE NOTE:
|
||||
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
|
||||
// ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
|
||||
// "In sync" means that the template parameter order, names, and types in the declaration below
|
||||
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
|
||||
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
|
||||
// difficult to diagnose. Always update both files together and review changes carefully.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
#include "instance_traits_util.hpp"
|
||||
|
||||
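// Forward declaration to avoid a circular dependency with the kernel header.
// NOTE(review): this declares the kernel template in ck_tile::device. Confirm that the kernel
// header defines GroupedConvolutionForwardKernel in that same namespace; if it is defined
// directly in ck_tile, the InstanceTraits specialization below never matches the real kernel.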
|
||||
namespace ck_tile::device {
|
||||
|
||||
template <typename GroupedConvTraitsType_,
|
||||
typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionForwardKernel;
|
||||
|
||||
} // namespace ck_tile::device
|
||||
|
||||
namespace ck_tile {
|
||||
namespace reflect {
|
||||
|
||||
// Specialization for GroupedConvolutionForwardKernel
|
||||
template <typename GroupedConvTraitsType_,
|
||||
typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_>
|
||||
struct InstanceTraits<ck_tile::device::GroupedConvolutionForwardKernel<GroupedConvTraitsType_,
|
||||
TilePartitioner_,
|
||||
GemmPipeline_,
|
||||
EpiloguePipeline_>>
|
||||
{
|
||||
// CK Tile Conv Traits
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = GroupedConvTraitsType_::NDimSpatial;
|
||||
// Specialization
|
||||
static constexpr ck_tile::ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType_::ConvSpecialization;
|
||||
// DataType types
|
||||
using InLayout = typename GroupedConvTraitsType_::InLayout;
|
||||
using WeiLayout = typename GroupedConvTraitsType_::WeiLayout;
|
||||
using DsLayout = typename GroupedConvTraitsType_::DsLayout;
|
||||
using OutLayout = typename GroupedConvTraitsType_::OutLayout;
|
||||
// Vector size
|
||||
static constexpr int kVectorSizeA = GroupedConvTraitsType_::VectorSizeA;
|
||||
static constexpr int kVectorSizeB = GroupedConvTraitsType_::VectorSizeB;
|
||||
static constexpr int kVectorSizeC = GroupedConvTraitsType_::VectorSizeC;
|
||||
// Num Groups To Merge
|
||||
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// Split image (large tensors)
|
||||
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
|
||||
// TilePartitioner
|
||||
// Block configuration
|
||||
static constexpr int kMPerBlock = TilePartitioner_::MPerBlock;
|
||||
static constexpr int kNPerBlock = TilePartitioner_::NPerBlock;
|
||||
static constexpr int kKPerBlock = TilePartitioner_::KPerBlock;
|
||||
|
||||
static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{});
|
||||
static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{});
|
||||
static constexpr int kKWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{});
|
||||
|
||||
static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{});
|
||||
static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{});
|
||||
static constexpr int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
|
||||
// Data types
|
||||
using ADataType = typename GemmPipeline_::ADataType;
|
||||
using BDataType = typename GemmPipeline_::BDataType;
|
||||
// Gemm Pipeline
|
||||
using GemmPipeline = GemmPipeline_;
|
||||
static constexpr ck_tile::GemmPipelineScheduler kPipelineScheduler = GemmPipeline_::Scheduler;
|
||||
static constexpr bool kDoubleSmemBuffer = GemmPipeline_::DoubleSmemBuffer;
|
||||
static constexpr int kNumWaveGroups = GemmPipeline_::NumWaveGroups;
|
||||
|
||||
// Epilogue Pipeline
|
||||
using AccDataType = typename EpiloguePipeline_::AccDataType;
|
||||
using EDataType = typename EpiloguePipeline_::ODataType;
|
||||
using DsDataType = typename EpiloguePipeline_::DsDataType;
|
||||
using CDEElementwiseOperation = typename EpiloguePipeline_::CDElementwise;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
// Kernel type name
|
||||
oss << "GroupedConvolutionForwardKernel";
|
||||
|
||||
// Template parameters in exact order matching InstanceTraits member order
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << ","
|
||||
<< ck_tile::getConvSpecializationString(ConvSpecialization); // 2. ConvSpecialization
|
||||
oss << "," << detail::layout_name<InLayout>(); // 3. InLayout
|
||||
oss << "," << detail::layout_name<WeiLayout>(); // 4. WeiLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 5. DsLayout
|
||||
oss << "," << detail::layout_name<OutLayout>(); // 6. OutLayout
|
||||
oss << "," << kVectorSizeA; // 7. VectorSizeA
|
||||
oss << "," << kVectorSizeB; // 8. VectorSizeB
|
||||
oss << "," << kVectorSizeC; // 9. VectorSizeC
|
||||
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
|
||||
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
|
||||
oss << "," << kMPerBlock; // 12. MPerBlock
|
||||
oss << "," << kNPerBlock; // 13. NPerBlock
|
||||
oss << "," << kKPerBlock; // 14. KPerBlock
|
||||
oss << "," << kMWarp; // 15. MWarp
|
||||
oss << "," << kNWarp; // 16. NWarp
|
||||
oss << "," << kKWarp; // 17. KWarp
|
||||
oss << "," << kMWarpTile; // 18. MWarpTile
|
||||
oss << "," << kNWarpTile; // 19. NWarpTile
|
||||
oss << "," << kKWarpTile; // 20. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 25. NumWaveGroups
|
||||
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
|
||||
// CDEElementwiseOperation
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace reflect
|
||||
} // namespace ck_tile
|
||||
@@ -28,6 +28,10 @@
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
#include <ck_tile/ops/gemm.hpp>
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
|
||||
namespace ck_tile::reflect::detail {
|
||||
|
||||
@@ -38,7 +42,7 @@ namespace impl {
|
||||
template <typename T>
|
||||
consteval std::string_view type_name_impl()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck::half_t>)
|
||||
if constexpr(std::is_same_v<T, ck::half_t> || std::is_same_v<T, ck_tile::half_t>)
|
||||
return "fp16";
|
||||
else if constexpr(std::is_same_v<T, float>)
|
||||
return "fp32";
|
||||
@@ -50,11 +54,11 @@ consteval std::string_view type_name_impl()
|
||||
return "s8";
|
||||
else if constexpr(std::is_same_v<T, int32_t>)
|
||||
return "s32";
|
||||
else if constexpr(std::is_same_v<T, ck::bhalf_t>)
|
||||
else if constexpr(std::is_same_v<T, ck::bhalf_t> || std::is_same_v<T, ck_tile::bf16_t>)
|
||||
return "bf16";
|
||||
else if constexpr(std::is_same_v<T, ck::f8_t>)
|
||||
else if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck_tile::fp8_t>)
|
||||
return "fp8";
|
||||
else if constexpr(std::is_same_v<T, ck::bf8_t>)
|
||||
else if constexpr(std::is_same_v<T, ck::bf8_t> || std::is_same_v<T, ck_tile::bf8_t>)
|
||||
return "bf8";
|
||||
else
|
||||
return std::string_view{}; // Return empty for supported types
|
||||
@@ -168,6 +172,17 @@ constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineSchedule
|
||||
}
|
||||
}
|
||||
|
||||
// Convert a ck_tile::GemmPipelineScheduler enum value to its string name.
//
// Returns the enumerator's spelling for the three enumerators handled below, and "Unknown"
// for any other value (for example an integer that was cast to the enum type). Without that
// fallback, control would flow off the end of this value-returning function, which is
// undefined behaviour.
constexpr std::string_view pipeline_scheduler_name(ck_tile::GemmPipelineScheduler sched)
{
    using enum ck_tile::GemmPipelineScheduler;
    switch(sched)
    {
    case Default: return "Default";
    case Intrawave: return "Intrawave";
    case Interwave: return "Interwave";
    }
    // Deliberately no `default:` label, so -Wswitch keeps reporting enumerators that are
    // added to the enum later but not handled here.
    return "Unknown";
}
|
||||
|
||||
// Convert BlockGemmPipelineVersion enum to string
|
||||
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
|
||||
{
|
||||
@@ -206,6 +221,26 @@ constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert TailNumber enum to string
|
||||
constexpr std::string_view tail_number_name(ck_tile::TailNumber tail_num)
|
||||
{
|
||||
using enum ck_tile::TailNumber;
|
||||
switch(tail_num)
|
||||
{
|
||||
case Odd: return "Odd";
|
||||
case Even: return "Even";
|
||||
case One: return "One";
|
||||
case Two: return "Two";
|
||||
case Three: return "Three";
|
||||
case Four: return "Four";
|
||||
case Five: return "Five";
|
||||
case Six: return "Six";
|
||||
case Seven: return "Seven";
|
||||
case Empty: return "Empty";
|
||||
case Full: return "Full";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert std::array to string
|
||||
template <typename T, std::size_t N>
|
||||
inline std::string array_to_string(const std::array<T, N>& arr)
|
||||
@@ -356,17 +391,53 @@ constexpr std::string tuple_name()
|
||||
}(static_cast<T*>(nullptr));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
requires requires { []<typename... Ts>(ck_tile::tuple<Ts...>*) {}(static_cast<T*>(nullptr)); }
|
||||
constexpr std::string tuple_name()
|
||||
{
|
||||
return []<typename... Ts>(ck_tile::tuple<Ts...>*) constexpr {
|
||||
if constexpr(sizeof...(Ts) == 0)
|
||||
{
|
||||
return std::string("EmptyTuple");
|
||||
}
|
||||
else if constexpr((IsLayoutType<Ts> && ...))
|
||||
{
|
||||
// Lambda wrapper for layout_name
|
||||
auto layout_name_fn = []<typename U>() { return layout_name<U>(); };
|
||||
return detail::build_list_string<decltype(layout_name_fn), Ts...>("tuple",
|
||||
layout_name_fn);
|
||||
}
|
||||
else if constexpr((IsDataType<Ts> && ...))
|
||||
{
|
||||
// Lambda wrapper for type_name
|
||||
auto type_name_fn = []<typename U>() { return type_name<U>(); };
|
||||
return detail::build_list_string<decltype(type_name_fn), Ts...>("tuple", type_name_fn);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert((IsLayoutType<Ts> && ...) || (IsDataType<Ts> && ...),
|
||||
"tuple elements must be all layouts or all data types, not mixed");
|
||||
return std::string{}; // unreachable
|
||||
}
|
||||
}(static_cast<T*>(nullptr));
|
||||
}
|
||||
|
||||
// Concept to check if a type is a ck::Tuple
|
||||
template <typename T>
|
||||
concept IsCkTuple =
|
||||
requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); };
|
||||
|
||||
// Concept to check if a type is a ck_tile::tuple
|
||||
template <typename T>
|
||||
concept IsCkTileTuple =
|
||||
requires { []<typename... Ts>(ck_tile::tuple<Ts...>*) {}(static_cast<T*>(nullptr)); };
|
||||
|
||||
// Deduces whether to use tuple_name or type_name
|
||||
// Handles scalar data types as well as ck::Tuple and ck_tile::tuple types
|
||||
template <typename T>
|
||||
constexpr std::string type_or_type_tuple_name()
|
||||
{
|
||||
if constexpr(IsCkTuple<T>)
|
||||
if constexpr(IsCkTuple<T> || IsCkTileTuple<T>)
|
||||
{
|
||||
return tuple_name<T>();
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -720,4 +721,126 @@ TEST(InstanceTraits, DlInstanceStringReturnsCorrectFormat)
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
// Assembles one concrete CK Tile grouped forward convolution kernel type (2D, bf16 inputs and
// output, fp32 accumulation, CompV3 pipeline, CShuffle epilogue) and checks that its
// reflected instance string matches the expected text field by field.
TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
{
    namespace conv_layout = ck_tile::tensor_layout::convolution;
    using BF16            = ck_tile::bf16_t;
    using PassThrough     = ck_tile::element_wise::PassThrough;

    // Convolution-level traits
    using ConvTraits = ck_tile::GroupedConvTraits<
        2,                                           // NDimSpatial
        ck_tile::ConvolutionSpecialization::Default, // ConvSpec
        conv_layout::NHWGC,                          // InLayout
        conv_layout::GKYXC,                          // WeiLayout
        ck_tile::tuple<>,                            // DsLayout
        conv_layout::NHWGK,                          // OutLayout
        4,                                           // VectorSizeA
        4,                                           // VectorSizeB
        4,                                           // VectorSizeC
        1,                                           // NumGroupsToMerge
        false>;                                      // EnableSplitImage
    using FixedParams = ConvTraits::FixedGemmParams;

    // Block tile, warp layout and warp tile, each given as (M, N, K)
    using BlockShape = ck_tile::TileGemmShape<ck_tile::sequence<128, 128, 32>,
                                              ck_tile::sequence<4, 1, 1>,
                                              ck_tile::sequence<16, 16, 16>>;

    using Partitioner =
        ck_tile::GemmSpatiallyLocalTilePartitioner<BlockShape,
                                                   FixedParams::TilePartitionerGroupNum,
                                                   FixedParams::TilePartitionerM01>;

    using GemmTraits = ck_tile::TileGemmUniversalTraits<FixedParams::kPadM,
                                                        FixedParams::kPadN,
                                                        FixedParams::kPadK,
                                                        false, // DoubleSmemBuffer
                                                        ConvTraits::AsLayoutFwd,
                                                        ConvTraits::BsLayoutFwd,
                                                        ConvTraits::CLayoutFwd,
                                                        FixedParams::TransposeC,
                                                        FixedParams::UseStructuredSparsity,
                                                        FixedParams::Persistent,
                                                        1>; // NumWaveGroups

    using PipelineProblem =
        ck_tile::UniversalGemmPipelineProblem<BF16,  // InDataType
                                              BF16,  // WeiDataType
                                              float, // AccDataType
                                              BlockShape,
                                              GemmTraits,
                                              ck_tile::GemmPipelineScheduler::Intrawave,
                                              true,                      // has_hot_loop_v
                                              ck_tile::TailNumber::Full, // tail_number_v
                                              PassThrough,               // AElementwiseOperation
                                              PassThrough,               // BElementwiseOperation
                                              BF16,                      // OutDataType
                                              FixedParams::FixedVectorSize,
                                              ConvTraits::VectorSizeA,
                                              ConvTraits::VectorSizeB>;

    using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;

    using Epilogue = ck_tile::CShuffleEpilogue<
        ck_tile::CShuffleEpilogueProblem<BF16,             // InDataType
                                         BF16,             // WeiDataType
                                         ck_tile::tuple<>, // DsDataType
                                         float,            // AccDataType
                                         BF16,             // OutDataType
                                         ConvTraits::ImplicitGemmDsLayout,
                                         FixedParams::ELayout,
                                         PassThrough, // CDElementWise
                                         128,         // MPerBlock
                                         128,         // NPerBlock
                                         4,           // M_Warp
                                         1,           // N_Warp
                                         16,          // M_Warp_Tile
                                         16,          // N_Warp_Tile
                                         16,          // K_Warp_Tile
                                         FixedParams::TransposeC,
                                         ck_tile::memory_operation_enum::set,
                                         1, // kNumWaveGroups
                                         FixedParams::FixedVectorSize,
                                         ConvTraits::VectorSizeC>>;

    using Kernel =
        ck_tile::device::GroupedConvolutionForwardKernel<ConvTraits, Partitioner, Pipeline, Epilogue>;

    const std::string actual = ck_tile::reflect::instance_string<Kernel>();

    const std::string expected =
        "GroupedConvolutionForwardKernel"
        "<2,Default"                 // NDimSpatial, ConvSpecialization
        ",NHWGC,GKYXC"               // InLayout, WeiLayout
        ",EmptyTuple,NHWGK"          // DsLayout, OutLayout
        ",4,4,4"                     // VectorSizeA, VectorSizeB, VectorSizeC
        ",1,0"                       // NumGroupsToMerge, EnableSplitImage
        ",128,128,32"                // MPerBlock, NPerBlock, KPerBlock
        ",4,1,1"                     // MWarp, NWarp, KWarp
        ",16,16,16"                  // MWarpTile, NWarpTile, KWarpTile
        ",bf16,bf16"                 // ADataType, BDataType
        ",COMPUTE_V3,Intrawave"      // BlkGemmPipelineVer, BlkGemmPipeSched
        ",0,1"                       // DoubleSmemBuffer, NumWaveGroups
        ",fp32,bf16"                 // AccDataType, EDataType
        ",EmptyTuple,PassThrough"    // DsDataType, CDEElementwiseOperation
        ">";

    EXPECT_EQ(actual, expected);
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
10
include/ck_tile/core/arch/arch.hpp
Executable file → Normal file
10
include/ck_tile/core/arch/arch.hpp
Executable file → Normal file
@@ -299,12 +299,12 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0)
|
||||
#endif
|
||||
}
|
||||
|
||||
#define CK_CONSTANT_ADDRESS_SPACE \
|
||||
__attribute__((address_space( \
|
||||
#define CK_TILE_CONSTANT_ADDRESS_SPACE \
|
||||
__attribute__((address_space( \
|
||||
static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))
|
||||
|
||||
template <typename T>
|
||||
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
|
||||
__device__ T* cast_pointer_to_generic_address_space(T CK_TILE_CONSTANT_ADDRESS_SPACE* p)
|
||||
{
|
||||
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
|
||||
// only a C-style pointer cast seems to compile here
|
||||
@@ -315,13 +315,13 @@ __device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE*
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
|
||||
__host__ __device__ T CK_TILE_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
|
||||
{
|
||||
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
|
||||
// only a C-style pointer cast seems to compile here
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
|
||||
return (T CK_TILE_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
|
||||
@@ -190,7 +190,7 @@ struct GroupedGemmKernel
|
||||
*/
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
|
||||
using ConstantPointer = const void CK_TILE_CONSTANT_ADDRESS_SPACE*;
|
||||
const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>;
|
||||
int occupancy;
|
||||
HIP_CHECK_ERROR(
|
||||
@@ -518,7 +518,7 @@ struct GroupedGemmKernel
|
||||
|
||||
// For non-persistent kernels
|
||||
template <bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
|
||||
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
index_t group_count) const
|
||||
{
|
||||
const index_t block_id = ck_tile::get_block_1d_id();
|
||||
@@ -541,7 +541,7 @@ struct GroupedGemmKernel
|
||||
template <bool U = UsePersistentKernel,
|
||||
typename = std::enable_if_t<U>,
|
||||
typename = void> // extra template parameter to avoid redefinition
|
||||
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
const index_t group_count) const
|
||||
{
|
||||
const index_t grid_size = ck_tile::get_grid_size();
|
||||
|
||||
@@ -164,6 +164,13 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
// Short fixed tag identifying this pipeline.
// Always returns "COMPUTE_ASYNC".
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
    return std::string("COMPUTE_ASYNC");
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
|
||||
@@ -170,6 +170,13 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
using Base::PrefetchStages;
|
||||
using Base::UsePersistentKernel;
|
||||
|
||||
// Short fixed tag identifying this pipeline.
// Always returns "COMPUTE_V3".
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
    return std::string("COMPUTE_V3");
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -172,6 +172,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
// Short fixed tag identifying this pipeline.
// Always returns "COMPUTE_V4".
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
    return std::string("COMPUTE_V4");
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -99,6 +99,13 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
static constexpr index_t NumWarps = BlockGemmShape::NumWarps;
|
||||
static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{});
|
||||
|
||||
// Short fixed tag identifying this pipeline.
// Always returns "COMPUTE_V5".
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
    return std::string("COMPUTE_V5");
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -159,6 +159,13 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<BasePImpl::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<BasePImpl::is_b_load_tr>{};
|
||||
|
||||
// Short fixed tag identifying this pipeline.
// Always returns "COMPUTE_V6".
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
    return std::string("COMPUTE_V6");
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -214,6 +214,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
// Short fixed tag identifying this pipeline.
// Always returns "MEMORY".
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
    return std::string("MEMORY");
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -70,6 +70,13 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
|
||||
// Short fixed tag identifying this pipeline.
// Always returns "BASIC_V1".
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
    return std::string("BASIC_V1");
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -70,6 +70,13 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
// For the basic gemm pipeline, DoubleSmemBuffer is naturally set to false.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
// Short fixed tag identifying this pipeline.
// Always returns "BASIC_V2".
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
    return std::string("BASIC_V2");
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -176,6 +176,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
|
||||
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
|
||||
|
||||
// Short fixed tag identifying this pipeline.
// Always returns "PRESHUFFLE_V2".
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
    return std::string("PRESHUFFLE_V2");
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -208,7 +208,7 @@ struct QuantGroupedGemmKernel
|
||||
*/
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
|
||||
using ConstantPointer = const void CK_TILE_CONSTANT_ADDRESS_SPACE*;
|
||||
const auto kernel_func = kentry<1, Kernel, ConstantPointer, index_t>;
|
||||
int occupancy;
|
||||
HIP_CHECK_ERROR(
|
||||
@@ -499,7 +499,7 @@ struct QuantGroupedGemmKernel
|
||||
template <bool U = UsePersistentKernel,
|
||||
typename = std::enable_if_t<U>,
|
||||
typename = void> // extra template parameter to avoid redefinition
|
||||
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
const index_t group_count) const
|
||||
{
|
||||
const index_t grid_size = ck_tile::get_grid_size();
|
||||
|
||||
@@ -16,6 +16,10 @@
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp"
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The Grouped Convolution kernel device arguments.
|
||||
@@ -568,6 +572,19 @@ struct GroupedConvolutionForwardKernel
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
// Returns the reflected instance string of this kernel type, as produced by
// ck_tile::reflect::instance_string. Compilation fails with the message below when no
// InstanceTraits specialization is found for this kernel's template arguments.
CK_TILE_HOST std::string GetInstanceString() const
{
    using Self = GroupedConvolutionForwardKernel;

    static_assert(ck_tile::reflect::HasInstanceTraits<Self>,
                  "Specialization of instance_traits not found. Please check that a "
                  "specialization exists in file "
                  "ck_tile/builder/reflect/"
                  "instance_traits_tile_grouped_convolution_forward.hpp "
                  "for the given template parameters.");

    return ck_tile::reflect::instance_string<Self>();
}
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST static auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
{
|
||||
return dim3(
|
||||
|
||||
Reference in New Issue
Block a user