# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# ============================================================================
# Pooling Tile Engine Unit Tests
#
# This CMake file creates unit tests for tile_engine generated pooling kernels.
# Each kernel configuration gets its own test executable.
# ============================================================================

# Directory holding the tile_engine pooling generator scripts. Every step
# below invokes a script from it, so stop processing this file if it is absent.
set(TILE_ENGINE_POOLING_DIR "${PROJECT_SOURCE_DIR}/tile_engine/ops/pooling")

if(NOT EXISTS "${TILE_ENGINE_POOLING_DIR}")
    message(WARNING "Tile engine pooling directory not found: ${TILE_ENGINE_POOLING_DIR}")
    return()
endif()

# ============================================================================
# create_individual_pool_test_target
#
# Creates a single test executable for a specific pooling kernel configuration.
#
# Parameters:
#   datatype     - Data type (fp16, fp32, bf16)
#   config_name  - Configuration file name without .json extension
#   kernel_name  - Kernel name as listed in pool_kernel_list.txt; selects the
#                  generated header pooling_single_<kernel_name>.hpp
#   trait        - Kernel trait combination string; a trailing "3d" selects the
#                  3D test parameters, any other value falls back to 2D
#   tile_config  - Tile configuration parameters
#   config_json  - Full path to JSON configuration file (not used by this
#                  function; kept so existing callers keep working)
#
# Reads POOLING_TEST_GPU_TARGETS and CK_USE_OCP_FP8 from the calling scope.
# Emits a WARNING and creates no target when the generated kernel header or
# the matching test-parameters header does not exist.
# ============================================================================
function(create_individual_pool_test_target datatype config_name kernel_name trait tile_config config_json)
    set(target_name "test_pooling_tile_engine_${datatype}_${config_name}_${trait}_${tile_config}")
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${config_name}")

    # Generated header path (already created during cmake configuration)
    # Use kernel_name from pool_kernel_list.txt to match the filename generated by pooling_instance_builder.py
    set(test_header "${working_path}/pooling_single_${kernel_name}.hpp")

    # Determine pooling dimension from trait string (format: reduce_op_output_index_propagate_nan_pooling_dim)
    # The pooling_dim is the last field: "2d" or "3d"
    string(REGEX MATCH "[23]d$" kernel_pooling_dim "${trait}")
    if(kernel_pooling_dim STREQUAL "3d")
        set(test_params_header "${working_path}/test_params_3d.hpp")
        set(pooling_dim_value 3)
    else()
        # Covers "2d" as well as trait strings with no recognizable suffix.
        set(test_params_header "${working_path}/test_params_2d.hpp")
        set(pooling_dim_value 2)
    endif()

    # Verify header exists
    if(NOT EXISTS "${test_header}")
        message(WARNING "Generated header not found: ${test_header}")
        return()
    endif()

    # Verify test parameters header exists
    if(NOT EXISTS "${test_params_header}")
        message(WARNING "Test parameters header not found: ${test_params_header}")
        return()
    endif()

    # Create GTest executable for this kernel configuration
    add_gtest_executable(${target_name}
        ${CMAKE_CURRENT_SOURCE_DIR}/test_pooling_simple.cpp
    )

    # Configure GPU architectures for HIP compilation
    set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${POOLING_TEST_GPU_TARGETS})

    # Define preprocessor macros for generated header location, test parameters, and pooling dimension
    target_compile_definitions(${target_name} PRIVATE
        POOLING_SINGLE_INSTANCE_HPP="${test_header}"
        POOLING_TEST_PARAMS_HPP="${test_params_header}"
        POOLING_DIM_VALUE=${pooling_dim_value}
    )

    # Include directories for headers and dependencies
    target_include_directories(${target_name} PRIVATE
        ${PROJECT_SOURCE_DIR}/include
        ${PROJECT_BINARY_DIR}/include
        ${PROJECT_SOURCE_DIR}  # Root directory for tile_engine access
        ${GTEST_INCLUDE_DIRS}
    )

    # Compiler options matching tile_engine requirements.
    # CMake de-duplicates individual compile options, so the two-word
    # "-include <file>" option is grouped with the SHELL: prefix to guarantee
    # the flag and its path always stay together on the command line.
    target_compile_options(${target_name} PRIVATE
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
        "SHELL:-include \"${test_header}\""
    )

    # Add FP8 format definitions for proper data type interpretation
    if(CK_USE_OCP_FP8)
        target_compile_definitions(${target_name} PRIVATE CK_TILE_USE_OCP_FP8)
    endif()

    message(STATUS "  Created test target: ${target_name}")
endfunction()

# ============================================================================
# build_pool_test_targets
#
# Builds all test targets for a specific datatype/config combination.
# Everything runs at configure time:
#   1.  list the kernels for the config (writes pool_kernel_list.txt)
#   2a. extract 2D test parameters into test_params_2d.hpp
#   2b. extract 3D test parameters into test_params_3d.hpp
#   2c. generate one header per listed kernel (--gen_single)
#   3.  create one test target per listed kernel
#
# Parameters:
#   datatype     - Data type (fp16, fp32, bf16)
#   config_name  - Configuration file name without .json extension
#
# Returns early (after a WARNING/STATUS message) when the config file is
# missing, a script fails, or no kernel list is produced.
# ============================================================================
function(build_pool_test_targets datatype config_name)
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${config_name}")
    set(kernel_list_file "${working_path}/pool_kernel_list.txt")
    set(instance_builder "${TILE_ENGINE_POOLING_DIR}/pooling_instance_builder.py")
    set(params_extractor "${CMAKE_CURRENT_SOURCE_DIR}/extract_test_params.py")

    # Locate and validate configuration file
    set(config_filename "${config_name}.json")
    set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}")

    if(NOT EXISTS "${json_blob}")
        message(WARNING "Test config file not found: ${json_blob}")
        return()
    endif()

    # The headers and targets below are derived from these files at configure
    # time, so register them as configure inputs: editing any of them makes
    # the next build re-run CMake instead of reusing stale generated output.
    foreach(configure_input IN ITEMS "${json_blob}" "${instance_builder}" "${params_extractor}")
        if(EXISTS "${configure_input}")
            set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${configure_input}")
        endif()
    endforeach()

    # Prepare build directory for this configuration
    file(MAKE_DIRECTORY "${working_path}")

    # Remove any kernel list left by a previous configure run so that the
    # existence check after the discovery step reflects this run only.
    file(REMOVE "${kernel_list_file}")

    # STEP 1: Discovery phase - list all valid kernel configurations
    # (stdout is discarded; only the exit code and stderr are inspected)
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -u "${instance_builder}"
                --working_path "${working_path}"
                --datatype ${datatype}
                --config_json "${json_blob}"
                --list_kernels
        WORKING_DIRECTORY "${TILE_ENGINE_POOLING_DIR}"
        RESULT_VARIABLE ret
        OUTPUT_QUIET
        ERROR_VARIABLE list_error
    )

    if(NOT ret EQUAL 0)
        message(WARNING "Failed to list pooling kernels for ${datatype}_${config_name}: ${list_error}")
        return()
    endif()

    # Verify kernel list file was generated
    if(NOT EXISTS "${kernel_list_file}")
        message(STATUS "No pooling kernels found for ${datatype}_${config_name}")
        return()
    endif()

    message(STATUS "Building pooling tests for ${datatype}_${config_name}")

    # STEP 2a/2b: Extract test parameters from config for BOTH 2D and 3D dimensions.
    # Each kernel's pooling_dim is embedded in its trait string, so we generate
    # separate test_params headers and select the right one per kernel target.
    foreach(pooling_dim IN ITEMS 2d 3d)
        string(TOUPPER "${pooling_dim}" pooling_dim_label)
        execute_process(
            COMMAND ${Python3_EXECUTABLE} -u "${params_extractor}"
                    --config_file "${json_blob}"
                    --output_file "${working_path}/test_params_${pooling_dim}.hpp"
                    --pooling_dim ${pooling_dim}
            WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
            RESULT_VARIABLE extract_ret
            OUTPUT_QUIET
            ERROR_VARIABLE extract_error
        )
        if(NOT extract_ret EQUAL 0)
            message(WARNING "Failed to extract ${pooling_dim_label} test parameters for pooling ${datatype}: ${extract_error}")
            return()
        endif()
    endforeach()

    # STEP 2c: Header generation phase - generate headers using --gen_single
    message(STATUS "  Generating pooling headers using --gen_single...")

    file(STRINGS "${kernel_list_file}" kernel_lines)
    set(gen_count 0)

    foreach(line IN LISTS kernel_lines)
        # Parse kernel specification format: kernel_name|tile_config|trait_combo
        string(REPLACE "|" ";" parts "${line}")
        list(LENGTH parts parts_len)
        if(NOT parts_len EQUAL 3)
            # Skip lines that do not have exactly three fields.
            continue()
        endif()
        list(GET parts 0 kernel_name)
        list(GET parts 1 tile_config)
        list(GET parts 2 trait_combo)

        # Generate header using --gen_single
        execute_process(
            COMMAND ${Python3_EXECUTABLE} -u "${instance_builder}"
                    --working_path "${working_path}"
                    --datatype ${datatype}
                    --config_json "${json_blob}"
                    --gen_single
                    --kernel_name "${kernel_name}"
                    --tile_config "${tile_config}"
                    --trait_combo "${trait_combo}"
            WORKING_DIRECTORY "${TILE_ENGINE_POOLING_DIR}"
            RESULT_VARIABLE gen_ret
            OUTPUT_QUIET
            ERROR_VARIABLE gen_error
        )

        if(NOT gen_ret EQUAL 0)
            message(WARNING "Failed to generate pooling header for ${kernel_name}: ${gen_error}")
        else()
            math(EXPR gen_count "${gen_count} + 1")
        endif()
    endforeach()

    message(STATUS "  Generated ${gen_count} pooling headers for ${datatype}")

    # STEP 3: Target creation phase - create test targets
    message(STATUS "  Creating pooling test targets...")
    # Read the list again so this phase sees the file as it is after generation.
    file(STRINGS "${kernel_list_file}" kernel_lines)
    set(test_count 0)
    foreach(line IN LISTS kernel_lines)
        string(REPLACE "|" ";" parts "${line}")
        list(LENGTH parts parts_len)
        if(NOT parts_len EQUAL 3)
            continue()
        endif()
        list(GET parts 0 kernel_name)
        list(GET parts 1 tile_config)
        list(GET parts 2 trait_combo)

        # The helper returns without creating a target when a generated header
        # is missing, so count a target only if this directory's target list
        # actually grew across the call.
        get_property(targets_before DIRECTORY PROPERTY BUILDSYSTEM_TARGETS)
        list(LENGTH targets_before targets_before_len)

        create_individual_pool_test_target("${datatype}" "${config_name}" "${kernel_name}" "${trait_combo}" "${tile_config}" "${json_blob}")

        get_property(targets_after DIRECTORY PROPERTY BUILDSYSTEM_TARGETS)
        list(LENGTH targets_after targets_after_len)
        if(targets_after_len GREATER targets_before_len)
            math(EXPR test_count "${test_count} + 1")
        endif()
    endforeach()
    message(STATUS "  Created ${test_count} pooling test targets for ${datatype}")
endfunction()

# ============================================================================
# MAIN EXECUTION - Test Target Generation
# ============================================================================

message(STATUS "=== Starting Pooling Tile Engine Test Configuration ===")
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")

# Architectures the pooling tests are built for. POOLING_TEST_GPU_TARGETS
# collects those of them that also appear in SUPPORTED_GPU_TARGETS, keeping
# the order of SUPPORTED_GPU_TARGETS.
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
set(POOLING_TEST_GPU_TARGETS "")

foreach(gpu_arch IN LISTS SUPPORTED_GPU_TARGETS)
    if(NOT gpu_arch IN_LIST DESIRED_TARGETS)
        continue()
    endif()
    list(APPEND POOLING_TEST_GPU_TARGETS ${gpu_arch})
    message(STATUS "  Adding GPU target for pooling tests: ${gpu_arch}")
endforeach()

# Nothing to build when the intersection is empty: warn and leave this file.
if("${POOLING_TEST_GPU_TARGETS}" STREQUAL "")
    message(WARNING "Skipping Pooling Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
    return()
endif()

message(STATUS "Building Pooling tile engine tests for GPU targets: ${POOLING_TEST_GPU_TARGETS}")

# Declare two job pools (used by Ninja generators only) with 4 and 16 slots.
# NOTE(review): this sets the global JOB_POOLS property without APPEND, so it
# replaces any value assigned earlier, and no target created in this file
# visibly selects either pool - confirm both are intended.
set_property(GLOBAL PROPERTY JOB_POOLS
    compile_heavy=4
    compile_normal=16
)

# Opt-in compiler cache: when requested and ccache is found, it becomes the
# C++ compiler launcher for targets created after this point.
option(ENABLE_CCACHE_TESTS "Enable ccache for test compilation" OFF)
if(NOT ENABLE_CCACHE_TESTS)
    message(STATUS "ccache disabled for tests (use -DENABLE_CCACHE_TESTS=ON to enable)")
else()
    find_program(CCACHE_PROGRAM ccache)
    if(NOT CCACHE_PROGRAM)
        message(WARNING "ccache requested but not found")
    else()
        set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
        message(STATUS "Using ccache for faster test compilation")
    endif()
endif()

# ============================================================================
# Test Configuration Matrix
# ============================================================================

set(TEST_DATATYPES "fp16;fp32")

# COVERAGE LEVEL: Quick or comprehensive testing
#    Quick: ~2 kernels (1 tile config × 1 trait combo × fp16/fp32) from simple config only
#    Comprehensive: ~200+ kernels with extensive tile sizes, warp configurations, and all trait combinations
set(POOLING_COVERAGE_LEVEL "quick" CACHE STRING "Pooling coverage level: quick or comprehensive")
set_property(CACHE POOLING_COVERAGE_LEVEL PROPERTY STRINGS "quick" "comprehensive")

# Validate the user-supplied level before any configure-time generation runs,
# so an invalid value fails immediately instead of after the simple-config
# kernels have already been listed and generated.
if(NOT POOLING_COVERAGE_LEVEL STREQUAL "quick" AND NOT POOLING_COVERAGE_LEVEL STREQUAL "comprehensive")
    message(FATAL_ERROR "Invalid POOLING_COVERAGE_LEVEL: ${POOLING_COVERAGE_LEVEL}. Must be 'quick' or 'comprehensive'")
endif()

# ============================================================================
# Test Target Generation
# ============================================================================

# 1. SIMPLE TEST: Basic functionality validation (always built)
set(SIMPLE_TEST_CONFIG "simple_test_config")
set(SIMPLE_TEST_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${SIMPLE_TEST_CONFIG}.json")

if(EXISTS "${SIMPLE_TEST_CONFIG_FILE}")
    message(STATUS "Processing pooling simple test config: ${SIMPLE_TEST_CONFIG}")
    foreach(datatype IN LISTS TEST_DATATYPES)
        build_pool_test_targets("${datatype}" "${SIMPLE_TEST_CONFIG}")
    endforeach()
else()
    message(WARNING "Pooling simple test config file not found: ${SIMPLE_TEST_CONFIG_FILE}")
endif()

# 2. COMPREHENSIVE TEST: Only built when POOLING_COVERAGE_LEVEL is "comprehensive"
if(POOLING_COVERAGE_LEVEL STREQUAL "comprehensive")
    set(COMPREHENSIVE_CONFIG "comprehensive_coverage_config")
    set(COMPREHENSIVE_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${COMPREHENSIVE_CONFIG}.json")

    if(EXISTS "${COMPREHENSIVE_CONFIG_FILE}")
        message(STATUS "Processing pooling comprehensive coverage config: ${COMPREHENSIVE_CONFIG}")
        foreach(datatype IN LISTS TEST_DATATYPES)
            build_pool_test_targets("${datatype}" "${COMPREHENSIVE_CONFIG}")
        endforeach()
    else()
        message(WARNING "Pooling comprehensive config file not found: ${COMPREHENSIVE_CONFIG_FILE}")
    endif()
endif()

# Summary; the datatype list is derived from TEST_DATATYPES so the message
# cannot drift from the matrix above ("fp16;fp32" -> "fp16/fp32").
string(REPLACE ";" "/" pooling_datatype_summary "${TEST_DATATYPES}")
message(STATUS "Pooling tile engine tests configured:")
message(STATUS "  - Simple test: ${pooling_datatype_summary} (always)")
message(STATUS "  - Coverage level: ${POOLING_COVERAGE_LEVEL}")
message(STATUS "    Use -DPOOLING_COVERAGE_LEVEL=comprehensive for extensive testing")
