Remove unused arguments

This commit is contained in:
Kawrakow
2026-01-27 17:26:26 +00:00
parent 574cf2cd2d
commit 345545d1be

View File

@@ -39,8 +39,6 @@ typedef void (* fattn_new_mma_kernel_t)(
const int ne13,
const int ne31,
const int nb31,
const int ne33,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@@ -1408,7 +1406,7 @@ static __global__ void flash_attn_ext_f16(
const uint32_t n_head_log2,
const int ne00, const int ne01, const int ne02, const int ne03,
const int ne10, const int ne11, const int ne12, const int ne13,
const int ne31, const int nb31, const int ne33, const int nb33,
const int ne31, const int nb31,
const int nb01, const int nb02, const int nb03,
const int nb11, const int nb12, const int nb13,
const int nb21, const int nb22, const int nb23,
@@ -1469,7 +1467,6 @@ static __global__ void flash_attn_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
//const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : (const half2 *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + zt_Q) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*z_KV);
@@ -1517,7 +1514,6 @@ static __global__ void flash_attn_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
//const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : (const half2 *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + zt_Q) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*z_KV);
@@ -1969,7 +1965,6 @@ static void launch_fattn_new_mma(
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
mask ? mask->ne[3] : 0, mask ? mask->nb[3] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
nb11, nb12, nb13,
nb21, nb22, nb23,