Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-24 07:04:11 +00:00)
mxfp4: minor CUDA tweaks
This commit is contained in:
@@ -2122,7 +2122,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
|
||||
|
||||
union { float f; uint32_t u; } helper;
|
||||
constexpr uint32_t uval[2] = { 0x00200000, 0x00400000 };
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
|
||||
@@ -2133,12 +2132,12 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbxd;
|
||||
helper.u = bxi->e >= 2 ? uint32_t(bxi->e - 1) << 23u : uval[bxi->e];
|
||||
helper.u = bxi->e ? uint32_t(bxi->e) << 23u : 0x00400000;
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = helper.f;
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = 0.5f * helper.f;
|
||||
#else
|
||||
x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = helper.f;
|
||||
x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = 0.5f * helper.f;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1186,22 +1186,20 @@ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
|
||||
|
||||
const int * q8 = (const int *) bq8_1->qs + iqs;
|
||||
|
||||
constexpr uint32_t uval[2] = { 0x00200000, 0x00400000 };
|
||||
|
||||
int sumi = 0;
|
||||
int2 sumi = {0, 0};
|
||||
#pragma unroll
|
||||
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
|
||||
const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
|
||||
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
|
||||
|
||||
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
|
||||
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
|
||||
sumi.x = ggml_cuda_dp4a(v.x, q8[l + 0], sumi.x);
|
||||
sumi.y = ggml_cuda_dp4a(v.y, q8[l + 4], sumi.y);
|
||||
}
|
||||
|
||||
union { float f; uint32_t u; } helper;
|
||||
helper.u = bq4->e >= 2 ? uint32_t(bq4->e - 1) << 23u : uval[bq4->e];
|
||||
helper.u = bq4->e ? uint32_t(bq4->e) << 23u : 0x00400000;
|
||||
|
||||
return helper.f * __low2float(bq8_1->ds) * sumi;
|
||||
return 0.5f * helper.f * __low2float(bq8_1->ds) * (sumi.x + sumi.y);
|
||||
}
|
||||
|
||||
#define VDR_IQ4_XS_Q8_1_MMVQ 4
|
||||
|
||||
Reference in New Issue
Block a user