mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Ck tile engine commons (#3166)
* Moving Preshuffle to commons * Fixing Common Validations * Addressing Review Comments * Partial Rebasing * Partial Rebasing * Partial Rebasing * Rebasing Complete
This commit is contained in:
committed by
GitHub
parent
797ddfa41e
commit
9af30f04b6
@@ -1,16 +1,19 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Validation utilities for GEMM kernel generation.
|
||||
Extracted from tile_engine_develop for consistency.
|
||||
"""
|
||||
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
import logging
|
||||
from typing import Tuple, List
|
||||
|
||||
# Element size mapping for different data types
|
||||
GEMM_PIPELINES = ["mem", "compv3", "compv4"]
|
||||
|
||||
GEMM_PRESHUFFLE_PIPELINES = ["preshufflev2"]
|
||||
|
||||
LAYOUT_MAP = {
|
||||
"r": "ck_tile::tensor_layout::gemm::RowMajor",
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
ELEMENT_SIZE_MAP = {
|
||||
"fp16": 2,
|
||||
"bf16": 2,
|
||||
@@ -47,9 +50,79 @@ WARP_SUPPORTED_COMBINATIONS = {
|
||||
],
|
||||
}
|
||||
|
||||
# [TODO] Handle this while moving code to commons
|
||||
# Supported warp tile combinations for different GPU architectures and data types
|
||||
WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
GEMM_PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
},
|
||||
"gfx942": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
|
||||
},
|
||||
"gfx950": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"bf8_bf8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 32],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
@@ -132,7 +205,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
},
|
||||
}
|
||||
|
||||
# Unsupported trait combinations
|
||||
TRAIT_UNSUPPORTED_COMBINATIONS = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
@@ -220,7 +292,7 @@ def validate_lds_capacity(
|
||||
matrix_b_size = (tile_n * tile_k) * element_size(b_datatype)
|
||||
total_tile_in_lds = matrix_a_size + matrix_b_size
|
||||
|
||||
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
|
||||
max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16
|
||||
|
||||
if total_tile_in_lds > max_tile_size:
|
||||
error_msg = (
|
||||
@@ -234,7 +306,7 @@ def validate_lds_capacity(
|
||||
return True, ""
|
||||
|
||||
|
||||
def validate_warp_tile_combination(
|
||||
def validate_gemm_warp_tile_combination(
|
||||
warp_tile_m: int,
|
||||
warp_tile_n: int,
|
||||
warp_tile_k: int,
|
||||
@@ -250,7 +322,51 @@ def validate_warp_tile_combination(
|
||||
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
|
||||
|
||||
# Check if we have GPU-specific combinations
|
||||
gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {})
|
||||
gpu_warp_tile_combinations = GEMM_WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {})
|
||||
if not gpu_warp_tile_combinations:
|
||||
# If GPU not recognized, try to be permissive but log warning
|
||||
logging.warning(f"No warp tile combinations found for GPU: {gpu_name}")
|
||||
return True, ""
|
||||
|
||||
# Check if we have combinations for this data type combination
|
||||
allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, [])
|
||||
if not allowed_combinations:
|
||||
# For data type combinations not in the list, be permissive
|
||||
logging.debug(
|
||||
f"No warp tile combinations found for data types: {warp_tile_key}"
|
||||
)
|
||||
return True, ""
|
||||
|
||||
# Check if current combination is in the allowed list
|
||||
if current_combination not in allowed_combinations:
|
||||
error_msg = (
|
||||
f"Invalid warp tile combination: {current_combination} not in allowed list. "
|
||||
f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def validate_gemm_preshuffle_warp_tile_combination(
|
||||
warp_tile_m: int,
|
||||
warp_tile_n: int,
|
||||
warp_tile_k: int,
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
gpu_name: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate warp tile combination against GPU-specific supported combinations."""
|
||||
|
||||
# Construct the key for looking up supported combinations
|
||||
warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
|
||||
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
|
||||
|
||||
# Check if we have GPU-specific combinations
|
||||
gpu_warp_tile_combinations = GEMM_PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS.get(
|
||||
gpu_name, {}
|
||||
)
|
||||
if not gpu_warp_tile_combinations:
|
||||
# If GPU not recognized, try to be permissive but log warning
|
||||
logging.warning(f"No warp tile combinations found for GPU: {gpu_name}")
|
||||
@@ -292,7 +408,6 @@ def is_tile_config_valid(
|
||||
pipeline: str,
|
||||
layout: str,
|
||||
gpu_target: str,
|
||||
trait_name: str = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Comprehensive tile configuration validation.
|
||||
@@ -349,37 +464,81 @@ def is_tile_config_valid(
|
||||
logging.debug(f"LDS validation failed: {lds_error}")
|
||||
return False
|
||||
|
||||
# Validate whole workgroup cover configuration
|
||||
wr_cover_valid, wg_cover_error = validate_whole_wg_cover_configuration(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
layout,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
)
|
||||
if not wr_cover_valid:
|
||||
logging.debug(
|
||||
f"Whole workgroup cover configuration validation failed: {wg_cover_error}"
|
||||
if pipeline in GEMM_PIPELINES:
|
||||
gemm_valid, gemm_valid_error = validate_gemm(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
pipeline,
|
||||
layout,
|
||||
gpu_target,
|
||||
)
|
||||
return False
|
||||
if not gemm_valid:
|
||||
logging.debug(f"GEMM validation failed: {gemm_valid_error}")
|
||||
return False
|
||||
|
||||
# Validate warp tile combination
|
||||
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
gpu_target,
|
||||
)
|
||||
if not warp_tile_valid:
|
||||
logging.debug(f"Warp tile validation failed: {warp_tile_error}")
|
||||
return False
|
||||
# Validate warp tile combination
|
||||
warp_tile_valid, warp_tile_error = validate_gemm_warp_tile_combination(
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
gpu_target,
|
||||
)
|
||||
if not warp_tile_valid:
|
||||
logging.debug(f"Warp tile validation failed: {warp_tile_error}")
|
||||
return False
|
||||
|
||||
elif pipeline in GEMM_PRESHUFFLE_PIPELINES:
|
||||
preshuffle_valid, preshuffle_valid_error = validate_gemm_preshuffle(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
pipeline,
|
||||
layout,
|
||||
gpu_target,
|
||||
)
|
||||
if not preshuffle_valid:
|
||||
logging.debug(
|
||||
f"GEMM Preshuffle validation failed: {preshuffle_valid_error}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Validate warp tile combination
|
||||
warp_tile_valid, warp_tile_error = (
|
||||
validate_gemm_preshuffle_warp_tile_combination(
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
gpu_target,
|
||||
)
|
||||
)
|
||||
if not warp_tile_valid:
|
||||
logging.debug(f"Warp tile validation failed: {warp_tile_error}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -398,12 +557,6 @@ def get_dtype_string(datatype: str) -> str:
|
||||
return dtype_map.get(datatype, "float")
|
||||
|
||||
|
||||
LAYOUT_MAP = {
|
||||
"r": "ck_tile::tensor_layout::gemm::RowMajor",
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
|
||||
def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]:
|
||||
"""
|
||||
Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'.
|
||||
@@ -600,3 +753,200 @@ def get_global_vector_load_size(
|
||||
return int(PackedSize * 2 / element_size(DataType))
|
||||
else:
|
||||
return PackedSize
|
||||
|
||||
|
||||
def validate_gemm(
|
||||
tile_m: int,
|
||||
tile_n: int,
|
||||
tile_k: int,
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
warp_tile_m: int,
|
||||
warp_tile_n: int,
|
||||
warp_tile_k: int,
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
pipeline: str,
|
||||
layout: str,
|
||||
gpu_target: str,
|
||||
trait_name: str = None,
|
||||
) -> bool:
|
||||
# GEMM Validation
|
||||
# Validate whole workgroup cover configuration
|
||||
whole_workgroup_cover_valid, whole_workgroup_cover_error = (
|
||||
validate_whole_wg_cover_configuration(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
layout,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
)
|
||||
)
|
||||
if not whole_workgroup_cover_valid:
|
||||
logging.debug(
|
||||
f"Whole workgroup cover configuration validation failed: {whole_workgroup_cover_error}"
|
||||
)
|
||||
return False, whole_workgroup_cover_error
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def validate_gemm_preshuffle(
|
||||
tile_m: int,
|
||||
tile_n: int,
|
||||
tile_k: int,
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
warp_tile_m: int,
|
||||
warp_tile_n: int,
|
||||
warp_tile_k: int,
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
pipeline: str,
|
||||
layout: str,
|
||||
gpu_target: str,
|
||||
trait_name: str = None,
|
||||
) -> bool:
|
||||
# Preshuffle Validations
|
||||
# Validate vector load alignment
|
||||
m_iter_per_warp = tile_m / (warp_m * warp_tile_m)
|
||||
vector_valid, vector_error = validate_vector_load_alignment(
|
||||
warp_tile_m,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
m_iter_per_warp,
|
||||
wave_size=64,
|
||||
vector_load_size=16,
|
||||
)
|
||||
if not vector_valid:
|
||||
logging.debug(f"Vector load alignment failed: {vector_error}")
|
||||
return False, "vector load alignment error"
|
||||
|
||||
# Validate M0, M1, M2 configuration for matrix A row-major layout
|
||||
m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration(
|
||||
tile_m,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
a_datatype,
|
||||
vector_load_size=16,
|
||||
warp_size=64,
|
||||
)
|
||||
if not m0_m1_m2_valid:
|
||||
logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}")
|
||||
return False, m0_m1_m2_error
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def validate_vector_load_alignment(
|
||||
wg_m: int,
|
||||
wg_k: int,
|
||||
a_datatype: str,
|
||||
m_iter_per_warp: int,
|
||||
wave_size: int,
|
||||
vector_load_size: int,
|
||||
) -> Tuple[bool, str]:
|
||||
try:
|
||||
# Calculate the memory access pattern size
|
||||
a_element_size = element_size(a_datatype)
|
||||
access_size = (wg_m * wg_k * a_element_size * m_iter_per_warp) / wave_size
|
||||
|
||||
# Check if it's aligned to vector load size
|
||||
if access_size % vector_load_size != 0:
|
||||
error_msg = (
|
||||
f"Vector load alignment violation: "
|
||||
f"({wg_m} * {wg_k} * {a_element_size} * {m_iter_per_warp} / {wave_size}) "
|
||||
f"% {vector_load_size} = {access_size % vector_load_size} != 0. "
|
||||
f"Access size: {access_size} bytes"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
return False, f"Error in vector load validation: {str(e)}"
|
||||
|
||||
|
||||
def validate_m0_m1_m2_configuration(
|
||||
tile_m: int,
|
||||
tile_k: int,
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
a_datatype: str,
|
||||
vector_load_size: int = 16,
|
||||
warp_size: int = 64,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Validate M0, M1, M2 configuration for matrix A row-major layout.
|
||||
This ensures proper memory access pattern alignment.
|
||||
"""
|
||||
try:
|
||||
# Validation for A as row-major
|
||||
MPerBlock = tile_m
|
||||
|
||||
# Calculate K1 using element size
|
||||
K1 = vector_load_size / element_size(a_datatype)
|
||||
|
||||
# Check if K1 is valid (must be integer)
|
||||
if K1 != int(K1):
|
||||
return (
|
||||
False,
|
||||
f"K1 = {K1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})",
|
||||
)
|
||||
K1 = int(K1)
|
||||
|
||||
# Calculate K0
|
||||
if tile_k % K1 != 0:
|
||||
return False, f"tile_k({tile_k}) must be divisible by K1({K1})"
|
||||
K0 = tile_k // K1
|
||||
|
||||
# Calculate M2
|
||||
if warp_size % K0 != 0:
|
||||
return False, f"warp_size({warp_size}) must be divisible by K0({K0})"
|
||||
M2 = warp_size // K0
|
||||
|
||||
# Calculate number of warps and block size
|
||||
NumWarps = warp_m * warp_n * warp_k
|
||||
BlockSize = NumWarps * warp_size
|
||||
|
||||
# Calculate M0 (assuming get_warp_size() returns warp_size)
|
||||
M0 = BlockSize // warp_size # This should equal NumWarps
|
||||
|
||||
# Calculate M1
|
||||
if (M2 * M0) == 0:
|
||||
return False, f"M2({M2}) * M0({M0}) cannot be zero"
|
||||
|
||||
if MPerBlock % (M2 * M0) != 0:
|
||||
return (
|
||||
False,
|
||||
f"MPerBlock({MPerBlock}) must be divisible by M2({M2}) * M0({M0}) = {M2 * M0}",
|
||||
)
|
||||
M1 = MPerBlock // (M2 * M0)
|
||||
|
||||
# Validate the assertion: M0 * M1 * M2 == MPerBlock
|
||||
calculated_m_per_block = M0 * M1 * M2
|
||||
if calculated_m_per_block != MPerBlock:
|
||||
error_msg = (
|
||||
f"Incorrect M0, M1, M2 configuration! "
|
||||
f"M0({M0}) * M1({M1}) * M2({M2}) = {calculated_m_per_block} != MPerBlock({MPerBlock}). "
|
||||
f"Configuration: K0={K0}, K1={K1}, NumWarps={NumWarps}, BlockSize={BlockSize}"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
except ZeroDivisionError as e:
|
||||
return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}"
|
||||
except Exception as e:
|
||||
return False, f"Error in M0/M1/M2 validation: {str(e)}"
|
||||
@@ -1,210 +0,0 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Mappings and utility functions for kernel code generation.
|
||||
"""
|
||||
|
||||
DATA_TYPE_MAP = {
|
||||
"fp32": "float",
|
||||
"fp16": "ck_tile::half_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"int8": "ck_tile::int8_t",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"int4": "ck_tile::pk_int4_t",
|
||||
"int32": "ck_tile::int32_t",
|
||||
}
|
||||
|
||||
LAYOUT_MAP = {
|
||||
"r": "ck_tile::tensor_layout::gemm::RowMajor",
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
DEFAULT_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
true,
|
||||
memory_operation>>;
|
||||
"""
|
||||
|
||||
CSHUFFLE_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
WarpM,
|
||||
WarpN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
"""
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"],
|
||||
"compv3": [
|
||||
"ck_tile::BaseGemmPipelineAgBgCrCompV3",
|
||||
"ck_tile::GemmPipelineAgBgCrCompV3",
|
||||
],
|
||||
"compv4": [
|
||||
"ck_tile::BaseGemmPipelineAgBgCrCompV4",
|
||||
"ck_tile::GemmPipelineAgBgCrCompV4",
|
||||
],
|
||||
}
|
||||
|
||||
SCHEDULER_MAP = {
|
||||
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
|
||||
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
|
||||
}
|
||||
|
||||
EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE}
|
||||
|
||||
|
||||
def BOOL_MAP(b_):
|
||||
return {True: "true", False: "false"}[bool(b_)]
|
||||
|
||||
|
||||
# To Do: add some more supported combinations
|
||||
warp_tile_supported_combinations = {
|
||||
"gfx90a": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
},
|
||||
"gfx942": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
|
||||
},
|
||||
"gfx950": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"bf8_bf8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 32],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"fp8_bf8_fp16": [
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"bf8_fp8_fp16": [
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
},
|
||||
"gfx1201": {
|
||||
"fp16_fp16_fp16": [
|
||||
[16, 16, 16],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# To Do: remove some unsupported combinations
|
||||
trait_unsupported_combinations = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave"),
|
||||
}
|
||||
|
||||
|
||||
ELEMENT_SIZE_MAP = {
|
||||
"fp16": 2,
|
||||
"bf16": 2,
|
||||
"int8": 1,
|
||||
"fp8": 1,
|
||||
"bf8": 1,
|
||||
"int4": 0.5,
|
||||
"int32": 4,
|
||||
}
|
||||
|
||||
|
||||
def element_size(data_type: str) -> float:
|
||||
"""Calculate the size (in bytes) of a single element for given data type."""
|
||||
data_type = data_type.lower()
|
||||
if data_type not in ELEMENT_SIZE_MAP:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
@@ -21,7 +21,8 @@ def _import_validation_utils():
|
||||
|
||||
# Load the module dynamically
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py")
|
||||
"validation_utils",
|
||||
os.path.join(parent_dir, "commons", "gemm_validation_utils.py"),
|
||||
)
|
||||
validation_utils = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(validation_utils)
|
||||
|
||||
@@ -21,7 +21,8 @@ def _import_validation_utils():
|
||||
|
||||
# Load the module dynamically
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py")
|
||||
"validation_utils",
|
||||
os.path.join(parent_dir, "commons", "gemm_validation_utils.py"),
|
||||
)
|
||||
validation_utils = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(validation_utils)
|
||||
@@ -824,7 +825,7 @@ def main():
|
||||
elif elementwise_function == "add":
|
||||
function_name = "MultiDAdd"
|
||||
elif elementwise_function == "passthrough":
|
||||
function_name = "PassThrough" # TODO Change this
|
||||
function_name = "PassThrough"
|
||||
|
||||
args.elementwise_function = function_name
|
||||
|
||||
|
||||
@@ -1,483 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Validation utilities for GEMM kernel generation.
|
||||
Extracted from tile_engine_develop for consistency.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Tuple, List
|
||||
|
||||
# Element size mapping for different data types
|
||||
ELEMENT_SIZE_MAP = {
|
||||
"fp16": 2,
|
||||
"bf16": 2,
|
||||
"int8": 1,
|
||||
"fp8": 1,
|
||||
"bf8": 1,
|
||||
"int4": 0.5,
|
||||
"int32": 4,
|
||||
"fp32": 4,
|
||||
"fp64": 8,
|
||||
}
|
||||
|
||||
# [TODO] Handle this while moving code to commons
|
||||
# Supported warp tile combinations for different GPU architectures and data types
|
||||
WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
},
|
||||
"gfx942": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
|
||||
},
|
||||
"gfx950": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"bf8_bf8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 32],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# Unsupported trait combinations
|
||||
TRAIT_UNSUPPORTED_COMBINATIONS = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave"),
|
||||
}
|
||||
|
||||
|
||||
def element_size(data_type: str) -> float:
|
||||
"""Calculate the size (in bytes) of a single element for given data type."""
|
||||
data_type = data_type.lower()
|
||||
if data_type not in ELEMENT_SIZE_MAP:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
|
||||
|
||||
def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool:
|
||||
"""Check if a trait combination is valid."""
|
||||
if pipeline not in ["preshufflev2"]:
|
||||
raise ValueError("Accepted pipeline values are: ['preshufflev2']")
|
||||
if epilogue not in ["default", "cshuffle"]:
|
||||
return ValueError("Accepted epilogue values are: ['default', 'cshuffle']")
|
||||
if scheduler not in ["default"]:
|
||||
return ValueError("Accepted scheduler values are: ['default']")
|
||||
return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
|
||||
|
||||
|
||||
def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool:
|
||||
"""Validate warp configuration."""
|
||||
return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]
|
||||
|
||||
|
||||
def validate_dimension_alignment(
|
||||
tile_m: int,
|
||||
tile_n: int,
|
||||
tile_k: int,
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
warp_tile_m: int,
|
||||
warp_tile_n: int,
|
||||
warp_tile_k: int,
|
||||
) -> Tuple[bool, List[str]]:
|
||||
"""Check if tile dimensions are properly aligned with warp dimensions."""
|
||||
alignment_issues = []
|
||||
|
||||
if tile_m % (warp_m * warp_tile_m) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}"
|
||||
)
|
||||
if tile_n % (warp_n * warp_tile_n) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}"
|
||||
)
|
||||
if tile_k % (warp_k * warp_tile_k) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}"
|
||||
)
|
||||
|
||||
return len(alignment_issues) == 0, alignment_issues
|
||||
|
||||
|
||||
def validate_lds_capacity(
|
||||
tile_m: int,
|
||||
tile_n: int,
|
||||
tile_k: int,
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
pipeline: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate LDS capacity requirements."""
|
||||
matrix_a_size = (tile_m * tile_k) * element_size(a_datatype)
|
||||
matrix_b_size = (tile_n * tile_k) * element_size(b_datatype)
|
||||
total_tile_in_lds = matrix_a_size + matrix_b_size
|
||||
|
||||
max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16
|
||||
|
||||
if total_tile_in_lds > max_tile_size:
|
||||
error_msg = (
|
||||
f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
|
||||
f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
|
||||
f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
|
||||
f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def validate_warp_tile_combination(
|
||||
warp_tile_m: int,
|
||||
warp_tile_n: int,
|
||||
warp_tile_k: int,
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
gpu_name: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate warp tile combination against GPU-specific supported combinations."""
|
||||
|
||||
# Construct the key for looking up supported combinations
|
||||
warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
|
||||
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
|
||||
|
||||
# Check if we have GPU-specific combinations
|
||||
gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {})
|
||||
if not gpu_warp_tile_combinations:
|
||||
# If GPU not recognized, try to be permissive but log warning
|
||||
logging.warning(f"No warp tile combinations found for GPU: {gpu_name}")
|
||||
return True, ""
|
||||
|
||||
# Check if we have combinations for this data type combination
|
||||
allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, [])
|
||||
if not allowed_combinations:
|
||||
# For data type combinations not in the list, be permissive
|
||||
logging.debug(
|
||||
f"No warp tile combinations found for data types: {warp_tile_key}"
|
||||
)
|
||||
return True, ""
|
||||
|
||||
# Check if current combination is in the allowed list
|
||||
if current_combination not in allowed_combinations:
|
||||
error_msg = (
|
||||
f"Invalid warp tile combination: {current_combination} not in allowed list. "
|
||||
f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def is_tile_config_valid(
|
||||
tile_m: int,
|
||||
tile_n: int,
|
||||
tile_k: int,
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
warp_tile_m: int,
|
||||
warp_tile_n: int,
|
||||
warp_tile_k: int,
|
||||
a_datatype: str,
|
||||
b_datatype: str,
|
||||
c_datatype: str,
|
||||
pipeline: str,
|
||||
gpu_target: str,
|
||||
trait_name: str = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Comprehensive tile configuration validation.
|
||||
Returns True if configuration is valid, False otherwise.
|
||||
"""
|
||||
# Basic sanity checks
|
||||
if tile_m <= 0 or tile_n <= 0 or tile_k <= 0:
|
||||
return False
|
||||
if warp_m <= 0 or warp_n <= 0 or warp_k <= 0:
|
||||
return False
|
||||
if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0:
|
||||
return False
|
||||
|
||||
# Check that warp tiles fit within block tiles
|
||||
if warp_m * warp_tile_m > tile_m:
|
||||
return False
|
||||
if warp_n * warp_tile_n > tile_n:
|
||||
return False
|
||||
if warp_k * warp_tile_k > tile_k:
|
||||
return False
|
||||
|
||||
# Validate vector load alignment
|
||||
m_iter_per_warp = tile_m / (warp_m * warp_tile_m)
|
||||
vector_valid, vector_error = validate_vector_load_alignment(
|
||||
warp_tile_m,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
m_iter_per_warp,
|
||||
wave_size=64,
|
||||
vector_load_size=16,
|
||||
)
|
||||
if not vector_valid:
|
||||
logging.debug(f"Vector load alignment failed: {vector_error}")
|
||||
return False
|
||||
|
||||
# Validate M0, M1, M2 configuration for matrix A row-major layout
|
||||
m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration(
|
||||
tile_m,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
a_datatype,
|
||||
vector_load_size=16,
|
||||
warp_size=64,
|
||||
)
|
||||
if not m0_m1_m2_valid:
|
||||
logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}")
|
||||
return False
|
||||
|
||||
# Validate warp configuration
|
||||
if not validate_warp_configuration(warp_m, warp_n, warp_k):
|
||||
logging.debug(
|
||||
f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})"
|
||||
)
|
||||
return False
|
||||
|
||||
# Validate dimension alignment
|
||||
is_aligned, alignment_issues = validate_dimension_alignment(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
)
|
||||
if not is_aligned:
|
||||
logging.debug(
|
||||
f"Dimension alignment failed: {', '.join(alignment_issues)}. "
|
||||
f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
|
||||
f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Validate LDS capacity
|
||||
lds_valid, lds_error = validate_lds_capacity(
|
||||
tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline
|
||||
)
|
||||
if not lds_valid:
|
||||
logging.debug(f"LDS validation failed: {lds_error}")
|
||||
return False
|
||||
|
||||
# Validate warp tile combination
|
||||
warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
gpu_target,
|
||||
)
|
||||
if not warp_tile_valid:
|
||||
logging.debug(f"Warp tile validation failed: {warp_tile_error}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def validate_vector_load_alignment(
|
||||
wg_m: int,
|
||||
wg_k: int,
|
||||
a_datatype: str,
|
||||
m_iter_per_warp: int,
|
||||
wave_size: int,
|
||||
vector_load_size: int,
|
||||
) -> Tuple[bool, str]:
|
||||
try:
|
||||
# Calculate the memory access pattern size
|
||||
a_element_size = element_size(a_datatype)
|
||||
access_size = (wg_m * wg_k * a_element_size * m_iter_per_warp) / wave_size
|
||||
|
||||
# Check if it's aligned to vector load size
|
||||
if access_size % vector_load_size != 0:
|
||||
error_msg = (
|
||||
f"Vector load alignment violation: "
|
||||
f"({wg_m} * {wg_k} * {a_element_size} * {m_iter_per_warp} / {wave_size}) "
|
||||
f"% {vector_load_size} = {access_size % vector_load_size} != 0. "
|
||||
f"Access size: {access_size} bytes"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
return False, f"Error in vector load validation: {str(e)}"
|
||||
|
||||
|
||||
def validate_m0_m1_m2_configuration(
|
||||
tile_m: int,
|
||||
tile_k: int,
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
a_datatype: str,
|
||||
vector_load_size: int = 16,
|
||||
warp_size: int = 64,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Validate M0, M1, M2 configuration for matrix A row-major layout.
|
||||
This ensures proper memory access pattern alignment.
|
||||
"""
|
||||
try:
|
||||
# Validation for A as row-major
|
||||
MPerBlock = tile_m
|
||||
|
||||
# Calculate K1 using element size
|
||||
K1 = vector_load_size / element_size(a_datatype)
|
||||
|
||||
# Check if K1 is valid (must be integer)
|
||||
if K1 != int(K1):
|
||||
return (
|
||||
False,
|
||||
f"K1 = {K1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})",
|
||||
)
|
||||
K1 = int(K1)
|
||||
|
||||
# Calculate K0
|
||||
if tile_k % K1 != 0:
|
||||
return False, f"tile_k({tile_k}) must be divisible by K1({K1})"
|
||||
K0 = tile_k // K1
|
||||
|
||||
# Calculate M2
|
||||
if warp_size % K0 != 0:
|
||||
return False, f"warp_size({warp_size}) must be divisible by K0({K0})"
|
||||
M2 = warp_size // K0
|
||||
|
||||
# Calculate number of warps and block size
|
||||
NumWarps = warp_m * warp_n * warp_k
|
||||
BlockSize = NumWarps * warp_size
|
||||
|
||||
# Calculate M0 (assuming get_warp_size() returns warp_size)
|
||||
M0 = BlockSize // warp_size # This should equal NumWarps
|
||||
|
||||
# Calculate M1
|
||||
if (M2 * M0) == 0:
|
||||
return False, f"M2({M2}) * M0({M0}) cannot be zero"
|
||||
|
||||
if MPerBlock % (M2 * M0) != 0:
|
||||
return (
|
||||
False,
|
||||
f"MPerBlock({MPerBlock}) must be divisible by M2({M2}) * M0({M0}) = {M2 * M0}",
|
||||
)
|
||||
M1 = MPerBlock // (M2 * M0)
|
||||
|
||||
# Validate the assertion: M0 * M1 * M2 == MPerBlock
|
||||
calculated_m_per_block = M0 * M1 * M2
|
||||
if calculated_m_per_block != MPerBlock:
|
||||
error_msg = (
|
||||
f"Incorrect M0, M1, M2 configuration! "
|
||||
f"M0({M0}) * M1({M1}) * M2({M2}) = {calculated_m_per_block} != MPerBlock({MPerBlock}). "
|
||||
f"Configuration: K0={K0}, K1={K1}, NumWarps={NumWarps}, BlockSize={BlockSize}"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
return True, ""
|
||||
|
||||
except ZeroDivisionError as e:
|
||||
return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}"
|
||||
except Exception as e:
|
||||
return False, f"Error in M0/M1/M2 validation: {str(e)}"
|
||||
|
||||
|
||||
# [TODO] Handle this while moving code to commons Add more datatype to this function if needed
|
||||
def get_dtype_string(datatype: str) -> str:
|
||||
"""Get C++ type string for datatype"""
|
||||
dtype_map = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp32": "float",
|
||||
"fp64": "double",
|
||||
}
|
||||
return dtype_map.get(datatype, "float")
|
||||
|
||||
|
||||
LAYOUT_MAP = {
|
||||
"r": "ck_tile::tensor_layout::gemm::RowMajor",
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
|
||||
def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]:
|
||||
"""
|
||||
Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'.
|
||||
"""
|
||||
code = str(layout_code).strip().lower()
|
||||
|
||||
a_layout = LAYOUT_MAP[code[0]]
|
||||
b_layout = LAYOUT_MAP[code[1]]
|
||||
c_layout = LAYOUT_MAP[code[2]]
|
||||
return a_layout, b_layout, c_layout
|
||||
@@ -8,15 +8,34 @@ import itertools
|
||||
import logging
|
||||
import multiprocessing
|
||||
import concurrent.futures
|
||||
|
||||
from pathlib import Path
|
||||
import importlib.util
|
||||
|
||||
from commons.validation_utils import (
|
||||
is_tile_config_valid,
|
||||
is_trait_combination_valid,
|
||||
get_dtype_string,
|
||||
get_abc_layouts,
|
||||
)
|
||||
|
||||
def _import_validation_utils():
|
||||
"""Import validation utilities from commons directory."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
|
||||
# Load the module dynamically
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"validation_utils",
|
||||
os.path.join(parent_dir, "commons", "gemm_validation_utils.py"),
|
||||
)
|
||||
validation_utils = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(validation_utils)
|
||||
|
||||
return validation_utils
|
||||
|
||||
|
||||
# Import validation functions
|
||||
_validation_utils = _import_validation_utils()
|
||||
is_tile_config_valid = _validation_utils.is_tile_config_valid
|
||||
is_trait_combination_valid = _validation_utils.is_trait_combination_valid
|
||||
get_dtype_string = _validation_utils.get_dtype_string
|
||||
get_abc_layouts = _validation_utils.get_abc_layouts
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class GemmPreshuffleKernelBuilder:
|
||||
@@ -305,6 +324,8 @@ class GemmPreshuffleKernelBuilder:
|
||||
b_datatype = self.datatype
|
||||
c_datatype = self.datatype
|
||||
|
||||
layout = self.layout
|
||||
|
||||
# Special handling for certain data types
|
||||
if self.datatype in ["fp8", "bf8"]:
|
||||
c_datatype = "fp16"
|
||||
@@ -324,6 +345,7 @@ class GemmPreshuffleKernelBuilder:
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
pipeline,
|
||||
layout,
|
||||
self.gpu_target,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user