mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[CK_BUILDER] Add bwd weight factories (#3509)
* Add placeholder test. * Initial conv bwd weight factory. * Conv builder test refactoring. * Add missing pieces to bwd weight factory. * Improve compile time erros message when no matching factory is found. * Use amcro to ensure automatic macthing between concepts are their string representations. * Improve compile time diagnostics. * Small improvements. * Improve missing member/wrong type compile-time errors. * Improve compile time diagnostics. * Concept bug fixes. * Remove debug assert. * Update algorithm signature diagnostics. * Factory bug fixes. * First functional version of bwd weight conv factory. * Refactor handing of GEMM-K batch template parameter in conv bwd weight factory. * Concept improvements. * Improve concept diagnostics. * Introduve a common size type for concepts. * Update compiletime diagnostics to use the size type. * Update conv specialization enum. * Fix fwd conv builder tests. * Fix smoke tests. * Separate bwd weigth and bwd data tests into separate targets. * Clean-up CK Tile builder tests. * Add bwd weight XDL CShuffle V3 factory. * Build conv bwd weigth v3 instances successfully. * Add instance traits for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3. * Test fix. * Add instance traits for bwd weight algorithms. * Add unit tests for instance strings. * Build new instance traits unit tests but exclude WMMA for now. * Added factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle. * Conv bwd weight DL factory. * Final implementation for bwd weight DL factory. * Add test for creating DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance. * Add factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle * Treat ref algorithm the same way as real algorithms in the dispatcher. * Refactor large tensor support and WMMA configuration. * Add factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3. * Update Readme. * Fix WMMA bwd weight tests. * Added factory and tests for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3. * Factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffle. * Dispatching for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle. * Add factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 * Fix DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 factory and compute types for input and output tensor in bwd weigth convs. * Fix fwd factories after refactoring. * clang-format * Move compile-time diagnostics to a separate branch. * Fix ref algorithm dispatching. * Fix smoke tests. * clang-format * Fix factory for regular WMMA conv bwd weight. * Clarify builder Readme. * Remove obsolete test file. * Fix test after merge. * clang-format * Remove the C++26 extensions. * Unify conv elementwise ops and layout definitions for fwd and bwd directions. * Remove old layout and elementwise ops. * Unify handling of conv tensor types between fwd and bwd directions. * Unify block transfer for fwd and bwd directions. Rename ThreadSliceDim to ThreadClusterRank. * Make BlockTransferDescriptor concept parametrized. Introduce a common TileTransferParameters concept for conv algorithms. * clang-format --------- Co-authored-by: Ville Pietilä <>
This commit is contained in:
@@ -8,26 +8,27 @@ namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_DATA,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
constexpr ConvSignature BwdDataConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_DATA,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
constexpr auto BwdDataConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(FwdTileTransfer_4x4x4)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
using Builder = ConvBuilder<BwdDataConvSignature, BwdDataConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
"grouped_convolution_backward_data",
|
||||
"fp16",
|
||||
|
||||
@@ -8,26 +8,27 @@ namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
constexpr ConvSignature BwdWeightConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
constexpr auto BwdWeightConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(FwdTileTransfer_4x4x4)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
using Builder = ConvBuilder<BwdWeightConvSignature, BwdWeightConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
"grouped_convolution_backward_weight",
|
||||
"fp16",
|
||||
|
||||
@@ -21,9 +21,9 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(FwdTileTransfer_4x4x4)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user