[rocm-libraries] ROCm/rocm-libraries#5237 (commit ef10dc6)

[CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK
 Tile profiler (#5237)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

PR #4797 added CK Tile bwd weight kernels to the CK Profiler. The
two-stage kernels were not supported in the initial PR. This PR adds the
missing bwd weight two-stage kernels to the CK Profiler.

## Technical Details

Extended the CK Tile conv builder factory to also build the elementwise
ops required for the two-stage kernels. Extended the CK Builder for CK
Tile instance to accept the two-stage flag as part of the algorithm
configuration.

## Test Plan

Added unit tests for CK Builder that verify the two-stage kernel
construction.

## Test Result

If CI passes, the added unit tests are passing.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Ville Pietilä
2026-03-13 01:21:08 +00:00
committed by assistant-librarian[bot]
parent fc2f95620d
commit e2f5ab8000
16 changed files with 336 additions and 50 deletions

View File

@@ -25,8 +25,10 @@ TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D
.with_tile_thread_block(TileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(TileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
.with_tile_optimizations(TileOptimizations{.num_groups_to_merge = 1,
.split_image = false,
.explicit_gemm = false,
.two_stage = false});
using Builder = ConvBuilder<BwdDataConvSignature, BwdDataConvAlgorithm>;
run_ck_tile_test<Builder>({

View File

@@ -12,6 +12,7 @@
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
namespace ckf = ck_tile::builder::factory;
using enum ck_tile::builder::TensorLayout;
using ck_tile::test::MatchesReference;
@@ -31,12 +32,49 @@ constexpr auto ALGORITHM =
.with_tile_thread_block(cku::TileThreadBlock_64x64x64)
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(cku::TileTransfer_4x4x4)
.with_tile_optimizations(ckt::TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
.with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1,
.split_image = false,
.explicit_gemm = false,
.two_stage = false});
// Tile algorithm configuration with the two_stage optimization flag enabled.
// It is used below to instantiate both the two-stage convolution builder and
// the elementwise-op factory.
constexpr auto TWO_STAGE_ALGORITHM =
cku::ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(ckb::TileConvSpecialization::DEFAULT)
.with_tile_thread_block(cku::TileThreadBlock_64x64x64)
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(cku::TileTransfer_4x4x4)
.with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1,
.split_image = false,
.explicit_gemm = false,
// Only setting that marks this configuration as two-stage.
.two_stage = true});
// Problem description shared by the execution tests in this file:
// a 2D convolution with batch 2, 4 groups, 32 input / 48 output channels,
// a 32x56 (width x height) image and a 3x3 filter, unit strides and dilation,
// and no padding. All elementwise ops are value-initialized.
constexpr ckt::Args<SIGNATURE> Args = {
.lengths =
{
.batch_size = 2,
.groups = 4,
.input_channels = 32,
.output_channels = 48,
.image = {.width = 32, .height = 56},
.filter = {.width = 3, .height = 3},
},
.filter_strides = {.width = 1, .height = 1},
.filter_dilation = {.width = 1, .height = 1},
.input_left_pad = {.width = 0, .height = 0},
.input_right_pad = {.width = 0, .height = 0},
.a_elementwise_op = {},
.b_elementwise_op = {},
.cde_elementwise_op = {},
};
// Builder / instance for the single-stage configuration (ALGORITHM).
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
// Builder / instance for the two-stage configuration (TWO_STAGE_ALGORITHM).
using TwoStageBuilder = ckb::ConvBuilder<SIGNATURE, TWO_STAGE_ALGORITHM>;
using TwoStageInstance = TwoStageBuilder::Instance;
// Elementwise-op factory / instance built from the same signature and the
// two-stage configuration; passed alongside the two-stage convolution to ckt::run.
using ElementwiseOpBuilder = ckf::ElementwiseOpTileFactory<SIGNATURE, TWO_STAGE_ALGORITHM>;
using ElementwiseOpInstance = ElementwiseOpBuilder::Instance;
// Instance built with ConvAlgorithm_Reference; its output is what the tests compare against.
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
TEST(BwdWeight_2D_FP16_NHWGC, Create)
@@ -61,38 +99,47 @@ TEST(BwdWeight_2D_FP16_NHWGC, Create)
});
}
// Passes ElementwiseOpBuilder (the elementwise-op factory instantiated with
// TWO_STAGE_ALGORITHM) to run_ck_tile_test together with the list of expected
// strings below.
// NOTE(review): presumably each string must appear in the generated instance's
// name/description — confirm against run_ck_tile_test.
TEST(ElementWiseOp, CreateBwdWeightTwoStageElementwiseOp)
{
cku::run_ck_tile_test<ElementwiseOpBuilder>({"elementwise_kernel",
"4096_256_4_4_64_4_256",
"UnaryConvert",
"kPad_1",
"ElementWiseDefaultPolicy"});
}
// Runs the single-stage tile instance and the reference instance on the same
// initialized inputs, expects both runs to succeed, and expects the tile
// output to match the reference output.
// NOTE(review): this span appears to interleave the pre-change lines (local
// `args`) and post-change lines (file-level `Args`) of a diff whose +/- markers
// were lost — confirm against the actual file before relying on it.
TEST(BwdWeight_2D_FP16_NHWGC, Execution)
{
ckt::Args<SIGNATURE> args = {
.lengths =
{
.batch_size = 2,
.groups = 4,
.input_channels = 32,
.output_channels = 48,
.image = {.width = 32, .height = 56},
.filter = {.width = 3, .height = 3},
},
.filter_strides = {.width = 1, .height = 1},
.filter_dilation = {.width = 1, .height = 1},
.input_left_pad = {.width = 0, .height = 0},
.input_right_pad = {.width = 0, .height = 0},
.a_elementwise_op = {},
.b_elementwise_op = {},
.cde_elementwise_op = {},
};
auto inputs = ckt::alloc_inputs(Args);
auto outputs = ckt::alloc_outputs(Args);
auto reference = ckt::alloc_outputs(Args);
auto inputs = ckt::alloc_inputs(args);
auto outputs = ckt::alloc_outputs(args);
auto reference = ckt::alloc_outputs(args);
ckt::init_inputs(args, inputs.get());
ckt::init_inputs(Args, inputs.get());
auto conv = Instance{};
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
EXPECT_THAT(ckt::run(conv, Args, inputs.get(), outputs.get()), SuccessfulRun());
auto ref_conv = Reference{};
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
EXPECT_THAT(ckt::run(ref_conv, Args, inputs.get(), reference.get()), SuccessfulRun());
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
EXPECT_THAT(outputs.get(), MatchesReference(Args, reference.get()));
}
// Runs the two-stage convolution instance together with its elementwise-op
// instance, then the reference instance, on the same initialized inputs.
// Both runs must succeed and the two-stage output must match the reference.
TEST(BwdWeight_TwoStage_2D_FP16_NHWGC, Execution)
{
    // Buffers for the inputs, the kernel result, and the reference result.
    auto input_buffers    = ckt::alloc_inputs(Args);
    auto kernel_result    = ckt::alloc_outputs(Args);
    auto reference_result = ckt::alloc_outputs(Args);
    ckt::init_inputs(Args, input_buffers.get());

    // Two-stage path: the convolution instance is run with the elementwise op.
    TwoStageInstance two_stage_conv{};
    ElementwiseOpInstance convert_op{};
    EXPECT_THAT(
        ckt::run(two_stage_conv, convert_op, Args, input_buffers.get(), kernel_result.get()),
        SuccessfulRun());

    // Reference path over the same inputs.
    Reference reference_conv{};
    EXPECT_THAT(ckt::run(reference_conv, Args, input_buffers.get(), reference_result.get()),
                SuccessfulRun());

    EXPECT_THAT(kernel_result.get(), MatchesReference(Args, reference_result.get()));
}

View File

@@ -24,8 +24,10 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1
.with_tile_thread_block(TileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(TileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
.with_tile_optimizations(TileOptimizations{.num_groups_to_merge = 1,
.split_image = false,
.explicit_gemm = false,
.two_stage = false});
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_ck_tile_test<Builder>({

View File

@@ -31,8 +31,10 @@ constexpr auto ALGORITHM =
.with_tile_thread_block(cku::FwdTileThreadBlock_64x64x64)
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(cku::FwdTileTransfer_4x4x4)
.with_tile_optimizations(ckt::TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
.with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1,
.split_image = false,
.explicit_gemm = false,
.two_stage = false});
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;