mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK] CK Tile grouped convolution direct load
This commit is contained in:
@@ -60,7 +60,6 @@ class ConvInstanceTemplateParams:
|
||||
|
||||
def get_block_gemm_desc(self):
|
||||
double_smem_buffer = "true" if self.double_smem_buffer else "false"
|
||||
pipeline_version = self.pipeline_version[-1:]
|
||||
scheduler = (
|
||||
"INTRAWAVE" if self.scheduler.find("Intrawave") != -1 else "INTERWAVE"
|
||||
)
|
||||
@@ -69,7 +68,7 @@ class ConvInstanceTemplateParams:
|
||||
.warp_tile = {{.m = {self.warp_tile[0]}, .n = {self.warp_tile[1]}, .k = {self.warp_tile[2]}}},
|
||||
.double_smem_buffer = {double_smem_buffer},
|
||||
.num_wave_groups = {self.num_wave_groups},
|
||||
.pipeline_version = ckb::PipelineVersion::V{pipeline_version},
|
||||
.pipeline_version = ckb::PipelineVersion::{self.pipeline_version},
|
||||
.scheduler = ckb::PipelineScheduler::{scheduler}}}"""
|
||||
|
||||
def get_block_transfer(self):
|
||||
@@ -180,6 +179,16 @@ def parse_fwd_instances(instances, problem_name):
|
||||
pipeline_version = (
|
||||
"v1" if instance.find("BlkGemmPipelineVersion") == -1 else args[15]
|
||||
)
|
||||
# Replace pipeline if Direct Load
|
||||
if instance.find("DirectLoad") != -1:
|
||||
if instance.find("BlkGemmPipelineVersion: v1") != -1:
|
||||
pipeline_version = "ASYNC_V1"
|
||||
elif instance.find("BlkGemmPipelineVersion: v4") != -1:
|
||||
pipeline_version = "ASYNC_V4"
|
||||
else:
|
||||
raise RuntimeError("not supported pipeline for direct load")
|
||||
else:
|
||||
pipeline_version = f"""V{pipeline_version[-1:]}"""
|
||||
|
||||
m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
|
||||
n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))
|
||||
|
||||
Reference in New Issue
Block a user