mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Implement fp8 quant for layernorm and rmsnorm (#1814)
This commit is contained in:
@@ -33,7 +33,7 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
|
||||
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template so the sources compile without explicitly declaring the function template specializations
|
||||
list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress)
|
||||
|
||||
target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
|
||||
|
||||
|
||||
@@ -39,7 +39,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
|
||||
DATA_TYPE_MAP = {'fp32' : 'float',
|
||||
'fp16' : 'ck_tile::fp16_t',
|
||||
'bf16' : 'ck_tile::bf16_t',
|
||||
'int8' : 'ck_tile::int8_t'}
|
||||
'int8' : 'ck_tile::int8_t',
|
||||
'fp8' : 'ck_tile::fp8_t'}
|
||||
|
||||
def BOOL_MAP(b_) -> str:
|
||||
if b_:
|
||||
@@ -504,12 +505,13 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
h_traits = layernorm_fwd_codegen.h_traits
|
||||
h_instance = layernorm_fwd_codegen.h_instance
|
||||
|
||||
dynamic_quant_out_dtype = ['int8']
|
||||
dynamic_quant_out_dtype = ['int8', 'fp8']
|
||||
# some predefined support range
|
||||
# (prec_i,prec_o): for simplicity, this string is used as the dict key
|
||||
scale_list = [('fp32,fp32')]
|
||||
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
|
||||
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
|
||||
('fp16,int8'), ('bf16,int8'),
|
||||
('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 or fp8 out
|
||||
types_8bit = ('int8', 'fp8')
|
||||
types_16bit = ('int16', 'fp16', 'bf16')
|
||||
#fused_add_list = [0, 1, 2]
|
||||
|
||||
@@ -20,6 +20,14 @@ auto get_elimit<ck_tile::bf16_t>()
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
// Explicit specialization of get_elimit for ck_tile::int8_t.
// Returns the comparison tolerances as a tuple in the order (rtol, atol),
// here rtol = 1e-2 and atol = 1.0.
// NOTE(review): an atol of 1.0 presumably allows the quantized integer output
// to differ from the host reference by one step — confirm.
template <>
|
||||
auto get_elimit<ck_tile::int8_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1.0;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -97,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int xbias = arg_parser.get_int("xbias");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
if(fused_quant == 1 && prec_o != "int8")
|
||||
if(fused_quant == 1 && prec_o != "int8" && prec_o != "fp8")
|
||||
{
|
||||
std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl;
|
||||
std::cout
|
||||
<< "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -291,7 +301,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
absmax = a > absmax ? a : absmax;
|
||||
}
|
||||
// printf("cpu:absmax:%f\n", absmax);
|
||||
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
|
||||
constexpr ComputeDataType kMaxY =
|
||||
std::is_same<YDataType, ck_tile::fp8_t>::value ? 240.0
|
||||
: std::is_same<YDataType, ck_tile::int8_t>::value ? 127.0
|
||||
: 0.0;
|
||||
ComputeDataType y_scale = absmax / kMaxY;
|
||||
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
@@ -334,7 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
y_residual_buf.FromDevice(y_residual_host_dev.data());
|
||||
}
|
||||
|
||||
auto [rtol, atol] = get_elimit<InDataType>();
|
||||
auto [rtol, atol] = get_elimit<OutDataType>();
|
||||
|
||||
if(x_stride == n)
|
||||
{
|
||||
@@ -452,6 +466,16 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"
|
||||
|
||||
for fquant in "" "-fquant=1 -prec_o=int8"; do
|
||||
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=1 -prec_o=fp8"; do
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
|
||||
|
||||
Reference in New Issue
Block a user