mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
mxfp4: minor CUDA tweaks
This commit is contained in:
@@ -2122,7 +2122,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||||||
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
|
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
|
||||||
|
|
||||||
union { float f; uint32_t u; } helper;
|
union { float f; uint32_t u; } helper;
|
||||||
constexpr uint32_t uval[2] = { 0x00200000, 0x00400000 };
|
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
|
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
|
||||||
@@ -2133,12 +2132,12 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||||||
}
|
}
|
||||||
|
|
||||||
const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbxd;
|
const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbxd;
|
||||||
helper.u = bxi->e >= 2 ? uint32_t(bxi->e - 1) << 23u : uval[bxi->e];
|
helper.u = bxi->e ? uint32_t(bxi->e) << 23u : 0x00400000;
|
||||||
|
|
||||||
#ifdef INT8_MMA_AVAILABLE
|
#ifdef INT8_MMA_AVAILABLE
|
||||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = helper.f;
|
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = 0.5f * helper.f;
|
||||||
#else
|
#else
|
||||||
x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = helper.f;
|
x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = 0.5f * helper.f;
|
||||||
#endif // INT8_MMA_AVAILABLE
|
#endif // INT8_MMA_AVAILABLE
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1186,22 +1186,20 @@ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
|
|||||||
|
|
||||||
const int * q8 = (const int *) bq8_1->qs + iqs;
|
const int * q8 = (const int *) bq8_1->qs + iqs;
|
||||||
|
|
||||||
constexpr uint32_t uval[2] = { 0x00200000, 0x00400000 };
|
int2 sumi = {0, 0};
|
||||||
|
|
||||||
int sumi = 0;
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
|
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
|
||||||
const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
|
const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
|
||||||
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
|
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
|
||||||
|
|
||||||
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
|
sumi.x = ggml_cuda_dp4a(v.x, q8[l + 0], sumi.x);
|
||||||
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
|
sumi.y = ggml_cuda_dp4a(v.y, q8[l + 4], sumi.y);
|
||||||
}
|
}
|
||||||
|
|
||||||
union { float f; uint32_t u; } helper;
|
union { float f; uint32_t u; } helper;
|
||||||
helper.u = bq4->e >= 2 ? uint32_t(bq4->e - 1) << 23u : uval[bq4->e];
|
helper.u = bq4->e ? uint32_t(bq4->e) << 23u : 0x00400000;
|
||||||
|
|
||||||
return helper.f * __low2float(bq8_1->ds) * sumi;
|
return 0.5f * helper.f * __low2float(bq8_1->ds) * (sumi.x + sumi.y);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define VDR_IQ4_XS_Q8_1_MMVQ 4
|
#define VDR_IQ4_XS_Q8_1_MMVQ 4
|
||||||
|
|||||||
Reference in New Issue
Block a user