diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 0a664dbd..a46f03e5 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -52,7 +52,7 @@ typedef half (*vec_dot_KQ_f16_t)( typedef float (*vec_dot_KQ_f32_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -62,7 +62,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -92,7 +92,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -102,7 +102,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -142,7 +142,7 @@ static __device__ __forceinline__ int get_one_int_from_table_16(const int & q4) return *((const int *) &val0_8); } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -152,7 +152,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -179,7 +179,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -189,7 +189,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -226,7 +226,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -277,7 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -287,7 +287,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -320,7 +320,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -330,7 +330,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_0; @@ -353,7 +353,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { @@ -368,7 +368,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( half2 sum2 = make_half2(0.0f, 0.0f); #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; @@ -384,7 +384,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( float sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; @@ -603,29 +603,29 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v return x[i]; } -template +template constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : - type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : - type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : nullptr; } -template +template constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : - type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : - type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : nullptr; } @@ -653,20 +653,20 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } -template // D == head size +template // Dv == V head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) +__launch_bounds__(Dv, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const float2 * __restrict__ VKQ_meta, float * __restrict__ dst) { - VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; - VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; - dst += D * gridDim.y*blockIdx.x; + VKQ_parts += parallel_blocks*Dv * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += Dv * gridDim.y*blockIdx.x; const int tid = threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < Dv); __shared__ float2 meta[parallel_blocks]; if (tid < 2*parallel_blocks) { @@ -690,20 +690,20 @@ static __global__ void flash_attn_combine_results( const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint32_t *) &KQ_max_scale) &= ftz_mask; - VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*Dv + blockIdx.y*Dv + tid]; VKQ_denominator += KQ_max_scale * meta[l].y; } - dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; + dst[blockIdx.y*Dv + tid] = VKQ_numerator / VKQ_denominator; } -static void on_no_fattn_vec_case(const int D) { - if (D == 64) { +static void on_no_fattn_vec_case(const int Dk, const int Dv) { + if (Dk == 64 && Dv == 64) { fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); fprintf(stderr, "By default only f16 KV cache is supported.\n"); fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); GGML_ABORT("fatal error"); - } else if (D == 128) { + } else if (Dk == 128 && Dv == 128) { fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); fprintf(stderr, "Supported combinations:\n"); fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n"); @@ -715,14 +715,22 @@ static void on_no_fattn_vec_case(const int D) { fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n"); fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n"); GGML_ABORT("fatal error"); + } + else if (Dk == 192 && Dv == 128) { + fprintf(stderr, "Unsupported KV type combination for head_sizes 192 / 128\n"); + // TODO: add what is supported + } + else if (Dk == 576 && Dv == 512) { + fprintf(stderr, "Unsupported KV type combination for head_sizes 576 / 512\n"); + // TODO: add what is supported } else { - fprintf(stderr, "Unsupported KV type combination for head_size 256.\n"); + fprintf(stderr, "Unsupported KV type combination for head_sizes %d, %d.\n", Dk, Dv); fprintf(stderr, "Only f16 is supported.\n"); GGML_ABORT("fatal error"); } } -template +template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V @@ -838,11 +846,11 @@ void launch_fattn( return; } - const dim3 block_dim_combine(D, 1, 1); + const dim3 block_dim_combine(Dv, 1, 1); const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); const int shmem_combine = 0; - flash_attn_combine_results + flash_attn_combine_results <<>> (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); CUDA_CHECK(cudaGetLastError()); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index d1bbf01f..bf2a4521 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -291,13 +291,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int D = 64; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 25908d7a..28846561 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F32 32 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -44,8 +44,9 @@ static __global__ void flash_attn_tile_ext_f32( const int ne1, const int ne2, const int ne3) { + static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)); // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } @@ -61,15 +62,22 @@ static __global__ void flash_attn_tile_ext_f32( const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; - const int stride_KV2 = nb11 / sizeof(half2); + const int stride_K2 = nb11 / sizeof(half2); + const int stride_V2 = nb12 / sizeof(half2); const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + // TODO: is it Dk or Dv or both that need to be multiple of 2*WARP_SIZE ? + // let's assume it is is both. + static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64."); + static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64."); + + constexpr int Dkv = Dk < Dv ? Dv : Dk; // let's use this when we don't understand if it is Dk or Dv __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32]; - __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts. + // This is being used to store either K or V data. Hence we need max(Dk, Dv) as the dimension + __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][Dkv + 1]; // Pad D to avoid memory bank conflicts. float2 * KV_tmp2 = (float2 *) KV_tmp; float kqmax[ncols/nwarps]; @@ -79,16 +87,16 @@ static __global__ void flash_attn_tile_ext_f32( } float kqsum[ncols/nwarps] = {0.0f}; - float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; + float2 VKQ[ncols/nwarps][(Dv/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; // Convert Q to half2 and store in registers: - __shared__ float Q_f[ncols][D]; + __shared__ float Q_f[ncols][Dk]; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) { + for (int i0 = 0; i0 < Dk; i0 += 2*WARP_SIZE) { float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f); Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale; Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale; @@ -112,8 +120,8 @@ static __global__ void flash_attn_tile_ext_f32( const int i_KQ = i_KQ_0 + threadIdx.y; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) { - const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; + for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 2*WARP_SIZE) { + const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_K2 + k_KQ_0/2 + threadIdx.x]; KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp); KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp); } @@ -124,7 +132,7 @@ static __global__ void flash_attn_tile_ext_f32( float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}}; #pragma unroll - for (int k_KQ = 0; k_KQ < D; ++k_KQ) { + for (int k_KQ = 0; k_KQ < Dk; ++k_KQ) { float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE]; float Q_k[ncols/nwarps]; @@ -193,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f32( kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; } @@ -206,11 +214,11 @@ static __global__ void flash_attn_tile_ext_f32( const int k = k0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]); - KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]); + KV_tmp2[k*(Dv/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]); + KV_tmp2[k*(Dv/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]); } } @@ -218,14 +226,14 @@ static __global__ void flash_attn_tile_ext_f32( #pragma unroll for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) { - float2 V_k[(D/2)/WARP_SIZE]; + float2 V_k[(Dv/2)/WARP_SIZE]; float KQ_k[ncols/nwarps]; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i]; + V_k[i0/WARP_SIZE] = KV_tmp2[k*(Dv/2) + i]; } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { @@ -235,7 +243,7 @@ static __global__ void flash_attn_tile_ext_f32( } #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps]; @@ -259,7 +267,7 @@ static __global__ void flash_attn_tile_ext_f32( kqsum_j = warp_reduce_sum(kqsum_j); #pragma unroll - for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { + for (int i00 = 0; i00 < Dv; i00 += 2*WARP_SIZE) { const int i0 = i00 + 2*threadIdx.x; float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)]; @@ -268,8 +276,8 @@ static __global__ void flash_attn_tile_ext_f32( dst_val.y /= kqsum_j; } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 0] = dst_val.x; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 1] = dst_val.y; } if (parallel_blocks != 1 && threadIdx.x == 0) { @@ -285,14 +293,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 7f14e78b..83d7bd49 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -302,7 +302,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); } template diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 1aa88272..36289dcf 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -283,7 +283,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); } template diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index efe78a2f..4804a9ef 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -5,7 +5,7 @@ #include #endif // FP16_MMA_AVAILABLE -// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +// Dk == K head size, Dv = V head size, VKQ_stride == num VKQ rows calculated in parallel: template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) @@ -464,7 +464,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm fattn_kernel_t fattn_kernel = softcap == 0.0f ? flash_attn_ext_f16 : flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } if (2*blocks_num_pb1 < 2*nsm) { @@ -472,14 +472,14 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm fattn_kernel_t fattn_kernel = softcap == 0.0f ? flash_attn_ext_f16 : flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } constexpr int parallel_blocks = 1; fattn_kernel_t fattn_kernel = softcap == 0.0f ? flash_attn_ext_f16 : flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \ diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index c15d6c81..c93da895 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -233,7 +233,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0) #endif // GGML_CUDA_FA_ALL_QUANTS - on_no_fattn_vec_case(Q->ne[0]); + on_no_fattn_vec_case(Q->ne[0], V->ne[0]); } #define FATTN_VEC_F32_CASE(D, type_K, type_V) \ @@ -308,7 +308,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) #endif // GGML_CUDA_FA_ALL_QUANTS - on_no_fattn_vec_case(Q->ne[0]); + on_no_fattn_vec_case(Q->ne[0], V->ne[0]); } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {