mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Merge commit '9f35cde374381ba76ea793d0794ac31ced075bb0' into develop
This commit is contained in:
@@ -1143,7 +1143,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
{i_m0, i_n1});
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile);
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -513,7 +513,7 @@ struct FmhaFwdV3Kernel
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
{i_m0, i_n1});
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile);
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -8,6 +8,7 @@ import multiprocessing
|
||||
import concurrent.futures
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from typing import Optional
|
||||
from validation_utils import is_tile_config_valid, is_trait_combination_valid
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -325,7 +326,7 @@ class GemmKernelBuilder:
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
def _get_abc_layouts(self, layout_code: str | None = None):
|
||||
def _get_abc_layouts(self, layout_code: Optional[str] = None):
|
||||
"""
|
||||
Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'.
|
||||
If layout_code is None, use self.layout.
|
||||
|
||||
@@ -11,6 +11,7 @@ import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import Tuple, List
|
||||
|
||||
# Element size mapping for different data types
|
||||
ELEMENT_SIZE_MAP = {
|
||||
@@ -169,7 +170,7 @@ def validate_dimension_alignment(
|
||||
warp_tile_m: int,
|
||||
warp_tile_n: int,
|
||||
warp_tile_k: int,
|
||||
) -> tuple[bool, list[str]]:
|
||||
) -> Tuple[bool, List[str]]:
|
||||
"""Check if tile dimensions are properly aligned with warp dimensions."""
|
||||
alignment_issues = []
|
||||
|
||||
@@ -196,7 +197,7 @@ def validate_lds_capacity(
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
pipeline: str,
|
||||
) -> tuple[bool, str]:
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate LDS capacity requirements."""
|
||||
matrix_a_size = (tile_m * tile_k) * element_size(a_datatype)
|
||||
matrix_b_size = (tile_n * tile_k) * element_size(b_datatype)
|
||||
@@ -224,7 +225,7 @@ def validate_warp_tile_combination(
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
gpu_name: str = None,
|
||||
) -> tuple[bool, str]:
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate warp tile combination against GPU-specific supported combinations."""
|
||||
if gpu_name is None:
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
|
||||
Reference in New Issue
Block a user