# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import argparse
from enum import IntEnum
from pathlib import Path
import sys
from typing import List, Optional, Any
import functools
import itertools
import copy
from dataclasses import dataclass


def get_if_str(idx, total, lase_else = True):
    """Return the C++ branch keyword for the idx-th of `total` chained conditions.

    The first branch is 'if' and middle branches are 'else if'.  The last
    branch is a bare 'else' when `lase_else` is true; otherwise it is another
    'else if', so the chain has no catch-all and may match nothing.
    """
    if idx == 0:
        return 'if'
    if idx < total - 1:
        return 'else if'
    return 'else' if lase_else else 'else if'


# File-name suffixes for the fused-add modes, indexed by the kFusedAdd value.
FUSED_ADD_ENUM_STR_MAP = [
    'no',
    'pras',  # pre-norm
    'pra' ]  # post-norm

# File-name suffixes for the fused-quant modes, indexed by the kFusedQuant value.
FUSED_FUSED_SWEEP_STR_MAP = [
    'no',
    'sdquant',  # smooth dynamic quant
    'dquant' ]  # dynamic quant (without sm_scale)

# Short precision name -> C++ type spelled in the generated template arguments.
DATA_TYPE_MAP = {'fp32' : 'float',
                 'fp16' : 'ck_tile::fp16_t',
                 'bf16' : 'ck_tile::bf16_t',
                 'int8' : 'ck_tile::int8_t',
                 'fp8'  : 'ck_tile::fp8_t'}


def BOOL_MAP(b_) -> str:
    """Render a Python truth value as a C++ boolean literal."""
    return 'true' if b_ else 'false'


class rmsnorm_fwd_codegen:
    """Code generator for the ck_tile rmsnorm2d forward kernels.

    Writes into `working_path`:
      * rmsnorm2d_fwd_api.cpp        -- runtime dispatcher (content_api)
      * rmsnorm2d_fwd_api_common.hpp -- templated launcher (content_common_header)
      * one .cpp per h_instance       -- explicit instantiations (h_instance.content)
    """

    # Traits struct + `traits_` alias; spliced into both API_COMMON_HEADER and
    # API_BASE through their {F_traits_define} field.  Not itself .format()-ed,
    # so braces here are single.
    API_TRAITS_DEFINE = """
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template
struct rmsnorm2d_fwd_traits_
{
    using XDataType = ck_tile::remove_cvref_t;
    using YDataType = ck_tile::remove_cvref_t;
    using SmoothScaleDataType = ck_tile::remove_cvref_t;
    using YScaleDataType = ck_tile::remove_cvref_t;
    using UnquantYDataType = ck_tile::remove_cvref_t;

    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
    static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
    static constexpr ck_tile::index_t total_warps = (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();

    // num of warps along m
    static constexpr ck_tile::index_t BlockWarps_M = []() {
        if constexpr(is_warp_per_row)
        {
            static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
            return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
        }
        else
        {
            // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
            return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
        }
    }();

    // num of warps along n
    static constexpr ck_tile::index_t BlockWarps_N = []() {
        if constexpr(is_warp_per_row)
        {
            static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
            return 1;
        }
        else
        {
            static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
            return ThreadPerBlock_N_ / ck_tile::get_warp_size();
        }
    }();

    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;

    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
    static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;

    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
    static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;

    using BlockTile = ck_tile::sequence;
    using BlockWarps = ck_tile::sequence;
    using WarpTile = ck_tile::sequence;
    using Vector = ck_tile::sequence<1, Vector_N_>;

    using Shape = ck_tile::Generic2dBlockShape;

    static constexpr bool kPadN = kPadN_;
    static constexpr bool kSaveInvRms = kSaveInvRms_;
    static constexpr bool kSaveUnquant = kSaveUnquant_;
    static constexpr bool kTwoPass = kTwoPass_;
    static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
    static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
    static constexpr ck_tile::index_t kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_;
};

template
using traits_ = rmsnorm2d_fwd_traits_;
"""

    # Header shared by every instance file: defines the templated launcher
    # rmsnorm2d_fwd_ that the instance files explicitly instantiate.
    # .format()-ed, so literal C++ braces are doubled.
    API_COMMON_HEADER = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include
