mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_BUILDER] Ck Tile Grouped convolution factory (#3352)
* [BUILDER] Ck Tile Grouped convolution factory * Part 2 * Fixes after rebase * Remove leftovers
This commit is contained in:
@@ -243,6 +243,73 @@ struct LargeTensorWrapper
|
||||
ConvAlgorithmSpecialization::LARGE_TENSOR;
|
||||
};
|
||||
|
||||
// Specify thread block dimensions for a GEMM (CK Tile).
// Plain aggregate with a single field. MNK is declared outside this view;
// presumably it bundles one value per M/N/K dimension -- confirm.
|
||||
struct TileThreadBlock
|
||||
{
|
||||
// Size of the submatrix problem in a thread block.
|
||||
MNK<size_t> tile_size;
|
||||
};
|
||||
// Compile-time check that TileThreadBlock satisfies
// ckb::TileThreadBlockDescriptor (declared elsewhere; presumably a concept).
static_assert(ckb::TileThreadBlockDescriptor<TileThreadBlock>);
|
||||
|
||||
// Transfer settings for the CK Tile path: one scalar-per-vector count for
// each of the a, b and c operands.
// NOTE(review): presumably these are the vector widths used when accessing
// the A/B/C tensors -- confirm against the consumer of this descriptor.
struct TileTransfer
|
||||
{
|
||||
size_t a_scalar_per_vector;
|
||||
size_t b_scalar_per_vector;
|
||||
size_t c_scalar_per_vector;
|
||||
};
|
||||
// Compile-time check that TileTransfer satisfies ckb::TileTransferDescriptor
// (declared elsewhere; presumably a concept).
static_assert(ckb::TileTransferDescriptor<TileTransfer>);
|
||||
|
||||
// Block-level GEMM configuration for the CK Tile path: warp layout, per-warp
// tile, LDS buffering, wave grouping, and the pipeline version / scheduler.
// PipelineVersion and PipelineScheduler are declared outside this view.
struct TileBlockGemm
|
||||
{
|
||||
// Number of warps per each dimension.
|
||||
MNK<int> warps;
|
||||
// Number of data processed per each dimension for each XDL/WMMA instruction.
|
||||
MNK<int> warp_tile;
|
||||
// Double LDS buffer.
|
||||
bool double_smem_buffer;
|
||||
// Waves grouping (Ping-Pong scheduler).
|
||||
int num_wave_groups;
|
||||
// Pipeline version to use (enumerators are not visible in this view).
PipelineVersion pipeline_version;
|
||||
// Pipeline scheduler to use (enumerators are not visible in this view).
PipelineScheduler scheduler;
|
||||
};
|
||||
// Compile-time check that TileBlockGemm satisfies ckb::TileBlockGemmDescriptor
// (declared elsewhere; presumably a concept).
static_assert(ckb::TileBlockGemmDescriptor<TileBlockGemm>);
|
||||
|
||||
// Optional optimizations for the CK Tile grouped convolution: group merging,
// image splitting and the explicit-GEMM path.
struct TileOptimizations
|
||||
{
|
||||
// Number of convolution groups processed per one workgroup
|
||||
int num_groups_to_merge;
|
||||
// Split image for large tensors
|
||||
bool split_image;
|
||||
// Explicit gemm for 1x1, stride=0, pad=0 cases
// NOTE(review): "stride=0" looks like a typo for stride=1 (a convolution
// stride of zero is not meaningful) -- confirm against the kernel's condition.
|
||||
bool explicit_gemm;
|
||||
};
|
||||
// Compile-time check that TileOptimizations satisfies
// ckb::TileOptimizationsDescriptor (declared elsewhere; presumably a concept).
static_assert(ckb::TileOptimizationsDescriptor<TileOptimizations>);
|
||||
|
||||
// Component (mixin base) for ConvAlgorithmTemplate that contributes the
// `specialization` field. ConvAlgorithmTemplate::with_tile_specializations
// requires this type as a base and assigns that field.
struct TileConvSpecialization_
|
||||
{
|
||||
// Selected tile convolution specialization (type declared outside this view).
TileConvSpecialization specialization;
|
||||
};
|
||||
|
||||
// Component (mixin base) for ConvAlgorithmTemplate that contributes the
// `thread_block` field. ConvAlgorithmTemplate::with_tile_thread_block
// requires this type as a base and assigns that field.
struct TileThreadBlock_
|
||||
{
|
||||
// Thread-block tile configuration carried by the algorithm.
TileThreadBlock thread_block;
|
||||
};
|
||||
|
||||
// Component (mixin base) for ConvAlgorithmTemplate that contributes the
// `transfer` field. ConvAlgorithmTemplate::with_tile_transfer requires this
// type as a base and assigns that field.
struct TileTransfer_
|
||||
{
|
||||
// Scalar-per-vector transfer configuration carried by the algorithm.
TileTransfer transfer;
|
||||
};
|
||||
|
||||
// Component (mixin base) for ConvAlgorithmTemplate that contributes the
// `block_gemm` field. ConvAlgorithmTemplate::with_tile_block_gemm requires
// this type as a base and assigns that field.
struct TileBlockGemm_
|
||||
{
|
||||
// Block GEMM configuration carried by the algorithm.
TileBlockGemm block_gemm;
|
||||
};
|
||||
|
||||
// Component (mixin base) for ConvAlgorithmTemplate that contributes the
// `optimizations` field. ConvAlgorithmTemplate::with_tile_optimizations
// requires this type as a base and assigns that field.
struct TileOptimizations_
|
||||
{
|
||||
// Optimization switches carried by the algorithm.
TileOptimizations optimizations;
|
||||
};
|
||||
|
||||
// Factory
|
||||
|
||||
template <typename... Components>
|
||||
@@ -339,6 +406,51 @@ struct ConvAlgorithmTemplate : Components...
|
||||
result.transfer = t;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
constexpr auto with_tile_specializations(const S& s) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileConvSpecialization_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.specialization = s;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename TB>
|
||||
constexpr auto with_tile_thread_block(const TB& tb) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileThreadBlock_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.thread_block = tb;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BG>
|
||||
constexpr auto with_tile_block_gemm(const BG& bg) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileBlockGemm_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_gemm = bg;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr auto with_tile_transfer(const T& t) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileTransfer_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.transfer = t;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
constexpr auto with_tile_optimizations(const O& o) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileOptimizations_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.optimizations = o;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// Algorithm types
|
||||
@@ -361,4 +473,10 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
// Algorithm type obtained by wrapping
// ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle in
// LargeTensorWrapper (both are defined earlier in this file, outside this view).
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
|
||||
|
||||
// Algorithm type for the CK Tile grouped convolution kernel. It composes all
// five Tile* components, so every with_tile_* builder method on
// ConvAlgorithmTemplate passes its base-class static_assert for this type.
using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileThreadBlock_,
|
||||
TileBlockGemm_,
|
||||
TileTransfer_,
|
||||
TileConvSpecialization_,
|
||||
TileOptimizations_>;
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
Reference in New Issue
Block a user