mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-22 14:14:32 +00:00
Remove unused arguments
This commit is contained in:
@@ -39,8 +39,6 @@ typedef void (* fattn_new_mma_kernel_t)(
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
- const int ne33,
|
||||
- const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
@@ -1408,7 +1406,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00, const int ne01, const int ne02, const int ne03,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
- const int ne31, const int nb31, const int ne33, const int nb33,
|
||||
+ const int ne31, const int nb31,
|
||||
const int nb01, const int nb02, const int nb03,
|
||||
const int nb11, const int nb12, const int nb13,
|
||||
const int nb21, const int nb22, const int nb23,
|
||||
@@ -1469,7 +1467,6 @@ static __global__ void flash_attn_ext_f16(
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
- //const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : (const half2 *) (mask + nb33*(sequence % ne33));
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + zt_Q) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*z_KV);
|
||||
@@ -1517,7 +1514,6 @@ static __global__ void flash_attn_ext_f16(
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
- //const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : (const half2 *) (mask + nb33*(sequence % ne33));
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + zt_Q) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*z_KV);
|
||||
@@ -1969,7 +1965,6 @@ static void launch_fattn_new_mma(
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
- mask ? mask->ne[3] : 0, mask ? mask->nb[3] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
|
||||
Reference in New Issue
Block a user