mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Before, generate.py appended the list at the end of the output file. When running the cmake configuration steps multiple times on the examples, the blob list (such as fwd_blob_list.txt) would grow at every configuration. `library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt` worked around this issue by removing the output file if it exists. Now, generate.py overrides the content of the output file. There is no need for the workaround in the CMakeLists.txt; and the issue is solved for the example projects too.
672 lines
29 KiB
Python
672 lines
29 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
# generate kernel instances to speed up compilation
|
|
|
|
import argparse
|
|
from enum import IntEnum
|
|
from pathlib import Path
|
|
import sys
|
|
from typing import List, Optional, Any
|
|
import functools
|
|
import itertools
|
|
import copy
|
|
from dataclasses import dataclass
|
|
|
|
def get_if_str(idx, total, lase_else = True):
|
|
if idx == 0:
|
|
return 'if'
|
|
elif idx < total - 1:
|
|
return 'else if'
|
|
else:
|
|
if lase_else:
|
|
return 'else'
|
|
else:
|
|
return 'else if'
|
|
|
|
FUSED_ADD_ENUM_STR_MAP = [
|
|
'no',
|
|
'pras', # pre-norm
|
|
'pra' ] # post-norm
|
|
|
|
FUSED_FUSED_SWEEP_STR_MAP = [
|
|
'no',
|
|
'dquant' ]
|
|
|
|
DATA_TYPE_MAP = {'fp32' : 'float',
|
|
'fp16' : 'ck_tile::fp16_t',
|
|
'bf16' : 'ck_tile::bf16_t',
|
|
'int8' : 'ck_tile::int8_t'}
|
|
|
|
def BOOL_MAP(b_) -> str:
|
|
if b_:
|
|
return 'true'
|
|
else:
|
|
return 'false'
|
|
|
|
class layernorm_fwd_codegen:
|
|
API_TRAITS_DEFINE = """
|
|
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
|
template <typename XDataType_,
|
|
typename YDataType_,
|
|
typename XScaleDataType_,
|
|
typename YScaleDataType_,
|
|
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
|
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
|
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
|
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
|
ck_tile::index_t Vector_N_, // vector size along N
|
|
bool kPadN_,
|
|
bool kSaveMeanInvStd_,
|
|
bool kTwoPass_,
|
|
ck_tile::index_t kFusedAdd_ = 0,
|
|
ck_tile::index_t kFusedQuant_ = 0>
|
|
struct layernorm2d_fwd_traits_
|
|
{
|
|
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
|
|
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
|
|
using XScaleDataType = ck_tile::remove_cvref_t<XScaleDataType_>;
|
|
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
|
|
|
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
|
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
|
static constexpr ck_tile::index_t total_warps =
|
|
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
|
|
|
// num of warps along m
|
|
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
|
if constexpr(is_warp_per_row)
|
|
{
|
|
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
|
return total_warps * (warpSize / ThreadPerBlock_N_);
|
|
}
|
|
else
|
|
{
|
|
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
|
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
|
}
|
|
}();
|
|
|
|
// num of warps along n
|
|
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
|
if constexpr(is_warp_per_row)
|
|
{
|
|
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
|
return 1;
|
|
}
|
|
else
|
|
{
|
|
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
|
return ThreadPerBlock_N_ / warpSize;
|
|
}
|
|
}();
|
|
|
|
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
|
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
|
|
|
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
|
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
|
|
|
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
|
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
|
|
|
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
|
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
|
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
|
using Vector = ck_tile::sequence<1, Vector_N_>;
|
|
|
|
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
|
|
|
static constexpr bool kPadN = kPadN_;
|
|
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
|
static constexpr bool kTwoPass = kTwoPass_;
|
|
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
|
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
|
};
|
|
|
|
template <typename XDataType_,
|
|
typename YDataType_,
|
|
typename XScaleDataType_,
|
|
typename YScaleDataType_,
|
|
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
|
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
|
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
|
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
|
ck_tile::index_t Vector_N_, // vector size along N
|
|
bool kPadN_,
|
|
bool kSaveMeanInvStd_,
|
|
bool kTwoPass_,
|
|
int kFusedAdd_,
|
|
int kFusedQuant_>
|
|
using traits_ = layernorm2d_fwd_traits_<XDataType_,
|
|
YDataType_,
|
|
XScaleDataType_,
|
|
YScaleDataType_,
|
|
Repeat_M_,
|
|
Repeat_N_,
|
|
ThreadPerBlock_M_,
|
|
ThreadPerBlock_N_,
|
|
Vector_N_,
|
|
kPadN_,
|
|
kSaveMeanInvStd_,
|
|
kTwoPass_,
|
|
kFusedAdd_,
|
|
kFusedQuant_>;
|
|
"""
|
|
API_COMMON_HEADER = """
|
|
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include <ck_tile/core.hpp>
|
|
#include "layernorm2d_fwd.hpp"
|
|
#include <ck_tile/ops/epilogue.hpp>
|
|
#include <iostream>
|
|
|
|
#pragma once
|
|
|
|
using S = ck_tile::stream_config;
|
|
using A = layernorm2d_fwd_args;
|
|
|
|
{F_traits_define}
|
|
|
|
template <typename Traits_>
|
|
float layernorm2d_fwd_(const S& s, A a)
|
|
{{
|
|
using XDataType = typename Traits_::XDataType;
|
|
using YDataType = typename Traits_::YDataType;
|
|
using XScaleDataType = typename Traits_::XScaleDataType;
|
|
using YScaleDataType = typename Traits_::YScaleDataType;
|
|
using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType;
|
|
|
|
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
|
|
Traits_::kSaveMeanInvStd,
|
|
Traits_::kTwoPass,
|
|
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
|
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
|
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
|
|
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::GammaDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BetaDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::YDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::MeanDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::InvStdDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XScaleDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::YScaleDataType,
|
|
typename Traits_::Shape,
|
|
PipelineTraits>;
|
|
|
|
using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
|
|
using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
|
|
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
|
|
|
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
|
|
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
|
|
|
|
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
|
|
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, XScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
|
|
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
|
|
|
|
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
|
|
|
|
using Epilogue = std::conditional_t<Traits_::kFusedQuant == 1, DynamicQuantEpilogue, Default2DEpilogue>;
|
|
|
|
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
|
|
|
|
const dim3 grids = Kernel::GridSize(a);
|
|
constexpr dim3 blocks = Kernel::BlockSize();
|
|
constexpr ck_tile::index_t kBlockPerCu = 1;
|
|
|
|
auto kargs = Kernel::MakeKargs(a);
|
|
if(s.log_level_ > 0)
|
|
std::cout << ", " << Kernel::GetName() << std::flush;
|
|
|
|
return ck_tile::launch_kernel(
|
|
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
|
|
}}
|
|
|
|
"""
|
|
|
|
API_BASE = """
|
|
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include <ck_tile/core.hpp>
|
|
#include "layernorm2d_fwd.hpp"
|
|
|
|
{F_traits_define}
|
|
|
|
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
|
template <typename Traits_>
|
|
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
|
|
|
|
float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
|
layernorm2d_fwd_args a,
|
|
const ck_tile::stream_config& s)
|
|
{{
|
|
float r = -1;
|
|
{F_dispatch}
|
|
return r;
|
|
}}
|
|
|
|
"""
|
|
|
|
API_PER_DTYPE=""" {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{
|
|
{F_per_n_case}
|
|
}}
|
|
"""
|
|
API_PER_N_CASE=""" {F_if} {F_N_COND} {{
|
|
{F_inner_dispatch}
|
|
}}
|
|
"""
|
|
API_INNER_CASE=""" {F_if} {F_VEC_COND}
|
|
r={F_instance_func}(s, a);
|
|
"""
|
|
|
|
INSTANCE_BASE = """
|
|
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include "layernorm2d_fwd_api_common.hpp"
|
|
|
|
// clang-format off
|
|
// prec_i prec_o prec_sy rm rn tm tn vn pd mv 2p add sweep
|
|
{F_instance_def}
|
|
// clang-format on
|
|
|
|
"""
|
|
|
|
def __init__(self, working_path, kernel_filter):
|
|
self.working_path = working_path
|
|
self.kernel_filter = kernel_filter
|
|
|
|
class k_fuesd_add_enum(IntEnum):
|
|
F_NO_ADD = 0
|
|
F_PRE_ADD = 1
|
|
F_PRE_ADD_STORE_RESIDUAL = 2
|
|
|
|
class k_fused_sweep_enum(IntEnum):
|
|
F_NO_SWEEP = 0
|
|
F_RENORM = 1
|
|
F_DYNAMIC_QUANT = 2
|
|
|
|
@dataclass
|
|
class k_traits:
|
|
F_kPadN : bool
|
|
F_kSaveMeanInvStd : bool
|
|
F_kTwoPass : bool
|
|
F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum
|
|
F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum
|
|
|
|
@dataclass
|
|
class k_shape:
|
|
F_BlockTile : List[int]
|
|
F_WarpPerBlock : List[int]
|
|
F_WarpTile : List[int]
|
|
F_Vector_ : List[int]
|
|
@property
|
|
def F_BlockSize(self) -> int:
|
|
return functools.reduce(lambda a, b: a*b, self.F_WarpTile)
|
|
|
|
@dataclass
|
|
class k_problem:
|
|
F_XDataType : str
|
|
F_GammaDataType : str
|
|
F_BetaDataType : str
|
|
F_ComputeDataType : str
|
|
F_YDataType : str
|
|
F_MeanDataType : str
|
|
F_InvStdDataType : str
|
|
F_BlockShape : str
|
|
F_Traits : Any #k_traits
|
|
|
|
@dataclass
|
|
class k_pipeline_one_pass:
|
|
F_Problem : Any #k_problem
|
|
|
|
@dataclass
|
|
class k_pipeline_two_pass:
|
|
F_Problem : Any #k_problem
|
|
|
|
@dataclass
|
|
class default_2d_epilogue_problem:
|
|
F_AccDataType : str
|
|
F_ODataType : str
|
|
F_kPadM : bool
|
|
F_kPadN : bool
|
|
|
|
@dataclass
|
|
class default_2d_epilogue:
|
|
F_problem : Any
|
|
|
|
@dataclass
|
|
class k_kernel:
|
|
F_pipeline : Any
|
|
F_epilogue : Any
|
|
|
|
@dataclass
|
|
class h_traits:
|
|
F_XDataType : str
|
|
F_YDataType : str
|
|
F_XScaleDataType : str
|
|
F_YScaleDataType : str
|
|
F_Repeat_M : int
|
|
F_Repeat_N : int
|
|
F_ThreadPerBlock_M : int
|
|
F_ThreadPerBlock_N : int
|
|
F_Vector_N : int
|
|
F_kPadN : bool
|
|
F_kSaveMeanInvStd_ : bool
|
|
F_kTwoPass_ : bool
|
|
F_kFusedAdd : int
|
|
F_kFusedQuant : int
|
|
|
|
@property
|
|
def trait_name(self) ->str:
|
|
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
|
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}'
|
|
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
|
return t_
|
|
|
|
# string when calling this kernel
|
|
@property
|
|
def call_name(self) -> str:
|
|
return f'layernorm2d_fwd_<traits_<{self.trait_name}>>'
|
|
|
|
# string when define this kernel
|
|
@property
|
|
def def_name(self) -> str:
|
|
return f'template float layernorm2d_fwd_<traits_<{self.trait_name}>>(const S&, A);'
|
|
|
|
# this class hold kernel under same source file
|
|
@dataclass
|
|
class h_instance:
|
|
F_DataTypePair : str
|
|
F_N : str
|
|
F_add : int
|
|
F_sweep : int
|
|
instance_list : List[Any] # List[h_traits]
|
|
|
|
@property
|
|
def name(self) -> str:
|
|
prec_i, prec_o = self.F_DataTypePair.split(',')
|
|
dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}'
|
|
nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}'
|
|
if self.F_add != 0:
|
|
nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
|
|
if self.F_sweep != 0:
|
|
nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
|
|
return nnn
|
|
|
|
@property
|
|
def instance_name(self) ->str:
|
|
return self.name
|
|
|
|
@property
|
|
def content(self) ->str:
|
|
instance_defs = ''
|
|
for ins in self.instance_list:
|
|
instance_defs += ins.def_name + '\n'
|
|
return layernorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs)
|
|
|
|
@property
|
|
def name_api(self) -> str:
|
|
return 'layernorm2d_fwd_api'
|
|
|
|
@property
|
|
def name_common_header(self) -> str:
|
|
return 'layernorm2d_fwd_api_common'
|
|
|
|
@property
|
|
def content_api(self) -> str:
|
|
# 1 sort based on dtype
|
|
t_dtype_dict = dict()
|
|
blobs = self.get_blobs()
|
|
for blob in blobs:
|
|
if blob.F_DataTypePair not in t_dtype_dict:
|
|
t_dtype_dict[blob.F_DataTypePair] = {}
|
|
if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]:
|
|
t_dtype_dict[blob.F_DataTypePair][blob.F_N] = []
|
|
t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob)
|
|
|
|
d_str = ''
|
|
for i_d, dtype_ in enumerate(t_dtype_dict):
|
|
blob_per_t = t_dtype_dict[dtype_]
|
|
n_str = ''
|
|
for i_n, n_ in enumerate(blob_per_t):
|
|
blob_per_n = blob_per_t[n_]
|
|
inner_str = ""
|
|
for i_b, b_ in enumerate(blob_per_n):
|
|
# generate single kernel instance file
|
|
#vec_str = ""
|
|
for i_ins, ins in enumerate(b_.instance_list):
|
|
idx_in_n = i_b * len(b_.instance_list) + i_ins
|
|
len_in_n = len(blob_per_n) * len(b_.instance_list)
|
|
# _if = 'if' if i_ins == 0 else 'else if'
|
|
if ins.F_kFusedQuant == 0:
|
|
_sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
|
|
elif ins.F_kFusedQuant == 1:
|
|
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sx == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
|
|
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_XScaleDataType, f_sy_type=ins.F_YScaleDataType)
|
|
elif ins.F_kFusedQuant == 2:
|
|
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
|
|
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
|
|
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
|
|
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
|
|
f_sweep_cond = _sweep_cond)
|
|
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
|
|
F_VEC_COND = _cond, F_instance_func=ins.call_name)
|
|
#inner_str = inner_str + vec_str
|
|
n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
|
|
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
|
|
prec_i, prec_o = dtype_.split(',')
|
|
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
|
|
|
|
api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str)
|
|
return api_base
|
|
|
|
@property
|
|
def content_common_header(self) -> str:
|
|
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
|
|
|
|
def get_blobs(self):
|
|
h_traits = layernorm_fwd_codegen.h_traits
|
|
h_instance = layernorm_fwd_codegen.h_instance
|
|
|
|
dynamic_quant_out_dtype = ['int8']
|
|
# some predefined support range
|
|
# (prec_i,prec_o) for simplicity this string will be used as key for dict
|
|
scale_list = [('fp32,fp32')]
|
|
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
|
|
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
|
|
#fused_add_list = [0, 1, 2]
|
|
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
|
|
fused_add_list = [0, 1]
|
|
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
|
|
|
|
# rm rn tm tn vn pd mv 2p add sweep
|
|
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)],
|
|
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)],
|
|
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)],
|
|
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)],
|
|
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)],
|
|
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)],
|
|
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)],
|
|
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)],
|
|
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)],
|
|
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)],
|
|
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)],
|
|
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)],
|
|
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0),
|
|
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]}
|
|
total_blob = list()
|
|
for hs_key in h_trait_dict:
|
|
hs = h_trait_dict[hs_key]
|
|
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
|
|
for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list):
|
|
prec_i, prec_o = dtype.split(',')
|
|
scale_x, scale_y = scale_type.split(',')
|
|
if prec_o in dynamic_quant_out_dtype and fused_quant != 1:
|
|
continue # skip non dynamic quant case
|
|
if fused_quant == 1 and hs_key == 'big':
|
|
continue
|
|
current_hs = list()
|
|
for chs_ in hs:
|
|
h_ = copy.copy(chs_) # copy the base instance out
|
|
h_.F_XDataType = prec_i
|
|
h_.F_YDataType = prec_o
|
|
h_.F_XScaleDataType = scale_y
|
|
h_.F_YScaleDataType = scale_x
|
|
h_.F_kFusedAdd = fused_add
|
|
h_.F_kFusedQuant = fused_quant
|
|
current_hs.append(h_) # + "\n"
|
|
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
|
current_n_str = 'big' if hs_key == 'big' else current_n
|
|
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs))
|
|
return total_blob
|
|
|
|
def list_blobs(self) -> None:
|
|
w_p = Path(self.working_path)
|
|
list_p = w_p / 'layernorm2d_fwd_blobs.txt'
|
|
blobs = self.get_blobs()
|
|
with list_p.open('w') as list_f:
|
|
# api related file
|
|
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
|
|
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
|
|
# kernel instance file
|
|
for b in blobs:
|
|
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
|
|
|
|
def gen_blobs(self) -> None:
|
|
w_p = Path(self.working_path)
|
|
(w_p / (self.name_api + ".cpp")).write_text(self.content_api)
|
|
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
|
|
blobs = self.get_blobs()
|
|
for b in blobs:
|
|
(w_p / (b.name + ".cpp")).write_text(b.content)
|
|
|
|
def list_blobs(args):
|
|
api_list = args.api.split(',')
|
|
for api in api_list:
|
|
if api == 'fwd':
|
|
layernorm_fwd_codegen(args.working_path, args.filter).list_blobs()
|
|
|
|
|
|
def gen_blobs(args):
|
|
api_list = args.api.split(',')
|
|
for api in api_list:
|
|
if api == 'fwd':
|
|
layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs()
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser(
|
|
prog="generate",
|
|
description="gen API for CK layernorm kernel",
|
|
)
|
|
parser.add_argument(
|
|
"-a",
|
|
"--api",
|
|
default='fwd[all]',
|
|
required=False,
|
|
help="supply API(s) to generate (default: fwd). separated by comma."
|
|
)
|
|
|
|
# the directory for list_blobs/gen_blobs to write files into
|
|
parser.add_argument(
|
|
"-w",
|
|
"--working_path",
|
|
default="./",
|
|
required=False,
|
|
help="the path where all the blobs are going to be generated"
|
|
)
|
|
|
|
# this script have 2 modes
|
|
# 1) list_blobs mode, will generate a txt file with all the files going to be generated.
|
|
# this is useful in build system like cmake to construct source code dependency, by
|
|
# reading the content out of this file
|
|
# 2) gen_blobs mode, will generate the actuall kernel instance and api. If in framework
|
|
# like FA, only need to use this mode
|
|
parser.add_argument(
|
|
"-l",
|
|
"--list_blobs",
|
|
action='store_true',
|
|
help="list all the kernels to a file, "
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-g",
|
|
"--gen_blobs",
|
|
action='store_true',
|
|
help="generate all kernels into different tile"
|
|
)
|
|
|
|
# TODO: if using filter, must apply same value to output_dir and list_blobs
|
|
parser.add_argument(
|
|
"-f",
|
|
"--filter",
|
|
required=False,
|
|
help="filter out kernels that need to generate, using fnmatch module"
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-t",
|
|
"--traits",
|
|
default="all",
|
|
required=False,
|
|
help="enable/disable some feature. default generate all"
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-r",
|
|
"--receipt",
|
|
default=0,
|
|
required=False,
|
|
help="codegen receipt."
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
# print(f'{args.list_blobs}-{args.gen_blobs}')
|
|
if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)):
|
|
print('gen_blobs/list_blobs must specify only one option')
|
|
sys.exit()
|
|
|
|
p = Path(args.working_path)
|
|
if not p.exists():
|
|
p.mkdir()
|
|
|
|
if args.list_blobs:
|
|
list_blobs(args)
|
|
else:
|
|
gen_blobs(args)
|