mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
test fp16 ok
This commit is contained in:
@@ -638,7 +638,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
return None
|
||||
|
||||
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'bf16':
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
### '32' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '64' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
@@ -736,7 +736,7 @@ def get_batch_decode_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
continue
|
||||
# Aiter(batch_decode) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in [ 'bf16']
|
||||
cond = dtype in [ 'fp16','bf16']
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias == 'no'
|
||||
|
||||
Reference in New Issue
Block a user