mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-29 18:51:15 +00:00
[fix]: add amx optimization
This commit is contained in:
@@ -177,7 +177,7 @@ class MOEBindings {
|
||||
const uintptr_t physical_to_logical_map = 0) {
|
||||
Args* args = new Args{nullptr, moe.get()};
|
||||
if (physical_to_logical_map) {
|
||||
printf("debug physical_to_logical_map in arg:%lu\n", physical_to_logical_map);
|
||||
// printf("debug physical_to_logical_map in arg:%lu\n", physical_to_logical_map);
|
||||
moe->config.physical_to_logical_map = reinterpret_cast<void*>(physical_to_logical_map);
|
||||
printf("moe ptr:%p,confirm: moe->config.physical_to_logical_map:%lu\n", reinterpret_cast<void*>(moe.get()),
|
||||
reinterpret_cast<uintptr_t>(moe->config.physical_to_logical_map));
|
||||
|
||||
@@ -48,13 +48,37 @@ struct BufferAImpl {
|
||||
assert(ith == 0 && nth == 1);
|
||||
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
|
||||
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
|
||||
float amax = 0.0f;
|
||||
for (int j = 0; j < k; j += 32) {
|
||||
__m512 f0, f1;
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j), &f0, &f1);
|
||||
amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));
|
||||
amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));
|
||||
__m512 amax_v0 = _mm512_setzero_ps();
|
||||
__m512 amax_v1 = _mm512_setzero_ps();
|
||||
__m512 amax_v2 = _mm512_setzero_ps();
|
||||
__m512 amax_v3 = _mm512_setzero_ps();
|
||||
__m512 amax_v4 = _mm512_setzero_ps();
|
||||
__m512 amax_v5 = _mm512_setzero_ps();
|
||||
__m512 amax_v6 = _mm512_setzero_ps();
|
||||
__m512 amax_v7 = _mm512_setzero_ps();
|
||||
for (int j = 0; j < k; j += 128) {
|
||||
__m512 f0, f1, f2, f3, f4, f5, f6, f7;
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 0), &f0, &f1);
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 32), &f2, &f3);
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 64), &f4, &f5);
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 96), &f6, &f7);
|
||||
amax_v0 = vector_abs_max(amax_v0, f0);
|
||||
amax_v1 = vector_abs_max(amax_v1, f1);
|
||||
amax_v2 = vector_abs_max(amax_v2, f2);
|
||||
amax_v3 = vector_abs_max(amax_v3, f3);
|
||||
amax_v4 = vector_abs_max(amax_v4, f4);
|
||||
amax_v5 = vector_abs_max(amax_v5, f5);
|
||||
amax_v6 = vector_abs_max(amax_v6, f6);
|
||||
amax_v7 = vector_abs_max(amax_v7, f7);
|
||||
}
|
||||
amax_v0 = vector_abs_max(amax_v0, amax_v1);
|
||||
amax_v2 = vector_abs_max(amax_v2, amax_v3);
|
||||
amax_v4 = vector_abs_max(amax_v4, amax_v5);
|
||||
amax_v6 = vector_abs_max(amax_v6, amax_v7);
|
||||
amax_v0 = vector_abs_max(amax_v0, amax_v2);
|
||||
amax_v4 = vector_abs_max(amax_v4, amax_v6);
|
||||
amax_v0 = vector_abs_max(amax_v0, amax_v4);
|
||||
float amax = _mm512_reduce_max_ps(amax_v0);
|
||||
d[m_begin + i] = amax / ((1 << 7) - 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1181,8 +1181,8 @@ struct GemmKernel224Int4 {
|
||||
|
||||
static void load_a(dt* a, size_t lda) {
|
||||
#ifdef HAVE_AMX
|
||||
_tile_loadd(0, a, lda);
|
||||
_tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);
|
||||
_tile_stream_loadd(0, a, lda);
|
||||
_tile_stream_loadd(1, offset_pointer(a, lda * TILE_M), lda);
|
||||
#else
|
||||
(void)a;
|
||||
(void)lda;
|
||||
|
||||
Reference in New Issue
Block a user