#include "rmsnorm2d_fwd.hpp"
#include
#include

#pragma once

using S = ck_tile::stream_config;
using A = rmsnorm2d_fwd_args;

{F_traits_define}

template
float rmsnorm2d_fwd_(const S& s, A a)
{{
    using XDataType = typename Traits_::XDataType;
    using YDataType = typename Traits_::YDataType;
    using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
    using YScaleDataType = typename Traits_::YScaleDataType;
    using UnquantYDataType = typename Traits_::UnquantYDataType;
    using ComputeDataType = typename RmsnormTypeConfig::ComputeDataType;

    using PipelineTraits = ck_tile::Rmsnorm2dFwdTraits(Traits_::kFusedAdd),
        static_cast(Traits_::kFusedQuant),
        static_cast(Traits_::kUseModelSensitiveRMSNorm)>;

    using PipelineProblem = ck_tile::Rmsnorm2dFwdPipelineProblem::XDataType,
        typename RmsnormTypeConfig::GammaDataType,
        typename RmsnormTypeConfig::ComputeDataType,
        typename RmsnormTypeConfig::YDataType,
        typename RmsnormTypeConfig::InvRmsDataType,
        typename RmsnormTypeConfig::UnquantYDataType,
        typename RmsnormTypeConfig::SmoothScaleDataType,
        typename RmsnormTypeConfig::YScaleDataType,
        typename Traits_::Shape,
        PipelineTraits>;

    using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass;
    using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass;
    using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass;
    using Pipeline = std::conditional_t<
        (Traits_::kUseModelSensitiveRMSNorm == 0 || Traits_::kTwoPass), // TODO: consider TwoPass for T5PassPipeline
        std::conditional_t, // kUseModelSensitiveRMSNorm == 0
        T5PassPipeline
    >;

    using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem;
    using Default2DEpilogue = ck_tile::Default2DEpilogue;

    static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
    using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>;
    using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue;

    using Default2DAndDynamicQuantEpilogueProblem = ck_tile::Default2DAndDynamicQuantEpilogueProblem<
        ComputeDataType,
        SmoothScaleDataType,
        YScaleDataType,
        YDataType,
        UnquantYDataType,
        typename Traits_::Shape,
        ck_tile::Default2DAndDynamicQuantEpilogueTraits>;
    using Default2DAndDynamicQuantEpilogue = ck_tile::Default2DAndDynamicQuantEpilogue;

    using Epilogue = std::conditional_t, Default2DEpilogue>;

    using Kernel = ck_tile::Rmsnorm2dFwd;

    const dim3 grids = Kernel::GridSize(a);
    constexpr dim3 blocks = Kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 1;

    auto kargs = Kernel::MakeKargs(a);
    if(s.log_level_ > 0)
        std::cout << ", " << Kernel::GetName() << std::flush;

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs));
}}
"""

    # Dispatcher translation unit; {F_dispatch} is the nested if-chain built by
    # content_api.  r stays -1 when no branch matches.
    API_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include
#include "rmsnorm2d_fwd.hpp"

{F_traits_define}

// Note: this internal API only declare, not define here, otherwise will block `make -j`
template
float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a);

float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
                    rmsnorm2d_fwd_args a,
                    const ck_tile::stream_config& s)
{{
    float r = -1;
{F_dispatch}
    return r;
}}
"""

    # One instance .cpp: a list of explicit instantiations (h_traits.def_name).
    INSTANCE_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "rmsnorm2d_fwd_api_common.hpp"

// clang-format off
// rm rn tm tn vn pd rms 2p
{F_instance_def}
// clang-format on
"""

    # Outer dispatch level: input/output precision pair.
    API_PER_DTYPE = """
    {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{
{F_per_n_case}
    }}
