mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
This reverts commit 7a93b16ff6.
This commit is contained in:
@@ -118,7 +118,7 @@ FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
@@ -288,7 +288,7 @@ class FmhaFwdApiPool:
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
@@ -417,7 +417,6 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -490,10 +489,6 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't' or (pipeline.F_mask not in ['no', 's_no']):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
|
||||
@@ -8,8 +8,11 @@
|
||||
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#if __clang_major__ >= 20
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#endif
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
@@ -2555,5 +2553,3 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
@@ -2555,5 +2553,3 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
@@ -252,11 +252,3 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
#else // for GPU code
|
||||
#define CK_TILE_USE_OCP_FP8 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
#if __clang_major__ >= 20
|
||||
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 1
|
||||
#else
|
||||
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -33,12 +33,12 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
|
||||
@@ -112,13 +112,6 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 192)
|
||||
{
|
||||
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 256)
|
||||
{
|
||||
return 1;
|
||||
|
||||
@@ -13,8 +13,6 @@ static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index
|
||||
return 128;
|
||||
if(len == 160)
|
||||
return 256;
|
||||
if(len == 192)
|
||||
return 192;
|
||||
|
||||
// only length of 96, 160 and power-of-two is supported
|
||||
if(!(len & (len - 1)))
|
||||
|
||||
Reference in New Issue
Block a user