mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[rocm-libraries] ROCm/rocm-libraries#4406 (commit 61f9f90)
[CK] CK Tile grouped convolution direct load ## Motivation CK Tile grouped convolution forward direct load support. ## Technical Details Basic pipeline for direct load and new instances for forward for v1 and v4 pipelines. ## Test Plan test_grouped_convnd_fwd_tile ## Test Result CI pending ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-130
This commit is contained in:
committed by
assistant-librarian[bot]
parent
0cafa68b6f
commit
27e0a34e0f
@@ -60,7 +60,6 @@ class ConvInstanceTemplateParams:
|
||||
|
||||
def get_block_gemm_desc(self):
|
||||
double_smem_buffer = "true" if self.double_smem_buffer else "false"
|
||||
pipeline_version = self.pipeline_version[-1:]
|
||||
scheduler = (
|
||||
"INTRAWAVE" if self.scheduler.find("Intrawave") != -1 else "INTERWAVE"
|
||||
)
|
||||
@@ -69,7 +68,7 @@ class ConvInstanceTemplateParams:
|
||||
.warp_tile = {{.m = {self.warp_tile[0]}, .n = {self.warp_tile[1]}, .k = {self.warp_tile[2]}}},
|
||||
.double_smem_buffer = {double_smem_buffer},
|
||||
.num_wave_groups = {self.num_wave_groups},
|
||||
.pipeline_version = ckb::PipelineVersion::V{pipeline_version},
|
||||
.pipeline_version = ckb::PipelineVersion::{self.pipeline_version},
|
||||
.scheduler = ckb::PipelineScheduler::{scheduler}}}"""
|
||||
|
||||
def get_block_transfer(self):
|
||||
@@ -180,6 +179,16 @@ def parse_fwd_instances(instances, problem_name):
|
||||
pipeline_version = (
|
||||
"v1" if instance.find("BlkGemmPipelineVersion") == -1 else args[15]
|
||||
)
|
||||
# Replace pipeline if Direct Load
|
||||
if instance.find("DirectLoad") != -1:
|
||||
if instance.find("BlkGemmPipelineVersion: v1") != -1:
|
||||
pipeline_version = "ASYNC_V1"
|
||||
elif instance.find("BlkGemmPipelineVersion: v4") != -1:
|
||||
pipeline_version = "ASYNC_V4"
|
||||
else:
|
||||
raise RuntimeError("not supported pipeline for direct load")
|
||||
else:
|
||||
pipeline_version = f"""V{pipeline_version[-1:]}"""
|
||||
|
||||
m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
|
||||
n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))
|
||||
|
||||
Reference in New Issue
Block a user