diff --git a/tile_engine/ops/commons/validation_utils.py b/tile_engine/ops/commons/gemm_validation_utils.py similarity index 58% rename from tile_engine/ops/commons/validation_utils.py rename to tile_engine/ops/commons/gemm_validation_utils.py index 5787446e8c..1b4a7191cd 100644 --- a/tile_engine/ops/commons/validation_utils.py +++ b/tile_engine/ops/commons/gemm_validation_utils.py @@ -1,16 +1,19 @@ #!/usr/bin/env python -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT - -""" -Validation utilities for GEMM kernel generation. -Extracted from tile_engine_develop for consistency. -""" +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. import logging from typing import Tuple, List -# Element size mapping for different data types +GEMM_PIPELINES = ["mem", "compv3", "compv4"] + +GEMM_PRESHUFFLE_PIPELINES = ["preshufflev2"] + +LAYOUT_MAP = { + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor", +} + ELEMENT_SIZE_MAP = { "fp16": 2, "bf16": 2, @@ -47,9 +50,79 @@ WARP_SUPPORTED_COMBINATIONS = { ], } -# [TODO] Handle this while moving code to commons -# Supported warp tile combinations for different GPU architectures and data types -WARP_TILE_SUPPORTED_COMBINATIONS = { +GEMM_PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx90a": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], + }, + "gfx942": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + "gfx950": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_bf8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 64], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + }, +} + +GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { "gfx90a": { "fp16_fp16_fp16": [ [32, 32, 8], @@ -132,7 +205,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { }, } -# Unsupported trait combinations TRAIT_UNSUPPORTED_COMBINATIONS = { ("compv3", "cshuffle", "interwave"), ("compv3", "default", "interwave"), @@ -220,7 +292,7 @@ def validate_lds_capacity( matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) total_tile_in_lds = matrix_a_size + matrix_b_size - max_tile_size = 2**15 if pipeline == "compv4" else 2**16 + max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16 if total_tile_in_lds > max_tile_size: error_msg = ( @@ -234,7 +306,7 @@ def validate_lds_capacity( return True, "" -def validate_warp_tile_combination( +def validate_gemm_warp_tile_combination( warp_tile_m: int, warp_tile_n: int, warp_tile_k: int, @@ -250,7 +322,51 @@ def validate_warp_tile_combination( current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] # Check if we have GPU-specific combinations - gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) + gpu_warp_tile_combinations = GEMM_WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) + if not gpu_warp_tile_combinations: + # If GPU not recognized, try to be permissive but log warning + logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") + return True, "" + + # Check if we have combinations for this data type combination + allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, []) + if not allowed_combinations: + # For data type combinations not in the list, be permissive + logging.debug( + f"No warp tile combinations found for data types: {warp_tile_key}" + ) + return True, "" + + # Check if current combination is in the allowed list + if current_combination not in allowed_combinations: + error_msg = ( + f"Invalid warp tile combination: {current_combination} not in allowed list. " + f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}" + ) + return False, error_msg + + return True, "" + + +def validate_gemm_preshuffle_warp_tile_combination( + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + gpu_name: str, +) -> Tuple[bool, str]: + """Validate warp tile combination against GPU-specific supported combinations.""" + + # Construct the key for looking up supported combinations + warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}" + current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] + + # Check if we have GPU-specific combinations + gpu_warp_tile_combinations = GEMM_PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS.get( + gpu_name, {} + ) if not gpu_warp_tile_combinations: # If GPU not recognized, try to be permissive but log warning logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") @@ -292,7 +408,6 @@ def is_tile_config_valid( pipeline: str, layout: str, gpu_target: str, - trait_name: str = None, ) -> bool: """ Comprehensive tile configuration validation. @@ -349,37 +464,81 @@ def is_tile_config_valid( logging.debug(f"LDS validation failed: {lds_error}") return False - # Validate whole workgroup cover configuration - wr_cover_valid, wg_cover_error = validate_whole_wg_cover_configuration( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - layout, - a_datatype, - b_datatype, - ) - if not wr_cover_valid: - logging.debug( - f"Whole workgroup cover configuration validation failed: {wg_cover_error}" + if pipeline in GEMM_PIPELINES: + gemm_valid, gemm_valid_error = validate_gemm( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + layout, + gpu_target, ) - return False + if not gemm_valid: + logging.debug(f"GEMM validation failed: {gemm_valid_error}") + return False - # Validate warp tile combination - warp_tile_valid, warp_tile_error = validate_warp_tile_combination( - warp_tile_m, - warp_tile_n, - warp_tile_k, - a_datatype, - b_datatype, - c_datatype, - gpu_target, - ) - if not warp_tile_valid: - logging.debug(f"Warp tile validation failed: {warp_tile_error}") - return False + # Validate warp tile combination + warp_tile_valid, warp_tile_error = validate_gemm_warp_tile_combination( + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + gpu_target, + ) + if not warp_tile_valid: + logging.debug(f"Warp tile validation failed: {warp_tile_error}") + return False + + elif pipeline in GEMM_PRESHUFFLE_PIPELINES: + preshuffle_valid, preshuffle_valid_error = validate_gemm_preshuffle( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + layout, + gpu_target, + ) + if not preshuffle_valid: + logging.debug( + f"GEMM Preshuffle validation failed: {preshuffle_valid_error}" + ) + return False + + # Validate warp tile combination + warp_tile_valid, warp_tile_error = ( + validate_gemm_preshuffle_warp_tile_combination( + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + gpu_target, + ) + ) + if not warp_tile_valid: + logging.debug(f"Warp tile validation failed: {warp_tile_error}") + return False return True @@ -398,12 +557,6 @@ def get_dtype_string(datatype: str) -> str: return dtype_map.get(datatype, "float") -LAYOUT_MAP = { - "r": "ck_tile::tensor_layout::gemm::RowMajor", - "c": "ck_tile::tensor_layout::gemm::ColumnMajor", -} - - def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]: """ Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. @@ -600,3 +753,200 @@ def get_global_vector_load_size( return int(PackedSize * 2 / element_size(DataType)) else: return PackedSize + + +def validate_gemm( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + pipeline: str, + layout: str, + gpu_target: str, + trait_name: str = None, +) -> bool: + # GEMM Validation + # Validate whole workgroup cover configuration + whole_workgroup_cover_valid, whole_workgroup_cover_error = ( + validate_whole_wg_cover_configuration( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + layout, + a_datatype, + b_datatype, + ) + ) + if not whole_workgroup_cover_valid: + logging.debug( + f"Whole workgroup cover configuration validation failed: {whole_workgroup_cover_error}" + ) + return False, whole_workgroup_cover_error + + return True, "" + + +def validate_gemm_preshuffle( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + pipeline: str, + layout: str, + gpu_target: str, + trait_name: str = None, +) -> bool: + # Preshuffle Validations + # Validate vector load alignment + m_iter_per_warp = tile_m / (warp_m * warp_tile_m) + vector_valid, vector_error = validate_vector_load_alignment( + warp_tile_m, + warp_tile_k, + a_datatype, + m_iter_per_warp, + wave_size=64, + vector_load_size=16, + ) + if not vector_valid: + logging.debug(f"Vector load alignment failed: {vector_error}") + return False, "vector load alignment error" + + # Validate M0, M1, M2 configuration for matrix A row-major layout + m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration( + tile_m, + tile_k, + warp_m, + warp_n, + warp_k, + a_datatype, + vector_load_size=16, + warp_size=64, + ) + if not m0_m1_m2_valid: + logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}") + return False, m0_m1_m2_error + + return True, "" + + +def validate_vector_load_alignment( + wg_m: int, + wg_k: int, + a_datatype: str, + m_iter_per_warp: int, + wave_size: int, + vector_load_size: int, +) -> Tuple[bool, str]: + try: + # Calculate the memory access pattern size + a_element_size = element_size(a_datatype) + access_size = (wg_m * wg_k * a_element_size * m_iter_per_warp) / wave_size + + # Check if it's aligned to vector load size + if access_size % vector_load_size != 0: + error_msg = ( + f"Vector load alignment violation: " + f"({wg_m} * {wg_k} * {a_element_size} * {m_iter_per_warp} / {wave_size}) " + f"% {vector_load_size} = {access_size % vector_load_size} != 0. " + f"Access size: {access_size} bytes" + ) + return False, error_msg + + return True, "" + + except Exception as e: + return False, f"Error in vector load validation: {str(e)}" + + +def validate_m0_m1_m2_configuration( + tile_m: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + a_datatype: str, + vector_load_size: int = 16, + warp_size: int = 64, +) -> Tuple[bool, str]: + """ + Validate M0, M1, M2 configuration for matrix A row-major layout. + This ensures proper memory access pattern alignment. + """ + try: + # Validation for A as row-major + MPerBlock = tile_m + + # Calculate K1 using element size + K1 = vector_load_size / element_size(a_datatype) + + # Check if K1 is valid (must be integer) + if K1 != int(K1): + return ( + False, + f"K1 = {K1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})", + ) + K1 = int(K1) + + # Calculate K0 + if tile_k % K1 != 0: + return False, f"tile_k({tile_k}) must be divisible by K1({K1})" + K0 = tile_k // K1 + + # Calculate M2 + if warp_size % K0 != 0: + return False, f"warp_size({warp_size}) must be divisible by K0({K0})" + M2 = warp_size // K0 + + # Calculate number of warps and block size + NumWarps = warp_m * warp_n * warp_k + BlockSize = NumWarps * warp_size + + # Calculate M0 (assuming get_warp_size() returns warp_size) + M0 = BlockSize // warp_size # This should equal NumWarps + + # Calculate M1 + if (M2 * M0) == 0: + return False, f"M2({M2}) * M0({M0}) cannot be zero" + + if MPerBlock % (M2 * M0) != 0: + return ( + False, + f"MPerBlock({MPerBlock}) must be divisible by M2({M2}) * M0({M0}) = {M2 * M0}", + ) + M1 = MPerBlock // (M2 * M0) + + # Validate the assertion: M0 * M1 * M2 == MPerBlock + calculated_m_per_block = M0 * M1 * M2 + if calculated_m_per_block != MPerBlock: + error_msg = ( + f"Incorrect M0, M1, M2 configuration! " + f"M0({M0}) * M1({M1}) * M2({M2}) = {calculated_m_per_block} != MPerBlock({MPerBlock}). " + f"Configuration: K0={K0}, K1={K1}, NumWarps={NumWarps}, BlockSize={BlockSize}" + ) + return False, error_msg + + return True, "" + + except ZeroDivisionError as e: + return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}" + except Exception as e: + return False, f"Error in M0/M1/M2 validation: {str(e)}" diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py deleted file mode 100644 index eecc2228a6..0000000000 --- a/tile_engine/ops/gemm/codegen_utils.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -# -*- coding: utf-8 -*- - -""" -Mappings and utility functions for kernel code generation. -""" - -DATA_TYPE_MAP = { - "fp32": "float", - "fp16": "ck_tile::half_t", - "bf16": "ck_tile::bf16_t", - "int8": "ck_tile::int8_t", - "fp8": "ck_tile::fp8_t", - "bf8": "ck_tile::bf8_t", - "int4": "ck_tile::pk_int4_t", - "int32": "ck_tile::int32_t", -} - -LAYOUT_MAP = { - "r": "ck_tile::tensor_layout::gemm::RowMajor", - "c": "ck_tile::tensor_layout::gemm::ColumnMajor", -} - -DEFAULT_EPILOGUE = """ - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - kPadM, - kPadN, - WarpTileM, - WarpTileN, - WarpTileK, - UniversalGemmProblem::TransposeC, - true, - memory_operation>>; -""" - -CSHUFFLE_EPILOGUE = """ - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - WarpM, - WarpN, - WarpTileM, - WarpTileN, - WarpTileK, - UniversalGemmProblem::TransposeC, - memory_operation>>; -""" - -PIPELINE_MAP = { - "mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"], - "compv3": [ - "ck_tile::BaseGemmPipelineAgBgCrCompV3", - "ck_tile::GemmPipelineAgBgCrCompV3", - ], - "compv4": [ - "ck_tile::BaseGemmPipelineAgBgCrCompV4", - "ck_tile::GemmPipelineAgBgCrCompV4", - ], -} - -SCHEDULER_MAP = { - "interwave": "ck_tile::GemmPipelineScheduler::Interwave", - "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", -} - -EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE} - - -def BOOL_MAP(b_): - return {True: "true", False: "false"}[bool(b_)] - - -# To Do: add some more supported combinations -warp_tile_supported_combinations = { - "gfx90a": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], - }, - "gfx942": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], - "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], - }, - "gfx950": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "fp8_fp8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 32], - [16, 16, 64], - [16, 16, 128], - [32, 32, 64], - ], - "bf8_bf8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 64], - [16, 16, 32], - [16, 16, 128], - [32, 32, 64], - ], - "fp8_bf8_fp16": [ - [16, 16, 128], - [32, 32, 64], - ], - "bf8_fp8_fp16": [ - [16, 16, 128], - [32, 32, 64], - ], - }, - "gfx1201": { - "fp16_fp16_fp16": [ - [16, 16, 16], - ], - }, -} - -# To Do: remove some unsupported combinations -trait_unsupported_combinations = { - ("compv3", "cshuffle", "interwave"), - ("compv3", "default", "interwave"), - ("compv4", "cshuffle", "interwave"), - ("compv4", "default", "interwave"), -} - - -ELEMENT_SIZE_MAP = { - "fp16": 2, - "bf16": 2, - "int8": 1, - "fp8": 1, - "bf8": 1, - "int4": 0.5, - "int32": 4, -} - - -def element_size(data_type: str) -> float: - """Calculate the size (in bytes) of a single element for given data type.""" - data_type = data_type.lower() - if data_type not in ELEMENT_SIZE_MAP: - raise ValueError(f"Unsupported data type: {data_type}") - return ELEMENT_SIZE_MAP[data_type] diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 8885c821c1..d450f20105 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -21,7 +21,8 @@ def _import_validation_utils(): # Load the module dynamically spec = importlib.util.spec_from_file_location( - "validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py") + "validation_utils", + os.path.join(parent_dir, "commons", "gemm_validation_utils.py"), ) validation_utils = importlib.util.module_from_spec(spec) spec.loader.exec_module(validation_utils) diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py index cc167fb75f..06da7ea8a2 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py @@ -21,7 +21,8 @@ def _import_validation_utils(): # Load the module dynamically spec = importlib.util.spec_from_file_location( - "validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py") + "validation_utils", + os.path.join(parent_dir, "commons", "gemm_validation_utils.py"), ) validation_utils = importlib.util.module_from_spec(spec) spec.loader.exec_module(validation_utils) @@ -824,7 +825,7 @@ def main(): elif elementwise_function == "add": function_name = "MultiDAdd" elif elementwise_function == "passthrough": - function_name = "PassThrough" # TODO Change this + function_name = "PassThrough" args.elementwise_function = function_name diff --git a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py deleted file mode 100644 index 70ce3b0d72..0000000000 --- a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py +++ /dev/null @@ -1,483 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -""" -Validation utilities for GEMM kernel generation. -Extracted from tile_engine_develop for consistency. -""" - -import logging -from typing import Tuple, List - -# Element size mapping for different data types -ELEMENT_SIZE_MAP = { - "fp16": 2, - "bf16": 2, - "int8": 1, - "fp8": 1, - "bf8": 1, - "int4": 0.5, - "int32": 4, - "fp32": 4, - "fp64": 8, -} - -# [TODO] Handle this while moving code to commons -# Supported warp tile combinations for different GPU architectures and data types -WARP_TILE_SUPPORTED_COMBINATIONS = { - "gfx90a": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], - }, - "gfx942": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], - "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], - }, - "gfx950": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "fp8_fp8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 32], - [16, 16, 64], - [16, 16, 128], - [32, 32, 64], - ], - "bf8_bf8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 64], - [16, 16, 32], - [16, 16, 128], - [32, 32, 64], - ], - }, -} - -# Unsupported trait combinations -TRAIT_UNSUPPORTED_COMBINATIONS = { - ("compv3", "cshuffle", "interwave"), - ("compv3", "default", "interwave"), - ("compv4", "cshuffle", "interwave"), - ("compv4", "default", "interwave"), -} - - -def element_size(data_type: str) -> float: - """Calculate the size (in bytes) of a single element for given data type.""" - data_type = data_type.lower() - if data_type not in ELEMENT_SIZE_MAP: - raise ValueError(f"Unsupported data type: {data_type}") - return ELEMENT_SIZE_MAP[data_type] - - -def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool: - """Check if a trait combination is valid.""" - if pipeline not in ["preshufflev2"]: - raise ValueError("Accepted pipeline values are: ['preshufflev2']") - if epilogue not in ["default", "cshuffle"]: - return ValueError("Accepted epilogue values are: ['default', 'cshuffle']") - if scheduler not in ["default"]: - return ValueError("Accepted scheduler values are: ['default']") - return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS - - -def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool: - """Validate warp configuration.""" - return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)] - - -def validate_dimension_alignment( - tile_m: int, - tile_n: int, - tile_k: int, - warp_m: int, - warp_n: int, - warp_k: int, - warp_tile_m: int, - warp_tile_n: int, - warp_tile_k: int, -) -> Tuple[bool, List[str]]: - """Check if tile dimensions are properly aligned with warp dimensions.""" - alignment_issues = [] - - if tile_m % (warp_m * warp_tile_m) != 0: - alignment_issues.append( - f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" - ) - if tile_n % (warp_n * warp_tile_n) != 0: - alignment_issues.append( - f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" - ) - if tile_k % (warp_k * warp_tile_k) != 0: - alignment_issues.append( - f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" - ) - - return len(alignment_issues) == 0, alignment_issues - - -def validate_lds_capacity( - tile_m: int, - tile_n: int, - tile_k: int, - a_datatype: str, - b_datatype: str, - pipeline: str, -) -> Tuple[bool, str]: - """Validate LDS capacity requirements.""" - matrix_a_size = (tile_m * tile_k) * element_size(a_datatype) - matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) - total_tile_in_lds = matrix_a_size + matrix_b_size - - max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16 - - if total_tile_in_lds > max_tile_size: - error_msg = ( - f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > " - f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n" - f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" - f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B" - ) - return False, error_msg - - return True, "" - - -def validate_warp_tile_combination( - warp_tile_m: int, - warp_tile_n: int, - warp_tile_k: int, - a_datatype: str, - b_datatype: str, - c_datatype: str, - gpu_name: str, -) -> Tuple[bool, str]: - """Validate warp tile combination against GPU-specific supported combinations.""" - - # Construct the key for looking up supported combinations - warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}" - current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] - - # Check if we have GPU-specific combinations - gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) - if not gpu_warp_tile_combinations: - # If GPU not recognized, try to be permissive but log warning - logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") - return True, "" - - # Check if we have combinations for this data type combination - allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, []) - if not allowed_combinations: - # For data type combinations not in the list, be permissive - logging.debug( - f"No warp tile combinations found for data types: {warp_tile_key}" - ) - return True, "" - - # Check if current combination is in the allowed list - if current_combination not in allowed_combinations: - error_msg = ( - f"Invalid warp tile combination: {current_combination} not in allowed list. " - f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}" - ) - return False, error_msg - - return True, "" - - -def is_tile_config_valid( - tile_m: int, - tile_n: int, - tile_k: int, - warp_m: int, - warp_n: int, - warp_k: int, - warp_tile_m: int, - warp_tile_n: int, - warp_tile_k: int, - a_datatype: str, - b_datatype: str, - c_datatype: str, - pipeline: str, - gpu_target: str, - trait_name: str = None, -) -> bool: - """ - Comprehensive tile configuration validation. - Returns True if configuration is valid, False otherwise. - """ - # Basic sanity checks - if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: - return False - if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: - return False - if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: - return False - - # Check that warp tiles fit within block tiles - if warp_m * warp_tile_m > tile_m: - return False - if warp_n * warp_tile_n > tile_n: - return False - if warp_k * warp_tile_k > tile_k: - return False - - # Validate vector load alignment - m_iter_per_warp = tile_m / (warp_m * warp_tile_m) - vector_valid, vector_error = validate_vector_load_alignment( - warp_tile_m, - warp_tile_k, - a_datatype, - m_iter_per_warp, - wave_size=64, - vector_load_size=16, - ) - if not vector_valid: - logging.debug(f"Vector load alignment failed: {vector_error}") - return False - - # Validate M0, M1, M2 configuration for matrix A row-major layout - m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration( - tile_m, - tile_k, - warp_m, - warp_n, - warp_k, - a_datatype, - vector_load_size=16, - warp_size=64, - ) - if not m0_m1_m2_valid: - logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}") - return False - - # Validate warp configuration - if not validate_warp_configuration(warp_m, warp_n, warp_k): - logging.debug( - f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})" - ) - return False - - # Validate dimension alignment - is_aligned, alignment_issues = validate_dimension_alignment( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - ) - if not is_aligned: - logging.debug( - f"Dimension alignment failed: {', '.join(alignment_issues)}. " - f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " - f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" - ) - return False - - # Validate LDS capacity - lds_valid, lds_error = validate_lds_capacity( - tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline - ) - if not lds_valid: - logging.debug(f"LDS validation failed: {lds_error}") - return False - - # Validate warp tile combination - warp_tile_valid, warp_tile_error = validate_warp_tile_combination( - warp_tile_m, - warp_tile_n, - warp_tile_k, - a_datatype, - b_datatype, - c_datatype, - gpu_target, - ) - if not warp_tile_valid: - logging.debug(f"Warp tile validation failed: {warp_tile_error}") - return False - - return True - - -def validate_vector_load_alignment( - wg_m: int, - wg_k: int, - a_datatype: str, - m_iter_per_warp: int, - wave_size: int, - vector_load_size: int, -) -> Tuple[bool, str]: - try: - # Calculate the memory access pattern size - a_element_size = element_size(a_datatype) - access_size = (wg_m * wg_k * a_element_size * m_iter_per_warp) / wave_size - - # Check if it's aligned to vector load size - if access_size % vector_load_size != 0: - error_msg = ( - f"Vector load alignment violation: " - f"({wg_m} * {wg_k} * {a_element_size} * {m_iter_per_warp} / {wave_size}) " - f"% {vector_load_size} = {access_size % vector_load_size} != 0. " - f"Access size: {access_size} bytes" - ) - return False, error_msg - - return True, "" - - except Exception as e: - return False, f"Error in vector load validation: {str(e)}" - - -def validate_m0_m1_m2_configuration( - tile_m: int, - tile_k: int, - warp_m: int, - warp_n: int, - warp_k: int, - a_datatype: str, - vector_load_size: int = 16, - warp_size: int = 64, -) -> Tuple[bool, str]: - """ - Validate M0, M1, M2 configuration for matrix A row-major layout. - This ensures proper memory access pattern alignment. - """ - try: - # Validation for A as row-major - MPerBlock = tile_m - - # Calculate K1 using element size - K1 = vector_load_size / element_size(a_datatype) - - # Check if K1 is valid (must be integer) - if K1 != int(K1): - return ( - False, - f"K1 = {K1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})", - ) - K1 = int(K1) - - # Calculate K0 - if tile_k % K1 != 0: - return False, f"tile_k({tile_k}) must be divisible by K1({K1})" - K0 = tile_k // K1 - - # Calculate M2 - if warp_size % K0 != 0: - return False, f"warp_size({warp_size}) must be divisible by K0({K0})" - M2 = warp_size // K0 - - # Calculate number of warps and block size - NumWarps = warp_m * warp_n * warp_k - BlockSize = NumWarps * warp_size - - # Calculate M0 (assuming get_warp_size() returns warp_size) - M0 = BlockSize // warp_size # This should equal NumWarps - - # Calculate M1 - if (M2 * M0) == 0: - return False, f"M2({M2}) * M0({M0}) cannot be zero" - - if MPerBlock % (M2 * M0) != 0: - return ( - False, - f"MPerBlock({MPerBlock}) must be divisible by M2({M2}) * M0({M0}) = {M2 * M0}", - ) - M1 = MPerBlock // (M2 * M0) - - # Validate the assertion: M0 * M1 * M2 == MPerBlock - calculated_m_per_block = M0 * M1 * M2 - if calculated_m_per_block != MPerBlock: - error_msg = ( - f"Incorrect M0, M1, M2 configuration! " - f"M0({M0}) * M1({M1}) * M2({M2}) = {calculated_m_per_block} != MPerBlock({MPerBlock}). " - f"Configuration: K0={K0}, K1={K1}, NumWarps={NumWarps}, BlockSize={BlockSize}" - ) - return False, error_msg - - return True, "" - - except ZeroDivisionError as e: - return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}" - except Exception as e: - return False, f"Error in M0/M1/M2 validation: {str(e)}" - - -# [TODO] Handle this while moving code to commons Add more datatype to this function if needed -def get_dtype_string(datatype: str) -> str: - """Get C++ type string for datatype""" - dtype_map = { - "fp16": "ck_tile::fp16_t", - "fp8": "ck_tile::fp8_t", - "bf8": "ck_tile::bf8_t", - "bf16": "ck_tile::bf16_t", - "fp32": "float", - "fp64": "double", - } - return dtype_map.get(datatype, "float") - - -LAYOUT_MAP = { - "r": "ck_tile::tensor_layout::gemm::RowMajor", - "c": "ck_tile::tensor_layout::gemm::ColumnMajor", -} - - -def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]: - """ - Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. - """ - code = str(layout_code).strip().lower() - - a_layout = LAYOUT_MAP[code[0]] - b_layout = LAYOUT_MAP[code[1]] - c_layout = LAYOUT_MAP[code[2]] - return a_layout, b_layout, c_layout diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py index 9ce6d8cb25..654a039b9c 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -8,15 +8,34 @@ import itertools import logging import multiprocessing import concurrent.futures - from pathlib import Path +import importlib.util -from commons.validation_utils import ( - is_tile_config_valid, - is_trait_combination_valid, - get_dtype_string, - get_abc_layouts, -) + +def _import_validation_utils(): + """Import validation utilities from commons directory.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(current_dir) + + # Load the module dynamically + spec = importlib.util.spec_from_file_location( + "validation_utils", + os.path.join(parent_dir, "commons", "gemm_validation_utils.py"), + ) + validation_utils = importlib.util.module_from_spec(spec) + spec.loader.exec_module(validation_utils) + + return validation_utils + + +# Import validation functions +_validation_utils = _import_validation_utils() +is_tile_config_valid = _validation_utils.is_tile_config_valid +is_trait_combination_valid = _validation_utils.is_trait_combination_valid +get_dtype_string = _validation_utils.get_dtype_string +get_abc_layouts = _validation_utils.get_abc_layouts + +logging.basicConfig(level=logging.INFO) class GemmPreshuffleKernelBuilder: @@ -305,6 +324,8 @@ class GemmPreshuffleKernelBuilder: b_datatype = self.datatype c_datatype = self.datatype + layout = self.layout + # Special handling for certain data types if self.datatype in ["fp8", "bf8"]: c_datatype = "fp16" @@ -324,6 +345,7 @@ class GemmPreshuffleKernelBuilder: b_datatype, c_datatype, pipeline, + layout, self.gpu_target, )