[CK] CK Tile grouped convolution direct load

This commit is contained in:
Bartlomiej Kocot
2026-02-04 10:41:14 +00:00
parent 8b56ffb6ae
commit 951ee54edc
25 changed files with 885 additions and 20 deletions

View File

@@ -60,7 +60,6 @@ class ConvInstanceTemplateParams:
def get_block_gemm_desc(self):
double_smem_buffer = "true" if self.double_smem_buffer else "false"
pipeline_version = self.pipeline_version[-1:]
scheduler = (
"INTRAWAVE" if self.scheduler.find("Intrawave") != -1 else "INTERWAVE"
)
@@ -69,7 +68,7 @@ class ConvInstanceTemplateParams:
.warp_tile = {{.m = {self.warp_tile[0]}, .n = {self.warp_tile[1]}, .k = {self.warp_tile[2]}}},
.double_smem_buffer = {double_smem_buffer},
.num_wave_groups = {self.num_wave_groups},
.pipeline_version = ckb::PipelineVersion::V{pipeline_version},
.pipeline_version = ckb::PipelineVersion::{self.pipeline_version},
.scheduler = ckb::PipelineScheduler::{scheduler}}}"""
def get_block_transfer(self):
@@ -180,6 +179,16 @@ def parse_fwd_instances(instances, problem_name):
pipeline_version = (
"v1" if instance.find("BlkGemmPipelineVersion") == -1 else args[15]
)
# Replace pipeline if Direct Load
if instance.find("DirectLoad") != -1:
if instance.find("BlkGemmPipelineVersion: v1") != -1:
pipeline_version = "ASYNC_V1"
elif instance.find("BlkGemmPipelineVersion: v4") != -1:
pipeline_version = "ASYNC_V4"
else:
raise RuntimeError("not supported pipeline for direct load")
else:
pipeline_version = f"""V{pipeline_version[-1:]}"""
m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))