mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#5237 (commit ef10dc6)
[CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK Tile profiler (#5237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation PR #4797 added CK Tile bwd weight kernels to the CK Profiler. The two-stage kernels were not supported in the initial PR. This PR adds the missing bwd weight two-stage kernels to the CK Profiler. ## Technical Details Extended the CK Tile conv builder factory to also build the elementwise ops required for the two-stage kernels. Extended the CK Builder for CK Tile instance to accept the two-stage flag as part of the algorithm configuration. ## Test Plan Added unit tests for CK Builder that verify the two-stage kernel construction. ## Test Result If CI passes, the added unit tests are passing. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fc2f95620d
commit
e2f5ab8000
@@ -13,6 +13,7 @@ class ConvInstanceTemplateParams:
|
||||
warp_tile,
|
||||
double_smem_buffer,
|
||||
num_wave_groups,
|
||||
is_two_stage_instance,
|
||||
pipeline_version,
|
||||
scheduler,
|
||||
scalar_per_vector,
|
||||
@@ -27,6 +28,7 @@ class ConvInstanceTemplateParams:
|
||||
self.warp_tile = warp_tile
|
||||
self.double_smem_buffer = double_smem_buffer
|
||||
self.num_wave_groups = num_wave_groups
|
||||
self.is_two_stage_instance = is_two_stage_instance
|
||||
self.pipeline_version = pipeline_version
|
||||
self.scheduler = scheduler
|
||||
self.scalar_per_vector = scalar_per_vector
|
||||
@@ -39,7 +41,8 @@ class ConvInstanceTemplateParams:
|
||||
explicit_gemm = "true" if self.explicit_gemm else "false"
|
||||
split_image = "true" if self.split_image else "false"
|
||||
num_groups_to_merge = str(self.num_groups_to_merge)
|
||||
return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}}}"
|
||||
two_stage_instance = "true" if self.is_two_stage_instance else "false"
|
||||
return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}, .two_stage = {two_stage_instance}}}"
|
||||
|
||||
def get_specialization(self):
|
||||
namespace = "ckb::TileConvSpecialization::"
|
||||
@@ -270,6 +273,8 @@ def parse_fwd_instances(instances, problem_name):
|
||||
print(f"Skipping instance {instance_id} with ASYNC_V4 since it's not supported yet.")
|
||||
continue
|
||||
|
||||
is_two_stage = False
|
||||
|
||||
conv = ConvInstanceTemplateParams(
|
||||
spec,
|
||||
[m_per_block, n_per_block, k_per_block],
|
||||
@@ -277,6 +282,7 @@ def parse_fwd_instances(instances, problem_name):
|
||||
[m_per_xdl, n_per_xdl, k_per_xdl],
|
||||
double_smem_buffer,
|
||||
num_wave_groups,
|
||||
is_two_stage,
|
||||
pipeline_version,
|
||||
scheduler,
|
||||
[a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
|
||||
@@ -343,7 +349,7 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
num_groups_to_merge = 1
|
||||
|
||||
# Block GEMM pipeline parameters
|
||||
blk_gemm_pipeline_schduler = args[6]
|
||||
block_gemm_pipeline_scheduler = args[6]
|
||||
blk_gemm_pipeline_version = args[7]
|
||||
else:
|
||||
spec = args[11]
|
||||
@@ -372,20 +378,29 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
num_groups_to_merge = int(args[44])
|
||||
|
||||
# Block GEMM pipeline parameters
|
||||
blk_gemm_pipeline_schduler = args[39]
|
||||
block_gemm_pipeline_scheduler = args[39]
|
||||
blk_gemm_pipeline_version = args[40]
|
||||
elif is_two_stage_instance:
|
||||
print(f"Skipping instance {instance_id} with device op {device_op_name} since it's not supported yet.")
|
||||
continue
|
||||
if len(args) != 46:
|
||||
raise RuntimeError(f"Wrong number of parameters in the TwoStage instance string: {instance}\n" +
|
||||
f"Expected 46 parameters for TwoStage instance. Found {len(args)} parameters.")
|
||||
|
||||
num_groups_to_merge = args[41]
|
||||
|
||||
# Block GEMM pipeline parameters
|
||||
block_gemm_pipeline_scheduler = args[39]
|
||||
blk_gemm_pipeline_version = args[40]
|
||||
|
||||
else:
|
||||
# Regular V1 XDL CShuffle instance
|
||||
if len(args) != 43:
|
||||
raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}")
|
||||
raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}\n" +
|
||||
f"Expected 43 parameters for V1 instance. Found {len(args)} parameters.")
|
||||
|
||||
num_groups_to_merge = 1
|
||||
|
||||
# Block GEMM pipeline parameters
|
||||
blk_gemm_pipeline_schduler = "Intrawave"
|
||||
block_gemm_pipeline_scheduler = "Intrawave"
|
||||
blk_gemm_pipeline_version = "v1"
|
||||
|
||||
# Common part to all solvers.
|
||||
@@ -393,15 +408,15 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
# Sanity check for Block GEMM pipeline parameters
|
||||
# Scheduler must be either Intrawave or Interwave.
|
||||
# Version must be from v1 to v5
|
||||
if blk_gemm_pipeline_schduler not in ["Intrawave", "Interwave"]:
|
||||
raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {blk_gemm_pipeline_schduler} in instance: {instance}")
|
||||
if block_gemm_pipeline_scheduler not in ["Intrawave", "Interwave"]:
|
||||
raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {block_gemm_pipeline_scheduler} in instance: {instance}")
|
||||
if blk_gemm_pipeline_version not in ["v1", "v2", "v3", "v4", "v5"]:
|
||||
raise RuntimeError(f"Invalid Block GEMM pipeline version: {blk_gemm_pipeline_version} in instance: {instance}")
|
||||
|
||||
split_image = instance.find("Large") != -1
|
||||
double_smem_buffer = blk_gemm_pipeline_version == "v4"
|
||||
num_wave_groups = 1
|
||||
scheduler = blk_gemm_pipeline_schduler
|
||||
scheduler = block_gemm_pipeline_scheduler
|
||||
pipeline_version = blk_gemm_pipeline_version.upper()
|
||||
|
||||
# Old CK pipeline version V5 maps to V6 for CK Tile
|
||||
@@ -428,6 +443,7 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
[m_per_xdl, n_per_xdl, k_per_xdl],
|
||||
double_smem_buffer,
|
||||
num_wave_groups,
|
||||
is_two_stage_instance,
|
||||
pipeline_version,
|
||||
scheduler,
|
||||
[a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
namespace ckf = ck_tile::builder::factory;
|
||||
|
||||
namespace ck_tile::builder::profiling {
|
||||
|
||||
|
||||
@@ -1,7 +1,21 @@
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using ConvInstance = Builder::Instance;
|
||||
|
||||
auto conv = ConvInstance{};
|
||||
|
||||
auto result = [&]<auto Sig, auto Alg>() {
|
||||
if constexpr(ConvDirectionIsBackwardWeight<Sig> && Alg.optimizations.two_stage)
|
||||
{
|
||||
using ElementwiseOpBuilder = ckf::ElementwiseOpTileFactory<Sig, Alg>;
|
||||
using ElementwiseOpInstance = ElementwiseOpBuilder::Instance;
|
||||
auto elementwise_op = ElementwiseOpInstance{};
|
||||
return ckt::run(conv, elementwise_op, args, inputs, outputs, s_conf);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ckt::run(conv, args, inputs, outputs, s_conf);
|
||||
}
|
||||
}.template operator()<SIGNATURE, ALGORITHM>();
|
||||
|
||||
auto conv = Instance{};
|
||||
ckt::RunResult result = ckt::run(conv, args, inputs, outputs, s_conf);
|
||||
return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString());
|
||||
|
||||
Reference in New Issue
Block a user