"""
    # Middle dispatch level: N bucket.
    API_PER_N_CASE = """
        {F_if} {F_N_COND} {{
{F_inner_dispatch}
        }}
"""
    # Inner dispatch level: one branch per kernel instance.
    API_INNER_CASE = """
            {F_if} {F_VEC_COND}
                r={F_instance_func}(s, a);
"""

    def __init__(self, working_path, kernel_filter):
        # working_path: output directory; kernel_filter: the --filter value
        # (stored only; not applied anywhere in this generator).
        self.working_path = working_path
        self.kernel_filter = kernel_filter

    class k_fuesd_add_enum(IntEnum):
        F_NO_ADD = 0
        F_PRE_ADD = 1
        F_PRE_ADD_STORE_RESIDUAL = 2

    class k_fused_sweep_enum(IntEnum):
        F_NO_SWEEP = 0
        F_RENORM = 1
        F_DYNAMIC_QUANT = 2

    @dataclass
    class k_traits:
        F_kPadN : bool
        F_kSaveMeanInvStd : bool
        F_kTwoPass : bool
        F_kFusedAdd : Any
        F_kFusedQuant : Any

    @dataclass
    class k_shape:
        F_BlockTile : List[int]
        F_WarpPerBlock : List[int]
        F_WarpTile : List[int]
        F_Vector_ : List[int]

        @property
        def F_BlockSize(self) -> int:
            # product of the F_WarpTile entries
            return functools.reduce(lambda a, b: a*b, self.F_WarpTile)

    @dataclass
    class k_problem:
        F_XDataType : str
        F_GammaDataType : str
        F_ComputeDataType : str
        F_YDataType : str
        F_InvRmsDataType : str
        F_BlockShape : str
        F_Traits : Any  # k_traits

    @dataclass
    class k_pipeline_one_pass:
        F_Problem : Any  # k_problem

    @dataclass
    class k_pipeline_two_pass:
        F_Problem : Any  # k_problem

    @dataclass
    class default_2d_epilogue_problem:
        F_AccDataType : str
        F_ODataType : str
        F_kPadM : bool
        F_kPadN : bool

    @dataclass
    class default_2d_epilogue:
        F_problem : Any

    @dataclass
    class k_kernel:
        F_pipeline : Any
        F_epilogue : Any

    @dataclass
    class h_traits:
        """One kernel instantiation; fields follow rmsnorm2d_fwd_traits_ argument order."""
        F_XDataType : str
        F_YDataType : str
        F_SmoothScaleDataType : str
        F_YScaleDataType : str
        F_UnquantYDataType : str
        F_Repeat_M : int
        F_Repeat_N : int
        F_ThreadPerBlock_M : int
        F_ThreadPerBlock_N : int
        F_Vector_N : int
        F_kPadN : bool
        F_kSaveInvRms : bool
        F_kSaveUnquant: bool
        F_kTwoPass : bool
        F_kFusedAdd : int
        F_kFusedQuant : int
        F_use_model_sensitive_rmsnorm : int

        @property
        def trait_name(self) ->str:
            """Comma-separated C++ template arguments for traits_<...> (column-aligned)."""
            t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
            t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}'
            t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}, {self.F_use_model_sensitive_rmsnorm:4}'
            return t_

        # string when calling this kernel
        @property
        def call_name(self) -> str:
            # rmsnorm2d_fwd_ is templated on a traits_ alias instance; without
            # the template arguments the emitted call is not valid C++.
            return f'rmsnorm2d_fwd_<traits_<{self.trait_name}>>'

        # string when define this kernel
        @property
        def def_name(self) -> str:
            # explicit instantiation of the launcher for this trait set
            return f'template float rmsnorm2d_fwd_<traits_<{self.trait_name}>>(const S&, A);'

    # this class hold kernel under same source file
    @dataclass
    class h_instance:
        """All h_traits instances that share one generated .cpp file."""
        F_DataTypePair : str
        F_N : str
        F_add : int
        F_sweep : int
        F_saveunquant : bool
        F_use_model_sensitive_rmsnorm : int
        instance_list : List[Any]  # List[h_traits]

        @property
        def name(self) -> str:
            """File stem, e.g. rmsnorm2d_fwd_fp16_int8_n1024_pras_sdquant_saveunquant_t5ml."""
            prec_i, prec_o = self.F_DataTypePair.split(',')
            dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}'
            nnn = f'rmsnorm2d_fwd_{dtype_str}_n{self.F_N}'
            if self.F_add != 0:
                nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
            if self.F_sweep != 0:
                nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
            if self.F_saveunquant:
                nnn = nnn + '_saveunquant'
            if self.F_use_model_sensitive_rmsnorm == 0:
                nnn = nnn + '_nsm'
            elif self.F_use_model_sensitive_rmsnorm == 1:
                nnn = nnn + '_t5ml'
            return nnn

        @property
        def instance_name(self) ->str:
            return self.name

        @property
        def content(self) ->str:
            """Full text of this instance's .cpp file."""
            instance_defs = ''.join(ins.def_name + '\n' for ins in self.instance_list)
            return rmsnorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs)

    @property
    def name_api(self) -> str:
        return 'rmsnorm2d_fwd_api'

    @property
    def name_common_header(self) -> str:
        return 'rmsnorm2d_fwd_api_common'

    @staticmethod
    def _fused_quant_cond(ins) -> str:
        """C++ condition matching the runtime fused_quant settings of instance `ins`.

        Raises:
            ValueError: for a kFusedQuant value this generator does not know.
        """
        if ins.F_kFusedQuant == 0:
            return 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
        if ins.F_kFusedQuant == 1:
            # smooth dynamic quant: also match smooth-scale and y-scale precisions
            return 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format(
                f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType,
                f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant))
        if ins.F_kFusedQuant == 2:
            # plain dynamic quant: no smooth scale involved
            return 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format(
                f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType,
                f_suq=BOOL_MAP(ins.F_kSaveUnquant))
        raise ValueError(f'unsupported fused quant mode: {ins.F_kFusedQuant}')

    @property
    def content_api(self) -> str:
        """Render the rmsnorm2d_fwd() dispatcher source (rmsnorm2d_fwd_api.cpp).

        The dispatch is three nested if-chains:
          1. precision pair (t.prec_i / t.prec_o) -- no trailing else;
          2. N bucket: (a.n <= N) for every bucket but the last, which becomes
             the catch-all 'else';
          3. per kernel instance: a.n divisible by the vector size, fused add,
             fused quant settings and model-sensitive flag -- no trailing else.
        Unmatched inputs leave r == -1 (see API_BASE).
        """
        # 1. group blobs by dtype pair, then by N bucket (insertion order kept)
        t_dtype_dict = dict()
        for blob in self.get_blobs():
            per_n = t_dtype_dict.setdefault(blob.F_DataTypePair, {})
            per_n.setdefault(blob.F_N, []).append(blob)

        d_str = ''
        for i_d, (dtype_, blob_per_t) in enumerate(t_dtype_dict.items()):
            n_str = ''
            for i_n, (n_, blob_per_n) in enumerate(blob_per_t.items()):
                inner_str = ''
                for i_b, b_ in enumerate(blob_per_n):
                    for i_ins, ins in enumerate(b_.instance_list):
                        # position within the flattened per-N chain
                        idx_in_n = i_b * len(b_.instance_list) + i_ins
                        len_in_n = len(blob_per_n) * len(b_.instance_list)
                        _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}) && (t.use_model_sensitive_rmsnorm == {f_use_model_sensitive_rmsnorm}) )'.format(
                            f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
                            f_sweep_cond = self._fused_quant_cond(ins),
                            f_use_model_sensitive_rmsnorm = ins.F_use_model_sensitive_rmsnorm)
                        inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
                                                                F_VEC_COND = _cond, F_instance_func=ins.call_name)
                # NOTE(review): quant output dtypes have no 'big' bucket, so their
                # largest bucket (8192) becomes the catch-all -- confirm intended.
                n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
                n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)),
                                                    F_N_COND=n_cnd, F_inner_dispatch=inner_str)
            prec_i, prec_o = dtype_.split(',')
            d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False),
                                               F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)

        return self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str)

    @property
    def content_common_header(self) -> str:
        return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)

    def get_blobs(self):
        """Return one h_instance per generated instance .cpp file.

        Iterates model-sensitive flag (0: default; 1: model sensitive), N bucket,
        and every combination of dtype pair, scale pair, fused add, fused quant
        and save_unquant, skipping:
          * int8/fp8 outputs without fused (smooth) dynamic quant,
          * fused quant in the two-pass 'big' bucket,
          * save_unquant without fused quant.
        """
        h_traits = rmsnorm_fwd_codegen.h_traits
        h_instance = rmsnorm_fwd_codegen.h_instance

        dynamic_quant_out_dtype = ['int8', 'fp8']
        # some predefined support range
        # (prec_i,prec_o) for simplicity this string will be used as key for dict
        scale_list = [('fp32,fp32')]
        dtype_list = [('fp16,fp16'), ('bf16,bf16'),
                      ('fp16,int8'), ('bf16,int8'),
                      ('fp16,fp8'), ('bf16,fp8')]  # NOTE: only fused-dynamic-quant use int8 out
        fused_add_list = [0, 1]
        fused_sweep_list = [0, 1, 2]  # NOTE: only single pass can use fused (smooth) dynamic quant
        bool_list = [False, True]

        # model-sensitive flag -> N bucket -> tile shapes
        #   (rm, rn, tm, tn, vn) = (Repeat_M, Repeat_N, ThreadPerBlock_M, ThreadPerBlock_N, Vector_N)
        # All entries use kPadN=true, kSaveInvRms=false; only 'big' is two-pass.
        tile_tables = {
            0: {
                '64'  : [(1, 1, 8, 8, 8), (1, 1, 4, 16, 4), (1, 1, 4, 64, 1)],
                '128' : [(1, 1, 4, 16, 8), (1, 1, 4, 64, 2), (1, 2, 4, 64, 1)],
                '256' : [(1, 1, 4, 64, 4), (1, 2, 4, 64, 2), (1, 4, 4, 64, 1)],
                '512' : [(1, 1, 4, 64, 8), (1, 2, 4, 64, 4), (1, 4, 4, 64, 2), (1, 8, 4, 64, 1)],
                '640' : [(1, 5, 4, 64, 2), (1, 5, 4, 128, 1)],
                '768' : [(1, 3, 4, 64, 4), (1, 6, 4, 64, 2), (1, 12, 4, 64, 1)],
                '1024': [(1, 2, 2, 64, 8), (1, 4, 2, 64, 4), (1, 8, 2, 64, 2), (1, 4, 1, 256, 1)],
                '1536': [(1, 3, 4, 64, 8), (1, 3, 2, 128, 4), (1, 3, 1, 256, 2), (1, 6, 1, 256, 1)],
                '2048': [(1, 1, 1, 256, 8), (1, 2, 1, 256, 4), (1, 4, 1, 256, 2), (1, 8, 1, 256, 1)],
                '3072': [(1, 3, 1, 128, 8), (1, 3, 1, 256, 4), (1, 6, 1, 256, 2), (1, 3, 1, 1024, 1)],
                '4096': [(1, 2, 1, 256, 8), (1, 4, 1, 256, 4), (1, 2, 1, 1024, 2), (1, 4, 1, 1024, 1)],
                '6144': [(1, 3, 1, 256, 8), (1, 3, 1, 512, 4), (1, 3, 1, 1024, 2), (1, 6, 1, 1024, 1)],
                '8192': [(1, 4, 1, 256, 8), (1, 4, 1, 512, 4), (1, 4, 1, 1024, 2), (1, 8, 1, 1024, 1)],
                'big' : [(1, 1, 1, 1024, 8), (1, 4, 1, 256, 4), (1, 12, 1, 256, 2), (1, 4, 1, 1024, 1)],
            },
            1: {
                '64'  : [(1, 1, 8, 8, 8), (1, 1, 4, 16, 4), (1, 1, 4, 64, 1)],
                '128' : [(1, 1, 4, 16, 8), (1, 1, 4, 64, 2), (1, 2, 4, 64, 1)],
                '256' : [(1, 1, 8, 32, 8), (1, 1, 4, 64, 4), (1, 2, 4, 64, 2), (1, 4, 4, 64, 1)],
                '512' : [(1, 1, 4, 64, 8), (1, 2, 4, 64, 4), (1, 4, 4, 64, 2), (1, 8, 4, 64, 1)],
                '640' : [(1, 1, 2, 128, 8), (1, 5, 4, 64, 2), (1, 5, 4, 128, 1)],
                '768' : [(1, 1, 2, 128, 8), (1, 3, 4, 64, 4), (1, 6, 4, 64, 2), (1, 12, 4, 64, 1)],
                '1024': [(1, 1, 2, 128, 8), (1, 4, 2, 64, 4), (1, 8, 2, 64, 2), (1, 4, 1, 256, 1)],
                '1536': [(1, 1, 1, 256, 8), (1, 3, 2, 128, 4), (1, 3, 1, 256, 2), (1, 6, 1, 256, 1)],
                '2048': [(1, 1, 1, 256, 8), (1, 2, 1, 256, 4), (1, 4, 1, 256, 2), (1, 8, 1, 256, 1)],
                '3072': [(1, 2, 1, 256, 8), (1, 3, 1, 256, 4), (1, 6, 1, 256, 2), (1, 3, 1, 1024, 1)],
                '4096': [(1, 2, 1, 256, 8), (1, 4, 1, 256, 4), (1, 2, 1, 1024, 2), (1, 4, 1, 1024, 1)],
                '6144': [(1, 3, 1, 256, 8), (1, 3, 1, 512, 4), (1, 3, 1, 1024, 2), (1, 6, 1, 1024, 1)],
                '8192': [(1, 4, 1, 256, 8), (1, 4, 1, 512, 4), (1, 4, 1, 1024, 2), (1, 8, 1, 1024, 1)],
                'big' : [(1, 1, 1, 1024, 8), (1, 4, 1, 256, 4), (1, 12, 1, 256, 2), (1, 4, 1, 1024, 1)],
            },
        }

        total_blob = list()
        for model_sensitive_flag, n_table in tile_tables.items():
            for hs_key, shapes in n_table.items():
                two_pass = hs_key == 'big'
                for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(
                        dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
                    prec_i, prec_o = dtype.split(',')
                    scale_sm, scale_y = scale_type.split(',')
                    if prec_o in dynamic_quant_out_dtype and fused_quant not in (1, 2):
                        continue  # skip non dynamic quant case
                    if fused_quant in (1, 2) and two_pass:
                        continue  # fused quant is single-pass only
                    if fused_quant == 0 and save_unquant:
                        continue  # save_unquant should always be false when there is no quant enabled
                    current_hs = [
                        h_traits(F_XDataType=prec_i, F_YDataType=prec_o,
                                 F_SmoothScaleDataType=scale_sm, F_YScaleDataType=scale_y,
                                 F_UnquantYDataType=prec_i,
                                 F_Repeat_M=rm, F_Repeat_N=rn,
                                 F_ThreadPerBlock_M=tm, F_ThreadPerBlock_N=tn, F_Vector_N=vn,
                                 F_kPadN=True, F_kSaveInvRms=False, F_kSaveUnquant=save_unquant,
                                 F_kTwoPass=two_pass, F_kFusedAdd=fused_add, F_kFusedQuant=fused_quant,
                                 F_use_model_sensitive_rmsnorm=model_sensitive_flag)
                        for rm, rn, tm, tn, vn in shapes]
                    total_blob.append(h_instance(dtype, hs_key, fused_add, fused_quant, save_unquant,
                                                 model_sensitive_flag, current_hs))
        return total_blob

    def list_blobs(self) -> None:
        """Write rmsnorm2d_fwd_blobs.txt listing every file gen_blobs would create."""
        w_p = Path(self.working_path)
        list_p = w_p / 'rmsnorm2d_fwd_blobs.txt'
        lines = [str(w_p / (self.name_api + ".cpp")) + "\n",           # api related file
                 str(w_p / (self.name_common_header + ".hpp")) + "\n"]
        lines += [str(w_p / (b.name + ".cpp")) + "\n" for b in self.get_blobs()]  # kernel instance files
        with list_p.open('w') as list_f:
            list_f.writelines(lines)

    def gen_blobs(self) -> None:
        """Write the dispatcher, the common header and every instance .cpp."""
        w_p = Path(self.working_path)
        (w_p / (self.name_api + ".cpp")).write_text(self.content_api)
        (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
        for b in self.get_blobs():
            (w_p / (b.name + ".cpp")).write_text(b.content)


def _api_names(api_arg):
    """Split the --api value on commas, trimming whitespace and any '[...]' suffix ('fwd[all]' -> 'fwd')."""
    return [item.split('[', 1)[0].strip() for item in api_arg.split(',')]


def list_blobs(args):
    for api in _api_names(args.api):
        if api == 'fwd':
            rmsnorm_fwd_codegen(args.working_path, args.filter).list_blobs()
        else:
            print(f'unknown api "{api}", skipped', file=sys.stderr)


def gen_blobs(args):
    for api in _api_names(args.api):
        if api == 'fwd':
            rmsnorm_fwd_codegen(args.working_path, args.filter).gen_blobs()
        else:
            print(f'unknown api "{api}", skipped', file=sys.stderr)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK rmsnorm kernel",
    )
    # Default matches the help text; a '[...]' suffix (e.g. 'fwd[all]') is
    # also accepted and ignored by _api_names.
    parser.add_argument(
        "-a",
        "--api",
        default='fwd',
        required=False,
        help="supply API(s) to generate (default: fwd). separated by comma."
    )

    # the directory for list_blobs/gen_blobs to write files into
    parser.add_argument(
        "-w",
        "--working_path",
        default="./",
        required=False,
        help="the path where all the blobs are going to be generated"
    )

    # this script have 2 modes
    # 1) list_blobs mode, will generate a txt file with all the files going to be generated.
    #    this is useful in build system like cmake to construct source code dependency, by
    #    reading the content out of this file
    # 2) gen_blobs mode, will generate the actuall kernel instance and api. If in framework
    #    like FA, only need to use this mode
    parser.add_argument(
        "-l",
        "--list_blobs",
        action='store_true',
        help="list all the kernels to a file, "
    )

    parser.add_argument(
        "-g",
        "--gen_blobs",
        action='store_true',
        help="generate all kernels into different tile"
    )

    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module"
    )

    parser.add_argument(
        "-t",
        "--traits",
        default="all",
        required=False,
        help="enable/disable some feature. default generate all"
    )

    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        required=False,
        help="codegen receipt."
    )

    args = parser.parse_args()

    # exactly one of the two modes must be selected
    if args.gen_blobs == args.list_blobs:
        print('gen_blobs/list_blobs must specify only one option', file=sys.stderr)
        sys.exit(1)

    Path(args.working_path).mkdir(parents=True, exist_ok=True)

    if args.list_blobs:
        list_blobs(args)
    else:
        gen_blobs(args)