[CK_BUILDER] Ck Tile Grouped convolution factory (#3352)

* [BUILDER] Ck Tile Grouped convolution factory

* Part 2

* Fixes after rebase

* Remove leftovers
This commit is contained in:
Bartłomiej Kocot
2025-12-08 10:32:56 +01:00
committed by GitHub
parent 8fec8054b2
commit 04612c30ce
55 changed files with 1431 additions and 92 deletions

View File

@@ -243,6 +243,73 @@ struct LargeTensorWrapper
ConvAlgorithmSpecialization::LARGE_TENSOR;
};
// Specify thread block dimensions for a GEMM (CK Tile).
struct TileThreadBlock
{
// Size (M, N, K) of the sub-matrix problem computed by a single thread block.
MNK<size_t> tile_size;
};
// Compile-time check: this type must satisfy the builder's thread-block descriptor concept.
static_assert(ckb::TileThreadBlockDescriptor<TileThreadBlock>);
// Vectorization widths for operand data transfers (CK Tile).
struct TileTransfer
{
// Scalars per vectorized access for the A operand — presumably the load/store
// vector width; confirm against the kernel implementation.
size_t a_scalar_per_vector;
// Scalars per vectorized access for the B operand.
size_t b_scalar_per_vector;
// Scalars per vectorized access for the C operand.
size_t c_scalar_per_vector;
};
// Compile-time check against the builder's transfer descriptor concept.
static_assert(ckb::TileTransferDescriptor<TileTransfer>);
// Block-level GEMM configuration (warp layout and pipeline) for CK Tile.
struct TileBlockGemm
{
// Number of warps per each dimension (M, N, K).
MNK<int> warps;
// Number of data processed per each dimension for each XDL/WMMA instruction.
MNK<int> warp_tile;
// Double LDS buffer.
bool double_smem_buffer;
// Waves grouping (Ping-Pong scheduler).
int num_wave_groups;
// Which GEMM pipeline variant to instantiate.
PipelineVersion pipeline_version;
// Instruction-scheduling strategy used by the pipeline.
PipelineScheduler scheduler;
};
// Compile-time check against the builder's block-GEMM descriptor concept.
static_assert(ckb::TileBlockGemmDescriptor<TileBlockGemm>);
// Optional kernel-level optimizations for the CK Tile grouped convolution.
struct TileOptimizations
{
// Number of convolution groups processed per one workgroup.
int num_groups_to_merge;
// Split image for large tensors.
bool split_image;
// Explicit gemm for 1x1, stride=0, pad=0 cases
// NOTE(review): stride=1 is the usual 1x1 fast-path condition — confirm
// whether "stride=0" here is intentional.
bool explicit_gemm;
};
// Compile-time check against the builder's optimizations descriptor concept.
static_assert(ckb::TileOptimizationsDescriptor<TileOptimizations>);
// Mixin component: contributes the `specialization` member to ConvAlgorithmTemplate,
// enabling its with_tile_specializations() setter.
struct TileConvSpecialization_
{
TileConvSpecialization specialization;
};
// Mixin component: contributes the `thread_block` member to ConvAlgorithmTemplate,
// enabling its with_tile_thread_block() setter.
struct TileThreadBlock_
{
TileThreadBlock thread_block;
};
// Mixin component: contributes the `transfer` member to ConvAlgorithmTemplate,
// enabling its with_tile_transfer() setter.
struct TileTransfer_
{
TileTransfer transfer;
};
// Mixin component: contributes the `block_gemm` member to ConvAlgorithmTemplate,
// enabling its with_tile_block_gemm() setter.
struct TileBlockGemm_
{
TileBlockGemm block_gemm;
};
// Mixin component: contributes the `optimizations` member to ConvAlgorithmTemplate,
// enabling its with_tile_optimizations() setter.
struct TileOptimizations_
{
TileOptimizations optimizations;
};
// Factory
template <typename... Components>
@@ -339,6 +406,51 @@ struct ConvAlgorithmTemplate : Components...
result.transfer = t;
return result;
}
// Return a copy of this algorithm with its CK Tile convolution
// specialization replaced by `s`. Usable only when the
// TileConvSpecialization_ component is mixed into this template.
template <typename S>
constexpr auto with_tile_specializations(const S& s) const
{
    static_assert(std::is_base_of_v<TileConvSpecialization_, ConvAlgorithmTemplate>,
                  "with_tile_specializations requires the TileConvSpecialization_ component");
    auto updated = *this;
    updated.specialization = s;
    return updated;
}
// Return a copy of this algorithm with its CK Tile thread-block
// configuration replaced by `tb`. Usable only when the
// TileThreadBlock_ component is mixed into this template.
template <typename TB>
constexpr auto with_tile_thread_block(const TB& tb) const
{
    static_assert(std::is_base_of_v<TileThreadBlock_, ConvAlgorithmTemplate>,
                  "with_tile_thread_block requires the TileThreadBlock_ component");
    auto updated = *this;
    updated.thread_block = tb;
    return updated;
}
// Return a copy of this algorithm with its CK Tile block-GEMM
// configuration replaced by `bg`. Usable only when the
// TileBlockGemm_ component is mixed into this template.
template <typename BG>
constexpr auto with_tile_block_gemm(const BG& bg) const
{
    static_assert(std::is_base_of_v<TileBlockGemm_, ConvAlgorithmTemplate>,
                  "with_tile_block_gemm requires the TileBlockGemm_ component");
    auto updated = *this;
    updated.block_gemm = bg;
    return updated;
}
// Return a copy of this algorithm with its CK Tile transfer
// configuration replaced by `t`. Usable only when the
// TileTransfer_ component is mixed into this template.
template <typename T>
constexpr auto with_tile_transfer(const T& t) const
{
    static_assert(std::is_base_of_v<TileTransfer_, ConvAlgorithmTemplate>,
                  "with_tile_transfer requires the TileTransfer_ component");
    auto updated = *this;
    updated.transfer = t;
    return updated;
}
// Return a copy of this algorithm with its CK Tile optimizations
// replaced by `o`. Usable only when the TileOptimizations_
// component is mixed into this template.
template <typename O>
constexpr auto with_tile_optimizations(const O& o) const
{
    static_assert(std::is_base_of_v<TileOptimizations_, ConvAlgorithmTemplate>,
                  "with_tile_optimizations requires the TileOptimizations_ component");
    auto updated = *this;
    updated.optimizations = o;
    return updated;
}
};
// Algorithm types
@@ -361,4 +473,10 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
// Large-tensor variant: the XDL CShuffle forward-conv algorithm wrapped in
// LargeTensorWrapper (which tags it with the LARGE_TENSOR specialization).
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
// CK Tile grouped convolution algorithm, composed from the tile mixin
// components; each component enables the matching with_tile_*() setter.
using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileThreadBlock_,
TileBlockGemm_,
TileTransfer_,
TileConvSpecialization_,
TileOptimizations_>;
} // namespace ck_tile::builder::test