Add kQKHeaddimForGemmN and kVHeaddimForGemmN in order to support headdim 96

This commit is contained in:
Qianfeng Zhang
2024-10-09 14:27:28 +00:00
parent 0c094daa7e
commit fb496fa2d6
8 changed files with 206 additions and 146 deletions

View File

@@ -455,6 +455,8 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'96' : [FmhaBwdDQDKDVTileSize( 16, 128, 96, 16, 96, 16, 32, 96, 96, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
@@ -801,4 +803,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")