[CK_BUILDER] Add grouped conv fwd ck tile profiler (#3518)

* [BUILDER] Add grouped conv fwd ck tile profiler

* [CK TILE] Fix grouped conv kernels splitk and double lds

* Updates

* Fixes

* Move to ckProfiler

* Fixes

* fix

* fix

* Change instances to empty list by default

* fix

* fix

* Update grouped_convolution_signatures.hpp

* Update grouped_convolution_forward_tile_algs.hpp

* [CK TILE] Add grouped convolution forward tests (#3556)

* [CK TILE] Add grouped convolution forward tests

* fix jenkins

* fixes

* comments fixes

* unit test

* unit test fix

* Move instances outside builder

* fix includes

* clang format fix

* readme fix

* fix includes

* fixes
This commit is contained in:
Bartłomiej Kocot
2026-01-20 06:29:01 +01:00
committed by GitHub
parent 0517d43d31
commit 0727e85e52
44 changed files with 3083 additions and 65 deletions

View File

@@ -90,6 +90,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
static constexpr bool Preshuffle = Problem::Preshuffle;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// Byte alignment applied when placing buffers in LDS; 16 presumably matches
// the widest vectorized LDS access used by this pipeline — TODO confirm.
static constexpr index_t kLdsAlignmentInBytes = 16;

View File

@@ -23,6 +23,18 @@ using WarpGemmMfmaF32F32F32M16N16K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterate
4,
AttrNumAccess>>;
// fp32 warp-level GEMM covering a 16x16 tile with K = 8: the base
// 16x16x4 MFMA implementation iterated 2x along the K dimension (2 * 4 = 8).
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M16N16K8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
// fp32 warp-level GEMM covering a 32x32 tile with K = 8: the base
// 32x32x2 MFMA implementation iterated 4x along the K dimension (4 * 2 = 8).
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M32N32K8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF32F32F32M32N32K2<WGAttrCtlEnum::Default_>,
4,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<

View File

@@ -34,6 +34,8 @@ struct Dispatcher;
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct Dispatcher<float, float, float, 16, 16, 4, false> { using Type = WarpGemmMfmaF32F32F32M16N16K4; };
template<> struct Dispatcher<float, float, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; };
// KPerWave = 8 entries map onto the K-iterated aliases built from the
// smaller (K=4 / K=2) MFMA implementations.
template<> struct Dispatcher<float, float, float, 16, 16, 8, false> { using Type = WarpGemmMfmaF32F32F32M16N16K8<>; };
template<> struct Dispatcher<float, float, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF32F32F32M32N32K8<>; };
// TransposeC = true selects the variant with the transposed C distribution.
template<> struct Dispatcher<float, float, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
// fp16
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity

View File

@@ -723,8 +723,11 @@ struct GroupedConvolutionForwardKernel
if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
CK_TILE_ERROR(
"Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
}
return false;
}
@@ -736,13 +739,19 @@ struct GroupedConvolutionForwardKernel
// Check access per C
if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0)
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
}
return false;
}
}
else
{
CK_TILE_ERROR("Not supported input layout!");
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Not supported input layout!");
}
return false;
}
@@ -754,13 +763,19 @@ struct GroupedConvolutionForwardKernel
{
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
}
return false;
}
}
else
{
CK_TILE_ERROR("Not supported weight layout!");
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Not supported weight layout!");
}
return false;
}
@@ -771,13 +786,20 @@ struct GroupedConvolutionForwardKernel
{
if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0)
{
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Conv K is not a multiple of vector store size for output image!");
}
return false;
}
}
else
{
CK_TILE_ERROR("Not supported output layout!");
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Not supported output layout!");
}
return false;
}
@@ -786,7 +808,10 @@ struct GroupedConvolutionForwardKernel
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
{
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
}
return false;
}
}
@@ -955,7 +980,8 @@ struct GroupedConvolutionForwardKernel
else
{
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
is_any_of<OutDataType, fp16_t, bf16_t>::value) &&
IsSplitKSupported)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, c_desc, block_idx_m, block_idx_n);