mxfp4: minor CUDA tweaks

This commit is contained in:
Iwan Kawrakow
2025-08-09 08:15:37 +03:00
parent 34bb912db1
commit 80bdee3f85
2 changed files with 8 additions and 11 deletions

View File

@@ -2122,7 +2122,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
union { float f; uint32_t u; } helper;
constexpr uint32_t uval[2] = { 0x00200000, 0x00400000 };
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
@@ -2133,12 +2132,12 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbxd;
helper.u = bxi->e >= 2 ? uint32_t(bxi->e - 1) << 23u : uval[bxi->e];
helper.u = bxi->e ? uint32_t(bxi->e) << 23u : 0x00400000;
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = helper.f;
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = 0.5f * helper.f;
#else
x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = helper.f;
x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = 0.5f * helper.f;
#endif // INT8_MMA_AVAILABLE
}
}

View File

@@ -1186,22 +1186,20 @@ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
const int * q8 = (const int *) bq8_1->qs + iqs;
constexpr uint32_t uval[2] = { 0x00200000, 0x00400000 };
int sumi = 0;
int2 sumi = {0, 0};
#pragma unroll
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
sumi.x = ggml_cuda_dp4a(v.x, q8[l + 0], sumi.x);
sumi.y = ggml_cuda_dp4a(v.y, q8[l + 4], sumi.y);
}
union { float f; uint32_t u; } helper;
helper.u = bq4->e >= 2 ? uint32_t(bq4->e - 1) << 23u : uval[bq4->e];
helper.u = bq4->e ? uint32_t(bq4->e) << 23u : 0x00400000;
return helper.f * __low2float(bq8_1->ds) * sumi;
return 0.5f * helper.f * __low2float(bq8_1->ds) * (sumi.x + sumi.y);
}
#define VDR_IQ4_XS_Q8_1_MMVQ 4