mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
Minor
This commit is contained in:
@@ -135,7 +135,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
     // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
     // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.

-    if (use_cp_async) {
+    if constexpr (use_cp_async) {
         const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);

         constexpr int preload = 64;
@@ -205,7 +205,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
         const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
     static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter");

-    if (use_cp_async) {
+    if constexpr (use_cp_async) {
         constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
         constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
         constexpr int stride_j = nwarps * cols_per_warp;
Reference in New Issue
Block a user