mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
Implement fp8 quant for layernorm and rmsnorm (#1814)
This commit is contained in:
@@ -33,7 +33,7 @@ target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
|
||||
set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress)
|
||||
|
||||
target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS})
|
||||
|
||||
|
||||
@@ -37,7 +37,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
|
||||
DATA_TYPE_MAP = {'fp32' : 'float',
|
||||
'fp16' : 'ck_tile::fp16_t',
|
||||
'bf16' : 'ck_tile::bf16_t',
|
||||
'int8' : 'ck_tile::int8_t'}
|
||||
'int8' : 'ck_tile::int8_t',
|
||||
'fp8' : 'ck_tile::fp8_t'}
|
||||
|
||||
def BOOL_MAP(b_) -> str:
|
||||
if b_:
|
||||
@@ -477,12 +478,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
h_traits = rmsnorm_fwd_codegen.h_traits
|
||||
h_instance = rmsnorm_fwd_codegen.h_instance
|
||||
|
||||
dynamic_quant_out_dtype = ['int8']
|
||||
dynamic_quant_out_dtype = ['int8', 'fp8']
|
||||
# some predefined support range
|
||||
# (prec_i,prec_o) for simplicity this string will be used as key for dict
|
||||
scale_list = [('fp32,fp32')]
|
||||
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
|
||||
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
|
||||
('fp16,int8'), ('bf16,int8'),
|
||||
('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 out
|
||||
#fused_add_list = [0, 1, 2]
|
||||
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
|
||||
fused_add_list = [0, 1]
|
||||
|
||||
@@ -105,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
prec_sy = "fp32";
|
||||
}
|
||||
|
||||
if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8")
|
||||
if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8" && prec_o != "fp8")
|
||||
{
|
||||
std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl;
|
||||
std::cout
|
||||
<< "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -248,7 +250,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
absmax = a > absmax ? a : absmax;
|
||||
}
|
||||
// printf("cpu:absmax:%f\n", absmax);
|
||||
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
|
||||
constexpr ComputeDataType kMaxY =
|
||||
std::is_same<YDataType, ck_tile::fp8_t>::value ? 240.0
|
||||
: std::is_same<YDataType, ck_tile::int8_t>::value ? 127.0
|
||||
: 0.0;
|
||||
ComputeDataType y_scale = absmax / kMaxY;
|
||||
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
@@ -400,6 +406,16 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
|
||||
|
||||
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8"; do
|
||||
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"; do
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
|
||||
@@ -27,7 +27,7 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
|
||||
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
|
||||
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
|
||||
done
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user