mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
* Reapply "[CK_TILE] support hdim=192/128 pair for deepseekv3 (#1961)" (#1969)
This reverts commit 8cbcd3e0d0.
* fix codegen problem
* Update config.hpp
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -8,11 +8,8 @@
|
||||
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
||||
#if __clang_major__ == 20
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#endif
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
@@ -2553,3 +2555,5 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
@@ -2553,3 +2555,5 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
@@ -252,3 +252,11 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
#else // for GPU code
|
||||
#define CK_TILE_USE_OCP_FP8 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
#if __clang_major__ == 20
|
||||
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 1
|
||||
#else
|
||||
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -33,12 +33,12 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
|
||||
@@ -112,6 +112,13 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 192)
|
||||
{
|
||||
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 256)
|
||||
{
|
||||
return 1;
|
||||
|
||||
@@ -13,6 +13,8 @@ static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index
|
||||
return 128;
|
||||
if(len == 160)
|
||||
return 256;
|
||||
if(len == 192)
|
||||
return 192;
|
||||
|
||||
// only length of 96, 160 and power-of-two is supported
|
||||
if(!(len & (len - 1)))
|
||||
|
||||
Reference in New Issue
Block a user