#!/usr/bin/env python # SPDX-License-Identifier: MIT # Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. """ Validation utilities for GEMM kernel generation. Extracted from tile_engine_develop for consistency. """ import logging from typing import Tuple, List # Element size mapping for different data types ELEMENT_SIZE_MAP = { "fp16": 2, "bf16": 2, "int8": 1, "fp8": 1, "bf8": 1, "int4": 0.5, "int32": 4, "fp32": 4, "fp64": 8, } WARP_SUPPORTED_COMBINATIONS = { "gfx90a": [ [1, 4, 1], [2, 2, 1], [4, 1, 1], ], "gfx942": [ [1, 4, 1], [2, 2, 1], [4, 1, 1], ], "gfx950": [ [1, 4, 1], [2, 2, 1], [4, 1, 1], ], "gfx1201": [ [2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1], ], } # [TODO] Handle this while moving code to commons # Supported warp tile combinations for different GPU architectures and data types WARP_TILE_SUPPORTED_COMBINATIONS = { "gfx90a": { "fp16_fp16_fp16": [ [32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ [32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], }, "gfx942": { "fp16_fp16_fp16": [ [32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ [32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], }, "gfx950": { "fp16_fp16_fp16": [ [32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ [32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [ [32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64], ], "bf8_bf8_fp16": [ [32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64], ], }, "gfx1201": { # Check how to handle for GEMM and Multi D "fp16_fp16_fp16": [ [16, 16, 16], ], }, } # Unsupported trait combinations TRAIT_UNSUPPORTED_COMBINATIONS = { ("compv3", "cshuffle", "interwave"), ("compv3", "default", "interwave"), ("compv4", "cshuffle", "interwave"), ("compv4", "default", "interwave"), } def element_size(data_type: str) -> float: """Calculate the size (in bytes) of a single element for given data type.""" data_type = data_type.lower() if data_type not in ELEMENT_SIZE_MAP: raise ValueError(f"Unsupported data type: {data_type}") return ELEMENT_SIZE_MAP[data_type] def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool: """Check if a trait combination is valid.""" return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS def validate_warp_configuration( warp_m: int, warp_n: int, warp_k: int, gpu_name: str, ) -> bool: """Validate warp configuration.""" current_combination = [warp_m, warp_n, warp_k] allowed_combinations = WARP_SUPPORTED_COMBINATIONS.get(gpu_name, {}) if not allowed_combinations: # If GPU not recognized, try to be permissive but log warning logging.warning(f"No warp_[m/n/k] combinations found for GPU: {gpu_name}") return True # Check if current combination is in the allowed list if current_combination not in allowed_combinations: return False return True def validate_dimension_alignment( tile_m: int, tile_n: int, tile_k: int, warp_m: int, warp_n: int, warp_k: int, warp_tile_m: int, warp_tile_n: int, warp_tile_k: int, ) -> Tuple[bool, List[str]]: """Check if tile dimensions are properly aligned with warp dimensions.""" alignment_issues = [] if tile_m % (warp_m * warp_tile_m) != 0: alignment_issues.append( f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" ) if tile_n % (warp_n * warp_tile_n) != 0: alignment_issues.append( f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" ) if tile_k % (warp_k * warp_tile_k) != 0: alignment_issues.append( f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" ) return len(alignment_issues) == 0, alignment_issues def validate_lds_capacity( tile_m: int, tile_n: int, tile_k: int, a_datatype: str, b_datatype: str, pipeline: str, ) -> Tuple[bool, str]: """Validate LDS capacity requirements.""" matrix_a_size = (tile_m * tile_k) * element_size(a_datatype) matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) total_tile_in_lds = matrix_a_size + matrix_b_size max_tile_size = 2**15 if pipeline == "compv4" else 2**16 if total_tile_in_lds > max_tile_size: error_msg = ( f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > " f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n" f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B" ) return False, error_msg return True, "" def validate_warp_tile_combination( warp_tile_m: int, warp_tile_n: int, warp_tile_k: int, a_datatype: str, b_datatype: str, c_datatype: str, gpu_name: str, ) -> Tuple[bool, str]: """Validate warp tile combination against GPU-specific supported combinations.""" # Construct the key for looking up supported combinations warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}" current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] # Check if we have GPU-specific combinations gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) if not gpu_warp_tile_combinations: # If GPU not recognized, try to be permissive but log warning logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") return True, "" # Check if we have combinations for this data type combination allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, []) if not allowed_combinations: # For data type combinations not in the list, be permissive logging.debug( f"No warp tile combinations found for data types: {warp_tile_key}" ) return True, "" # Check if current combination is in the allowed list if current_combination not in allowed_combinations: error_msg = ( f"Invalid warp tile combination: {current_combination} not in allowed list. " f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}" ) return False, error_msg return True, "" def is_tile_config_valid( tile_m: int, tile_n: int, tile_k: int, warp_m: int, warp_n: int, warp_k: int, warp_tile_m: int, warp_tile_n: int, warp_tile_k: int, a_datatype: str, b_datatype: str, c_datatype: str, pipeline: str, layout: str, gpu_target: str, trait_name: str = None, ) -> bool: """ Comprehensive tile configuration validation. Returns True if configuration is valid, False otherwise. """ # Basic sanity checks if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: return False if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: return False if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: return False # Check that warp tiles fit within block tiles if warp_m * warp_tile_m > tile_m: return False if warp_n * warp_tile_n > tile_n: return False if warp_k * warp_tile_k > tile_k: return False # Validate warp configuration if not validate_warp_configuration(warp_m, warp_n, warp_k, gpu_target): logging.debug( f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})" ) return False # Validate dimension alignment is_aligned, alignment_issues = validate_dimension_alignment( tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k, ) if not is_aligned: logging.debug( f"Dimension alignment failed: {', '.join(alignment_issues)}. " f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" ) return False # Validate LDS capacity lds_valid, lds_error = validate_lds_capacity( tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline ) if not lds_valid: logging.debug(f"LDS validation failed: {lds_error}") return False # Validate whole workgroup cover configuration wr_cover_valid, wg_cover_error = validate_whole_wg_cover_configuration( tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, layout, a_datatype, b_datatype, ) if not wr_cover_valid: logging.debug( f"Whole workgroup cover configuration validation failed: {wg_cover_error}" ) return False # Validate warp tile combination warp_tile_valid, warp_tile_error = validate_warp_tile_combination( warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype, gpu_target, ) if not warp_tile_valid: logging.debug(f"Warp tile validation failed: {warp_tile_error}") return False return True # [TODO] Handle this while moving code to commons Add more datatype to this function if needed def get_dtype_string(datatype: str) -> str: """Get C++ type string for datatype""" dtype_map = { "fp16": "ck_tile::fp16_t", "fp8": "ck_tile::fp8_t", "bf8": "ck_tile::bf8_t", "bf16": "ck_tile::bf16_t", "fp32": "float", "fp64": "double", } return dtype_map.get(datatype, "float") LAYOUT_MAP = { "r": "ck_tile::tensor_layout::gemm::RowMajor", "c": "ck_tile::tensor_layout::gemm::ColumnMajor", } def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]: """ Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. """ code = str(layout_code).strip().lower() a_layout = LAYOUT_MAP[code[0]] b_layout = LAYOUT_MAP[code[1]] c_layout = LAYOUT_MAP[code[2]] return a_layout, b_layout, c_layout def get_abcd_layouts(layout_code: str) -> Tuple[str, str, str, List[str]]: """ Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcrr', 'ccrr', 'crrr', 'rrrr'. """ code = str(layout_code).strip().lower() a_layout = LAYOUT_MAP[code[0]] b_layout = LAYOUT_MAP[code[1]] c_layout = LAYOUT_MAP[code[2]] d0_layout = LAYOUT_MAP[code[3]] d1_layout = LAYOUT_MAP[code[3]] return a_layout, b_layout, c_layout, [d0_layout, d1_layout] def validate_whole_wg_cover_configuration( tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, layout, a_datatype, b_datatype, ) -> Tuple[bool, str]: # Validate whole workgroup cover configuration warp_size = 64 NumWarps = warp_m * warp_n * warp_k BlockSize = NumWarps * warp_size XPerTile = 0 YPerTile = 0 vector_load_size = 0 # A matrix validation if layout[0] == "r": vector_load_size = get_global_vector_load_size( BlockSize, tile_k, a_datatype, tile_m, tile_k ) XPerTile = tile_k YPerTile = tile_m elif layout[0] == "c": vector_load_size = get_global_vector_load_size( BlockSize, tile_k, a_datatype, tile_m, tile_m ) # Validate distribution XPerTile = tile_k YPerTile = tile_m wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation( XPerTile, YPerTile, BlockSize, vector_load_size, warp_size ) if not wg_cover_core_valid: logging.debug( f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}" ) return False, wg_cover_core_error XPerTile = tile_m YPerTile = tile_k wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation( XPerTile, YPerTile, BlockSize, vector_load_size, warp_size ) if not wg_cover_core_valid: logging.debug( f"whole workgroup cover failed for Matrix A: {wg_cover_core_error}" ) return False, wg_cover_core_error # B matrix validation if layout[1] == "r": vector_load_size = get_global_vector_load_size( BlockSize, tile_k, b_datatype, tile_n, tile_n ) # Validate distribution XPerTile = tile_k YPerTile = tile_n wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation( XPerTile, YPerTile, BlockSize, vector_load_size, warp_size ) if not wg_cover_core_valid: print("I am here 3") logging.debug( f"whole workgroup cover failed for Matrix B distribution: {wg_cover_core_error}" ) return False, wg_cover_core_error XPerTile = tile_n YPerTile = tile_k elif layout[1] == "c": XPerTile = tile_k YPerTile = tile_n vector_load_size = get_global_vector_load_size( BlockSize, tile_k, b_datatype, tile_n, tile_k ) wg_cover_core_valid, wg_cover_core_error = wg_cover_core_validation( XPerTile, YPerTile, BlockSize, vector_load_size, warp_size ) if not wg_cover_core_valid: logging.debug( f"whole workgroup cover failed for Matrix B: {wg_cover_core_error}" ) return False, wg_cover_core_error return True, "" def wg_cover_core_validation( XPerTile: int, YPerTile: int, BlockSize: int, vector_load_size: int, warp_size: int, ) -> Tuple[bool, str]: if XPerTile % vector_load_size != 0: return False, "XPerTile is not divisible by vector_load_size" num_warps = BlockSize / warp_size LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size) X1 = LargestVec if vector_load_size > LargestVec else vector_load_size X0 = XPerTile / X1 Y1 = warp_size // X0 if X0 * Y1 != warp_size: return False, "X0 * Y1 != warp_size" return True, "" def get_global_vector_load_size( BlockSize: int, KPerBlock: int, DataType: str, MNPerBlock: int, XPerTile: int, ) -> int: elements_per_thread = MNPerBlock * KPerBlock / BlockSize PackedSize = 1 if ( PackedSize == 2 and XPerTile % (PackedSize * 32 / element_size(DataType)) == 0 and elements_per_thread % (PackedSize * 32 / element_size(DataType)) == 0 ): return PackedSize * 32 / element_size(DataType) elif ( XPerTile % (PackedSize * 16 / element_size(DataType)) == 0 and elements_per_thread % (PackedSize * 16 / element_size(DataType)) == 0 ): return int(PackedSize * 16 / element_size(DataType)) elif ( XPerTile % (PackedSize * 8 / element_size(DataType)) == 0 and elements_per_thread % (PackedSize * 8 / element_size(DataType)) == 0 ): return int(PackedSize * 8 / element_size(DataType)) elif ( element_size(DataType) >= PackedSize * 4 and XPerTile % (PackedSize * 4 / element_size(DataType)) == 0 and elements_per_thread % (PackedSize * 4 / element_size(DataType)) == 0 ): return int(PackedSize * 4 / element_size(DataType)) elif ( element_size(DataType) >= PackedSize * 2 and XPerTile % (PackedSize * 2 / element_size(DataType)) == 0 and elements_per_thread % (PackedSize * 2 / element_size(DataType)) == 0 ): return int(PackedSize * 2 / element_size(DataType)) else: return PackedSize