From e5f0985b7bf56c9d735ef5fd7baf28e50c48ca9d Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Tue, 2 Sep 2025 17:11:16 +0000 Subject: [PATCH] Merge commit '9f35cde374381ba76ea793d0794ac31ced075bb0' into develop --- .../ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp | 2 +- include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 2 +- tile_engine/ops/gemm/gemm_instance_builder.py | 3 ++- tile_engine/ops/gemm/validation_utils.py | 7 ++++--- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 2850ce3379..fcd512056d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -1143,7 +1143,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel make_tuple(number{}, number{}), {i_m0, i_n1}); - EpiloguePipeline{}(o_dram_window, o_acc_tile); + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); } }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index be14a36353..87021354aa 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -513,7 +513,7 @@ struct FmhaFwdV3Kernel make_tuple(number{}, number{}), {i_m0, i_n1}); - EpiloguePipeline{}(o_dram_window, o_acc_tile); + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); } }; } // namespace ck_tile diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index d679be7b84..c2214da613 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -8,6 +8,7 @@ import multiprocessing import concurrent.futures from pathlib import Path import logging +from typing import Optional from validation_utils import is_tile_config_valid, is_trait_combination_valid logging.basicConfig(level=logging.INFO) @@ -325,7 +326,7 @@ class GemmKernelBuilder: "c": "ck_tile::tensor_layout::gemm::ColumnMajor", } - def _get_abc_layouts(self, layout_code: str | None = None): + def _get_abc_layouts(self, layout_code: Optional[str] = None): """ Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. If layout_code is None, use self.layout. diff --git a/tile_engine/ops/gemm/validation_utils.py b/tile_engine/ops/gemm/validation_utils.py index 4948fd5744..7367f2446d 100644 --- a/tile_engine/ops/gemm/validation_utils.py +++ b/tile_engine/ops/gemm/validation_utils.py @@ -11,6 +11,7 @@ import subprocess import re from functools import lru_cache import logging +from typing import Tuple, List # Element size mapping for different data types ELEMENT_SIZE_MAP = { @@ -169,7 +170,7 @@ def validate_dimension_alignment( warp_tile_m: int, warp_tile_n: int, warp_tile_k: int, -) -> tuple[bool, list[str]]: +) -> Tuple[bool, List[str]]: """Check if tile dimensions are properly aligned with warp dimensions.""" alignment_issues = [] @@ -196,7 +197,7 @@ def validate_lds_capacity( a_datatype: str, b_datatype: str, pipeline: str, -) -> tuple[bool, str]: +) -> Tuple[bool, str]: """Validate LDS capacity requirements.""" matrix_a_size = (tile_m * tile_k) * element_size(a_datatype) matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) @@ -224,7 +225,7 @@ def validate_warp_tile_combination( b_datatype: str, c_datatype: str, gpu_name: str = None, -) -> tuple[bool, str]: +) -> Tuple[bool, str]: """Validate warp tile combination against GPU-specific supported combinations.""" if gpu_name is None: gpu_name = get_gpu_name_by_id(0)