mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Implement fp8 quant for layernorm and rmsnorm (#1814)
This commit is contained in:
@@ -37,7 +37,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
|
||||
DATA_TYPE_MAP = {'fp32' : 'float',
|
||||
'fp16' : 'ck_tile::fp16_t',
|
||||
'bf16' : 'ck_tile::bf16_t',
|
||||
'int8' : 'ck_tile::int8_t'}
|
||||
'int8' : 'ck_tile::int8_t',
|
||||
'fp8' : 'ck_tile::fp8_t'}
|
||||
|
||||
def BOOL_MAP(b_) -> str:
|
||||
if b_:
|
||||
@@ -477,12 +478,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
h_traits = rmsnorm_fwd_codegen.h_traits
|
||||
h_instance = rmsnorm_fwd_codegen.h_instance
|
||||
|
||||
dynamic_quant_out_dtype = ['int8']
|
||||
dynamic_quant_out_dtype = ['int8', 'fp8']
|
||||
# some predefined support range
|
||||
# (prec_i,prec_o) for simplicity this string will be used as key for dict
|
||||
scale_list = [('fp32,fp32')]
|
||||
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
|
||||
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
|
||||
('fp16,int8'), ('bf16,int8'),
|
||||
('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 out
|
||||
#fused_add_list = [0, 1, 2]
|
||||
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
|
||||
fused_add_list = [0, 1]
|
||||
|
||||
Reference in New Issue
Block a user