mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-26 18:17:40 +00:00
[CK] Address a bunch of errors associated with targeting gfx1200 on Windows (#5045) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Still addressing errors that are blocking the merge of TheRock PR: https://github.com/ROCm/TheRock/actions/runs/22545831304/job/65308264096?pr=3382 ## Technical Details 1. There are multiple fmha python scripts that are writing native paths which are confusing cmake. I addressed one of these in an earlier PR https://github.com/ROCm/rocm-libraries/pull/4812 and now I'm addressing more that are exposed with gfx1200 target: ``` [composable_kernel configure] CMake Error at example/ck_tile/50_sparse_attn/CMakeLists.txt:61 (add_library): [composable_kernel configure] Syntax error in cmake code when parsing string [composable_kernel configure] [composable_kernel configure] B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_d128_fp16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psddv_nlogits_nbias_nmask_nskip_nsquant_ntrload.cpp [composable_kernel configure] [composable_kernel configure] Invalid character escape '\b'. ``` 2. In the following compiler error we see gemm_prec_str<ADataType, BDataType> being passed as a function to concat(...), instead of being evaluated with the parenthesis operator(), i.e., gemm_prec_str<ADataType, BDataType>(). There are multiples instances of this, I wonder what non-msvc compilers do here: ``` [composable_kernel] FAILED: [code=1] example/ck_tile/38_block_scale_gemm/CMakeFiles/tile_example_gemm_quant.dir/gemm_bquant_quantgrouped_mx_bf16bf8.cpp.obj [composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp:4: [composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/example/ck_tile/38_block_scale_gemm\run_gemm_quant_example.inc:17: [composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/host.hpp:7: [composable_kernel] E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/host/concat.hpp:119:21: error: implicit conversion between pointer-to-function and pointer-to-object is a Microsoft extension [-Werror,-Wmicrosoft-cast] [composable_kernel] 119 | ((oss << sep << rest), ...); [composable_kernel] | ^~~~ [composable_kernel] E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp:248:16: note: in instantiation of function template specialization 'ck_tile::concat<char, char[11], std::basic_string<char> (), std::basic_string<char>>' requested here [composable_kernel] 248 | return concat('_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName()); [composable_kernel] | ^ ``` There are plenty of other places where we use gemm_prec_str with the operator(), so I'm pretty sure these were just typos...but I'd like some eyes on it. 3. There are 2 tests that fail to build on Windows, which I've excluded from the build but will open bug tickets for: 1. gemm_weight_preshuffle 2. grouped_gemm_preshuffle Here's a sample of the compiler error for these tests: ``` [composable_kernel] [16/19] Building HIP object test/ck_tile/grouped_gemm_preshuffle/CMakeFiles/test_ck_tile_grouped_gemm_preshuffle.dir/test_grouped_gemm_preshuffle.cpp.obj [composable_kernel] FAILED: [code=1] test/ck_tile/grouped_gemm_preshuffle/CMakeFiles/test_ck_tile_grouped_gemm_preshuffle.dir/test_grouped_gemm_preshuffle.cpp.obj [composable_kernel] E:\TheRock\build\core\clr\dist\lib\llvm\bin\clang++.exe -DCK_ENABLE_BF16 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_FP8 -DCK_ENABLE_INT8 -DCK_TILE_USE_WMMA=1 -DCK_TIME_KERNEL=1 -DCK_USE_OCP_FP8 -DCK_USE_WMMA -DCK_USE_WMMA_FP8 -DCK_USE_XDL -DDPP_KERNELS -DUSE_PROF_API=1 -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1 -D__HIP_ROCclr__=1 -IE:/TheRock/rocm-libraries/projects/composablekernel/profiler/include -IE:/TheRock/rocm-libraries/projects/composablekernel -IE:/TheRock/rocm-libraries/projects/composablekernel/library/include -IE:/TheRock/rocm-libraries/projects/composablekernel/include -IE:/TheRock/build/ml-libs/composable_kernel/build/include -IE:/TheRock/build/base/half/stage/include -isystem E:/TheRock/build/core/clr/dist/include -isystem E:/TheRock/build/ml-libs/composable_kernel/build/_deps/gtest-src/googletest/include -isystem E:/TheRock/build/ml-libs/composable_kernel/build/_deps/gtest-src/googletest -isystem E:/TheRock/build/ml-libs/composable_kernel/build/_deps/getopt-src/src -O3 -DNDEBUG -std=gnu++20 --offload-arch=gfx1200 -D_DLL -D_MT -Xclang --dependent-lib=msvcrt -Wall -Wextra -Wcomment -Wendif-labels -Wformat -Winit-self -Wreturn-type -Wsequence-point -Wswitch -Wtrigraphs -Wundef -Wuninitialized -Wunreachable-code -Wunused -Wno-reserved-identifier -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt -Wno-unused-template -Wno-missing-field-initializers -Wno-error=deprecated-declarations -Wall -Wextra -Wcomment -Wendif-labels -Wformat -Winit-self -Wreturn-type -Wsequence-point -Wswitch -Wtrigraphs -Wundef -Wuninitialized -Wunreachable-code -Wunused -Wno-reserved-identifier -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt -Wno-unused-template -Weverything -Wno-c++98-compat -Wno-c++98-compat-pedantic -Wno-conversion -Wno-double-promotion -Wno-exit-time-destructors -Wno-extra-semi -Wno-float-conversion -Wno-gnu-anonymous-struct -Wno-gnu-zero-variadic-macro-arguments -Wno-missing-prototypes -Wno-nested-anon-types -Wno-padded -Wno-return-std-move-in-c++11 -Wno-shorten-64-to-32 -Wno-sign-conversion -Wno-unknown-warning-option -Wno-unused-command-line-argument -Wno-weak-vtables -Wno-covered-switch-default -Wno-unsafe-buffer-usage -Wno-unused-lambda-capture -Wno-nvcc-compat -Wno-c++20-compat -Wno-bit-int-extension -Wno-pass-failed -Wno-switch-default -Wno-unique-object-duplication -fbracket-depth=1024 -Wno-nrvo -Werror -Weverything -fcolor-diagnostics -Wno-c++20-extensions -Wno-global-constructors -Wno-undef -DCK_TILE_USE_OCP_FP8 -MD -MT test/ck_tile/grouped_gemm_preshuffle/CMakeFiles/test_ck_tile_grouped_gemm_preshuffle.dir/test_grouped_gemm_preshuffle.cpp.obj -MF test\ck_tile\grouped_gemm_preshuffle\CMakeFiles\test_ck_tile_grouped_gemm_preshuffle.dir\test_grouped_gemm_preshuffle.cpp.obj.d -o test/ck_tile/grouped_gemm_preshuffle/CMakeFiles/test_ck_tile_grouped_gemm_preshuffle.dir/test_grouped_gemm_preshuffle.cpp.obj -x hip -c E:/TheRock/rocm-libraries/projects/composablekernel/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp [composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp:8: [composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/host.hpp:6: [composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/host/check_err.hpp:16: [composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/core.hpp:89: [composable_kernel] E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/core/utility/env.hpp:110:31: warning: 'getenv' is deprecated: This function or variable may be unsafe. Consider using _dupenv_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. [-Wdeprecated-declarations] [composable_kernel] 110 | const char* vp = std::getenv(name); [composable_kernel] | ^ [composable_kernel] C:\Program Files (x86)\Windows Kits\10\include\10.0.22621.0\ucrt\stdlib.h:1183:20: note: 'getenv' has been explicitly marked deprecated here [composable_kernel] 1183 | _Check_return_ _CRT_INSECURE_DEPRECATE(_dupenv_s) [composable_kernel] | ^ [composable_kernel] C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\include\vcruntime.h:368:55: note: expanded from macro '_CRT_INSECURE_DEPRECATE' [composable_kernel] 368 | #define _CRT_INSECURE_DEPRECATE(_Replacement) _CRT_DEPRECATE_TEXT( \ [composable_kernel] | ^ [composable_kernel] C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\include\vcruntime.h:358:47: note: expanded from macro '_CRT_DEPRECATE_TEXT' [composable_kernel] 358 | #define _CRT_DEPRECATE_TEXT(_Text) __declspec(deprecated(_Text)) [composable_kernel] | ^ [composable_kernel] clang++: error: clang frontend command failed due to signal (use -v to see invocation) [composable_kernel] AMD clang version 22.0.0git (https://github.com/ROCm/llvm-project.git a2dc42b87c63e686377a69f09ea23aec7550babc+PATCHED:e4d5bf498b7b8626bb9716f1f5a5946d45025918) [composable_kernel] Target: x86_64-pc-windows-msvc [composable_kernel] Thread model: posix [composable_kernel] InstalledDir: E:\TheRock\build\core\clr\dist\lib\llvm\bin [composable_kernel] clang++: note: diagnostic msg: Error generating preprocessed source(s). [composable_kernel] ninja: build stopped: subcommand failed. [composable_kernel FAILED WITH CODE 1 in 238 seconds] ninja: build stopped: subcommand failed. ``` ## Test Plan Wait for internal CI and make sure build compiles locally. ## Test Result Waiting on CI ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1663 lines
49 KiB
Python
1663 lines
49 KiB
Python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
# generate kernel instances to speed up compilation
|
|
|
|
|
|
import argparse
|
|
from enum import IntEnum
|
|
from pathlib import Path
|
|
import sys
|
|
from typing import List, Any
|
|
import functools
|
|
import itertools
|
|
import copy
|
|
from dataclasses import dataclass
|
|
|
|
|
|
def get_if_str(idx, total, lase_else=True):
|
|
if idx == 0:
|
|
return "if"
|
|
elif idx < total - 1:
|
|
return "else if"
|
|
else:
|
|
if lase_else:
|
|
return "else"
|
|
else:
|
|
return "else if"
|
|
|
|
|
|
XBIAS_ENUM_STR_MAP = ["no", "xbias"] # pre-norm add bias
|
|
|
|
FUSED_ADD_ENUM_STR_MAP = [
|
|
"no",
|
|
"pras", # pre-norm
|
|
"pra",
|
|
] # post-norm
|
|
|
|
FUSED_FUSED_SWEEP_STR_MAP = ["no", "dquant"]
|
|
|
|
DATA_TYPE_MAP = {
|
|
"fp32": "float",
|
|
"fp16": "ck_tile::fp16_t",
|
|
"bf16": "ck_tile::bf16_t",
|
|
"int8": "ck_tile::int8_t",
|
|
"fp8": "ck_tile::fp8_t",
|
|
}
|
|
|
|
|
|
def BOOL_MAP(b_) -> str:
|
|
if b_:
|
|
return "true"
|
|
else:
|
|
return "false"
|
|
|
|
|
|
class layernorm_fwd_codegen:
|
|
API_TRAITS_DEFINE = """
|
|
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
|
template <typename XDataType_,
|
|
typename YDataType_,
|
|
typename SmoothScaleDataType_,
|
|
typename YScaleDataType_,
|
|
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
|
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
|
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
|
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
|
ck_tile::index_t Vector_N_, // vector size along N
|
|
bool kPadN_,
|
|
bool kSaveMeanInvStd_,
|
|
bool kFastFDiv_,
|
|
bool kWelford_,
|
|
bool kTwoPass_,
|
|
ck_tile::index_t kXbias_ = 0,
|
|
ck_tile::index_t kFusedAdd_ = 0,
|
|
ck_tile::index_t kFusedQuant_ = 0>
|
|
struct layernorm2d_fwd_traits_
|
|
{
|
|
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
|
|
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
|
|
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
|
|
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
|
|
|
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
|
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
|
static constexpr ck_tile::index_t total_warps =
|
|
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
|
|
|
// num of warps along m
|
|
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
|
if constexpr(is_warp_per_row)
|
|
{
|
|
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
|
return total_warps;
|
|
}
|
|
else
|
|
{
|
|
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
|
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
|
}
|
|
}();
|
|
|
|
// num of warps along n
|
|
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
|
if constexpr(is_warp_per_row)
|
|
{
|
|
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
|
return 1;
|
|
}
|
|
else
|
|
{
|
|
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
|
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
|
}
|
|
}();
|
|
|
|
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
|
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
|
|
|
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
|
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
|
|
|
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
|
using Vector = ck_tile::sequence<1, Vector_N_>;
|
|
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
|
|
|
|
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
|
|
|
static constexpr bool kPadN = kPadN_;
|
|
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
|
static constexpr bool kFastFDiv = kFastFDiv_;
|
|
static constexpr bool kWelford = kWelford_;
|
|
static constexpr bool kTwoPass = kTwoPass_;
|
|
static constexpr ck_tile::index_t kXbias = kXbias_;
|
|
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
|
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
|
};
|
|
|
|
template <typename XDataType_,
|
|
typename YDataType_,
|
|
typename SmoothScaleDataType_,
|
|
typename YScaleDataType_,
|
|
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
|
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
|
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
|
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
|
ck_tile::index_t Vector_N_, // vector size along N
|
|
bool kPadN_,
|
|
bool kSaveMeanInvStd_,
|
|
bool kFastFDiv_,
|
|
bool kWelford_,
|
|
bool kTwoPass_,
|
|
int kXbias_,
|
|
int kFusedAdd_,
|
|
int kFusedQuant_>
|
|
using traits_ = layernorm2d_fwd_traits_<XDataType_,
|
|
YDataType_,
|
|
SmoothScaleDataType_,
|
|
YScaleDataType_,
|
|
Repeat_M_,
|
|
Repeat_N_,
|
|
ThreadPerBlock_M_,
|
|
ThreadPerBlock_N_,
|
|
Vector_N_,
|
|
kPadN_,
|
|
kSaveMeanInvStd_,
|
|
kFastFDiv_,
|
|
kWelford_,
|
|
kTwoPass_,
|
|
kXbias_,
|
|
kFusedAdd_,
|
|
kFusedQuant_>;
|
|
"""
|
|
API_COMMON_HEADER = """
|
|
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include <ck_tile/core.hpp>
|
|
#include "layernorm2d_fwd.hpp"
|
|
#include <ck_tile/ops/epilogue.hpp>
|
|
#include <iostream>
|
|
|
|
#pragma once
|
|
|
|
using S = ck_tile::stream_config;
|
|
using A = layernorm2d_fwd_args;
|
|
|
|
{F_traits_define}
|
|
|
|
template <typename Traits_>
|
|
float layernorm2d_fwd_(const S& s, A a)
|
|
{{
|
|
using XDataType = typename Traits_::XDataType;
|
|
using YDataType = typename Traits_::YDataType;
|
|
using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
|
|
using YScaleDataType = typename Traits_::YScaleDataType;
|
|
using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType;
|
|
|
|
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
|
|
Traits_::kSaveMeanInvStd,
|
|
Traits_::kFastFDiv,
|
|
Traits_::kWelford,
|
|
Traits_::kTwoPass,
|
|
static_cast<ck_tile::Layernorm2dXBiasEnum>(Traits_::kXbias),
|
|
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
|
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
|
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
|
|
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XBiasDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::GammaDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::BetaDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::MeanDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::InvStdDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::SmoothScaleDataType,
|
|
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YScaleDataType,
|
|
typename Traits_::Shape,
|
|
PipelineTraits>;
|
|
|
|
using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
|
|
using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
|
|
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
|
|
|
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, true>;
|
|
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
|
|
|
|
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
|
|
static constexpr bool UseRawStore = sizeof(YDataType) == 4;
|
|
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
|
|
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, UseRawStore, true/*max3*/>>;
|
|
|
|
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
|
|
|
|
using Epilogue = std::conditional_t<Traits_::kFusedQuant == 1, DynamicQuantEpilogue, Default2DEpilogue>;
|
|
|
|
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
|
|
|
|
const dim3 grids = Kernel::GridSize(a);
|
|
const dim3 blocks = Kernel::BlockSize();
|
|
constexpr ck_tile::index_t kBlockPerCu = 1;
|
|
|
|
auto kargs = Kernel::MakeKargs(a);
|
|
if(s.log_level_ > 0)
|
|
std::cout << ", " << Kernel::GetName() << std::flush;
|
|
|
|
return ck_tile::launch_kernel(
|
|
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
|
|
}}
|
|
|
|
"""
|
|
|
|
API_BASE = """
|
|
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include <ck_tile/core.hpp>
|
|
#include "layernorm2d_fwd.hpp"
|
|
|
|
{F_traits_define}
|
|
|
|
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
|
template <typename Traits_>
|
|
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
|
|
|
|
float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
|
layernorm2d_fwd_args a,
|
|
const ck_tile::stream_config& s)
|
|
{{
|
|
float r = -1;
|
|
{F_dispatch}
|
|
return r;
|
|
}}
|
|
|
|
"""
|
|
|
|
API_PER_DTYPE = """ {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{
|
|
{F_per_n_case}
|
|
}}
|
|
"""
|
|
API_PER_N_CASE = """ {F_if} {F_N_COND} {{
|
|
{F_inner_dispatch}
|
|
}}
|
|
"""
|
|
API_INNER_CASE = """ {F_if} {F_VEC_COND}
|
|
r={F_instance_func}(s, a);
|
|
"""
|
|
|
|
INSTANCE_BASE = """
|
|
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include "layernorm2d_fwd_api_common.hpp"
|
|
|
|
// clang-format off
|
|
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf welford 2p xbias add sweep
|
|
{F_instance_def}
|
|
// clang-format on
|
|
|
|
"""
|
|
|
|
def __init__(self, working_path, kernel_filter):
|
|
self.working_path = working_path
|
|
self.kernel_filter = kernel_filter
|
|
|
|
class k_xbias_enum(IntEnum):
|
|
F_NO_XBIAS = 0
|
|
F_ADD_XBIAS = 1
|
|
|
|
class k_fuesd_add_enum(IntEnum):
|
|
F_NO_ADD = 0
|
|
F_PRE_ADD = 1
|
|
F_PRE_ADD_STORE_RESIDUAL = 2
|
|
|
|
class k_fused_sweep_enum(IntEnum):
|
|
F_NO_SWEEP = 0
|
|
F_RENORM = 1
|
|
F_DYNAMIC_QUANT = 2
|
|
|
|
@dataclass
|
|
class k_traits:
|
|
F_kPadN: bool
|
|
F_kSaveMeanInvStd: bool
|
|
F_kTwoPass: bool
|
|
F_kXbias: Any #: layernorm_fwd_codegen.k_bias_enum
|
|
F_kFusedAdd: Any #: layernorm_fwd_codegen.k_fuesd_add_enum
|
|
F_kFusedQuant: Any #: layernorm_fwd_codegen.k_fused_sweep_enum
|
|
|
|
@dataclass
|
|
class k_shape:
|
|
F_BlockTile: List[int]
|
|
F_WarpPerBlock: List[int]
|
|
F_WarpTile: List[int]
|
|
F_Vector_: List[int]
|
|
|
|
@property
|
|
def F_BlockSize(self) -> int:
|
|
return functools.reduce(lambda a, b: a * b, self.F_WarpTile)
|
|
|
|
@dataclass
|
|
class k_problem:
|
|
F_XDataType: str
|
|
F_XBiasDataType: str
|
|
F_GammaDataType: str
|
|
F_BetaDataType: str
|
|
F_ComputeDataType: str
|
|
F_YDataType: str
|
|
F_MeanDataType: str
|
|
F_InvStdDataType: str
|
|
F_BlockShape: str
|
|
F_Traits: Any # k_traits
|
|
|
|
@dataclass
|
|
class k_pipeline_one_pass:
|
|
F_Problem: Any # k_problem
|
|
|
|
@dataclass
|
|
class k_pipeline_two_pass:
|
|
F_Problem: Any # k_problem
|
|
|
|
@dataclass
|
|
class default_2d_epilogue_problem:
|
|
F_AccDataType: str
|
|
F_ODataType: str
|
|
F_kPadM: bool
|
|
F_kPadN: bool
|
|
|
|
@dataclass
|
|
class default_2d_epilogue:
|
|
F_problem: Any
|
|
|
|
@dataclass
|
|
class k_kernel:
|
|
F_pipeline: Any
|
|
F_epilogue: Any
|
|
|
|
@dataclass
|
|
class h_traits:
|
|
F_XDataType: str
|
|
F_YDataType: str
|
|
F_SmoothScaleDataType: str
|
|
F_YScaleDataType: str
|
|
F_Repeat_M: int
|
|
F_Repeat_N: int
|
|
F_ThreadPerBlock_M: int
|
|
F_ThreadPerBlock_N: int
|
|
F_Vector_N: int
|
|
F_kPadN: bool
|
|
F_kSaveMeanInvStd_: bool
|
|
F_kFastFDiv_: bool
|
|
F_kWelford_: bool
|
|
F_kTwoPass_: bool
|
|
F_kXbias_: int
|
|
F_kFusedAdd: int
|
|
F_kFusedQuant: int
|
|
|
|
@property
|
|
def trait_name(self) -> str:
|
|
t_ = f"{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}"
|
|
t_ += f", {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}"
|
|
t_ += f", {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}"
|
|
return t_
|
|
|
|
# string when calling this kernel
|
|
@property
|
|
def call_name(self) -> str:
|
|
return f"layernorm2d_fwd_<traits_<{self.trait_name}>>"
|
|
|
|
# string when define this kernel
|
|
@property
|
|
def def_name(self) -> str:
|
|
return f"template float layernorm2d_fwd_<traits_<{self.trait_name}>>(const S&, A);"
|
|
|
|
# this class hold kernel under same source file
|
|
@dataclass
|
|
class h_instance:
|
|
F_DataTypePair: str
|
|
F_N: str
|
|
F_xbias: int
|
|
F_add: int
|
|
F_sweep: int
|
|
instance_list: List[Any] # List[h_traits]
|
|
|
|
@property
|
|
def name(self) -> str:
|
|
prec_i, prec_o = self.F_DataTypePair.split(",")
|
|
dtype_str = f"{prec_i}" if prec_i == prec_o else f"{prec_i}_{prec_o}"
|
|
nnn = f"layernorm2d_fwd_{dtype_str}_n{self.F_N}"
|
|
if self.F_xbias != 0:
|
|
nnn = nnn + "_" + XBIAS_ENUM_STR_MAP[self.F_xbias]
|
|
if self.F_add != 0:
|
|
nnn = nnn + "_" + FUSED_ADD_ENUM_STR_MAP[self.F_add]
|
|
if self.F_sweep != 0:
|
|
nnn = nnn + "_" + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
|
|
return nnn
|
|
|
|
@property
|
|
def instance_name(self) -> str:
|
|
return self.name
|
|
|
|
@property
|
|
def content(self) -> str:
|
|
instance_defs = ""
|
|
for ins in self.instance_list:
|
|
instance_defs += ins.def_name + "\n"
|
|
return layernorm_fwd_codegen.INSTANCE_BASE.format(
|
|
F_instance_def=instance_defs
|
|
)
|
|
|
|
@property
|
|
def name_api(self) -> str:
|
|
return "layernorm2d_fwd_api"
|
|
|
|
@property
|
|
def name_common_header(self) -> str:
|
|
return "layernorm2d_fwd_api_common"
|
|
|
|
def content_api(self, args) -> str:
|
|
# 1 sort based on dtype
|
|
t_dtype_dict = dict()
|
|
blobs = self.get_blobs(args)
|
|
for blob in blobs:
|
|
if blob.F_DataTypePair not in t_dtype_dict:
|
|
t_dtype_dict[blob.F_DataTypePair] = {}
|
|
if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]:
|
|
t_dtype_dict[blob.F_DataTypePair][blob.F_N] = []
|
|
t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob)
|
|
|
|
d_str = ""
|
|
for i_d, dtype_ in enumerate(t_dtype_dict):
|
|
blob_per_t = t_dtype_dict[dtype_]
|
|
n_str = ""
|
|
for i_n, n_ in enumerate(blob_per_t):
|
|
blob_per_n = blob_per_t[n_]
|
|
inner_str = ""
|
|
for i_b, b_ in enumerate(blob_per_n):
|
|
# generate single kernel instance file
|
|
# vec_str = ""
|
|
for i_ins, ins in enumerate(b_.instance_list):
|
|
idx_in_n = i_b * len(b_.instance_list) + i_ins
|
|
len_in_n = len(blob_per_n) * len(b_.instance_list)
|
|
# _if = 'if' if i_ins == 0 else 'else if'
|
|
if ins.F_kFusedQuant == 0:
|
|
_sweep_cond = "t.fused_quant == {f_fused_sweep}".format(
|
|
f_fused_sweep=ins.F_kFusedQuant
|
|
)
|
|
elif ins.F_kFusedQuant == 1:
|
|
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == "{f_sx_type}" && t.prec_sy == "{f_sy_type}")'.format(
|
|
f_fused_sweep=ins.F_kFusedQuant,
|
|
f_sx_type=ins.F_SmoothScaleDataType,
|
|
f_sy_type=ins.F_YScaleDataType,
|
|
)
|
|
elif ins.F_kFusedQuant == 2:
|
|
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == "{f_sy_type}")'.format(
|
|
f_fused_sweep=ins.F_kFusedQuant,
|
|
f_sy_type=ins.F_YScaleDataType,
|
|
)
|
|
_cond = "((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))".format(
|
|
f_vec_n=ins.F_Vector_N,
|
|
f_xbias=ins.F_kXbias,
|
|
f_fused_add=ins.F_kFusedAdd,
|
|
f_sweep_cond=_sweep_cond,
|
|
)
|
|
inner_str += self.API_INNER_CASE.format(
|
|
F_if=get_if_str(idx_in_n, len_in_n, False),
|
|
F_VEC_COND=_cond,
|
|
F_instance_func=ins.call_name,
|
|
)
|
|
# inner_str = inner_str + vec_str
|
|
n_cnd = f"(a.n <= {n_})" if isinstance(n_, int) else ""
|
|
n_str += self.API_PER_N_CASE.format(
|
|
F_if=get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)),
|
|
F_N_COND=n_cnd,
|
|
F_inner_dispatch=inner_str,
|
|
)
|
|
prec_i, prec_o = dtype_.split(",")
|
|
d_str += self.API_PER_DTYPE.format(
|
|
F_if=get_if_str(i_d, len(t_dtype_dict), False),
|
|
F_i_type=prec_i,
|
|
F_o_type=prec_o,
|
|
F_per_n_case=n_str,
|
|
)
|
|
|
|
api_base = self.API_BASE.format(
|
|
F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str
|
|
)
|
|
return api_base
|
|
|
|
@property
|
|
def content_common_header(self) -> str:
|
|
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
|
|
|
|
def get_blobs(self, args):
|
|
h_traits = layernorm_fwd_codegen.h_traits
|
|
h_instance = layernorm_fwd_codegen.h_instance
|
|
|
|
dynamic_quant_out_dtype = ["int8", "fp8"]
|
|
# some predefined support range
|
|
# (prec_i,prec_o) for simplicity this string will be used as key for dict
|
|
scale_list = [("fp32,fp32")]
|
|
dtype_list = [
|
|
("fp16,fp16"),
|
|
("bf16,bf16"),
|
|
("fp16,int8"),
|
|
("bf16,int8"),
|
|
("fp16,fp8"),
|
|
("bf16,fp8"),
|
|
] # NOTE: only fused-dynamic-quant use int8 or fp8 out
|
|
types_8bit = ("int8", "fp8")
|
|
types_16bit = ("int16", "fp16", "bf16")
|
|
# fused_add_list = [0, 1, 2]
|
|
# fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
|
|
xbias_list = [0, 1]
|
|
fused_add_list = [0, 1]
|
|
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
|
|
# rm rn tm tn vn pd mv fdiv welford 2p xbias add sweep
|
|
h_trait_dict = {
|
|
"64": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
1,
|
|
8,
|
|
8,
|
|
8,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
1,
|
|
4,
|
|
16,
|
|
4,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
1,
|
|
4,
|
|
64,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
"128": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
1,
|
|
4,
|
|
16,
|
|
8,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
1,
|
|
4,
|
|
64,
|
|
2,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
2,
|
|
4,
|
|
64,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
"256": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
1,
|
|
4,
|
|
64,
|
|
4,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
2,
|
|
4,
|
|
64,
|
|
2,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
4,
|
|
4,
|
|
64,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
"512": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
1,
|
|
4,
|
|
64,
|
|
8,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
2,
|
|
4,
|
|
64,
|
|
4,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
4,
|
|
4,
|
|
64,
|
|
2,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
8,
|
|
4,
|
|
64,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
"768": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
3,
|
|
4,
|
|
64,
|
|
4,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
6,
|
|
4,
|
|
64,
|
|
2,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
12,
|
|
4,
|
|
64,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
"1024": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
1,
|
|
2,
|
|
128,
|
|
8,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
2,
|
|
2,
|
|
128,
|
|
4,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
4,
|
|
2,
|
|
128,
|
|
2,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
4,
|
|
1,
|
|
256,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
"1536": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
3,
|
|
4,
|
|
64,
|
|
8,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
3,
|
|
2,
|
|
128,
|
|
4,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
3,
|
|
1,
|
|
256,
|
|
2,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
6,
|
|
1,
|
|
256,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
"2048": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
1,
|
|
1,
|
|
256,
|
|
8,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
2,
|
|
1,
|
|
256,
|
|
4,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
4,
|
|
1,
|
|
256,
|
|
2,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
8,
|
|
1,
|
|
256,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
"3072": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
3,
|
|
1,
|
|
128,
|
|
8,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
3,
|
|
1,
|
|
256,
|
|
4,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
6,
|
|
1,
|
|
256,
|
|
2,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
3,
|
|
1,
|
|
1024,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
"4096": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
2,
|
|
1,
|
|
256,
|
|
8,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
4,
|
|
1,
|
|
256,
|
|
4,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
2,
|
|
1,
|
|
1024,
|
|
2,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
4,
|
|
1,
|
|
1024,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
"6144": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
3,
|
|
1,
|
|
256,
|
|
8,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
3,
|
|
1,
|
|
512,
|
|
4,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
3,
|
|
1,
|
|
1024,
|
|
2,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
6,
|
|
1,
|
|
1024,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
"8192": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
4,
|
|
1,
|
|
256,
|
|
8,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
4,
|
|
1,
|
|
512,
|
|
4,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
4,
|
|
1,
|
|
1024,
|
|
2,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
8,
|
|
1,
|
|
1024,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
False,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
"big": [
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
1,
|
|
1,
|
|
1024,
|
|
8,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
True,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
4,
|
|
1,
|
|
256,
|
|
4,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
True,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
12,
|
|
1,
|
|
256,
|
|
2,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
True,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
h_traits(
|
|
"x",
|
|
"y",
|
|
"xs",
|
|
"ys",
|
|
1,
|
|
4,
|
|
1,
|
|
1024,
|
|
1,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
True,
|
|
0,
|
|
0,
|
|
0,
|
|
),
|
|
],
|
|
}
|
|
total_blob = list()
|
|
for hs_key in h_trait_dict:
|
|
hs = h_trait_dict[hs_key]
|
|
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
|
|
for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(
|
|
dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list
|
|
):
|
|
prec_i, prec_o = dtype.split(",")
|
|
scale_sm, scale_y = scale_type.split(",")
|
|
if prec_o in dynamic_quant_out_dtype and fused_quant != 1:
|
|
continue # skip non dynamic quant case
|
|
if fused_quant == 1 and hs_key == "big":
|
|
continue
|
|
current_hs = list()
|
|
for chs_ in hs:
|
|
h_ = copy.copy(chs_) # copy the base instance out
|
|
h_.F_XDataType = prec_i
|
|
h_.F_YDataType = prec_o
|
|
h_.F_SmoothScaleDataType = scale_sm
|
|
h_.F_YScaleDataType = scale_y
|
|
h_.F_kXbias = xbias
|
|
h_.F_kFusedAdd = fused_add
|
|
h_.F_kFusedQuant = fused_quant
|
|
# disable welford update for 8bit and 16 bit smallN
|
|
if not h_.F_kTwoPass_:
|
|
# disable 16 bit when set args disable_16b_welford
|
|
if args.disable_16b_welford and prec_i in types_16bit:
|
|
h_.F_kWelford_ = False
|
|
# disable 8bit by default
|
|
elif prec_i in types_8bit or prec_o in types_8bit:
|
|
h_.F_kWelford_ = False
|
|
# disable 16bit small N
|
|
elif prec_i in types_16bit and hs_key == "64":
|
|
h_.F_kWelford_ = False
|
|
current_hs.append(h_) # + "\n"
|
|
# f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
|
current_n_str = "big" if hs_key == "big" else current_n
|
|
total_blob.append(
|
|
h_instance(
|
|
dtype, current_n_str, xbias, fused_add, fused_quant, current_hs
|
|
)
|
|
)
|
|
return total_blob
|
|
|
|
def list_blobs(self, args) -> None:
|
|
w_p = Path(self.working_path)
|
|
list_p = w_p / "layernorm2d_fwd_blobs.txt"
|
|
blobs = self.get_blobs(args)
|
|
with list_p.open("w") as list_f:
|
|
# api related file
|
|
list_f.write((w_p / (self.name_api + ".cpp")).as_posix() + "\n")
|
|
list_f.write((w_p / (self.name_common_header + ".hpp")).as_posix() + "\n")
|
|
# kernel instance file
|
|
for b in blobs:
|
|
list_f.write((w_p / (b.name + ".cpp")).as_posix() + "\n")
|
|
|
|
def gen_blobs(self, args) -> None:
|
|
w_p = Path(self.working_path)
|
|
w_str = self.content_api(args)
|
|
(w_p / (self.name_api + ".cpp")).write_text(w_str)
|
|
(w_p / (self.name_common_header + ".hpp")).write_text(
|
|
self.content_common_header
|
|
)
|
|
blobs = self.get_blobs(args)
|
|
for b in blobs:
|
|
(w_p / (b.name + ".cpp")).write_text(b.content)
|
|
|
|
|
|
def list_blobs(args):
|
|
api_list = args.api.split(",")
|
|
for api in api_list:
|
|
if api == "fwd":
|
|
layernorm_fwd_codegen(args.working_path, args.filter).list_blobs(args)
|
|
|
|
|
|
def gen_blobs(args):
|
|
api_list = args.api.split(",")
|
|
for api in api_list:
|
|
if api == "fwd":
|
|
layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs(args)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser(
|
|
prog="generate",
|
|
description="gen API for CK layernorm kernel",
|
|
)
|
|
parser.add_argument(
|
|
"-a",
|
|
"--api",
|
|
default="fwd[all]",
|
|
required=False,
|
|
help="supply API(s) to generate (default: fwd). separated by comma.",
|
|
)
|
|
|
|
# the directory for list_blobs/gen_blobs to write files into
|
|
parser.add_argument(
|
|
"-w",
|
|
"--working_path",
|
|
default="./",
|
|
required=False,
|
|
help="the path where all the blobs are going to be generated",
|
|
)
|
|
|
|
# this script have 2 modes
|
|
# 1) list_blobs mode, will generate a txt file with all the files going to be generated.
|
|
# this is useful in build system like cmake to construct source code dependency, by
|
|
# reading the content out of this file
|
|
# 2) gen_blobs mode, will generate the actuall kernel instance and api. If in framework
|
|
# like FA, only need to use this mode
|
|
parser.add_argument(
|
|
"-l",
|
|
"--list_blobs",
|
|
action="store_true",
|
|
help="list all the kernels to a file, ",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-g",
|
|
"--gen_blobs",
|
|
action="store_true",
|
|
help="generate all kernels into different tile",
|
|
)
|
|
|
|
# TODO: if using filter, must apply same value to output_dir and list_blobs
|
|
parser.add_argument(
|
|
"-f",
|
|
"--filter",
|
|
required=False,
|
|
help="filter out kernels that need to generate, using fnmatch module",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-t",
|
|
"--traits",
|
|
default="all",
|
|
required=False,
|
|
help="enable/disable some feature. default generate all",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"-r", "--receipt", default=0, required=False, help="codegen receipt."
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--disable_16b_welford",
|
|
default=False,
|
|
required=False,
|
|
help="enable/disable welford for 16bit datatype n > 64",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
# print(f'{args.list_blobs}-{args.gen_blobs}')
|
|
if (args.gen_blobs and args.list_blobs) or (
|
|
(not args.gen_blobs) and (not args.list_blobs)
|
|
):
|
|
print("gen_blobs/list_blobs must specify only one option")
|
|
sys.exit()
|
|
|
|
p = Path(args.working_path)
|
|
if not p.exists():
|
|
p.mkdir()
|
|
|
|
if args.list_blobs:
|
|
list_blobs(args)
|
|
else:
|
|
gen_blobs(args)
|