[rocm-libraries] ROCm/rocm-libraries#5237 (commit ef10dc6)

[CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK
 Tile profiler (#5237)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

PR #4797 added CK Tile bwd weight kernels to the CK Profiler. The
two-stage kernels were not supported in the initial PR. This PR adds the
the missing bwd weight two-stage kernels to the CK Profiler.

## Technical Details

Extended the CK Tile conv builder factory to build also the elementwise
ops required for the two-stage kernels. Extended the CK Builder for CK
Tile instance to accept the two-stage flag as part of the algorithm
configuration.

## Test Plan

Added units tests for CK Builder that verify the two-stage kernel
construction.

## Test Result

If CI passes, the added unit tests are passing.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Ville Pietilä
2026-03-13 01:21:08 +00:00
committed by assistant-librarian[bot]
parent fc2f95620d
commit e2f5ab8000
16 changed files with 336 additions and 50 deletions

View File

@@ -13,6 +13,7 @@ class ConvInstanceTemplateParams:
warp_tile,
double_smem_buffer,
num_wave_groups,
is_two_stage_instance,
pipeline_version,
scheduler,
scalar_per_vector,
@@ -27,6 +28,7 @@ class ConvInstanceTemplateParams:
self.warp_tile = warp_tile
self.double_smem_buffer = double_smem_buffer
self.num_wave_groups = num_wave_groups
self.is_two_stage_instance = is_two_stage_instance
self.pipeline_version = pipeline_version
self.scheduler = scheduler
self.scalar_per_vector = scalar_per_vector
@@ -39,7 +41,8 @@ class ConvInstanceTemplateParams:
explicit_gemm = "true" if self.explicit_gemm else "false"
split_image = "true" if self.split_image else "false"
num_groups_to_merge = str(self.num_groups_to_merge)
return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}}}"
two_stage_instance = "true" if self.is_two_stage_instance else "false"
return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}, .two_stage = {two_stage_instance}}}"
def get_specialization(self):
namespace = "ckb::TileConvSpecialization::"
@@ -270,6 +273,8 @@ def parse_fwd_instances(instances, problem_name):
print(f"Skipping instance {instance_id} with ASYNC_V4 since it's not supported yet.")
continue
is_two_stage = False
conv = ConvInstanceTemplateParams(
spec,
[m_per_block, n_per_block, k_per_block],
@@ -277,6 +282,7 @@ def parse_fwd_instances(instances, problem_name):
[m_per_xdl, n_per_xdl, k_per_xdl],
double_smem_buffer,
num_wave_groups,
is_two_stage,
pipeline_version,
scheduler,
[a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
@@ -343,7 +349,7 @@ def parse_bwd_weight_instances(instances, problem_name):
num_groups_to_merge = 1
# Block GEMM pipeline parameters
blk_gemm_pipeline_schduler = args[6]
block_gemm_pipeline_scheduler = args[6]
blk_gemm_pipeline_version = args[7]
else:
spec = args[11]
@@ -372,20 +378,29 @@ def parse_bwd_weight_instances(instances, problem_name):
num_groups_to_merge = int(args[44])
# Block GEMM pipeline parameters
blk_gemm_pipeline_schduler = args[39]
block_gemm_pipeline_scheduler = args[39]
blk_gemm_pipeline_version = args[40]
elif is_two_stage_instance:
print(f"Skipping instance {instance_id} with device op {device_op_name} since it's not supported yet.")
continue
if len(args) != 46:
raise RuntimeError(f"Wrong number of parameters in the TwoStage instance string: {instance}\n" +
f"Expected 46 parameters for TwoStage instance. Found {len(args)} parameters.")
num_groups_to_merge = args[41]
# Block GEMM pipeline parameters
block_gemm_pipeline_scheduler = args[39]
blk_gemm_pipeline_version = args[40]
else:
# Regular V1 XDL CShuffle instance
if len(args) != 43:
raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}")
raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}\n" +
f"Expected 43 parameters for V1 instance. Found {len(args)} parameters.")
num_groups_to_merge = 1
# Block GEMM pipeline parameters
blk_gemm_pipeline_schduler = "Intrawave"
block_gemm_pipeline_scheduler = "Intrawave"
blk_gemm_pipeline_version = "v1"
# Common part to all solvers.
@@ -393,15 +408,15 @@ def parse_bwd_weight_instances(instances, problem_name):
# Sanity check for Block GEMM pipeline parameters
# Scheduler must be either Intrawave or Interwave.
# Version must be from v1 to v5
if blk_gemm_pipeline_schduler not in ["Intrawave", "Interwave"]:
raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {blk_gemm_pipeline_schduler} in instance: {instance}")
if block_gemm_pipeline_scheduler not in ["Intrawave", "Interwave"]:
raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {block_gemm_pipeline_scheduler} in instance: {instance}")
if blk_gemm_pipeline_version not in ["v1", "v2", "v3", "v4", "v5"]:
raise RuntimeError(f"Invalid Block GEMM pipeline version: {blk_gemm_pipeline_version} in instance: {instance}")
split_image = instance.find("Large") != -1
double_smem_buffer = blk_gemm_pipeline_version == "v4"
num_wave_groups = 1
scheduler = blk_gemm_pipeline_schduler
scheduler = block_gemm_pipeline_scheduler
pipeline_version = blk_gemm_pipeline_version.upper()
# OLd CK pipeline version V5 maps to V6 for CK Tile
@@ -428,6 +443,7 @@ def parse_bwd_weight_instances(instances, problem_name):
[m_per_xdl, n_per_xdl, k_per_xdl],
double_smem_buffer,
num_wave_groups,
is_two_stage_instance,
pipeline_version,
scheduler,
[a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],

View File

@@ -6,6 +6,7 @@
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
namespace ckf = ck_tile::builder::factory;
namespace ck_tile::builder::profiling {

View File

@@ -1,7 +1,21 @@
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using ConvInstance = Builder::Instance;
auto conv = ConvInstance{};
auto result = [&]<auto Sig, auto Alg>() {
if constexpr(ConvDirectionIsBackwardWeight<Sig> && Alg.optimizations.two_stage)
{
using ElementwiseOpBuilder = ckf::ElementwiseOpTileFactory<Sig, Alg>;
using ElementwiseOpInstance = ElementwiseOpBuilder::Instance;
auto elementwise_op = ElementwiseOpInstance{};
return ckt::run(conv, elementwise_op, args, inputs, outputs, s_conf);
}
else
{
return ckt::run(conv, args, inputs, outputs, s_conf);
}
}.template operator()<SIGNATURE, ALGORITHM>();
auto conv = Instance{};
ckt::RunResult result = ckt::run(conv, args, inputs, outputs, s_conf);
return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString());