# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-

"""
Generate kernel instances to speed up compilation.

This module only defines data: lookup tables that translate short
configuration keys into ck_tile C++ names, and C++ source snippets
(plain strings) meant to be pasted into generated kernel code.
"""

# Short data-type key -> ck_tile C++ element type name.
DATA_TYPE_MAP = {
    'fp32': 'float',
    'fp16': 'ck_tile::half_t',
    'bf16': 'ck_tile::bf16_t',
    'int8': 'ck_tile::int8_t',
    'fp8': 'ck_tile::fp8_t',
    'bf8': 'ck_tile::bf8_t',
    'int4': 'ck_tile::pk_int4_t',
}

# Layout key -> ck_tile GEMM tensor layout ('r' = row-major, 'c' = column-major).
LAYOUT_MAP = {
    'r': 'ck_tile::tensor_layout::gemm::RowMajor',
    'c': 'ck_tile::tensor_layout::gemm::ColumnMajor',
}

# C++ snippets declaring the `GemmEpilogue` type; chosen via EPILOGUE_MAP below.
# NOTE(review): DefaultGemm2DEpilogueProblem / CShuffleEpilogueProblem are used
# here without any template arguments. It looks like their `<...>` argument
# lists were lost (e.g. stripped as markup); restore them from upstream before
# relying on these snippets.
DEFAULT_EPILOGUE = """
            using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
                                ck_tile::DefaultGemm2DEpilogueProblem>;
"""

CSHUFFLE_EPILOGUE = """
            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                                ck_tile::CShuffleEpilogueProblem>;
"""

# C++ dispatch for the has-hot-loop == false case: turns the runtime
# `tail_num` into a compile-time RunSplitk instantiation and throws for any
# tail number other than Full / Odd / Even.
# The snippet must stay multi-line: generated code is line-sensitive wherever
# a `//` comment appears.
HOT_LOOP_FALSE = """
            if(tail_num == ck_tile::TailNumber::Full)
            {
                RunSplitk(ck_tile::bool_constant<false>{},
                          ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
            }
            else if(tail_num == ck_tile::TailNumber::Odd)
            {
                RunSplitk(ck_tile::bool_constant<false>{},
                          ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
            }
            else if(tail_num == ck_tile::TailNumber::Even)
            {
                RunSplitk(ck_tile::bool_constant<false>{},
                          ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
            }
            else
            {
                throw std::runtime_error("Num K loop must be larger than number of prefetch stages.");
            }
"""

# Dispatch for the has-hot-loop == true case of the 'mem' pipeline.
# One and Full are dispatched explicitly. Two..Seven are forwarded to
# `try_run` through a C++17 fold expression; `try_run` is not defined here and
# is expected to come from the surrounding generated code.
# Each `//` comment must stay on its own line or it would comment out the code
# that follows it.
RUN_MEM = """
            // Handle One and Full cases directly
            if (tail_num == ck_tile::TailNumber::One) {
                RunSplitk(ck_tile::bool_constant<true>{},
                          ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
            } else if (tail_num == ck_tile::TailNumber::Full) {
                RunSplitk(ck_tile::bool_constant<true>{},
                          ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
            }

            // Variadic call using fold expression
            auto check_tail = [&](auto... TNs) {
                (try_run<BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
            };

            check_tail(
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}
            );
"""

# Dispatch for the has-hot-loop == true case of the 'compv3' pipeline:
# Full / Odd / Even are accepted, anything else throws.
RUN_COMPV3 = """
            if(tail_num == ck_tile::TailNumber::Full)
            {
                RunSplitk(ck_tile::bool_constant<true>{},
                          ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
            }
            else if(tail_num == ck_tile::TailNumber::Odd)
            {
                RunSplitk(ck_tile::bool_constant<true>{},
                          ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
            }
            else if(tail_num == ck_tile::TailNumber::Even)
            {
                RunSplitk(ck_tile::bool_constant<true>{},
                          ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
            }
            else
            {
                throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even.");
            }
"""

# Dispatch for the has-hot-loop == true case of the 'compv4' pipeline:
# Three runs the Three instantiation, every other tail number runs Two.
RUN_COMPV4 = """
            if(tail_num == ck_tile::TailNumber::Three)
            {
                RunSplitk(ck_tile::bool_constant<true>{},
                          ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
            }
            else
            {
                RunSplitk(ck_tile::bool_constant<true>{},
                          ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
            }
"""

# Pipeline key -> [base pipeline type name, pipeline type name].
PIPELINE_MAP = {
    'mem': ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
    'compv3': ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
    'compv4': ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4'],
}

# Scheduler key -> ck_tile GemmPipelineScheduler enumerator.
SCHEDULER_MAP = {
    'interwave': 'ck_tile::GemmPipelineScheduler::Interwave',
    'intrawave': 'ck_tile::GemmPipelineScheduler::Intrawave',
}

# Epilogue key -> C++ snippet declaring `GemmEpilogue`.
EPILOGUE_MAP = {
    'default': DEFAULT_EPILOGUE,
    'cshuffle': CSHUFFLE_EPILOGUE,
}

# Pipeline key -> tail-number dispatch snippet for the has-hot-loop == true case.
HOT_LOOP_TRUE = {
    'mem': RUN_MEM,
    'compv3': RUN_COMPV3,
    'compv4': RUN_COMPV4,
}


def BOOL_MAP(b_):
    """Return the C++ boolean literal for the truthiness of ``b_``.

    Args:
        b_: Any value; only its truthiness (``bool(b_)``) matters.

    Returns:
        ``'true'`` if ``b_`` is truthy, otherwise ``'false'``.
    """
    # Upper-case name kept for backward compatibility with existing callers;
    # formerly a lambda wrapping a per-call dict lookup.
    return 'true' if b_ else 'false'