[feat]: patch kml problem (#1704)

This commit is contained in:
ZiWei Yuan
2025-12-10 22:40:29 -08:00
committed by GitHub
parent c65febe05c
commit 53f6a6d6e1
33 changed files with 0 additions and 5401 deletions

View File

@@ -1,56 +0,0 @@
#include <stdexcept>
#include "../batch_gemm_api.hpp"
#include "utils.hpp"
#ifdef __cplusplus
extern "C" {
#endif
void decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,
int32_t* c, const size_t ldc, const int32_t* oc) {
BLASINT8* ptrA = (BLASINT8*)a;
BLASINT8* ptrB = (BLASINT8*)b;
int32_t* ptrC = c;
size_t split_n = n / N_SIZE;
for (size_t n_block = 0; n_block < split_n; n_block++) {
BLASINT8* cur_ptrA = ptrA;
BLASINT8* cur_ptrB = ptrB + n_block * (N_SIZE * k);
int32_t* cur_ptrC = ptrC + n_block * N_SIZE;
gemm_kernel_1x8(cur_ptrA, cur_ptrB, cur_ptrC, ldc, k, COMP_SV_LEN);
}
}
void decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
const int32_t* oc) {
BLASINT8* ptrA = (BLASINT8*)a;
BLASINT8* ptrB = (BLASINT8*)b;
int32_t* ptrC = c;
size_t split_n = n / N_SIZE;
for (size_t n_block = 0; n_block < split_n; n_block++) {
BLASINT8* cur_ptrA = ptrA;
BLASINT8* cur_ptrB = ptrB + n_block * (N_SIZE * (k / 2));
int32_t* cur_ptrC = ptrC + n_block * N_SIZE;
gemm_kernel_1x8_int4(cur_ptrA, cur_ptrB, cur_ptrC, (ldc / 2), (k / 2), COMP_SV_LEN);
}
}
// Reorders the B matrix into the kernel's packed layout.
// Not implemented for this backend: always throws std::runtime_error.
// Callers must not request B reordering on this code path.
void reorder_B_gemm(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
const size_t n, const size_t ldb, const void* b, void* b_reordered) {
    // Silence unused-parameter warnings while the function is unimplemented.
    (void)layout; (void)transb; (void)k; (void)n; (void)ldb; (void)b; (void)b_reordered;
    throw std::runtime_error("reorder_B_gemm: B reordering is not supported yet");
}
// Returns the buffer size (in bytes) required by reorder_B_gemm.
// Not implemented for this backend: always throws std::runtime_error.
size_t get_reorder_B_size(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
const size_t n) {
    // Silence unused-parameter warnings while the function is unimplemented.
    (void)layout; (void)transb; (void)k; (void)n;
    throw std::runtime_error("get_reorder_B_size: B reordering is not supported yet");
}
#ifdef __cplusplus
}
#endif

View File

@@ -1,414 +0,0 @@
#include "prefillgemm_int4/integer_gemm_kernels.h"
#include "utils.hpp"
#ifdef __cplusplus
extern "C" {
#endif
void pack_b_1x8(void* bufferB, const void* cur_b_ptr, size_t n, size_t k, size_t ldb, const BLASINT8 ob) {
BLASINT8* bufferB_typed = (BLASINT8*)bufferB;
BLASINT8* cur_b_typed = (BLASINT8*)cur_b_ptr;
size_t split_n = n / N_SIZE;
size_t split_k = k / K_SIZE;
// TODO::vectorization
for (size_t np = 0; np < split_n; np++) {
for (size_t n_idx = 0; n_idx < N_SIZE; n_idx++) {
for (size_t kp = 0; kp < split_k; kp++) {
for (size_t k_idx = 0; k_idx < K_SIZE; k_idx++) {
bufferB_typed[np * (N_SIZE * k) + kp * (K_SIZE * N_SIZE) + n_idx * K_SIZE + k_idx] =
cur_b_typed[INDEXING_B((kp * K_SIZE + k_idx), (np * N_SIZE + n_idx), ldb)] + ob;
}
}
}
}
}
// Packs a nibble-compressed (int4) B tile into the 1x8 kernel layout.
// Each source byte holds two 4-bit values; this routine regroups nibbles so that
// consecutive packed bytes pair elements (0,2) and (1,3) of each 4-element group,
// matching what the int4 kernel's mask/shift unpacking expects.
// NOTE(review): unlike pack_b_1x8, the zero-point ob is ignored here — confirm
// that int4 weights are always symmetric (zero offset) on this path.
void pack_b_1x8_int4(void* bufferB, const void* cur_b_ptr, size_t n, size_t k, size_t ldb, const BLASINT8 ob) {
    uint8_t* bufferB_typed = (uint8_t*)bufferB;
    uint8_t* cur_b_typed = (uint8_t*)cur_b_ptr;
    // Typed constants instead of function-local #defines: the original macros
    // (RHS_MASK/LHS_MASK) were never #undef'd and leaked into the rest of the
    // translation unit.
    const uint8_t low_nibble_mask = 0x0F;
    const uint8_t high_nibble_mask = 0xF0;
    size_t split_n = n / N_SIZE;
    size_t split_k = k / K_SIZE;
    // TODO::vectorization
    for (size_t np = 0; np < split_n; np++) {
        for (size_t n_idx = 0; n_idx < N_SIZE; n_idx++) {
            for (size_t kp = 0; kp < split_k; kp++) {
                for (size_t k_idx = 0; k_idx < K_SIZE; k_idx += 2) {
                    // b01 holds nibbles for elements 0/1, b23 for elements 2/3 of the group.
                    uint8_t b01 = cur_b_typed[INDEXING_B((kp * K_SIZE + k_idx / 2), (np * N_SIZE + n_idx), ldb)];
                    uint8_t b23 = cur_b_typed[INDEXING_B((kp * K_SIZE + k_idx / 2 + K_SIZE / 2), (np * N_SIZE + n_idx), ldb)];
                    // Re-pair nibbles: b02 = (high of b01 | high of b23), b13 = (low of b23 | low of b01).
                    uint8_t b02 = (b01 & high_nibble_mask) | ((b23 & high_nibble_mask) >> 4);
                    uint8_t b13 = (b23 & low_nibble_mask) | ((b01 & low_nibble_mask) << 4);
                    bufferB_typed[np * (N_SIZE * k) + kp * (K_SIZE * N_SIZE) + n_idx * K_SIZE + k_idx] = b02;
                    bufferB_typed[np * (N_SIZE * k) + kp * (K_SIZE * N_SIZE) + n_idx * K_SIZE + k_idx + 1] = b13;
                }
            }
        }
    }
}
// 1-row x 8-vector int8 GEMM micro-kernel using SVE SDOT.
// lhs_ptr:   packed A data; advances by sv_len bytes per K step.
// rhs_ptr:   packed B data; 8 consecutive SVE vectors per K step (N_SIZE * sv_len bytes).
// accum_ptr: destination C tile; stores are performed by the project macro
//            PROCESS_ACCUM (reduces each accumulator vector and writes via
//            x16..x19 temporaries) — exact store semantics defined elsewhere.
// ldc:       C leading dimension in elements; rebased below to a byte offset.
// k_depth:   K extent in bytes of packed data; sv_len: SVE vector length in bytes.
// Register roles: B -> z0..z7, A -> z8 (z9 for the unrolled copy), C accumulators -> z16..z23.
void gemm_kernel_1x8(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len) {
int64_t run_k_depth = k_depth;
int64_t run_sv_len = sv_len;
int64_t run_2sv_len = 2 * sv_len;
int64_t move_lhs = sv_len;
int64_t move_rhs = N_SIZE * sv_len;
int32_t* dst_ptr = accum_ptr;
// Convert ldc from elements to the byte step past one N_SIZE-wide int32 tile.
ldc -= N_SIZE;
ldc *= 4;
asm volatile(
// Prologue: load the first 8 B vectors and the first A vector, zero accumulators.
"ptrue p0.b, all\n"
"ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
"dup z16.s, #0\n"
"ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
"dup z17.s, #0\n"
"ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
"dup z18.s, #0\n"
"ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
"dup z19.s, #0\n"
"ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
"dup z20.s, #0\n"
"ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
"dup z21.s, #0\n"
"ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
"dup z22.s, #0\n"
"ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
"dup z23.s, #0\n"
"ld1b {z8.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
"subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
"add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
"add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
"ble 1f\n"
"cmp %[run_k_depth], %[run_2sv_len]\n"
"blt 2f\n"
// Label 3: main loop, unrolled by two K steps; loads for the next step are
// interleaved with SDOT of the current step (software pipelining via z8/z9).
"3:\n"
"ld1b {z9.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
"add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
"sdot z16.s, z8.b, z0.b\n"
"ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
"sdot z17.s, z8.b, z1.b\n"
"ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
"sdot z18.s, z8.b, z2.b\n"
"ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
"sdot z19.s, z8.b, z3.b\n"
"ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
"sdot z20.s, z8.b, z4.b\n"
"ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
"sdot z21.s, z8.b, z5.b\n"
"ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
"sdot z22.s, z8.b, z6.b\n"
"ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
"sdot z23.s, z8.b, z7.b\n"
"ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
"add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
"sub %[run_k_depth], %[run_k_depth], %[run_2sv_len]\n"
// Second half of the unrolled iteration: A in z9, reload z8 for the next pass.
"ld1b {z8.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
"add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
"sdot z16.s, z9.b, z0.b\n"
"ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
"sdot z17.s, z9.b, z1.b\n"
"ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
"sdot z18.s, z9.b, z2.b\n"
"ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
"sdot z19.s, z9.b, z3.b\n"
"ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
"sdot z20.s, z9.b, z4.b\n"
"ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
"sdot z21.s, z9.b, z5.b\n"
"ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
"sdot z22.s, z9.b, z6.b\n"
"ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
"sdot z23.s, z9.b, z7.b\n"
"ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
"add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
"cmp %[run_k_depth], %[run_2sv_len]\n"
"bge 3b\n"
"cmp %[run_k_depth], #0\n"
"ble 1f\n"
// Label 2: tail loop, one K step per iteration.
"2:\n"
"subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
"sdot z16.s, z8.b, z0.b\n"
"ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
"sdot z17.s, z8.b, z1.b\n"
"ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
"sdot z18.s, z8.b, z2.b\n"
"ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
"sdot z19.s, z8.b, z3.b\n"
"ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
"sdot z20.s, z8.b, z4.b\n"
"ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
"sdot z21.s, z8.b, z5.b\n"
"ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
"sdot z22.s, z8.b, z6.b\n"
"ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
"sdot z23.s, z8.b, z7.b\n"
"ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
"add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
"ld1b {z8.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
"add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
"bgt 2b\n"
// Label 1: final K step, then reduce/store accumulators via PROCESS_ACCUM.
"1:\n"
"sdot z16.s, z8.b, z0.b\n"
"sdot z17.s, z8.b, z1.b\n"
"sdot z18.s, z8.b, z2.b\n"
"sdot z19.s, z8.b, z3.b\n"
"sdot z20.s, z8.b, z4.b\n"
"sdot z21.s, z8.b, z5.b\n"
"sdot z22.s, z8.b, z6.b\n"
"sdot z23.s, z8.b, z7.b\n"
PROCESS_ACCUM(0, 16, x16, dst_ptr, p0) PROCESS_ACCUM(1, 17, x17, dst_ptr, p0)
PROCESS_ACCUM(2, 18, x18, dst_ptr, p0) PROCESS_ACCUM(3, 19, x19, dst_ptr, p0)
PROCESS_ACCUM(4, 20, x16, dst_ptr, p0) PROCESS_ACCUM(5, 21, x17, dst_ptr, p0)
PROCESS_ACCUM(6, 22, x18, dst_ptr, p0) PROCESS_ACCUM(7, 23, x19, dst_ptr, p0)
// NOTE(review): "+wr" on dst_ptr is an unusual constraint pair — presumably
// "+r" was intended; ldc/accum_ptr inputs are consumed by PROCESS_ACCUM.
: [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [run_k_depth] "+r"(run_k_depth), [dst_ptr] "+wr"(dst_ptr)
: [run_sv_len] "r"(run_sv_len), [run_2sv_len] "r"(run_2sv_len), [move_lhs] "r"(move_lhs),
[move_rhs] "r"(move_rhs), [ldc] "r"(ldc), [accum_ptr] "r"(accum_ptr)
: "cc", "memory", "w0", "w1", "w2", "w3", "w4", "w5", "w6", "w7", "w8", "w9", "w10", "w11", "w12", "w13", "w14",
"w15", "x16", "x17", "x18", "x19", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11",
"z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27",
"z28", "z29", "z30", "z31");
}
// Register allocation for the int4 kernel:
//   A (two half-vectors per K step): z8..z11
//   B (nibble-packed):               z0..z7
//   C accumulators:                  z16..z23
//   nibble masks:                    z12 (high, 0xF0), z13 (low, 0x0F)
//   unpacked-B scratch:              z14
// 1-row x 8-vector GEMM micro-kernel for nibble-packed (int4) B using SVE SDOT.
// Each B vector carries two 4-bit operands per byte; the project macro
// INT4_CP_MASK_SHIFT_1x8 splits a B vector into its second operand half (into
// z14) using the z12/z13 masks, so every K step issues two SDOTs per accumulator.
// lhs_ptr advances 2*sv_len per step (two A half-vectors); rhs_ptr advances
// N_SIZE*sv_len. Stores go through PROCESS_ACCUM as in gemm_kernel_1x8.
void gemm_kernel_1x8_int4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len) {
int64_t run_k_depth = k_depth;
int64_t run_sv_len = sv_len;
int64_t run_2sv_len = 2 * sv_len;
int64_t move_lhs = 2 * sv_len;
int64_t move_rhs = N_SIZE * sv_len;
int32_t* dst_ptr = accum_ptr;
// Convert ldc from elements to the byte step past one N_SIZE-wide int32 tile.
ldc -= N_SIZE;
ldc *= 4;
asm volatile(
// Prologue: set nibble masks, load the first 8 B vectors and both A
// half-vectors, zero accumulators.
"ptrue p0.b, all\n"
"mov z12.b, #0xF0\n" //mask high
"mov z13.b, #0x0F\n" //mask low
"ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
"dup z16.s, #0\n"
"ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
"dup z17.s, #0\n"
"ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
"dup z18.s, #0\n"
"ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
"dup z19.s, #0\n"
"ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
"dup z20.s, #0\n"
"ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
"dup z21.s, #0\n"
"ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
"dup z22.s, #0\n"
"ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
"dup z23.s, #0\n"
"ld1b {z8.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
"ld1b {z9.b}, p0/z, [%[lhs_ptr], #1, MUL VL]\n"
"subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
"add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
"add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
"ble 1f\n"
"cmp %[run_k_depth], %[run_2sv_len]\n"
"blt 2f\n"
// Label 3: main loop, unrolled by two K steps; A half-vectors alternate
// between z8/z9 and z10/z11 (software pipelining).
"3:\n"
"ld1b {z10.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
"ld1b {z11.b}, p0/z, [%[lhs_ptr], #1, MUL VL]\n"
"add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
INT4_CP_MASK_SHIFT_1x8(0, 14, 12, 13, 4)
"sdot z16.s, z8.b, z0.b\n"
"sdot z16.s, z9.b, z14.b\n"
"ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(1, 14, 12, 13, 4)
"sdot z17.s, z8.b, z1.b\n"
"sdot z17.s, z9.b, z14.b\n"
"ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(2, 14, 12, 13, 4)
"sdot z18.s, z8.b, z2.b\n"
"sdot z18.s, z9.b, z14.b\n"
"ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(3, 14, 12, 13, 4)
"sdot z19.s, z8.b, z3.b\n"
"sdot z19.s, z9.b, z14.b\n"
"ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(4, 14, 12, 13, 4)
"sdot z20.s, z8.b, z4.b\n"
"sdot z20.s, z9.b, z14.b\n"
"ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(5, 14, 12, 13, 4)
"sdot z21.s, z8.b, z5.b\n"
"sdot z21.s, z9.b, z14.b\n"
"ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(6, 14, 12, 13, 4)
"sdot z22.s, z8.b, z6.b\n"
"sdot z22.s, z9.b, z14.b\n"
"ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(7, 14, 12, 13, 4)
"sdot z23.s, z8.b, z7.b\n"
"sdot z23.s, z9.b, z14.b\n"
"ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
"add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
"sub %[run_k_depth], %[run_k_depth], %[run_2sv_len]\n"
// Second half of the unrolled iteration: A in z10/z11, reload z8/z9.
"ld1b {z8.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
"ld1b {z9.b}, p0/z, [%[lhs_ptr], #1, MUL VL]\n"
"add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
INT4_CP_MASK_SHIFT_1x8(0, 14, 12, 13, 4)
"sdot z16.s, z10.b, z0.b\n"
"sdot z16.s, z11.b, z14.b\n"
"ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(1, 14, 12, 13, 4)
"sdot z17.s, z10.b, z1.b\n"
"sdot z17.s, z11.b, z14.b\n"
"ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(2, 14, 12, 13, 4)
"sdot z18.s, z10.b, z2.b\n"
"sdot z18.s, z11.b, z14.b\n"
"ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(3, 14, 12, 13, 4)
"sdot z19.s, z10.b, z3.b\n"
"sdot z19.s, z11.b, z14.b\n"
"ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(4, 14, 12, 13, 4)
"sdot z20.s, z10.b, z4.b\n"
"sdot z20.s, z11.b, z14.b\n"
"ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(5, 14, 12, 13, 4)
"sdot z21.s, z10.b, z5.b\n"
"sdot z21.s, z11.b, z14.b\n"
"ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(6, 14, 12, 13, 4)
"sdot z22.s, z10.b, z6.b\n"
"sdot z22.s, z11.b, z14.b\n"
"ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(7, 14, 12, 13, 4)
"sdot z23.s, z10.b, z7.b\n"
"sdot z23.s, z11.b, z14.b\n"
"ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
"add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
"cmp %[run_k_depth], %[run_2sv_len]\n"
"bge 3b\n"
"cmp %[run_k_depth], #0\n"
"ble 1f\n"
// Label 2: tail loop, one K step per iteration.
"2:\n"
"subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
INT4_CP_MASK_SHIFT_1x8(0, 14, 12, 13, 4)
"sdot z16.s, z8.b, z0.b\n"
"sdot z16.s, z9.b, z14.b\n"
"ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(1, 14, 12, 13, 4)
"sdot z17.s, z8.b, z1.b\n"
"sdot z17.s, z9.b, z14.b\n"
"ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(2, 14, 12, 13, 4)
"sdot z18.s, z8.b, z2.b\n"
"sdot z18.s, z9.b, z14.b\n"
"ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(3, 14, 12, 13, 4)
"sdot z19.s, z8.b, z3.b\n"
"sdot z19.s, z9.b, z14.b\n"
"ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(4, 14, 12, 13, 4)
"sdot z20.s, z8.b, z4.b\n"
"sdot z20.s, z9.b, z14.b\n"
"ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(5, 14, 12, 13, 4)
"sdot z21.s, z8.b, z5.b\n"
"sdot z21.s, z9.b, z14.b\n"
"ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(6, 14, 12, 13, 4)
"sdot z22.s, z8.b, z6.b\n"
"sdot z22.s, z9.b, z14.b\n"
"ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
INT4_CP_MASK_SHIFT_1x8(7, 14, 12, 13, 4)
"sdot z23.s, z8.b, z7.b\n"
"sdot z23.s, z9.b, z14.b\n"
"ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
"add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
"ld1b {z8.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
"ld1b {z9.b}, p0/z, [%[lhs_ptr], #1, MUL VL]\n"
"add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
"bgt 2b\n"
// Label 1: final K step, then reduce/store accumulators via PROCESS_ACCUM.
"1:\n"
INT4_CP_MASK_SHIFT_1x8(0, 14, 12, 13, 4)
"sdot z16.s, z8.b, z0.b\n"
"sdot z16.s, z9.b, z14.b\n"
INT4_CP_MASK_SHIFT_1x8(1, 14, 12, 13, 4)
"sdot z17.s, z8.b, z1.b\n"
"sdot z17.s, z9.b, z14.b\n"
INT4_CP_MASK_SHIFT_1x8(2, 14, 12, 13, 4)
"sdot z18.s, z8.b, z2.b\n"
"sdot z18.s, z9.b, z14.b\n"
INT4_CP_MASK_SHIFT_1x8(3, 14, 12, 13, 4)
"sdot z19.s, z8.b, z3.b\n"
"sdot z19.s, z9.b, z14.b\n"
INT4_CP_MASK_SHIFT_1x8(4, 14, 12, 13, 4)
"sdot z20.s, z8.b, z4.b\n"
"sdot z20.s, z9.b, z14.b\n"
INT4_CP_MASK_SHIFT_1x8(5, 14, 12, 13, 4)
"sdot z21.s, z8.b, z5.b\n"
"sdot z21.s, z9.b, z14.b\n"
INT4_CP_MASK_SHIFT_1x8(6, 14, 12, 13, 4)
"sdot z22.s, z8.b, z6.b\n"
"sdot z22.s, z9.b, z14.b\n"
INT4_CP_MASK_SHIFT_1x8(7, 14, 12, 13, 4)
"sdot z23.s, z8.b, z7.b\n"
"sdot z23.s, z9.b, z14.b\n"
PROCESS_ACCUM(0, 16, x16, dst_ptr, p0)
PROCESS_ACCUM(1, 17, x17, dst_ptr, p0)
PROCESS_ACCUM(2, 18, x18, dst_ptr, p0)
PROCESS_ACCUM(3, 19, x19, dst_ptr, p0)
PROCESS_ACCUM(4, 20, x16, dst_ptr, p0)
PROCESS_ACCUM(5, 21, x17, dst_ptr, p0)
PROCESS_ACCUM(6, 22, x18, dst_ptr, p0)
PROCESS_ACCUM(7, 23, x19, dst_ptr, p0)
// NOTE(review): "+wr" on dst_ptr is an unusual constraint pair — presumably
// "+r" was intended; ldc/accum_ptr inputs are consumed by PROCESS_ACCUM.
:
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[run_k_depth] "+r"(run_k_depth),
[dst_ptr] "+wr"(dst_ptr)
:
[run_sv_len] "r"(run_sv_len), [run_2sv_len] "r"(run_2sv_len),
[move_lhs] "r"(move_lhs), [move_rhs] "r"(move_rhs), [ldc] "r"(ldc),
[accum_ptr] "r"(accum_ptr)
:
"cc", "memory",
"w0","w1","w2","w3","w4","w5","w6","w7",
"w8","w9","w10","w11","w12","w13","w14","w15",
"x16","x17","x18","x19",
"z0","z1","z2","z3","z4","z5","z6","z7",
"z8","z9","z10","z11","z12","z13","z14","z15",
"z16","z17","z18","z19","z20","z21","z22","z23",
"z24","z25","z26","z27","z28","z29","z30","z31"
);
}
#ifdef __cplusplus
}
#endif

View File

@@ -1,160 +0,0 @@
cmake_minimum_required(VERSION 3.14)
set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED ON)
# can be compiled for SVE256 (32) or SVE512 (64)
set(SV_LENGTH 32) # in bytes
include_directories("${PROJECT_SOURCE_DIR}")
include_directories("${PROJECT_SOURCE_DIR}/include")
include_directories("${PROJECT_SOURCE_DIR}/standalone")
add_compile_options(-fPIC -fvisibility=hidden -fstack-protector-strong -march=armv8.3-a+sve+i8mm -O3)
# add_compile_options(-Wall -Wextra -Werror)
# sources are split into several groups: matmul kernels (sources are compiled multiple times),
# packing kernels (sources are compiled multiple times),
# sequential/parallel pipelines,
# interface (sources are compiled multiple times)
set(INT_GEMM_KERNELS "")
set(INT_PACK_KERNELS "")
set(INT_GEMM_INTERFACE "")
set(INT_GEMM_PARALLEL "")
set(INT_GEMM_SEQ "")
set(BETA_KERNELS "")
set(POST_OPS_KERNELS "")
set(INT_SMALL_KERNELS "")
set(GEMM_DRIVERS "")
# Supported precisions are i/u for A and B matrices (4 combinations)
set(LHS_TYPES LHS_INT LHS_UINT)
set(RHS_TYPES RHS_INT RHS_UINT)
# compile matrix-multiplication kernels multiple times
set(INTEGER_MM_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_kernels.c")
set(M_SIZES 1 2 3 4)
set(N_SIZES 1 2 3 4)
foreach(M_SIZE ${M_SIZES})
foreach(N_SIZE ${N_SIZES})
foreach(LHS_TYPE ${LHS_TYPES})
foreach(RHS_TYPE ${RHS_TYPES})
add_library(integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH} OBJECT ${INTEGER_MM_KERNELS_SRC})
target_compile_options(integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH} PUBLIC -DM_SIZE=${M_SIZE}
-DN_SIZE=${N_SIZE} -D${LHS_TYPE} -D${RHS_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
target_include_directories(integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH} PUBLIC
"${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
list(APPEND INT_GEMM_KERNELS $<TARGET_OBJECTS:integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH}>)
endforeach()
endforeach()
endforeach()
endforeach()
# compile interface multiple times
set(INTEGER_GEMM_IFACE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_interface.c")
foreach(LHS_TYPE ${LHS_TYPES})
foreach(RHS_TYPE ${RHS_TYPES})
add_library(integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH} OBJECT ${INTEGER_GEMM_IFACE_SRC})
target_compile_options(integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH} PUBLIC -D${LHS_TYPE} -D${RHS_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
target_include_directories(integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH} PUBLIC
"${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
list(APPEND INT_GEMM_INTERFACE $<TARGET_OBJECTS:integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH}>)
endforeach()
endforeach()
# compile threading layer
# set(INTEGER_GEMM_PAR_PIPE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/parallel_int_gemm_pipeline.c")
# add_library(integer_gemm_par_pipe_${SV_LENGTH} OBJECT ${INTEGER_GEMM_PAR_PIPE_SRC})
# target_compile_options(integer_gemm_par_pipe_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH})
# target_include_directories(integer_gemm_par_pipe_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
# list(APPEND INT_GEMM_PARALLEL $<TARGET_OBJECTS:integer_gemm_par_pipe_${SV_LENGTH}>)
# compile sequential layer
set(INTEGER_GEMMSEQ_PIPE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/sequential_int_gemm_pipeline.c")
add_library(integer_gemm_seq_pipe_${SV_LENGTH} OBJECT ${INTEGER_GEMMSEQ_PIPE_SRC})
target_compile_options(integer_gemm_seq_pipe_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH})
target_include_directories(integer_gemm_seq_pipe_${SV_LENGTH} PUBLIC
"${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
list(APPEND INT_GEMM_SEQ $<TARGET_OBJECTS:integer_gemm_seq_pipe_${SV_LENGTH}>)
# compile packingA kernels
set(INT_PACK_A_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_pack_a_kernels.c")
set(TRANSA_VALS TRANSA NOTRANSA)
foreach(LHS_TYPE ${LHS_TYPES})
foreach(TRANSA_VAL ${TRANSA_VALS})
add_library(integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH} OBJECT ${INT_PACK_A_KERNELS_SRC})
target_compile_options(integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH} -D${LHS_TYPE} -D${TRANSA_VAL})
target_include_directories(integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
list(APPEND INT_PACK_KERNELS $<TARGET_OBJECTS:integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH}>)
endforeach()
endforeach()
# compile packingB kernels
set(INT_PACK_B_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_pack_b_kernels.c")
set(TRANSB_VALS TRANSB NOTRANSB)
foreach(RHS_TYPE ${RHS_TYPES})
foreach(TRANSB_VAL ${TRANSB_VALS})
add_library(integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH} OBJECT ${INT_PACK_B_KERNELS_SRC})
target_compile_options(integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH} -D${RHS_TYPE} -D${TRANSB_VAL})
target_include_directories(integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
list(APPEND INT_PACK_KERNELS $<TARGET_OBJECTS:integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH}>)
endforeach()
endforeach()
# compile beta kernels
set(BETA_OPTS BETA_OPT BETA_NO_OPT)
foreach(B_TYPE ${BETA_OPTS})
add_library(beta_kernels_${B_TYPE} OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/integer_beta_kernels.c")
target_compile_options(beta_kernels_${B_TYPE} PUBLIC -D${B_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
target_include_directories(beta_kernels_${B_TYPE} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
list(APPEND BETA_KERNELS $<TARGET_OBJECTS:beta_kernels_${B_TYPE}>)
endforeach()
# compile int gemm drivers
foreach(B_TYPE ${BETA_OPTS})
add_library(gemm_driver_${B_TYPE} OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_driver.c")
target_compile_options(gemm_driver_${B_TYPE} PUBLIC -D${B_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
target_include_directories(gemm_driver_${B_TYPE} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
list(APPEND GEMM_DRIVERS $<TARGET_OBJECTS:gemm_driver_${B_TYPE}>)
endforeach()
# compile post-ops kernels
foreach(B_TYPE ${BETA_OPTS})
add_library(post_ops_kernels_${B_TYPE} OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/integer_post_ops_kernels.c")
target_compile_options(post_ops_kernels_${B_TYPE} PUBLIC -D${B_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
target_include_directories(post_ops_kernels_${B_TYPE} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
list(APPEND POST_OPS_KERNELS $<TARGET_OBJECTS:post_ops_kernels_${B_TYPE}>)
endforeach()
# compile matrix-multiplication small kernels multiple times
set(SMALL_KERNELS_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_small_kernels.c")
set(TRANSA_VALS TRANSA NOTRANSA)
set(TRANSB_VALS TRANSB NOTRANSB)
set(OC_TYPES OC_FIX OC_COL OC_ROW)
foreach(LHS_TYPE ${LHS_TYPES})
foreach(RHS_TYPE ${RHS_TYPES})
foreach(TRANSA_VAL ${TRANSA_VALS})
foreach(TRANSB_VAL ${TRANSB_VALS})
foreach(OC_TYPE ${OC_TYPES})
# target names use the same `_${SV_LENGTH}` suffix convention as every other
# target in this file (the separator underscore was previously missing here)
add_library(small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}_${SV_LENGTH} OBJECT ${SMALL_KERNELS_KERNELS_SRC})
target_compile_options(small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}_${SV_LENGTH} PUBLIC -D${LHS_TYPE} -D${RHS_TYPE} -D${TRANSA_VAL} -D${TRANSB_VAL} -D${OC_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
target_include_directories(small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
list(APPEND INT_SMALL_KERNELS $<TARGET_OBJECTS:small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}_${SV_LENGTH}>)
endforeach()
endforeach()
endforeach()
endforeach()
endforeach()
list(APPEND OBJ_FILES_STANDALONE_DIR ${INT_GEMM_KERNELS} ${INT_GEMM_INTERFACE} ${INT_GEMM_SEQ} ${INT_PACK_KERNELS} ${BETA_KERNELS} ${POST_OPS_KERNELS} ${INT_SMALL_KERNELS} ${GEMM_DRIVERS})
# set(OBJ_FILES_STANDALONE_DIR ${OBJ_FILES_STANDALONE_DIR} PARENT_SCOPE)
# all compiled object files are united into one object library
add_library(prefillint8gemm SHARED ${OBJ_FILES_STANDALONE_DIR})
set_target_properties(prefillint8gemm PROPERTIES
ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/prefillint8gemm
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/prefillint8gemm
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/prefillint8gemm
)

View File

@@ -1,25 +0,0 @@
#ifndef BETA_MACROS_H
#define BETA_MACROS_H
// Selects the C-offset addressing mode for the beta kernels.
// OC_TYPE is a single-letter mangling tag; OC_IDX(mi, ni) yields the index
// into the oc array for element (row mi, column ni).
#if defined(OC_FIX)
#define OC_TYPE f
#define OC_IDX(mi, ni) 0
#elif defined(OC_COL)
#define OC_TYPE c
#define OC_IDX(mi, ni) mi
#else
#define OC_TYPE r
#define OC_IDX(mi, ni) ni
#endif
// BETA_OPT selects the optimized path (alpha == 1, beta in {0,1}): results stay
// int32 and C is addressed with its true leading dimension. The generic path
// works through a float view of the accumulator with an m-major temp buffer.
#if defined(BETA_OPT)
#define BETA_SUFF(name) name##_opt
#define LDC(m, ldc) ldc
#define DTYPE int32_t
#else
#define BETA_SUFF(name) name
#define LDC(m, ldc) m
#define DTYPE float
#endif
#endif

View File

@@ -1,63 +0,0 @@
// Name-mangling helpers shared by the integer GEMM kernel/interface sources.
// Guard renamed from __HELPING_MACROS_H__: identifiers containing a double
// underscore are reserved for the implementation (C11 7.1.3 / C++ [lex.name]).
#ifndef HELPING_MACROS_H
#define HELPING_MACROS_H
#include <stdio.h>
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
// Exactly one signedness must be selected per side.
#if defined(LHS_INT) && defined(LHS_UINT)
#error "Both LHS_INT and LHS_UINT are defined"
#endif
#if defined(RHS_INT) && defined(RHS_UINT)
#error "Both RHS_INT and RHS_UINT are defined"
#endif
#ifdef LHS_INT
#define LHS_TYPE s
#define LHS_INT_TYPE int8_t
#endif
#ifdef LHS_UINT
#define LHS_TYPE u
#define LHS_INT_TYPE uint8_t
#endif
#ifdef RHS_INT
#define RHS_TYPE s
#define RHS_INT_TYPE int8_t
#endif
#ifdef RHS_UINT
#define RHS_TYPE u
#define RHS_INT_TYPE uint8_t
#endif
// mangling macros (the *_MACRO indirection forces argument expansion before pasting)
#define ADD_M_N_SIZES(name, m_size, n_size) name##_##m_size##x##n_size
#define ADD_M_N_SIZES_MACRO(name, m_size, n_size) ADD_M_N_SIZES(name, m_size, n_size)
#define ADD_TYPES(name, lhs_type, rhs_type) name##_##lhs_type##8##rhs_type##8s32
#define ADD_TYPES_MACRO(name, lhs_type, rhs_type) ADD_TYPES(name, lhs_type, rhs_type)
#define ADD_TYPES_SUFF(name) ADD_TYPES_MACRO(name, LHS_TYPE, RHS_TYPE)
#define ADD_ONE_TYPE_TRANSP(name, type, nt) name##_##type##8_##nt
#define ADD_ONE_TYPE_TRANSP_MACRO(name, type, nt) ADD_ONE_TYPE_TRANSP(name, type, nt)
#define ADD_PACK_A_N_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, LHS_TYPE, n)
#define ADD_PACK_B_N_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, RHS_TYPE, n)
#define ADD_PACK_A_T_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, LHS_TYPE, t)
#define ADD_PACK_B_T_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, RHS_TYPE, t)
#define ADD_TWO_TYPES_TRANSP(name, lhs_type, rhs_type, a_t, b_t, oc_t) \
name##_##lhs_type##8##rhs_type##8s32##_##a_t##b_t##_##oc_t
#define ADD_TWO_TYPES_TRANSP_MACRO(name, lhs_type, rhs_type, a_t, b_t, oc_t) \
ADD_TWO_TYPES_TRANSP(name, lhs_type, rhs_type, a_t, b_t, oc_t)
#define ADD_TRANSP_MACRO(name, a_t, b_t, oc_t) ADD_TWO_TYPES_TRANSP_MACRO(name, LHS_TYPE, RHS_TYPE, a_t, b_t, oc_t)
#ifdef ENABLE_THREADING
#define ADD_THREAD_SUFF(name) name##_thread
#else
#define ADD_THREAD_SUFF(name) name
#endif
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -1,82 +0,0 @@
#include <stdint.h>
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
#include "beta_macros.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
// Column-major C with a single fixed C_offset applied to every element.
// Scales C in place through a DTYPE view of the int32 accumulator buffer.
void BETA_SUFF(beta_cf_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
    DTYPE* out = (DTYPE*) c_ptr;
    const DTYPE scale = (DTYPE)beta;
    const DTYPE offset = (DTYPE)*oc;
    for (size_t ni = 0; ni < n_block_size; ++ni) {
        const size_t col_base = ldc * ni;
        for (size_t mi = 0; mi < m_block_size; ++mi) {
            out[col_base + mi] = scale * ((DTYPE)c_ptr[col_base + mi]) + offset;
        }
    }
}
// Column-major C with a per-row (column-vector) C_offset: oc[mi] is added to row mi.
void BETA_SUFF(beta_cc_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
    DTYPE* out = (DTYPE*) c_ptr;
    const DTYPE scale = (DTYPE)beta;
    for (size_t ni = 0; ni < n_block_size; ++ni) {
        const size_t col_base = ldc * ni;
        for (size_t mi = 0; mi < m_block_size; ++mi) {
            out[col_base + mi] = scale * ((DTYPE)c_ptr[col_base + mi]) + ((DTYPE)oc[mi]);
        }
    }
}
// Column-major C with a per-column (row-vector) C_offset: oc[ni] is added to column ni.
void BETA_SUFF(beta_cr_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
    DTYPE* out = (DTYPE*) c_ptr;
    const DTYPE scale = (DTYPE)beta;
    for (size_t ni = 0; ni < n_block_size; ++ni) {
        const size_t col_base = ldc * ni;
        const DTYPE offset = (DTYPE)oc[ni];
        for (size_t mi = 0; mi < m_block_size; ++mi) {
            out[col_base + mi] = scale * ((DTYPE)c_ptr[col_base + mi]) + offset;
        }
    }
}
// Row-major C with a single fixed C_offset.
// For row-major we actually swap m and n values, so the block sizes are
// re-swapped here (outer loop runs over n_block_size rows).
void BETA_SUFF(beta_rf_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
    DTYPE* out = (DTYPE*) c_ptr;
    const DTYPE scale = (DTYPE)beta;
    const DTYPE offset = (DTYPE)*oc;
    for (size_t mi = 0; mi < n_block_size; ++mi) {
        const size_t row_base = ldc * mi;
        for (size_t ni = 0; ni < m_block_size; ++ni) {
            out[row_base + ni] = scale * ((DTYPE)c_ptr[row_base + ni]) + offset;
        }
    }
}
// Row-major C with a column-major C_offset: oc[mi] is added across row mi.
// For row-major we actually swap m and n values, so they are re-swapped here.
void BETA_SUFF(beta_rc_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
    DTYPE* out = (DTYPE*) c_ptr;
    const DTYPE scale = (DTYPE)beta;
    for (size_t mi = 0; mi < n_block_size; ++mi) {
        const size_t row_base = ldc * mi;
        const DTYPE offset = (DTYPE)oc[mi];
        for (size_t ni = 0; ni < m_block_size; ++ni) {
            out[row_base + ni] = scale * ((DTYPE)c_ptr[row_base + ni]) + offset;
        }
    }
}
// Row-major C with a row-major C_offset: oc[ni] is added down column ni.
// For row-major we actually swap m and n values, so they are re-swapped here.
void BETA_SUFF(beta_rr_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
    DTYPE* out = (DTYPE*) c_ptr;
    const DTYPE scale = (DTYPE)beta;
    for (size_t mi = 0; mi < n_block_size; ++mi) {
        const size_t row_base = ldc * mi;
        for (size_t ni = 0; ni < m_block_size; ++ni) {
            out[row_base + ni] = scale * ((DTYPE)c_ptr[row_base + ni]) + ((DTYPE)oc[ni]);
        }
    }
}
#ifdef __cplusplus
} // extern "C"
#endif /* __cplusplus */

View File

@@ -1,117 +0,0 @@
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include "integer_gemm_kernels.h"
#include "beta_macros.h"
#define OLD_N_SIZE 8
// Byte offset of the (n_step, k_step) tile inside the pre-packed A panel.
#define PACKED_LD_STEP(n_step, k_step, ldb) (n_step * ldb + k_step * OLD_N_SIZE)
/*
 * Blocked integer GEMM driver. Partitions the (m x n x k) problem into
 * N_BLOCK / K_BLOCK / M_BLOCK tiles, packs A and B panels into page-aligned
 * scratch buffers and dispatches the micro-kernels supplied via `arg`.
 * Two build variants share this body:
 *   - BETA_OPT: bufferC aliases the caller's `c` and kernels accumulate
 *     directly into it (valid when alpha == 1 and beta is 0 or 1);
 *   - otherwise: kernels accumulate into a temporary buffer that is merged
 *     into `c` per n-block by post_ops_func (which applies alpha).
 * oa/ob/oc are the cblas-style integer offsets; on allocation failure the
 * routine reports and returns without touching `c`.
 */
void BETA_SUFF(gemm_driver)(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha,
                            const void* a, size_t lda, const BLASINT8 oa,
                            const void* b, size_t ldb, const BLASINT8 ob,
                            float beta, int32_t* c, size_t ldc, const int32_t* oc) {
    void (**gemm_kernels)(const void*, const void*, int32_t*, size_t, int64_t, int64_t) = arg->gemm_kernels;
    void (*pack_a_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8) = arg->pack_a_fun;
    void (*pack_b_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8) = arg->pack_b_fun;
    void (*beta_func)(int32_t*, const int32_t*, float, size_t, size_t, size_t) = arg->beta_func;
    // NOTE: arg->a_indexing was loaded here but never used (A is addressed
    // through PACKED_LD_STEP below); the dead local has been removed.
    size_t (*b_indexing)(size_t m, size_t n, size_t ld) = arg->b_indexing;
#ifndef BETA_OPT
    void (*post_ops_func)(float, const int32_t*, int32_t*, size_t, size_t, size_t) = arg->post_ops_func;
#endif // BETA_OPT
    const BLASINT8* a_typed = (const BLASINT8*) a;
    const BLASINT8* b_typed = (const BLASINT8*) b;
    BLASINT8* bufferA = (BLASINT8*) aligned_alloc(ALIGNMENT, sizeof(BLASINT8) * K_BLOCK * M_BLOCK);
    BLASINT8* bufferB = (BLASINT8*) aligned_alloc(ALIGNMENT, sizeof(BLASINT8) * K_BLOCK * N_BLOCK);
    // Tmp buffer is not needed when (alpha = 1 and beta = 0/1)
#ifdef BETA_OPT
    int32_t* bufferC = c;
#else
    int32_t* bufferC = (int32_t*) aligned_alloc(ALIGNMENT, sizeof(int32_t) * m * N_BLOCK);
#endif
    if (!bufferA || !bufferB || !bufferC) {
        free(bufferA);
        free(bufferB);
#ifndef BETA_OPT
        // BUGFIX: in the BETA_OPT build bufferC aliases the caller-owned `c`,
        // so the unconditional free() here handed caller memory to free().
        // Only release the buffer when this driver actually allocated it
        // (mirrors the guarded free at the end of the function).
        free(bufferC);
#endif
        printf("Integer GEMM unsuccessful allocation");
        return;
    }
    // Apply beta scaling and the oc offset to C up front so the kernels can
    // simply accumulate on top.
    beta_func(c, oc, beta, m, n, ldc);
    for (size_t n_block = 0; n_block < n; n_block += N_BLOCK) {
        size_t n_block_size = n - n_block;
        if (n_block_size > N_BLOCK) {
            n_block_size = N_BLOCK;
        }
#ifndef BETA_OPT
        // fill bufferC w/ zeros
        for (size_t tmp_idx = 0; tmp_idx < (m * N_BLOCK); ++tmp_idx) {
            bufferC[tmp_idx] = 0;
        }
#endif // BETA_OPT
        if (alpha != 0.0f) {
            for (size_t k_block = 0; k_block < k; k_block += K_BLOCK) {
                size_t k_block_size = k - k_block;
                if (k_block_size > K_BLOCK) {
                    k_block_size = K_BLOCK;
                }
                // round the k extent up to a whole number of kernel k-steps
                size_t k_block_size_up = (k_block_size + KERNEL_K_STEP - 1) / KERNEL_K_STEP * KERNEL_K_STEP;
                const BLASINT8* curr_b_ptr = b_typed + b_indexing(k_block, n_block, ldb);
                pack_b_fun(bufferB, curr_b_ptr, n_block_size, k_block_size, ldb, ob);
                for (size_t m_block = 0; m_block < m; m_block += M_BLOCK) {
                    size_t m_block_size = m - m_block;
                    if (m_block_size > M_BLOCK) {
                        m_block_size = M_BLOCK;
                    }
                    const BLASINT8* curr_a_ptr = a_typed + PACKED_LD_STEP(m_block, k_block, lda);
                    pack_a_fun(bufferA, curr_a_ptr, m_block_size, k_block_size, lda, oa);
                    // loop over bufferB, taking parts which fit into L1
                    for (size_t n_sub_block = 0; n_sub_block < n_block_size; n_sub_block += KERNEL_N_STEP) {
                        size_t n_sub_block_size = n_block_size - n_sub_block;
                        if (n_sub_block_size > KERNEL_N_STEP) {
                            n_sub_block_size = KERNEL_N_STEP;
                        }
                        BLASINT8* current_bufferB_ptr = bufferB + n_sub_block * k_block_size_up;
                        // loop over bufferA, taking parts which fit into L1
                        for (size_t m_sub_block = 0; m_sub_block < m_block_size; m_sub_block += KERNEL_M_STEP) {
                            size_t m_sub_block_size = m_block_size - m_sub_block;
                            if (m_sub_block_size > KERNEL_M_STEP) {
                                m_sub_block_size = KERNEL_M_STEP;
                            }
                            BLASINT8* current_bufferA_ptr = bufferA + m_sub_block * k_block_size_up;
                            int32_t* current_bufferC_ptr = bufferC + n_block * ldc + n_sub_block * LDC(m, ldc) + m_sub_block + m_block;
                            // call kernel which performs loop over k_block_size;
                            // the table is indexed (n - 1) + (m - 1) * KERNEL_N_STEP
                            gemm_kernels[(n_sub_block_size - 1) + (m_sub_block_size - 1) * KERNEL_N_STEP](
                                current_bufferA_ptr, current_bufferB_ptr, current_bufferC_ptr,
                                LDC(m, ldc), k_block_size_up, COMP_SV_LEN);
                        }
                    }
                }
            }
        }
#ifndef BETA_OPT
        // copy C data from bufferC multiplying by alpha and adding initial C data (scaled by beta)
        int32_t* current_c_ptr = c + n_block * ldc; // col major
        post_ops_func(alpha, bufferC, current_c_ptr, LDC(m, ldc), n_block_size, ldc);
#endif
    }
    free(bufferA);
    free(bufferB);
#ifndef BETA_OPT
    free(bufferC);
#endif
}

View File

@@ -1,154 +0,0 @@
#include <stdint.h>
#include <stdio.h>
//#include "cblas.h"
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
/* matrix saved in rows or cols */
typedef enum CBLAS_ORDER {
CblasRowMajor = 101, // C-style: elements of a row are contiguous
CblasColMajor = 102  // Fortran-style: elements of a column are contiguous
} CBLAS_ORDER;
/* matrix transpose or conjugate transpose */
typedef enum CBLAS_TRANSPOSE {
CblasNoTrans = 111,
CblasTrans = 112,
CblasConjTrans = 113, // conjugate transpose
CblasConjNoTrans = 114
} CBLAS_TRANSPOSE;
typedef CBLAS_ORDER CBLAS_LAYOUT;
// How the oc offset vector is applied to C (cblas_gemm_*_offset convention).
typedef enum CBLAS_OFFSET {
CblasRowOffset = 171, // one offset per row of C
CblasColOffset = 172, // one offset per column of C
CblasFixOffset = 173  // a single scalar offset for all of C
} CBLAS_OFFSET;
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
#define ADD_KERNEL_SUFF(name, m_size, n_size) ADD_M_N_SIZES_MACRO(ADD_TYPES_MACRO(name, LHS_TYPE, RHS_TYPE), m_size, n_size)
// Micro-kernel dispatch table for the current LHS_TYPE/RHS_TYPE pair, laid out
// over (m, n) tile sizes 1..4 so the driver can index it with
// (n_size - 1) + (m_size - 1) * KERNEL_N_STEP.
static void (*gemm_kernels[])(const void*, const void*, int32_t*, size_t, int64_t, int64_t) = {
ADD_KERNEL_SUFF(gemm_kernel, 1, 1), ADD_KERNEL_SUFF(gemm_kernel, 1, 2),ADD_KERNEL_SUFF(gemm_kernel, 1, 3), ADD_KERNEL_SUFF(gemm_kernel, 1, 4),
ADD_KERNEL_SUFF(gemm_kernel, 2, 1), ADD_KERNEL_SUFF(gemm_kernel, 2, 2),ADD_KERNEL_SUFF(gemm_kernel, 2, 3), ADD_KERNEL_SUFF(gemm_kernel, 2, 4),
ADD_KERNEL_SUFF(gemm_kernel, 3, 1), ADD_KERNEL_SUFF(gemm_kernel, 3, 2),ADD_KERNEL_SUFF(gemm_kernel, 3, 3), ADD_KERNEL_SUFF(gemm_kernel, 3, 4),
ADD_KERNEL_SUFF(gemm_kernel, 4, 1), ADD_KERNEL_SUFF(gemm_kernel, 4, 2),ADD_KERNEL_SUFF(gemm_kernel, 4, 3), ADD_KERNEL_SUFF(gemm_kernel, 4, 4)
};
// Pack routines: index 0 = no-transpose variant, index 1 = transpose variant
// (selected with `trans == CblasTrans` in prefill_cblas_gemm below).
static void (*pack_b_funs[])(void*, const void*, size_t, size_t, size_t, const BLASINT8) = {
ADD_PACK_B_N_SUFF(pack_b),
ADD_PACK_B_T_SUFF(pack_b)
};
static void (*pack_a_funs[])(void*, const void*, size_t, size_t, size_t, const BLASINT8) = {
ADD_PACK_A_N_SUFF(pack_a),
ADD_PACK_A_T_SUFF(pack_a)
};
// Small-problem kernels in groups of four transpose combinations
// (nn, nt, tn, tt), one group per offset flavor: f = fixed, c = column,
// r = row — matching the index (transb) + 2*(transa) + offset_group*4.
static void (*small_kernels[])(const size_t, const size_t, const size_t, const float,
const void *, const size_t, const BLASINT8,
const void *, const size_t, const BLASINT8,
const float, int32_t *, const size_t, const int32_t *) = {
ADD_TRANSP_MACRO(small_kernel, n, n, f), ADD_TRANSP_MACRO(small_kernel, n, t, f),
ADD_TRANSP_MACRO(small_kernel, t, n, f), ADD_TRANSP_MACRO(small_kernel, t, t, f),
ADD_TRANSP_MACRO(small_kernel, n, n, c), ADD_TRANSP_MACRO(small_kernel, n, t, c),
ADD_TRANSP_MACRO(small_kernel, t, n, c), ADD_TRANSP_MACRO(small_kernel, t, t, c),
ADD_TRANSP_MACRO(small_kernel, n, n, r), ADD_TRANSP_MACRO(small_kernel, n, t, r),
ADD_TRANSP_MACRO(small_kernel, t, n, r), ADD_TRANSP_MACRO(small_kernel, t, t, r),
};
// Beta pre-pass variants: the first six are {layout c/r} x {offset f/c/r};
// the trailing six "_opt" versions are picked when alpha == 1 and beta is
// 0 or 1 (see opt_offset below).
static void (*beta_funcs[])(int32_t*, const int32_t*, float, size_t, size_t, size_t) = {
beta_cf_s8, beta_cc_s8, beta_cr_s8, beta_rf_s8, beta_rc_s8, beta_rr_s8,
beta_cf_s8_opt, beta_cc_s8_opt, beta_cr_s8_opt, beta_rf_s8_opt, beta_rc_s8_opt, beta_rr_s8_opt
};
// Post-op merge of the temporary C buffer: index 1 (alpha == 1) selects the
// optimized variant.
static void (*post_op_kernels[])(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc) = {
post_ops, post_ops_opt
};
// Flat-index helpers: translate an (m, n) element position into a linear
// offset for a matrix stored with leading dimension `ld`.
static size_t row_major_idx(size_t m, size_t n, size_t ld) {
    size_t linear = m * ld;
    linear += n;
    return linear;
}
static size_t col_major_idx(size_t m, size_t n, size_t ld) {
    size_t linear = n * ld;
    linear += m;
    return linear;
}
// Addressing selector: index 0 -> column-major, index 1 -> row-major.
static size_t (*compute_idx[])(size_t m, size_t n, size_t ld) = {
    col_major_idx,
    row_major_idx
};
// Offset-vector stepping helpers: given an element position (mi, ni), each
// returns the index into the oc vector that applies to that element.
static size_t mov_oc_fix(size_t mi, size_t ni) {
    // single scalar offset: always oc[0]
    (void) mi;
    (void) ni;
    return 0;
}
static size_t mov_oc_col(size_t mi, size_t ni) {
    // one offset per row position
    (void) ni;
    return mi;
}
static size_t mov_oc_row(size_t mi, size_t ni) {
    // one offset per column position
    (void) mi;
    return ni;
}
// Selector: 0 -> fixed offset, 1 -> per-mi offset, 2 -> per-ni offset.
static size_t (*move_oc[])(size_t, size_t) = {
    mov_oc_fix, mov_oc_col, mov_oc_row
};
// Public cblas-style entry point for the current LHS_TYPE/RHS_TYPE pair.
// Resolves layout, transpose and offset options into the function-pointer
// bundle consumed by gemm_impl_8bit. Row-major input is handled by swapping
// the roles of a/b and m/n (computing the transposed product), which is why
// the table indices below use transa/transb in swapped positions.
EXTERNAL_API void ADD_TYPES_SUFF(prefill_cblas_gemm)(
const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,
const size_t m, const size_t n, const size_t k, const float alpha,
const void *a, const size_t lda, const BLASINT8 oa,
const void *b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t *c, const size_t ldc, const int32_t *oc) {
// alpha == 1 with beta in {0, 1} selects the "_opt" beta/post-op variants
int opt_offset = ((alpha == 1.0f) && (beta == 0.0f || beta == 1.0f)) ? 1 : 0;
if(Layout == CblasColMajor) {
// offset flavor: 0 = fixed, 1 = per-column, 2 = per-row
int beta_offset = (offsetc == CblasFixOffset) ? 0 : (offsetc == CblasColOffset ? 1:2);
int_gemm_funcs arg = {
small_kernels[(transb == CblasTrans) + 2 * (transa == CblasTrans) + beta_offset * 4],
gemm_kernels,
pack_a_funs[transa == CblasTrans],
pack_b_funs[transb == CblasTrans],
beta_funcs[beta_offset + opt_offset * 6],
post_op_kernels[alpha == 1],
compute_idx[transa == CblasTrans],
compute_idx[transb == CblasTrans],
move_oc[beta_offset],
};
// 262144: threshold passed as gemm_impl_8bit's small_switch parameter
(gemm_impl_8bit(&arg, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc, 262144));
} else if (Layout == CblasRowMajor) {
// row-major beta funcs live at indices 3..5 of beta_funcs
int beta_offset = (offsetc == CblasFixOffset) ? 3 : (offsetc == CblasColOffset ? 4 : 5);
// col/row offset roles swap because m and n are swapped below
int beta_offset_small = (offsetc == CblasFixOffset) ? 0 : (offsetc == CblasColOffset ? 2 : 1);
int_gemm_funcs arg = {
small_kernels[(transa == CblasTrans) + 2 * (transb == CblasTrans) + beta_offset_small * 4],
gemm_kernels,
pack_a_funs[transb == CblasTrans],
pack_b_funs[transa == CblasTrans],
beta_funcs[beta_offset + opt_offset * 6],
post_op_kernels[alpha == 1],
compute_idx[transb == CblasTrans],
compute_idx[transa == CblasTrans],
move_oc[beta_offset_small]
};
// note the swapped (b, a) and (n, m) arguments for the row-major case
(gemm_impl_8bit(&arg, n, m, k, alpha, b, ldb, ob, a, lda, oa, beta, c, ldc, oc, 262144));
}
else {
printf("Incorrect layout");
return;
}
}
#ifdef __cplusplus
}
#endif

View File

@@ -1,453 +0,0 @@
#include <stdint.h>
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
#ifdef __cplusplus
extern "C" {
#endif
// Appends the type pair and M_SIZE x N_SIZE tile suffix to a kernel name.
#define ADD_SUFFIX(name) ADD_M_N_SIZES_MACRO(ADD_TYPES_MACRO(name, LHS_TYPE, RHS_TYPE), M_SIZE, N_SIZE)
// Emits an SVE ld1b loading one full vector of bytes from ptr + idx*VL.
#define LD1B_PTR(reg_name, p, ptr, idx) "ld1b {" #reg_name ".b}, " #p "/z, [%[" #ptr "], #" #idx ", MUL VL]\n"
#define COMPUTE_ADDP(out, in1, in2) "addp " #out ".s, " #in1 ".s, " #in2 ".s\n"
// Chooses the dot-product opcode for the signedness combination: same-signed
// pairs paste the LHS_TYPE token in front of "dot" (presumably yielding
// sdot/udot — confirm token values in helping_macros.h); mixed pairs use
// usdot with operands ordered so the unsigned vector comes first.
#if (defined(LHS_INT) && defined(RHS_INT)) || (defined(LHS_UINT) && defined(RHS_UINT))
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) #lhs_type "dot " #out ".s, " #in1 ".b, " #in2 ".b\n"
#else
#ifdef LHS_INT
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in2 ".b, " #in1 ".b\n"
#else
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in1 ".b, " #in2 ".b\n"
#endif
#endif
// Extra expansion layer so LHS_TYPE/RHS_TYPE are macro-expanded before pasting.
#define COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE) COMPUTE_DOT_TYPED(out, in1, in2, LHS_TYPE, RHS_TYPE)
#define COMPUTE_DOT(out, in1, in2) COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE)
// The kernel's register budget supports at most a 4x4 tile.
#if (N_SIZE > 4)
#error "N_SIZE can't be greater than 4"
#endif
#if (M_SIZE > 4)
#error "M_SIZE can't be greater than 4"
#endif
// Register plan: z0-z3 hold up to N_SIZE rhs vectors and z4-z7 up to M_SIZE
// lhs vectors; z8-z15 are the second set used by the 2x-unrolled main loop.
// Slots beyond the configured N_SIZE/M_SIZE expand to nothing, so all tile
// sizes share this one source file.
#define LOAD_Z0(p, ptr) LD1B_PTR(z0, p, ptr, 0)
#define LOAD_Z8(p, ptr) LD1B_PTR(z8, p, ptr, 0)
#if (N_SIZE > 1)
#define LOAD_Z1(p, ptr) LD1B_PTR(z1, p, ptr, 1)
#define LOAD_Z9(p, ptr) LD1B_PTR(z9, p, ptr, 1)
#else
#define LOAD_Z1(p, ptr)
#define LOAD_Z9(p, ptr)
#endif
#if (N_SIZE > 2)
#define LOAD_Z2(p, ptr) LD1B_PTR(z2, p, ptr, 2)
#define LOAD_Z10(p, ptr) LD1B_PTR(z10, p, ptr, 2)
#else
#define LOAD_Z2(p, ptr)
#define LOAD_Z10(p, ptr)
#endif
#if (N_SIZE > 3)
#define LOAD_Z3(p, ptr) LD1B_PTR(z3, p, ptr, 3)
#define LOAD_Z11(p, ptr) LD1B_PTR(z11, p, ptr, 3)
#else
#define LOAD_Z3(p, ptr)
#define LOAD_Z11(p, ptr)
#endif
// lhs vectors, gated on M_SIZE
#define LOAD_Z4(p, ptr) LD1B_PTR(z4, p, ptr, 0)
#define LOAD_Z12(p, ptr) LD1B_PTR(z12, p, ptr, 0)
#if (M_SIZE > 1)
#define LOAD_Z5(p, ptr) LD1B_PTR(z5, p, ptr, 1)
#define LOAD_Z13(p, ptr) LD1B_PTR(z13, p, ptr, 1)
#else
#define LOAD_Z5(p, ptr)
#define LOAD_Z13(p, ptr)
#endif
#if (M_SIZE > 2)
#define LOAD_Z6(p, ptr) LD1B_PTR(z6, p, ptr, 2)
#define LOAD_Z14(p, ptr) LD1B_PTR(z14, p, ptr, 2)
#else
#define LOAD_Z6(p, ptr)
#define LOAD_Z14(p, ptr)
#endif
#if (M_SIZE > 3)
#define LOAD_Z7(p, ptr) LD1B_PTR(z7, p, ptr, 3)
#define LOAD_Z15(p, ptr) LD1B_PTR(z15, p, ptr, 3)
#else
#define LOAD_Z7(p, ptr)
#define LOAD_Z15(p, ptr)
#endif
// macros for dot multiplication
// z16-z31 are the 16 accumulator vectors, one per (m, n) cell of the tile:
// accumulator z(16 + 4*mi + ni) collects lhs row mi against rhs column ni.
// Cells outside the configured M_SIZE x N_SIZE expand to nothing.
#define ACCUMULATE_Z16(lhs, rhs) COMPUTE_DOT(z16, lhs, rhs)
#if (N_SIZE > 1)
#define ACCUMULATE_Z17(lhs, rhs) COMPUTE_DOT(z17, lhs, rhs)
#else
#define ACCUMULATE_Z17(lhs, rhs)
#endif
#if (N_SIZE > 2)
#define ACCUMULATE_Z18(lhs, rhs) COMPUTE_DOT(z18, lhs, rhs)
#else
#define ACCUMULATE_Z18(lhs, rhs)
#endif
#if (N_SIZE > 3)
#define ACCUMULATE_Z19(lhs, rhs) COMPUTE_DOT(z19, lhs, rhs)
#else
#define ACCUMULATE_Z19(lhs, rhs)
#endif
#if (M_SIZE > 1)
#define ACCUMULATE_Z20(lhs, rhs) COMPUTE_DOT(z20, lhs, rhs)
#if (N_SIZE > 1)
#define ACCUMULATE_Z21(lhs, rhs) COMPUTE_DOT(z21, lhs, rhs)
#else
#define ACCUMULATE_Z21(lhs, rhs)
#endif
#if (N_SIZE > 2)
#define ACCUMULATE_Z22(lhs, rhs) COMPUTE_DOT(z22, lhs, rhs)
#else
#define ACCUMULATE_Z22(lhs, rhs)
#endif
#if (N_SIZE > 3)
#define ACCUMULATE_Z23(lhs, rhs) COMPUTE_DOT(z23, lhs, rhs)
#else
#define ACCUMULATE_Z23(lhs, rhs)
#endif
#else
#define ACCUMULATE_Z20(lhs, rhs)
#define ACCUMULATE_Z21(lhs, rhs)
#define ACCUMULATE_Z22(lhs, rhs)
#define ACCUMULATE_Z23(lhs, rhs)
#endif
#if (M_SIZE > 2)
#define ACCUMULATE_Z24(lhs, rhs) COMPUTE_DOT(z24, lhs, rhs)
#if (N_SIZE > 1)
#define ACCUMULATE_Z25(lhs, rhs) COMPUTE_DOT(z25, lhs, rhs)
#else
#define ACCUMULATE_Z25(lhs, rhs)
#endif
#if (N_SIZE > 2)
#define ACCUMULATE_Z26(lhs, rhs) COMPUTE_DOT(z26, lhs, rhs)
#else
#define ACCUMULATE_Z26(lhs, rhs)
#endif
#if (N_SIZE > 3)
#define ACCUMULATE_Z27(lhs, rhs) COMPUTE_DOT(z27, lhs, rhs)
#else
#define ACCUMULATE_Z27(lhs, rhs)
#endif
#else
#define ACCUMULATE_Z24(lhs, rhs)
#define ACCUMULATE_Z25(lhs, rhs)
#define ACCUMULATE_Z26(lhs, rhs)
#define ACCUMULATE_Z27(lhs, rhs)
#endif
#if (M_SIZE > 3)
#define ACCUMULATE_Z28(lhs, rhs) COMPUTE_DOT(z28, lhs, rhs)
#if (N_SIZE > 1)
#define ACCUMULATE_Z29(lhs, rhs) COMPUTE_DOT(z29, lhs, rhs)
#else
#define ACCUMULATE_Z29(lhs, rhs)
#endif
#if (N_SIZE > 2)
#define ACCUMULATE_Z30(lhs, rhs) COMPUTE_DOT(z30, lhs, rhs)
#else
#define ACCUMULATE_Z30(lhs, rhs)
#endif
#if (N_SIZE > 3)
#define ACCUMULATE_Z31(lhs, rhs) COMPUTE_DOT(z31, lhs, rhs)
#else
#define ACCUMULATE_Z31(lhs, rhs)
#endif
#else
#define ACCUMULATE_Z28(lhs, rhs)
#define ACCUMULATE_Z29(lhs, rhs)
#define ACCUMULATE_Z30(lhs, rhs)
#define ACCUMULATE_Z31(lhs, rhs)
#endif
// Advance the packed operand pointers by one k-step (M_SIZE or N_SIZE vectors).
#define MOVE_LHS_PTR(ptr) "add %[" #ptr "], %[" #ptr "], %[move_lhs]\n"
#define MOVE_RHS_PTR(ptr) "add %[" #ptr "], %[" #ptr "], %[move_rhs]\n"
// Drain one accumulator to C: load the current 32-bit C value into
// w<reg_idx>, horizontally sum the vector (saddv) into d<reg_idx>, add the
// two, store back and post-increment the destination by 4 bytes.
#define PROCESS_ACCUM(reg_idx, z_reg_idx, tmp_reg, dst, p) \
"ldr w" #reg_idx ", [%[" #dst "]]\n" \
"saddv d" #reg_idx ", " #p ", z" #z_reg_idx ".s\n" \
"fmov " #tmp_reg ", d" #reg_idx "\n" \
"add x" #reg_idx ", x" #reg_idx ", " #tmp_reg "\n" \
"str w" #reg_idx ", [%[" #dst "]], #4\n"
// function logic
// SVE micro-kernel for one M_SIZE x N_SIZE tile of C. Each k-step consumes
// sv_len bytes from each of the M_SIZE lhs rows and N_SIZE rhs rows, stored
// back-to-back in the packed panels (the pointers advance by M_SIZE*sv_len /
// N_SIZE*sv_len per step). Dot-product accumulators are reduced with saddv
// and ADDED to the existing C values, so C must be pre-initialized. ldc is
// in elements on entry and converted below to the byte step between C
// columns (PROCESS_ACCUM already advances dst by M_SIZE * 4 bytes).
void ADD_SUFFIX(gemm_kernel)(const void *lhs_ptr, const void *rhs_ptr,
int32_t *accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len) {
int64_t run_k_depth = k_depth;
int64_t run_sv_len = sv_len;
int64_t run_2sv_len = 2 * sv_len;
int64_t move_lhs = M_SIZE * sv_len;
int64_t move_rhs = N_SIZE * sv_len;
int32_t* dst_ptr = accum_ptr;
// remaining column step after PROCESS_ACCUM's M_SIZE post-increments, in bytes
ldc -= M_SIZE;
ldc *= 4;
asm volatile(
// predicate for operating on lhs and rhs is always true
"ptrue p0.b, all\n"
// Clear accumulators (interleaved with the first operand loads)
LOAD_Z0(p0, rhs_ptr)
"dup z16.s, #0\n"
LOAD_Z1(p0, rhs_ptr)
"dup z17.s, #0\n"
LOAD_Z4(p0, lhs_ptr)
"dup z18.s, #0\n"
LOAD_Z5(p0, lhs_ptr)
"dup z19.s, #0\n"
LOAD_Z6(p0, lhs_ptr)
"dup z20.s, #0\n"
LOAD_Z7(p0, lhs_ptr)
"dup z21.s, #0\n"
LOAD_Z2(p0, rhs_ptr)
"dup z22.s, #0\n"
LOAD_Z3(p0, rhs_ptr)
"dup z23.s, #0\n"
"subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
"dup z24.s, #0\n"
"mov x16, %[dst_ptr]\n"
"dup z25.s, #0\n"
"dup z26.s, #0\n"
"dup z27.s, #0\n"
MOVE_LHS_PTR(lhs_ptr)
"dup z28.s, #0\n"
MOVE_RHS_PTR(rhs_ptr)
"dup z29.s, #0\n"
"dup z30.s, #0\n"
"dup z31.s, #0\n"
// only one k-step in total -> straight to the drain at label 1
"ble 1f\n"
"cmp %[run_k_depth], %[run_2sv_len]\n"
"blt 2f\n"
// label 3: main loop, unrolled by two k-steps with double-buffered
// operand registers (z0-z7 vs z8-z15) and prefetch of the next panels
"3:\n"
LOAD_Z12(p0, lhs_ptr)
ACCUMULATE_Z16(z4,z0)
ACCUMULATE_Z17(z4,z1)
LOAD_Z13(p0, lhs_ptr)
ACCUMULATE_Z18(z4,z2)
ACCUMULATE_Z19(z4,z3)
LOAD_Z8(p0, rhs_ptr)
ACCUMULATE_Z20(z5,z0)
ACCUMULATE_Z21(z5,z1)
LOAD_Z9(p0, rhs_ptr)
ACCUMULATE_Z22(z5,z2)
ACCUMULATE_Z23(z5,z3)
LOAD_Z10(p0,rhs_ptr)
ACCUMULATE_Z24(z6,z0)
ACCUMULATE_Z25(z6,z1)
LOAD_Z11(p0,rhs_ptr)
ACCUMULATE_Z26(z6,z2)
MOVE_RHS_PTR(rhs_ptr)
"prfw pldl1keep, p0, [%[rhs_ptr], #4, MUL VL]\n"
ACCUMULATE_Z27(z6,z3)
LOAD_Z14(p0, lhs_ptr)
ACCUMULATE_Z28(z7,z0)
ACCUMULATE_Z29(z7,z1)
LOAD_Z15(p0,lhs_ptr)
ACCUMULATE_Z30(z7,z2)
MOVE_LHS_PTR(lhs_ptr)
"prfw pldl1keep, p0, [%[lhs_ptr], #4, MUL VL]\n"
ACCUMULATE_Z31(z7,z3)
// second k-step of the unroll, consuming the z8-z15 buffers
LOAD_Z4(p0, lhs_ptr)
ACCUMULATE_Z16(z12,z8)
ACCUMULATE_Z17(z12,z9)
LOAD_Z5(p0, lhs_ptr)
ACCUMULATE_Z18(z12,z10)
ACCUMULATE_Z19(z12,z11)
LOAD_Z6(p0, lhs_ptr)
ACCUMULATE_Z20(z13,z8)
ACCUMULATE_Z21(z13,z9)
LOAD_Z0(p0, rhs_ptr)
"sub %[run_k_depth], %[run_k_depth], %[run_2sv_len]\n"
ACCUMULATE_Z22(z13,z10)
ACCUMULATE_Z23(z13,z11)
LOAD_Z1(p0, rhs_ptr)
ACCUMULATE_Z24(z14,z8)
ACCUMULATE_Z25(z14,z9)
LOAD_Z2(p0,rhs_ptr)
ACCUMULATE_Z26(z14,z10)
ACCUMULATE_Z27(z14,z11)
LOAD_Z3(p0, rhs_ptr)
ACCUMULATE_Z28(z15,z8)
MOVE_RHS_PTR(rhs_ptr)
"prfw pldl1keep, p0, [%[rhs_ptr], #4, MUL VL]\n"
ACCUMULATE_Z29(z15,z9)
LOAD_Z7(p0, lhs_ptr)
"cmp %[run_k_depth], %[run_2sv_len]\n"
ACCUMULATE_Z30(z15, z10)
MOVE_LHS_PTR(lhs_ptr)
"prfw pldl1keep, p0, [%[lhs_ptr], #4, MUL VL]\n"
ACCUMULATE_Z31(z15,z11)
"bge 3b\n"
"cmp %[run_k_depth], #0\n"
"ble 1f\n"
// label 2: single-k-step tail loop
"2:\n"
"subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
ACCUMULATE_Z16(z4,z0)
ACCUMULATE_Z17(z4,z1)
ACCUMULATE_Z18(z4,z2)
ACCUMULATE_Z19(z4,z3)
LOAD_Z4(p0,lhs_ptr)
ACCUMULATE_Z20(z5,z0)
ACCUMULATE_Z21(z5,z1)
ACCUMULATE_Z22(z5,z2)
ACCUMULATE_Z23(z5,z3)
LOAD_Z5(p0,lhs_ptr)
ACCUMULATE_Z24(z6,z0)
ACCUMULATE_Z25(z6,z1)
ACCUMULATE_Z26(z6,z2)
ACCUMULATE_Z27(z6,z3)
LOAD_Z6(p0,lhs_ptr)
ACCUMULATE_Z28(z7,z0)
LOAD_Z0(p0,rhs_ptr)
ACCUMULATE_Z29(z7,z1)
LOAD_Z1(p0,rhs_ptr)
ACCUMULATE_Z30(z7,z2)
LOAD_Z2(p0,rhs_ptr)
ACCUMULATE_Z31(z7,z3)
LOAD_Z3(p0,rhs_ptr)
MOVE_RHS_PTR(rhs_ptr)
LOAD_Z7(p0,lhs_ptr)
MOVE_LHS_PTR(lhs_ptr)
"bgt 2b\n"
// label 1: drain — fold in the last loaded operands, then reduce the
// accumulators into C column by column
"1:\n"
ACCUMULATE_Z16(z4,z0)
ACCUMULATE_Z17(z4,z1)
ACCUMULATE_Z18(z4,z2)
ACCUMULATE_Z19(z4,z3)
ACCUMULATE_Z20(z5,z0)
ACCUMULATE_Z21(z5,z1)
ACCUMULATE_Z22(z5,z2)
ACCUMULATE_Z23(z5,z3)
ACCUMULATE_Z24(z6,z0)
ACCUMULATE_Z25(z6,z1)
ACCUMULATE_Z26(z6,z2)
ACCUMULATE_Z27(z6,z3)
ACCUMULATE_Z28(z7,z0)
ACCUMULATE_Z29(z7,z1)
ACCUMULATE_Z30(z7,z2)
ACCUMULATE_Z31(z7,z3)
// column 0 of the tile
#if (N_SIZE > 0)
#if (M_SIZE > 0)
PROCESS_ACCUM(0, 16, x16, dst_ptr, p0)
#endif
#if (M_SIZE > 1)
PROCESS_ACCUM(4, 20, x17, dst_ptr, p0)
#endif
#if (M_SIZE > 2)
PROCESS_ACCUM(8, 24, x18, dst_ptr, p0)
#endif
#if (M_SIZE > 3)
PROCESS_ACCUM(12, 28, x17, dst_ptr, p0)
#endif
#endif
"add %[dst_ptr], %[dst_ptr], %[ldc]\n"
// column 1
#if (N_SIZE > 1)
#if (M_SIZE > 0)
PROCESS_ACCUM(1, 17, x16, dst_ptr, p0)
#endif
#if (M_SIZE > 1)
PROCESS_ACCUM(5, 21, x17,dst_ptr,p0)
#endif
#if (M_SIZE > 2)
PROCESS_ACCUM(9, 25, x18, dst_ptr, p0)
#endif
#if (M_SIZE > 3)
PROCESS_ACCUM(13, 29, x17, dst_ptr, p0)
#endif
#endif
"add %[dst_ptr], %[dst_ptr], %[ldc]\n"
// column 2
#if (N_SIZE > 2)
#if (M_SIZE > 0)
PROCESS_ACCUM(2, 18, x16, dst_ptr, p0)
#endif
#if (M_SIZE > 1)
PROCESS_ACCUM(6, 22, x17,dst_ptr,p0)
#endif
#if (M_SIZE > 2)
PROCESS_ACCUM(10,26,x18,dst_ptr,p0)
#endif
#if (M_SIZE > 3)
PROCESS_ACCUM(14,30,x17,dst_ptr,p0)
#endif
#endif
"add %[dst_ptr], %[dst_ptr], %[ldc]\n"
// column 3
#if (N_SIZE > 3)
#if (M_SIZE > 0)
PROCESS_ACCUM(3, 19, x16, dst_ptr, p0)
#endif
#if (M_SIZE > 1)
PROCESS_ACCUM(7, 23, x17,dst_ptr,p0)
#endif
#if (M_SIZE > 2)
PROCESS_ACCUM(11,27,x18,dst_ptr,p0)
#endif
#if (M_SIZE > 3)
PROCESS_ACCUM(15,31,x17,dst_ptr,p0)
#endif
#endif
:
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[run_k_depth] "+r"(run_k_depth),
// NOTE(review): "+wr" allows the compiler to place this pointer in a
// SIMD register; "+r" looks intended — confirm.
[dst_ptr] "+wr"(dst_ptr)
:
[run_sv_len] "r"(run_sv_len), [run_2sv_len] "r"(run_2sv_len),
[move_lhs] "r"(move_lhs), [move_rhs] "r"(move_rhs), [ldc] "r"(ldc),
[accum_ptr] "r"(accum_ptr)
:
// NOTE(review): x18 is the reserved platform register on some AArch64
// ABIs — confirm it is safe for PROCESS_ACCUM to use it here.
"cc", "memory",
"w0","w1","w2","w3","w4","w5","w6","w7",
"w8","w9","w10","w11","w12","w13","w14","w15",
"x16","x17","x18","x19",
"z0","z1","z2","z3","z4","z5","z6","z7",
"z8","z9","z10","z11","z12","z13","z14","z15",
"z16","z17","z18","z19","z20","z21","z22","z23",
"z24","z25","z26","z27","z28","z29","z30","z31"
);
}
#ifdef __cplusplus
}
#endif

View File

@@ -1,468 +0,0 @@
#ifndef __GEMM_INTEGER_KERNELS_H__
#define __GEMM_INTEGER_KERNELS_H__
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
#include <stddef.h>
typedef int8_t BLASINT8;
typedef uint8_t BLASUINT8;
// Strategy bundle assembled by the cblas entry points: layout, transpose,
// offset and alpha/beta choices are resolved once into these function
// pointers, which the drivers then invoke without further dispatch.
typedef struct {
// fast path for small problems (presumably chosen via gemm_impl_8bit's small_switch — confirm)
void (*small_kernel)(const size_t, const size_t, const size_t, const float, const void*, const size_t, const BLASINT8,
const void*, const size_t, const BLASINT8, const float, int32_t*, const size_t, const int32_t*);
// table of micro-kernels indexed by (n - 1) + (m - 1) * KERNEL_N_STEP
void (**gemm_kernels)(const void*, const void*, int32_t*, size_t, int64_t, int64_t);
// tile packers for A and B (transpose-specific variants)
void (*pack_a_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8);
void (*pack_b_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8);
// pre-pass that scales C by beta and adds the oc offsets
void (*beta_func)(int32_t*, const int32_t*, float, size_t, size_t, size_t);
// merges the temporary C buffer into user C, applying alpha
void (*post_ops_func)(float, const int32_t*, int32_t*, size_t, size_t, size_t);
// map (row, col, ld) to a linear element offset for A and B
size_t (*a_indexing)(size_t, size_t, size_t);
size_t (*b_indexing)(size_t, size_t, size_t);
// selects which oc element applies to element (mi, ni)
size_t (*move_oc)(size_t, size_t);
} int_gemm_funcs;
// COMP_SV_LEN is the build-time SVE vector byte length; one kernel k-step
// consumes this many bytes per packed row.
#ifndef COMP_SV_LEN
#error "COMP_SV_LEN is not defined"
#endif
// Micro-kernel tile extents (up to 4x4 C elements per call).
#define KERNEL_M_STEP 4
#define KERNEL_N_STEP 4
#define KERNEL_K_STEP COMP_SV_LEN
// Cache-blocking tile sizes; each must be a multiple of its kernel step.
#define M_BLOCK 256
#if ((M_BLOCK % KERNEL_M_STEP) != 0)
#error "M_BLOCK % KERNEL_M_STEP != 0"
#endif
#define N_BLOCK 256
#if ((N_BLOCK % KERNEL_N_STEP) != 0)
#error "N_BLOCK % KERNEL_N_STEP != 0"
#endif
#define K_BLOCK 512
#if ((K_BLOCK % KERNEL_K_STEP) != 0)
#error "K_BLOCK % KERNEL_K_STEP != 0"
#endif
// Alignment for the packed scratch buffers (page-sized).
#define ALIGNMENT 4096
#define EXTERNAL_API __attribute__((visibility("default")))
#define UNUSED(arg) ((void)(arg))
// general pipeline
void gemm_impl_8bit(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha, const void* a, size_t lda,
const BLASINT8 oa, const void* b, size_t ldb, const BLASINT8 ob, float beta, int32_t* c, size_t ldc,
const int32_t* oc, size_t small_switch);
// s8 kernel
void pack_b_s8_n(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
const BLASINT8 ob);
void pack_a_s8_n(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
const BLASINT8 oa);
void pack_b_s8_t(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
const BLASINT8 ob);
void pack_a_s8_t(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
const BLASINT8 oa);
// u8 kernels
void pack_b_u8_n(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
const BLASINT8 ob);
void pack_a_u8_n(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
const BLASINT8 oa);
void pack_b_u8_t(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
const BLASINT8 ob);
void pack_a_u8_t(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
const BLASINT8 oa);
// beta kernels
void beta_cf_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
void beta_cc_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
void beta_cr_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
void beta_rf_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
void beta_rc_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
void beta_rr_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
void beta_cf_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
size_t ldc);
void beta_cc_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
size_t ldc);
void beta_cr_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
size_t ldc);
void beta_rf_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
size_t ldc);
void beta_rc_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
size_t ldc);
void beta_rr_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
size_t ldc);
// post-ops kernels
void post_ops(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc);
void post_ops_opt(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc);
// drivers
void gemm_driver(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha, const void* a, size_t lda,
const BLASINT8 oa, const void* b, size_t ldb, const BLASINT8 ob, float beta, int32_t* c, size_t ldc,
const int32_t* oc);
void gemm_driver_opt(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha, const void* a, size_t lda,
const BLASINT8 oa, const void* b, size_t ldb, const BLASINT8 ob, float beta, int32_t* c,
size_t ldc, const int32_t* oc);
// matrix multiplication kernels
// s8s8s32 kernels
void gemm_kernel_s8s8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
// u8u8s32 kernels
void gemm_kernel_u8u8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
// s8u8s32 kernels
void gemm_kernel_s8u8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
// u8s8s32 kernels
void gemm_kernel_u8s8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
// small kernels
// s8s8s32 kernels
void small_kernel_s8s8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
// s8u8s32 kernels
void small_kernel_s8u8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
// u8s8s32 kernels
void small_kernel_u8s8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
// u8u8s32 kernels
void small_kernel_u8u8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -1,158 +0,0 @@
#include <stdint.h>
#include <math.h>
#include <stdlib.h>
#include <stdio.h>
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
#if defined(OC_FIX)
#define OC_TYPE f
#define OC_IDX(mi, nn) 0
#elif defined(OC_COL)
#define OC_TYPE c
#define OC_IDX(mi, ni) mi
#else // OC_ROW
#define OC_TYPE r
#define OC_IDX(mi, ni) ni
#endif // OC_T
#if defined(TRANSA)
#if defined TRANSB
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, t, t, OC_TYPE)
#elif defined(NOTRANSB)
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, t, n, OC_TYPE)
#else
#error "Neither TRANSB or NOTRANSB is defined"
#endif
#elif defined(NOTRANSA)
#if defined TRANSB
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, n, t, OC_TYPE)
#elif defined(NOTRANSB)
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, n, n, OC_TYPE)
#else
#error "Neither TRANSB or NOTRANSB is defined"
#endif
#else
#error "Neither TRANSA or NOTRANSA is defined"
#endif
#if (defined (LHS_INT) && defined(RHS_INT)) || (defined (LHS_UINT) && defined(RHS_UINT))
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) #lhs_type "dot " #out ".s, " #in1 ".b, " #in2 ".b\n"
#else
#if defined(LHS_INT)
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in2 ".b," #in1 ".b\n"
#else // LHS_UINT
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in1 ".b," #in2 ".b\n"
#endif // LHS_INT
#endif // LHS_INT
#define COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE) COMPUTE_DOT_TYPED(out, in1, in2, LHS_TYPE, RHS_TYPE)
#define COMPUTE_DOT(out, in1, in2) COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE)
// Offset-adjusted 8-bit dot product: accumulates the dot of (a[i] + *oa) with
// (b[i] + *ob) using SVE dot-product instructions and returns the int32 total
// as a double. `oa`/`ob` point at vector-length buffers pre-filled with the
// scalar offsets (the caller broadcasts them into KERNEL_K_STEP-sized arrays).
// NOTE(review): the loop counter drops by 1 per iteration while both pointers
// advance by sv_len bytes and `whilelt` limits active lanes by the *remaining
// counter* — this only walks exactly k elements under the k/sv_len conventions
// the callers use; confirm against the drivers before reusing elsewhere.
// NOTE(review): `accum` is int32_t yet the final asm `add` mixes it with the
// 64-bit x2, and the "+wr" constraints also allow the SIMD register class;
// the current toolchain apparently picks workable registers — worth tightening.
static inline double compute_dot(size_t k, const void *a, const BLASINT8* oa,
const void *b, const BLASINT8* ob, int64_t sv_len) {
int32_t accum = 0;
int64_t run_k_depth = k;
int64_t run_sv_len = sv_len;
const void* lhs_ptr = a;
const void* rhs_ptr = b;
asm volatile(
// z4 holds the per-lane accumulators; p0 covers a full byte vector.
"dup z4.s, #0\n"
"ptrue p0.b, all\n"
// Load the broadcast offset vectors once, outside the loop.
"ld1b {z0.b}, p0/z, [%[oa]]\n"
"ld1b {z1.b}, p0/z, [%[ob]]\n"
"1:\n"
// Predicate limits the active lanes to the remaining element count.
"whilelt p1.b, xzr, %[run_k_depth]\n"
"ld1b {z2.b}, p1/z, [%[lhs_ptr]]\n"
"ld1b {z3.b}, p1/z, [%[rhs_ptr]]\n"
// Fold the a/b offsets in before the widening dot product.
"add z2.b, p1/m, z2.b, z0.b\n"
"add z3.b, p1/m, z3.b, z1.b\n"
"add %[lhs_ptr], %[lhs_ptr], %[run_sv_len]\n"
"add %[rhs_ptr], %[rhs_ptr], %[run_sv_len]\n"
// Expands to sdot/udot/usdot depending on the LHS/RHS signedness macros.
COMPUTE_DOT(z4, z2, z3)
"subs %[run_k_depth], %[run_k_depth], #1\n"
"bgt 1b\n"
// Horizontal reduction of the lane accumulators into a scalar.
"ptrue p2.s, all\n"
"saddv d1, p2, z4.s\n"
"fmov x2, d1\n"
"add %[accum], %[accum], x2\n"
: // outputs
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[run_k_depth] "+wr"(run_k_depth),
[accum] "+wr"(accum)
: // inputs
[run_sv_len] "r"(run_sv_len), [oa] "r"(oa), [ob] "r"(ob)
: // clobbers
"cc", "memory",
"d1", "x2",
"z0", "z1", "z2", "z3", "z4", "p0", "p1", "p2"
);
return (double) accum;
}
#if !defined(TRANSA) || defined(TRANSB)
// performs transposition (n0 * n1) -> (n1 * n0) assuming col major
// Naive out-of-place byte transpose. `in` is read as an n0-by-n1 col-major
// matrix with leading dimension ld0; its transpose is written to `out` as an
// n1-by-n0 col-major matrix with leading dimension n1. Elements are one byte
// wide, so the signed type serves signed and unsigned data alike.
static inline void simplest_transpose(const void *in, void *out, size_t n0, size_t ld0, size_t n1) {
    const BLASINT8 *src = (const BLASINT8 *) in;
    BLASINT8 *dst = (BLASINT8 *) out;
    for (size_t col = 0; col < n0; ++col) {
        for (size_t row = 0; row < n1; ++row) {
            dst[row + col * n1] = src[col + row * ld0];
        }
    }
}
#endif // !defined(TRANSA) || defined(TRANSB)
// A in row-major, B in col-major
// Reference GEMM for small problems:
//   C(mi, ni) = round(alpha * dot(A_row + oa, B_col + ob)) + beta*C + oc,
// computed one dot product at a time via the SVE helper compute_dot().
// A is brought to row-major and B to col-major so every C(mi, ni) is a
// contiguous dot of length k. The transposition (when needed) uses a
// 128-byte-aligned scratch buffer.
// NOTE(review): aligned_alloc() formally requires the size to be a multiple
// of the alignment; glibc is lenient, but confirm on other libcs.
void ADD_SUFFIX(small_kernel)(const size_t m, const size_t n, const size_t k, const float alpha,
                              const void *a, const size_t lda, const BLASINT8 oa,
                              const void *b, const size_t ldb, const BLASINT8 ob,
                              const float beta, int32_t *c, const size_t ldc, const int32_t *oc) {
    double double_alpha = (double) alpha;
    // Typed pointers are used only for indexing, so signedness does not matter.
#ifdef TRANSA
    // A is already row-major: use it in place, no scratch buffer.
    BLASINT8* a_typed = (BLASINT8*) a;
    const size_t used_lda = lda;
#else
    // A is col-major: transpose into scratch so rows become contiguous.
    BLASINT8* a_typed = (BLASINT8*) aligned_alloc(128, sizeof(BLASINT8) * m * k);
    simplest_transpose(a, a_typed, m, lda, k);
    const size_t used_lda = k;
#endif // TRANSA
#ifndef TRANSB
    // B is already col-major: use it in place.
    BLASINT8* b_typed = (BLASINT8*) b;
    const size_t used_ldb = ldb;
#else // TRANSB
    // B is row-major: transpose into scratch so columns become contiguous.
    BLASINT8* b_typed = (BLASINT8*) aligned_alloc(128, sizeof(BLASINT8) * k * n);
    simplest_transpose(b, b_typed, n, ldb, k);
    const size_t used_ldb = k;
#endif // TRANSB
    // Broadcast the scalar offsets into vector-length buffers for compute_dot.
    BLASINT8 oa_buf[KERNEL_K_STEP];
    BLASINT8 ob_buf[KERNEL_K_STEP];
    for (size_t i = 0; i < KERNEL_K_STEP; ++i) {
        oa_buf[i] = oa;
        ob_buf[i] = ob;
    }
    for (size_t mi = 0; mi < m; ++mi) {
        for (size_t ni = 0; ni < n; ++ni) {
            // OC_IDX selects fixed/column/row C-offset indexing per build flavor.
            double tmp = compute_dot(k, a_typed + mi * used_lda, oa_buf, b_typed + ni * used_ldb, ob_buf, KERNEL_K_STEP);
            c[mi + ni * ldc] = round(tmp * double_alpha + ((double)(beta * ((float)c[mi + ni * ldc]) + oc[OC_IDX(mi, ni)])));
        }
    }
    // BUGFIX: the scratch buffer for A exists only when TRANSA is NOT defined
    // (see the allocation above). The previous `#ifdef TRANSA` freed the
    // caller's `a` pointer in TRANSA builds and leaked the scratch buffer in
    // NOTRANSA builds.
#ifndef TRANSA
    free(a_typed);
#endif // TRANSA
#ifdef TRANSB
    free(b_typed);
#endif // TRANSB
}
#ifdef __cplusplus
} // extern "C"
#endif /* __cplusplus */

View File

@@ -1,134 +0,0 @@
#include <stdint.h>
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
#if defined(TRANSA)
// row major
#define INDEXING_A(row_idx, col_idx, lda) ((col_idx) * (lda) + row_idx)
#define ADD_A_SUFF(name) ADD_PACK_A_T_SUFF(name)
#elif defined(NOTRANSA)
// col major
#define INDEXING_A(row_idx, col_idx, lda) ((row_idx) * (lda) + col_idx)
#define ADD_A_SUFF(name) ADD_PACK_A_N_SUFF(name)
#else
#error "Neither TRANSA or NOTRANSA is defined"
#endif
#define OLD_N_SIZE 8
#define NEW_N_SIZE 4
void ADD_A_SUFF(pack_a)(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda, const BLASINT8 oa) {
LHS_INT_TYPE* bufferA_typed = (LHS_INT_TYPE*) bufferA;
LHS_INT_TYPE* curr_a_ptr_typed = (LHS_INT_TYPE*) curr_a_ptr;
// printf("m_block_size:%lu ,k_block_size: %lu\n", m_block_size, k_block_size);
for(size_t old_split_n = 0; old_split_n < (m_block_size / OLD_N_SIZE); old_split_n++) {
for(size_t split_k = 0; split_k < (k_block_size / KERNEL_K_STEP); split_k++) {
for(size_t old_idx_n = 0; old_idx_n < OLD_N_SIZE; old_idx_n++) {
for(size_t idx_k = 0; idx_k < KERNEL_K_STEP; idx_k++) {
size_t n_idx = old_split_n * OLD_N_SIZE + old_idx_n;
size_t new_split_n = n_idx / NEW_N_SIZE;
size_t new_idx_n = n_idx % NEW_N_SIZE;
size_t old_buff_idx =
old_split_n * OLD_N_SIZE * lda +
split_k * OLD_N_SIZE * KERNEL_K_STEP +
old_idx_n * KERNEL_K_STEP +
idx_k;
size_t new_buff_idx =
new_split_n * NEW_N_SIZE * k_block_size +
split_k * NEW_N_SIZE * KERNEL_K_STEP +
new_idx_n * KERNEL_K_STEP +
idx_k;
bufferA_typed[new_buff_idx] = curr_a_ptr_typed[old_buff_idx];
}
}
}
}
// for(size_t n_idx = 0; n_idx < m_block_size; n_idx++) {
// for(size_t k_idx = 0; k_idx < k_block_size; k_idx++) {
// size_t old_split_n = n_idx / OLD_N_SIZE;
// size_t old_idx_n = n_idx % OLD_N_SIZE;
// size_t new_split_n = n_idx / NEW_N_SIZE;
// size_t new_idx_n = n_idx % NEW_N_SIZE;
// size_t split_k = k_idx / KERNEL_K_STEP;
// size_t idx_k = k_idx % KERNEL_K_STEP;
// size_t old_buff_idx =
// old_split_n * OLD_N_SIZE * lda +
// split_k * OLD_N_SIZE * KERNEL_K_STEP +
// old_idx_n * KERNEL_K_STEP +
// idx_k;
// size_t new_buff_idx =
// new_split_n * NEW_N_SIZE * k_block_size +
// split_k * NEW_N_SIZE * KERNEL_K_STEP +
// new_idx_n * KERNEL_K_STEP +
// idx_k;
// bufferA_typed[new_buff_idx] = curr_a_ptr_typed[old_buff_idx] + oa;
// }
// }
// size_t k_block_size_up = (k_block_size + KERNEL_K_STEP - 1) / KERNEL_K_STEP * KERNEL_K_STEP;
// size_t k_portions = k_block_size / KERNEL_K_STEP;
// size_t k_resid = k_block_size - KERNEL_K_STEP * k_portions;
// size_t m_portions = m_block_size / KERNEL_M_STEP;
// size_t m_resid = m_block_size - KERNEL_M_STEP * m_portions;
// for (size_t im4 = 0; im4 < m_portions; ++im4) {
// for (size_t ik16 = 0; ik16 < k_portions; ++ik16) {
// for (size_t im = 0; im < KERNEL_M_STEP; ++im) {
// for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * KERNEL_M_STEP * ik16 + k_block_size_up * KERNEL_M_STEP * im4] =
// curr_a_ptr_typed[INDEXING_A((ik16 * KERNEL_K_STEP + ik), (im4 * KERNEL_M_STEP + im), lda)] + oa;
// }
// }
// }
// if (k_resid) {
// for (size_t im = 0; im < KERNEL_M_STEP; ++im) {
// for (size_t ik = 0; ik < k_resid; ++ik) {
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * KERNEL_M_STEP * k_portions + k_block_size_up * KERNEL_M_STEP * im4] =
// curr_a_ptr_typed[INDEXING_A((k_portions * KERNEL_K_STEP + ik), (im4 * KERNEL_M_STEP + im), lda)] + oa;
// }
// }
// for (size_t im = 0; im < KERNEL_M_STEP; ++im) {
// for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * KERNEL_M_STEP * k_portions + k_block_size_up * KERNEL_M_STEP * im4] = 0;
// }
// }
// }
// }
// if (m_resid) {
// for (size_t ik16 = 0; ik16 < k_portions; ++ik16) {
// for (size_t im = 0; im < m_resid; ++im) {
// for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * m_resid * ik16 + k_block_size_up * KERNEL_M_STEP * m_portions] =
// curr_a_ptr_typed[INDEXING_A((ik16 * KERNEL_K_STEP + ik), (m_portions * KERNEL_M_STEP + im), lda)] + oa;
// }
// }
// }
// if (k_resid) {
// for (size_t im = 0; im < m_resid; ++im) {
// for (size_t ik = 0; ik < k_resid; ++ik) {
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * m_resid * k_portions + k_block_size_up * KERNEL_M_STEP * m_portions] =
// curr_a_ptr_typed[INDEXING_A((k_portions * KERNEL_K_STEP + ik), (m_portions * KERNEL_M_STEP + im), lda)] + oa;
// }
// }
// for (size_t im = 0; im < m_resid; ++im) {
// for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * m_resid * k_portions + k_block_size_up * KERNEL_M_STEP * m_portions] = 0;
// }
// }
// }
// }
}
#ifdef __cplusplus
} // extern "C"
#endif /* __cplusplus */

View File

@@ -1,96 +0,0 @@
#include <stdint.h>
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
#if defined(TRANSB)
// row major
#define INDEXING_B(row_idx, col_idx, ldb) ((col_idx) * (ldb) + row_idx)
#define ADD_B_SUFF(name) ADD_PACK_B_T_SUFF(name)
#elif defined(NOTRANSB)
// col major
#define INDEXING_B(row_idx, col_idx, ldb) ((row_idx) * (ldb) + col_idx)
#define ADD_B_SUFF(name) ADD_PACK_B_N_SUFF(name)
#else
#error "Neither TRANSB or NOTRANSB is defined."
#endif
// Packs a k_block_size x n_block_size panel of B into the blocked layout the
// compute kernels consume: groups of KERNEL_N_STEP columns interleaved in
// KERNEL_K_STEP-deep slices, with the depth rounded up to a multiple of
// KERNEL_K_STEP (k_block_size_up) and zero-padded, and the rhs offset `ob`
// folded into every copied element. INDEXING_B abstracts over the row-major
// (TRANSB) versus col-major (NOTRANSB) source layout.
void ADD_B_SUFF(pack_b)(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb, const BLASINT8 ob) {
RHS_INT_TYPE* bufferB_typed = (RHS_INT_TYPE*) bufferB;
RHS_INT_TYPE* curr_b_ptr_typed = (RHS_INT_TYPE*) curr_b_ptr;
// Depth rounded up to full KERNEL_K_STEP slices; the gap is zero-filled below.
size_t k_block_size_up = (k_block_size + KERNEL_K_STEP - 1) / KERNEL_K_STEP * KERNEL_K_STEP;
size_t k_portions = k_block_size / KERNEL_K_STEP;
size_t k_resid = k_block_size - KERNEL_K_STEP * k_portions;
size_t n_portions = n_block_size / KERNEL_N_STEP;
size_t n_resid = n_block_size - KERNEL_N_STEP * n_portions;
// Full KERNEL_N_STEP-wide column groups.
for (size_t in4 = 0; in4 < n_portions; ++in4) {
for (size_t in = 0; in < KERNEL_N_STEP; ++in) {
for (size_t ik16 = 0; ik16 < k_block_size / KERNEL_K_STEP; ++ik16) {
for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * KERNEL_N_STEP * ik16 + k_block_size_up * KERNEL_N_STEP * in4] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * in4 + in), (KERNEL_K_STEP * ik16 + ik), ldb)] + ob;
}
}
// Partial last depth slice: copy the k_resid real elements, zero the rest.
if (k_resid) {
for (size_t ik = 0; ik < k_resid; ++ik) {
bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * KERNEL_N_STEP * k_portions + k_block_size_up * KERNEL_N_STEP * in4] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * in4 + in), (KERNEL_K_STEP * k_portions + ik), ldb)] + ob;
}
for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * KERNEL_N_STEP * k_portions + k_block_size_up * KERNEL_N_STEP * in4] = 0;
}
}
}
}
// Residual column group (fewer than KERNEL_N_STEP columns): note the panel
// stride becomes n_resid instead of KERNEL_N_STEP inside this tail block.
if (n_resid) {
for (size_t in = 0; in < n_resid; ++in) {
for (size_t ik16 = 0; ik16 < k_block_size / KERNEL_K_STEP; ++ik16) {
for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * n_resid * ik16 + k_block_size_up * KERNEL_N_STEP * n_portions] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * n_portions + in), (KERNEL_K_STEP * ik16 + ik), ldb)] + ob;
}
}
if (k_resid) {
for (size_t ik = 0; ik < k_resid; ++ik) {
bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * n_resid * k_portions + k_block_size_up * KERNEL_N_STEP * n_portions] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * n_portions + in), (KERNEL_K_STEP * k_portions + ik), ldb)] + ob;
}
for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * n_resid * k_portions + k_block_size_up * KERNEL_N_STEP * n_portions] = 0;
}
}
}
}
// Historical alternate repacking kept for reference (disabled).
// printf("n_block_size:%lu ,k_block_size: %lu\n", n_block_size, k_block_size);
// for(size_t n_idx = 0; n_idx < n_block_size; n_idx++) {
// for(size_t k_idx = 0; k_idx < k_block_size; k_idx++) {
// size_t old_split_n = n_idx / OLD_N_SIZE;
// size_t old_idx_n = n_idx % OLD_N_SIZE;
// size_t new_split_n = n_idx / NEW_N_SIZE;
// size_t new_idx_n = n_idx % NEW_N_SIZE;
// size_t split_k = k_idx / KERNEL_K_STEP;
// size_t idx_k = k_idx % KERNEL_K_STEP;
// size_t old_buff_idx =
// old_split_n * OLD_N_SIZE * ldb +
// split_k * OLD_N_SIZE * KERNEL_K_STEP +
// old_idx_n * KERNEL_K_STEP +
// idx_k;
// size_t new_buff_idx =
// new_split_n * NEW_N_SIZE * k_block_size +
// split_k * NEW_N_SIZE * KERNEL_K_STEP +
// new_idx_n * KERNEL_K_STEP +
// idx_k;
// bufferB_typed[new_buff_idx] = curr_b_ptr_typed[old_buff_idx] + ob;
// }
// }
}
#ifdef __cplusplus
} // extern "C"
#endif /* __cplusplus */

View File

@@ -1,25 +0,0 @@
#include <stdint.h>
#include <math.h>
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
#include "beta_macros.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
// Fixed OC
// Epilogue: folds the int32 accumulator panel `bufferC` into the output C as
//   C = round(float_view(C) + alpha * bufferC)
// where LDC (from beta_macros.h) selects the accumulator leading dimension
// for this beta flavor.
// NOTE(review): `current_c_ptr` is reinterpreted as float — the driver is
// expected to have pre-scaled C by beta and stored the result in-place as
// float bits before calling this (confirm against the driver). This aliasing
// also violates strict aliasing; it relies on the build using
// -fno-strict-aliasing or equivalent.
void BETA_SUFF(post_ops)(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc) {
float* current_c_float_ptr = (float*) current_c_ptr;
double double_alpha = (double) alpha;
for (size_t n_idx = 0; n_idx < n_block; ++n_idx) {
for (size_t m_idx = 0; m_idx < m; ++m_idx) {
// Accumulate in double, then round back to the int32 output.
current_c_ptr[m_idx + n_idx * ldc] = round(((double)(current_c_float_ptr[m_idx + n_idx * ldc]))
+ double_alpha * ((double) bufferC[m_idx + n_idx * LDC(m, ldc)]));
}
}
}
#ifdef __cplusplus
} // extern "C"
#endif /* __cplusplus */

View File

@@ -1,35 +0,0 @@
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include "integer_gemm_kernels.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
// here we will use BLASINT8, because we care only about type size, not signness (no math operations required)
// Top-level dispatch for 8-bit GEMM. Tiny problems go straight to the
// reference small kernel supplied in `arg`; everything else runs the blocked
// driver, with a specialised driver when alpha/beta require no scaling work.
// BLASINT8 is used for the offsets because only the type size matters here
// (no arithmetic is performed on them in this function).
void gemm_impl_8bit(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha,
                    const void* a, size_t lda, const BLASINT8 oa,
                    const void* b, size_t ldb, const BLASINT8 ob,
                    float beta, int32_t* c, size_t ldc, const int32_t* oc, size_t small_switch) {
    // `small_switch` is an experimentally measured work threshold on m*n*k.
    if (m * n * k < small_switch) {
        arg->small_kernel(m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
        return;
    }
    // With alpha == 1 and beta in {0, 1} the epilogue needs no float scaling,
    // so the optimized driver applies.
    if ((alpha == 1.0f) && (beta == 0.0f || beta == 1.0f)) {
        gemm_driver_opt(arg, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
    } else {
        gemm_driver(arg, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
    }
}
#ifdef __cplusplus
}
#endif /* __cplusplus */

View File

@@ -1,30 +0,0 @@
# Standalone build for the int8 GEMM tester executable.
cmake_minimum_required(VERSION 3.14.1)
project(THREAD_TEST)
# Inherit optimisation flags from the parent configure (CFLAGS_OPT) and add
# OpenMP for the threaded test paths.
set(CMAKE_CXX_FLAGS ${CFLAGS_OPT})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
message(STATUS "Current source dir: ${CMAKE_CURRENT_SOURCE_DIR}/..")
# BLAS_PATH is intentionally left empty so the fallback branch below runs.
set(BLAS_PATH "")
message(STATUS "BLAS_PATH=${BLAS_PATH}")
# NOTE(review): this condition looks inverted — the build proceeds only when
# BLAS_PATH is EMPTY (using the sibling libint8gemm.so) and fatals when a path
# IS provided, so the else branch is currently unreachable. Confirm the
# intended contract before changing it.
if ("${BLAS_PATH}" STREQUAL "")
set(BLAS_LIB "${CMAKE_CURRENT_SOURCE_DIR}/../libint8gemm.so")
message(STATUS "BLAS_LIB=${BLAS_LIB}")
# Add threading library to linker
#find_package(Threads)
add_executable(integer_tester integer_gemm.cpp)
#target_include_directories(integer_tester PUBLIC "${BLAS_PATH}")
#target_include_directories(integer_tester PUBLIC "${BLAS_BUILD_PATH}")
target_link_libraries(integer_tester PRIVATE ${BLAS_LIB})
set_property(TARGET integer_tester PROPERTY CXX_STANDARD 17)
# NOTE(review): add_test() is registered without an enable_testing()/CTest
# include in this file — verify the parent project enables testing.
add_test(integer_tester ${CMAKE_CURRENT_BINARY_DIR}/integer_tester)
else ()
message(FATAL_ERROR "Can not find int8_gemm path, pls set right path!")
endif ()

View File

@@ -1,438 +0,0 @@
#include <malloc.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <limits>
#include <random>
#include <type_traits>
#include <vector>
// Local mirrors of the standard CBLAS enums/typedefs so the tester can call
// the library without its headers. The numeric values (101/102, 111-114,
// 171-173) match the canonical CBLAS constants, keeping the ABI compatible
// with the library under test.
/* matrix saved in rows or cols */
typedef enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 } CBLAS_ORDER;
/* matrix transpose or conjugate transpose */
typedef enum CBLAS_TRANSPOSE {
CblasNoTrans = 111,
CblasTrans = 112,
CblasConjTrans = 113, // conjugate transpose
CblasConjNoTrans = 114
} CBLAS_TRANSPOSE;
typedef CBLAS_ORDER CBLAS_LAYOUT;
// C-offset flavor for the integer GEMMs: per-row, per-column, or one fixed value.
typedef enum CBLAS_OFFSET { CblasRowOffset = 171, CblasColOffset = 172, CblasFixOffset = 173 } CBLAS_OFFSET;
// 8-bit element types; only the size matters at this call boundary.
typedef int8_t BLASINT8;
typedef uint8_t BLASUINT8;
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
// C entry points of the integer GEMM library under test, one per
// signed/unsigned combination of the A and B operands.
// NOTE(review): these prototypes mirror the library's public header and must
// match it exactly; they are duplicated here so the tester builds standalone.
void cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
                        const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
                        const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                        const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void cblas_gemm_u8u8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
                        const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
                        const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                        const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void cblas_gemm_s8u8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
                        const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
                        const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                        const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void cblas_gemm_u8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
                        const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
                        const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                        const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
#ifdef __cplusplus
}
#endif /* __cplusplus */
namespace test {
namespace tools {
// Minimal C++17 allocator handing out `alignment`-byte aligned storage via
// memalign(3), so matrix buffers are vector-register/cache aligned.
// Fix: adds the default and converting constructors the Allocator named
// requirement demands — containers may rebind to an internal type and
// copy-construct the rebound allocator from this one.
template <typename T, std::size_t alignment = 128>
struct aligned_allocator {
  using value_type = T;
  using pointer = T*;
  using const_pointer = const T*;
  using reference = T&;
  using const_reference = const T&;
  using size_type = std::size_t;
  using difference_type = std::ptrdiff_t;
  template <typename U>
  struct rebind {
    typedef aligned_allocator<U, alignment> other;
  };
  aligned_allocator() noexcept = default;
  // Converting constructor: allocators of any value type (same alignment) are
  // interchangeable because the state is empty.
  template <typename U>
  aligned_allocator(const aligned_allocator<U, alignment>&) noexcept {}
  // Allocates raw, uninitialized storage for n objects of T.
  // Throws std::bad_array_new_length on byte-count overflow and
  // std::bad_alloc when memalign fails.
  [[nodiscard]] T* allocate(std::size_t n) {
    if (n > std::numeric_limits<std::size_t>::max() / sizeof(T)) throw std::bad_array_new_length();
    if (auto p = static_cast<T*>(memalign(alignment, n * sizeof(T)))) {
      return p;
    }
    throw std::bad_alloc();
  }
  // memalign memory is released with plain free(); the element count is unused.
  void deallocate(T* p, std::size_t n) noexcept {
    (void)(n);
    free(p);
  }
  ~aligned_allocator() {}
};
// Two aligned_allocators are interchangeable exactly when both the alignment
// and the element type agree.
template <typename T, std::size_t alignment_1, typename U, std::size_t alignment_2>
bool operator==(const aligned_allocator<T, alignment_1>&, const aligned_allocator<U, alignment_2>&) {
  constexpr bool same_alignment = (alignment_1 == alignment_2);
  constexpr bool same_value_type = std::is_same_v<T, U>;
  return same_alignment && same_value_type;
}
// Inequality is defined as the negation of operator== above.
template <typename T, std::size_t alignment_1, typename U, std::size_t alignment_2>
bool operator!=(const aligned_allocator<T, alignment_1>& lhs, const aligned_allocator<U, alignment_2>& rhs) {
  if (lhs == rhs) {
    return false;
  }
  return true;
}
// Micro-benchmark helper: calls `func(args...)` once to estimate its cost,
// then repeats it enough times to cover roughly one second of wall time and
// returns the average wall-clock duration of a single call, in seconds.
template <typename Func, typename... Args>
double timing(Func&& func, Args&&... args) {
  double time = 0.0;
  double time_begin = 0.0;
  std::size_t n_run = 0;
  // Warm-up invocation, also used to size the measurement loop below.
  auto start_begin = std::chrono::steady_clock::now();
  std::forward<Func>(func)(std::forward<Args>(args)...);
  auto end_begin = std::chrono::steady_clock::now();
  time_begin = std::chrono::duration_cast<std::chrono::nanoseconds>(end_begin - start_begin).count() / 1e9;
  // Target ~1 s total, with a floor of 3 repetitions.
  // NOTE(review): if the warm-up measures as exactly 0 ns (clock resolution),
  // converting 1.0/0.0 to size_t is undefined behavior — consider clamping
  // time_begin to a small positive epsilon.
  n_run = std::max<std::size_t>(std::size_t(1.0 / time_begin), 3);
  auto start = std::chrono::steady_clock::now();
  for (std::size_t i = 0; i < n_run; ++i) {
    // NOTE(review): forwarding `args` on every iteration is only safe because
    // the benchmarked callees do not move from their arguments.
    std::forward<Func>(func)(std::forward<Args>(args)...);
  }
  auto end = std::chrono::steady_clock::now();
  time += std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count() / 1e9;
  return time / n_run;
}
} // namespace tools
namespace helpers {
// Length of the oc offset vector for an m x n problem: 1 when a single fixed
// offset is used, m for per-row offsets (CblasColOffset), n for per-column
// offsets (CblasRowOffset). Prints a diagnostic and yields 0 on an unknown
// enum value.
std::size_t get_oc_size(CBLAS_OFFSET offset, std::size_t m, std::size_t n) {
  if (offset == CblasFixOffset) {
    return 1;
  }
  if (offset == CblasColOffset) {
    return m;
  }
  if (offset == CblasRowOffset) {
    return n;
  }
  std::cout << "Incorrect value of offset to the function " << __PRETTY_FUNCTION__ << std::endl;
  return 0;
}
// Picks the buffer whose storage matches the requested (layout, transpose)
// pair: col-major+NoTrans and row-major+Trans both read the non-transposed
// storage; the other two combinations read the transposed copy.
template <typename T>
auto get_ab_matrix(CBLAS_LAYOUT lt, CBLAS_TRANSPOSE trans_, T&& non_trans_mtx, T&& trans_mtx) {
  const bool col_major = (lt == CblasColMajor);
  const bool no_trans = (trans_ == CblasNoTrans);
  if (col_major == no_trans) {
    return non_trans_mtx.data();
  }
  return trans_mtx.data();
}
// Leading dimension matching get_ab_matrix's buffer choice: ld_n when the
// non-transposed storage is used, ld_t otherwise.
auto get_ldab(CBLAS_LAYOUT lt, CBLAS_TRANSPOSE trans_mtx, std::size_t ld_n, std::size_t ld_t) {
  const bool use_non_transposed = (lt == CblasColMajor) == (trans_mtx == CblasNoTrans);
  return use_non_transposed ? ld_n : ld_t;
}
// returns copy of the matrix
// Col-major runs work on the non-transposed C, row-major runs on the
// transposed copy.
template <typename T>
auto get_c_matrix(CBLAS_LAYOUT lt, T&& non_trans_mtx, T&& trans_mtx) {
  return (lt == CblasColMajor) ? non_trans_mtx : trans_mtx;
}
// Leading dimension of C matching get_c_matrix's buffer choice.
auto get_ldc(CBLAS_LAYOUT lt, std::size_t ldc_n, std::size_t ldc_t) {
  return (lt == CblasColMajor) ? ldc_n : ldc_t;
}
// Compile-time dispatch to the extern "C" GEMM entry point that matches the
// signedness of the A_Type / B_Type template arguments.
template <typename A_Type, typename B_Type>
void cblas_gemm_wrapper(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
                        const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
                        const A_Type* a, const size_t lda, const int8_t oa, const B_Type* b, const size_t ldb,
                        const int8_t ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc) {
  constexpr bool a_signed = std::is_same_v<A_Type, std::int8_t>;
  constexpr bool b_signed = std::is_same_v<B_Type, std::int8_t>;
  if constexpr (a_signed && b_signed) {
    cblas_gemm_s8s8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
  } else if constexpr (a_signed && !b_signed) {
    cblas_gemm_s8u8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
  } else if constexpr (!a_signed && b_signed) {
    cblas_gemm_u8s8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
  } else {
    cblas_gemm_u8u8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
  }
}
// Index into the oc vector for element (mi, ni) under the given offset mode:
// 0 for fixed, mi for per-row (Col), ni for per-column (Row) offsets.
std::size_t return_oc_idx(const CBLAS_OFFSET offsetc, std::size_t mi, std::size_t ni) {
  if (offsetc == CblasFixOffset) {
    return 0;
  }
  if (offsetc == CblasColOffset) {
    return mi;
  }
  return ni;
}
} // namespace helpers
// Verdict of one test configuration, printed as "PASSED"/"FAILED" by the
// inserter below.
enum class status_t { passed, failed };
std::ostream& operator<<(std::ostream& os, const status_t& st) {
  switch (st) {
    case status_t::passed:
      os << "PASSED";
      break;
    case status_t::failed:
      os << "FAILED";
      break;
  }
  return os;
}
// column major
// Scalar reference implementation used to validate the library:
// C(mi,ni) = round(alpha * sum_k (A(mi,ki)+oa) * (B(ki,ni)+ob)
//                  + beta * C(mi,ni) + oc[idx(mi,ni)]).
template <typename A_Type, typename B_Type>
void ref_gemm(const CBLAS_OFFSET offsetc, const std::size_t m, const std::size_t n, const std::size_t k,
              const float alpha, const A_Type* a, const std::size_t lda, const std::int8_t oa, const B_Type* b,
              const std::size_t ldb, const std::int8_t ob, const float beta, std::int32_t* c, const std::size_t ldc,
              const std::int32_t* oc) {
  for (std::size_t ni = 0; ni < n; ++ni) {
    for (std::size_t mi = 0; mi < m; ++mi) {
      // Integer accumulation of the offset-adjusted dot product.
      std::int32_t acc = 0;
      for (std::size_t ki = 0; ki < k; ++ki) {
        acc += (a[mi + ki * lda] + oa) * (b[ki + ni * ldb] + ob);
      }
      const std::size_t oc_idx = helpers::return_oc_idx(offsetc, mi, ni);
      // Mixed float/double expression kept exactly as the library rounds it.
      c[mi + ni * ldc] = std::round(alpha * static_cast<double>(acc) +
                                    static_cast<double>(beta * static_cast<float>(c[mi + ni * ldc])) +
                                    static_cast<float>(oc[oc_idx]));
    }
  }
}
// Fills `buffer` with deterministic pseudo-random values in [0, 64].
// The generator is statically seeded with 0 so runs are reproducible.
// Fix: std::uniform_int_distribution is undefined for 8-bit types — int8_t /
// uint8_t are not in its allowed IntType set — and this template is
// instantiated with both. Draw ints and narrow explicitly instead.
template <typename DataType>
void fill_random(DataType* buffer, std::size_t len) {
  static std::mt19937 generator(0);
  std::uniform_int_distribution<int> dist(0, 64);
  for (std::size_t i = 0; i < len; i++) {
    // [0, 64] fits every instantiated DataType, so the cast never truncates.
    buffer[i] = static_cast<DataType>(dist(generator));
  }
}
// Fills `buffer` with the sentinel value -8 converted to DataType.
// Fix/generalization: DataType{-8} is ill-formed (narrowing) for unsigned
// element types; static_cast keeps the behavior for the existing int32_t use
// and lets the template instantiate for any arithmetic type.
template <typename DataType>
void fill_const(DataType* buffer, std::size_t len) {
  for (std::size_t i = 0; i < len; i++) {
    buffer[i] = static_cast<DataType>(-8);
  }
}
// performs transposition (n0 * n1) -> (n1 * n0), assuming col major
// `in` has leading dimension ld0, `out` has leading dimension ld1; the
// buffers must not overlap.
template <typename T>
void simplest_transpose(T* in, T* out, std::size_t n0, std::size_t n1, std::size_t ld0, std::size_t ld1) {
  for (std::size_t col = 0; col < n1; ++col) {
    for (std::size_t row = 0; row < n0; ++row) {
      out[row + col * ld1] = in[col + row * ld0];
    }
  }
}
// Element-wise equality check of two column-major m x n matrices sharing the
// leading dimension ld; reports failure on the first mismatch.
template <typename DataType>
status_t compare(DataType* ref, DataType* test, std::size_t m, std::size_t n, std::size_t ld) {
  for (std::size_t ni = 0; ni < n; ++ni) {
    for (std::size_t mi = 0; mi < m; ++mi) {
      const std::size_t idx = mi + ni * ld;
      if (ref[idx] != test[idx]) {
        return status_t::failed;
      }
    }
  }
  return status_t::passed;
}
// Dumps an m x n column-major matrix (leading dimension == m) to stdout,
// widening each element to int32 so 8-bit values print as numbers.
template <typename DataType>
void print_matrix(DataType* buffer, std::size_t m, std::size_t n) {
  for (std::size_t row = 0; row < m; ++row) {
    for (std::size_t col = 0; col < n; ++col) {
      std::cout << static_cast<std::int32_t>(buffer[row + col * m]) << ", ";
    }
    std::cout << std::endl;
  }
  std::cout << std::endl;
}
// End-to-end check of one (m, n, k, alpha, beta) problem for the given A/B
// element types. Iterates over every offset mode, layout, and transpose
// combination, comparing each library result against ref_gemm, and prints
// '+'/'-' per combination.
// Env vars: LD_STRIDE pads all leading dimensions to exercise strided paths;
// ONLY_PERF skips correctness and prints per-combination TFlops instead.
// Returns failed if any tested combination mismatched the reference.
template <typename A_Type, typename B_Type>
status_t gemm(std::size_t m, std::size_t n, std::size_t k, float alpha, float beta) {
  // Fixed non-zero A/B offsets so the offset handling is actually exercised.
  std::int8_t oa = 4;
  std::int8_t ob = 9;
  // _n = leading dims of non-transposed storage, _t = of transposed copies.
  std::size_t lda_n = m;
  std::size_t ldb_n = k;
  std::size_t ldc_n = m;
  std::size_t lda_t = k;
  std::size_t ldb_t = n;
  std::size_t ldc_t = n;
  if (std::getenv("LD_STRIDE")) {
    // Arbitrary distinct paddings so ld > dim in every buffer.
    lda_n += 2;
    ldb_n += 7;
    ldc_n += 3;
    lda_t += 8;
    ldb_t += 3;
    ldc_t += 23;
  }
  bool only_performance = false;
  if (std::getenv("ONLY_PERF")) {
    only_performance = true;
  }
  // 128-byte aligned buffers; *_t hold transposed copies of the *_n data.
  std::vector<A_Type, tools::aligned_allocator<A_Type, 128>> a_n(lda_n * k);
  std::vector<A_Type, tools::aligned_allocator<A_Type, 128>> a_t(m * lda_t);
  std::vector<B_Type, tools::aligned_allocator<B_Type, 128>> b_n(ldb_n * n);
  std::vector<B_Type, tools::aligned_allocator<B_Type, 128>> b_t(k * ldb_t);
  std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_ref(ldc_n * n);
  std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_n(ldc_n * n);
  std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_t(m * ldc_t);
  // fill the whole array even if ld* > corresponding dim
  fill_random(a_n.data(), a_n.size());
  fill_random(b_n.data(), b_n.size());
  simplest_transpose(a_n.data(), a_t.data(), m, k, lda_n, lda_t);
  simplest_transpose(b_n.data(), b_t.data(), k, n, ldb_n, ldb_t);
  fill_const(c_ref.data(), ldc_n * n);
  c_n = c_ref;
  // NOTE(review): argument order here is (m, ldc_n, n, ldc_t) while the calls
  // above pass (n0, n1, ld0, ld1) — verify padded columns are handled as
  // intended when LD_STRIDE is set.
  simplest_transpose(c_n.data(), c_t.data(), m, ldc_n, n, ldc_t);
  auto return_st = status_t::passed;
  double total = 0;
  size_t cnt = 0;
  for (auto c_offset : {CblasFixOffset, CblasColOffset, CblasRowOffset}) {
    std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> oc(helpers::get_oc_size(c_offset, m, n));
    for (std::size_t i = 0; i < oc.size(); ++i) {
      // Simple deterministic offsets: oc[i] = 2*i.
      oc[i] = i + i;
    }
    // Fresh copy per offset mode; ref_gemm updates C in place.
    std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_ref_copy = c_ref;
    if (!only_performance) {
      ref_gemm(c_offset, m, n, k, alpha, a_n.data(), lda_n, oa, b_n.data(), ldb_n, ob, beta, c_ref_copy.data(), ldc_n,
               oc.data());
    }
    for (auto layout : {CblasColMajor, CblasRowMajor}) {
      for (auto transa : {CblasNoTrans, CblasTrans}) {
        for (auto transb : {CblasNoTrans, CblasTrans}) {
          // NOTE(review): c_tested is a copy of c_n/c_t taken fresh for each
          // combination, so runs do not contaminate each other.
          auto&& c_tested = helpers::get_c_matrix(layout, c_n, c_t);
          if (!only_performance) {
            helpers::cblas_gemm_wrapper(
                layout, transa, transb, c_offset, m, n, k, alpha, helpers::get_ab_matrix(layout, transa, a_n, a_t),
                helpers::get_ldab(layout, transa, lda_n, lda_t), oa, helpers::get_ab_matrix(layout, transb, b_n, b_t),
                helpers::get_ldab(layout, transb, ldb_n, ldb_t), ob, beta, c_tested.data(),
                helpers::get_ldc(layout, ldc_n, ldc_t), oc.data());
            // transpose c_tested to col-major if required
            auto loc_st = status_t::passed;
            if (layout == CblasRowMajor) {
              std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_tested_n(ldc_n * n);
              simplest_transpose(c_tested.data(), c_tested_n.data(), n, ldc_t, m, ldc_n);
              loc_st = compare(c_ref_copy.data(), c_tested_n.data(), m, n, ldc_n);
            } else {
              loc_st = compare(c_ref_copy.data(), c_tested.data(), m, n, ldc_n);
            }
            if (loc_st != status_t::passed) {
              std::cout << "-";
              return_st = status_t::failed;
            } else {
              std::cout << "+";
            }
          } else {
            // 2*m*n*k flops divided by average seconds, reported in TFlops.
            double cur = (2.0 * m * n * k) /
                         tools::timing(helpers::cblas_gemm_wrapper<A_Type, B_Type>, layout, transa, transb, c_offset, m,
                                       n, k, alpha, helpers::get_ab_matrix(layout, transa, a_n, a_t),
                                       helpers::get_ldab(layout, transa, lda_n, lda_t), oa,
                                       helpers::get_ab_matrix(layout, transb, b_n, b_t),
                                       helpers::get_ldab(layout, transb, ldb_n, ldb_t), ob, beta, c_tested.data(),
                                       helpers::get_ldc(layout, ldc_n, ldc_t), oc.data()) /
                         1e12;
            total += cur;
            ++cnt;
            std::cout << cur << ", ";
          }
        }
      }
    }
  }
  if (only_performance) {
    std::cout << "Average " << total / cnt << " TFlops";
  }
  std::cout << " ";
  return return_st;
}
} // namespace test
// Entry point: argv = [m [n [k [alpha [beta]]]]], defaults 128/128/128/1/1.
// Runs the tester for all four signed/unsigned A/B combinations.
int main(int argc, char** argv) {
  std::size_t m = 128;
  std::size_t n = 128;
  std::size_t k = 128;
  float alpha = 1.0f;
  float beta = 1.0f;
  // Positional overrides; each argument is only read when all the preceding
  // ones are present (equivalent to the nested-if parsing this replaces).
  if (argc > 1) m = std::stoi(argv[1]);
  if (argc > 2) n = std::stoi(argv[2]);
  if (argc > 3) k = std::stoi(argv[3]);
  if (argc > 4) alpha = std::stof(argv[4]);
  if (argc > 5) beta = std::stof(argv[5]);
  std::cout << "Testing matrix m = " << m << ", n = " << n << ", k = " << k << ", alpha = " << alpha
            << ", beta = " << beta << std::endl;
  std::cout << "\tTesting i8i8i32: " << test::gemm<std::int8_t, std::int8_t>(m, n, k, alpha, beta) << std::endl;
  std::cout << "\tTesting i8u8i32: " << test::gemm<std::int8_t, std::uint8_t>(m, n, k, alpha, beta) << std::endl;
  std::cout << "\tTesting u8i8i32: " << test::gemm<std::uint8_t, std::int8_t>(m, n, k, alpha, beta) << std::endl;
  std::cout << "\tTesting u8u8i32: " << test::gemm<std::uint8_t, std::uint8_t>(m, n, k, alpha, beta) << std::endl;
  return 0;
}

View File

@@ -1,160 +0,0 @@
cmake_minimum_required(VERSION 3.14)
set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED ON)
# can be compiled for SVE256 (32) or SVE512 (64)
set(SV_LENGTH 32) # in bytes
include_directories("${PROJECT_SOURCE_DIR}")
include_directories("${PROJECT_SOURCE_DIR}/include")
include_directories("${PROJECT_SOURCE_DIR}/standalone")
# i8mm + SVE are required by the matmul kernels; visibility=hidden keeps only
# the exported entry points in the shared library.
add_compile_options(-fPIC -fvisibility=hidden -fstack-protector-strong -march=armv8.3-a+sve+i8mm -O3)
# add_compile_options(-Wall -Wextra -Werror)
# sources are split into several groups: matmul kernels (sources are compiled multiple times),
# packing kernels (sources are compiled multiple times),
# sequential/parallel pipelines,
# interface (sources are compiled multiple times)
set(INT_GEMM_KERNELS "")
set(INT_PACK_KERNELS "")
set(INT_GEMM_INTERFACE "")
set(INT_GEMM_PARALLEL "")
set(INT_GEMM_SEQ "")
set(BETA_KERNELS "")
set(POST_OPS_KERNELS "")
set(INT_SMALL_KERNELS "")
set(GEMM_DRIVERS "")
# Supported precisions are i/u for A and B matrices (4 combinations)
set(LHS_TYPES LHS_INT LHS_UINT)
set(RHS_TYPES RHS_INT RHS_UINT)
# compile matrix-multiplication kernels multiple times
# One object library per (M tile, N tile, LHS signedness, RHS signedness)
# combination; the tile sizes and types are selected via -D defines.
set(INTEGER_MM_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_kernels.c")
set(M_SIZES 1 2 3 4)
set(N_SIZES 1 2 3 4)
foreach(M_SIZE ${M_SIZES})
  foreach(N_SIZE ${N_SIZES})
    foreach(LHS_TYPE ${LHS_TYPES})
      foreach(RHS_TYPE ${RHS_TYPES})
        add_library(int4_integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH} OBJECT ${INTEGER_MM_KERNELS_SRC})
        target_compile_options(int4_integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH} PUBLIC -DM_SIZE=${M_SIZE}
          -DN_SIZE=${N_SIZE} -D${LHS_TYPE} -D${RHS_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
        target_include_directories(int4_integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH} PUBLIC
          "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
        list(APPEND INT_GEMM_KERNELS $<TARGET_OBJECTS:int4_integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH}>)
      endforeach()
    endforeach()
  endforeach()
endforeach()
# compile interface multiple times
set(INTEGER_GEMM_IFACE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_interface.c")
foreach(LHS_TYPE ${LHS_TYPES})
  foreach(RHS_TYPE ${RHS_TYPES})
    add_library(int4_integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH} OBJECT ${INTEGER_GEMM_IFACE_SRC})
    target_compile_options(int4_integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH} PUBLIC -D${LHS_TYPE} -D${RHS_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
    target_include_directories(int4_integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH} PUBLIC
      "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
    list(APPEND INT_GEMM_INTERFACE $<TARGET_OBJECTS:int4_integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH}>)
  endforeach()
endforeach()
# compile threading layer
# NOTE(review): the target names in this disabled block are inconsistent
# (integer_gemm_par_pipe_* vs int4_integer_gemm_par_pipe_*) — align them
# before re-enabling the parallel pipeline.
# set(INTEGER_GEMM_PAR_PIPE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/parallel_int_gemm_pipeline.c")
# add_library(integer_gemm_par_pipe_${SV_LENGTH} OBJECT ${INTEGER_GEMM_PAR_PIPE_SRC})
# target_compile_options(int4_integer_gemm_par_pipe_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH})
# target_include_directories(int4_integer_gemm_par_pipe_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
# list(APPEND INT_GEMM_PARALLEL $<TARGET_OBJECTS:int4_integer_gemm_par_pipe_${SV_LENGTH}>)
# compile sequential layer
set(INTEGER_GEMMSEQ_PIPE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/sequential_int_gemm_pipeline.c")
add_library(int4_integer_gemm_seq_pipe_${SV_LENGTH} OBJECT ${INTEGER_GEMMSEQ_PIPE_SRC})
target_compile_options(int4_integer_gemm_seq_pipe_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH})
target_include_directories(int4_integer_gemm_seq_pipe_${SV_LENGTH} PUBLIC
  "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
list(APPEND INT_GEMM_SEQ $<TARGET_OBJECTS:int4_integer_gemm_seq_pipe_${SV_LENGTH}>)
# compile packingA kernels
set(INT_PACK_A_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_pack_a_kernels.c")
set(TRANSA_VALS TRANSA NOTRANSA)
foreach(LHS_TYPE ${LHS_TYPES})
  foreach(TRANSA_VAL ${TRANSA_VALS})
    add_library(int4_integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH} OBJECT ${INT_PACK_A_KERNELS_SRC})
    target_compile_options(int4_integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH} -D${LHS_TYPE} -D${TRANSA_VAL})
    target_include_directories(int4_integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
    list(APPEND INT_PACK_KERNELS $<TARGET_OBJECTS:int4_integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH}>)
  endforeach()
endforeach()
# compile packingB kernels
set(INT_PACK_B_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_pack_b_kernels.c")
set(TRANSB_VALS TRANSB NOTRANSB)
foreach(RHS_TYPE ${RHS_TYPES})
  foreach(TRANSB_VAL ${TRANSB_VALS})
    add_library(int4_integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH} OBJECT ${INT_PACK_B_KERNELS_SRC})
    target_compile_options(int4_integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH} -D${RHS_TYPE} -D${TRANSB_VAL})
    target_include_directories(int4_integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
    list(APPEND INT_PACK_KERNELS $<TARGET_OBJECTS:int4_integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH}>)
  endforeach()
endforeach()
# compile beta kernels
# Two variants per source: BETA_OPT (in-place int32) and BETA_NO_OPT (float
# scratch), matching the BETA_SUFF mangling in beta_macros.h.
set(BETA_OPTS BETA_OPT BETA_NO_OPT)
foreach(B_TYPE ${BETA_OPTS})
  add_library(int4_beta_kernels_${B_TYPE} OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/integer_beta_kernels.c")
  target_compile_options(int4_beta_kernels_${B_TYPE} PUBLIC -D${B_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
  target_include_directories(int4_beta_kernels_${B_TYPE} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
  list(APPEND BETA_KERNELS $<TARGET_OBJECTS:int4_beta_kernels_${B_TYPE}>)
endforeach()
# compile int gemm drivers
foreach(B_TYPE ${BETA_OPTS})
  add_library(int4_gemm_driver_${B_TYPE} OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_driver.c")
  target_compile_options(int4_gemm_driver_${B_TYPE} PUBLIC -D${B_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
  target_include_directories(int4_gemm_driver_${B_TYPE} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
  list(APPEND GEMM_DRIVERS $<TARGET_OBJECTS:int4_gemm_driver_${B_TYPE}>)
endforeach()
# compile post-ops kernels
foreach(B_TYPE ${BETA_OPTS})
  add_library(int4_post_ops_kernels_${B_TYPE} OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/integer_post_ops_kernels.c")
  target_compile_options(int4_post_ops_kernels_${B_TYPE} PUBLIC -D${B_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
  target_include_directories(int4_post_ops_kernels_${B_TYPE} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
  list(APPEND POST_OPS_KERNELS $<TARGET_OBJECTS:int4_post_ops_kernels_${B_TYPE}>)
endforeach()
# compile matrix-multiplication small kernels multiple times
set(SMALL_KERNELS_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_small_kernels.c")
# NOTE(review): TRANSA_VALS/TRANSB_VALS are re-set here with the same values
# as above — redundant but harmless.
set(TRANSA_VALS TRANSA NOTRANSA)
set(TRANSB_VALS TRANSB NOTRANSB)
set(OC_TYPES OC_FIX OC_COL OC_ROW)
# NOTE(review): these target names end in ${OC_TYPE}${SV_LENGTH} without a
# separating underscore, unlike every other target family — works, but
# inconsistent.
foreach(LHS_TYPE ${LHS_TYPES})
  foreach(RHS_TYPE ${RHS_TYPES})
    foreach(TRANSA_VAL ${TRANSA_VALS})
      foreach(TRANSB_VAL ${TRANSB_VALS})
        foreach(OC_TYPE ${OC_TYPES})
          add_library(int4_small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}${SV_LENGTH} OBJECT ${SMALL_KERNELS_KERNELS_SRC})
          target_compile_options(int4_small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}${SV_LENGTH} PUBLIC -D${LHS_TYPE} -D${RHS_TYPE} -D${TRANSA_VAL} -D${TRANSB_VAL} -D${OC_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
          target_include_directories(int4_small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
          list(APPEND INT_SMALL_KERNELS $<TARGET_OBJECTS:int4_small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}${SV_LENGTH}>)
        endforeach()
      endforeach()
    endforeach()
  endforeach()
endforeach()
list(APPEND OBJ_FILES_STANDALONE_DIR ${INT_GEMM_KERNELS} ${INT_GEMM_INTERFACE} ${INT_GEMM_SEQ} ${INT_PACK_KERNELS} ${BETA_KERNELS} ${POST_OPS_KERNELS} ${INT_SMALL_KERNELS} ${GEMM_DRIVERS})
# set(OBJ_FILES_STANDALONE_DIR ${OBJ_FILES_STANDALONE_DIR} PARENT_SCOPE)
# all compiled object files are united into one object library
add_library(prefillint4gemm SHARED ${OBJ_FILES_STANDALONE_DIR})
set_target_properties(prefillint4gemm PROPERTIES
  ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/prefillint4gemm
  LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/prefillint4gemm
  RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/prefillint4gemm
)

View File

@@ -1,25 +0,0 @@
#ifndef BETA_MACROS_H
#define BETA_MACROS_H
/*
 * Compile-time selection of (a) how the C-offset vector is indexed and
 * (b) which beta-pass variant is built.
 * Exactly one of OC_FIX / OC_COL / OC_ROW is expected per translation unit;
 * OC_ROW is the fallback when neither of the first two is defined.
 * Fix: the OC_FIX branch's OC_IDX parameter was spelled `nn` instead of `ni`;
 * harmless (the parameters are unused there) but now consistent.
 */
#if defined(OC_FIX)
#define OC_TYPE f
#define OC_IDX(mi, ni) 0 /* one shared offset value */
#elif defined(OC_COL)
#define OC_TYPE c
#define OC_IDX(mi, ni) mi /* one offset per row of C */
#else
#define OC_TYPE r
#define OC_IDX(mi, ni) ni /* one offset per column of C */
#endif
/*
 * BETA_OPT builds the `_opt` variant that works directly on the int32 C
 * buffer with its real leading dimension; the generic variant stores
 * intermediate values as float (same width as int32_t) with leading
 * dimension m.
 */
#if defined(BETA_OPT)
#define BETA_SUFF(name) name##_opt
#define LDC(m, ldc) ldc
#define DTYPE int32_t
#else
#define BETA_SUFF(name) name
#define LDC(m, ldc) m
#define DTYPE float
#endif
#endif

View File

@@ -1,63 +0,0 @@
/*
 * Name-mangling macros that stitch the configured LHS/RHS signedness, tile
 * sizes, and transpose flags into kernel symbol names, e.g.
 * ADD_TYPES_SUFF(gemm_kernel) -> gemm_kernel_s8u8s32.
 * Fix: the include guard used a reserved identifier (leading double
 * underscore); renamed to a conforming one.
 */
#ifndef HELPING_MACROS_INT4_H
#define HELPING_MACROS_INT4_H
#include <stdio.h>
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
/* Exactly one signedness macro may be defined per side. */
#if defined(LHS_INT) && defined(LHS_UINT)
#error "Both LHS_INT and LHS_UINT are defined"
#endif
#if defined(RHS_INT) && defined(RHS_UINT)
#error "Both RHS_INT and RHS_UINT are defined"
#endif
#ifdef LHS_INT
#define LHS_TYPE s
#define LHS_INT_TYPE int8_t
#endif
#ifdef LHS_UINT
#define LHS_TYPE u
#define LHS_INT_TYPE uint8_t
#endif
#ifdef RHS_INT
#define RHS_TYPE s
#define RHS_INT_TYPE int8_t
#endif
#ifdef RHS_UINT
#define RHS_TYPE u
#define RHS_INT_TYPE uint8_t
#endif
// mangling macros
// The *_MACRO wrappers force an extra expansion so LHS_TYPE/RHS_TYPE are
// substituted before token pasting.
#define ADD_M_N_SIZES(name, m_size, n_size) name##_##m_size##x##n_size
#define ADD_M_N_SIZES_MACRO(name, m_size, n_size) ADD_M_N_SIZES(name, m_size, n_size)
#define ADD_TYPES(name, lhs_type, rhs_type) name##_##lhs_type##8##rhs_type##8s32
#define ADD_TYPES_MACRO(name, lhs_type, rhs_type) ADD_TYPES(name, lhs_type, rhs_type)
#define ADD_TYPES_SUFF(name) ADD_TYPES_MACRO(name, LHS_TYPE, RHS_TYPE)
#define ADD_ONE_TYPE_TRANSP(name, type, nt) name##_##type##8_##nt
#define ADD_ONE_TYPE_TRANSP_MACRO(name, type, nt) ADD_ONE_TYPE_TRANSP(name, type, nt)
#define ADD_PACK_A_N_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, LHS_TYPE, n)
#define ADD_PACK_B_N_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, RHS_TYPE, n)
#define ADD_PACK_A_T_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, LHS_TYPE, t)
#define ADD_PACK_B_T_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, RHS_TYPE, t)
#define ADD_TWO_TYPES_TRANSP(name, lhs_type, rhs_type, a_t, b_t, oc_t) \
  name##_##lhs_type##8##rhs_type##8s32##_##a_t##b_t##_##oc_t
#define ADD_TWO_TYPES_TRANSP_MACRO(name, lhs_type, rhs_type, a_t, b_t, oc_t) \
  ADD_TWO_TYPES_TRANSP(name, lhs_type, rhs_type, a_t, b_t, oc_t)
#define ADD_TRANSP_MACRO(name, a_t, b_t, oc_t) ADD_TWO_TYPES_TRANSP_MACRO(name, LHS_TYPE, RHS_TYPE, a_t, b_t, oc_t)
/* Threaded builds get a distinct `_thread` symbol family. */
#ifdef ENABLE_THREADING
#define ADD_THREAD_SUFF(name) name##_thread
#else
#define ADD_THREAD_SUFF(name) name
#endif
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -1,82 +0,0 @@
#include <stdint.h>
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
#include "beta_macros.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
// Column major C, Fixed C_offset
void BETA_SUFF(beta_cf_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
DTYPE oc_val = (DTYPE)*oc;
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
DTYPE beta_val = (DTYPE)beta;
for (size_t ni = 0; ni < n_block_size; ++ni) {
for (size_t mi = 0; mi < m_block_size; ++mi) {
c_typed_ptr[ldc * ni + mi] = beta_val * ((DTYPE)c_ptr[ldc * ni + mi]) + oc_val;
}
}
}
// Column major C, Column major C_offset
void BETA_SUFF(beta_cc_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
DTYPE beta_val = (DTYPE)beta;
for (size_t ni = 0; ni < n_block_size; ++ni) {
for (size_t mi = 0; mi < m_block_size; ++mi) {
c_typed_ptr[ldc * ni + mi] = beta_val * ((DTYPE)c_ptr[ldc * ni + mi]) + ((DTYPE)oc[mi]);
}
}
}
// Column major C, Row major C_offset
void BETA_SUFF(beta_cr_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
DTYPE beta_val = (DTYPE)beta;
for (size_t ni = 0; ni < n_block_size; ++ni) {
for (size_t mi = 0; mi < m_block_size; ++mi) {
c_typed_ptr[ldc * ni + mi] = beta_val * ((DTYPE)c_ptr[ldc * ni + mi]) + ((DTYPE)oc[ni]);
}
}
}
// Row major C, Fixed C_offset
// for row-major we actually swap m and n values so we reswap it here again
void BETA_SUFF(beta_rf_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
DTYPE oc_val = (DTYPE)*oc;
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
DTYPE beta_val = (DTYPE)beta;
for (size_t mi = 0; mi < n_block_size; ++mi) {
for (size_t ni = 0; ni < m_block_size; ++ni) {
c_typed_ptr[ldc * mi + ni] = beta_val * ((DTYPE)c_ptr[ldc * mi + ni]) + oc_val;
}
}
}
// Row major C, Column major C_offset
// for row-major we actually swap m and n values so we reswap it here again
void BETA_SUFF(beta_rc_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
DTYPE beta_val = (DTYPE)beta;
for (size_t mi = 0; mi < n_block_size; ++mi) {
for (size_t ni = 0; ni < m_block_size; ++ni) {
c_typed_ptr[ldc * mi + ni] = beta_val * ((DTYPE)c_ptr[ldc * mi + ni]) + ((DTYPE)oc[mi]);
}
}
}
// Row major C, Row major C_offset
// for row-major we actually swap m and n values so we reswap it here again
void BETA_SUFF(beta_rr_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
DTYPE beta_val = (DTYPE)beta;
for (size_t mi = 0; mi < n_block_size; ++mi) {
for (size_t ni = 0; ni < m_block_size; ++ni) {
c_typed_ptr[ldc * mi + ni] = beta_val * ((DTYPE)c_ptr[ldc * mi + ni]) + ((DTYPE)oc[ni]);
}
}
}
#ifdef __cplusplus
} // extern "C"
#endif /* __cplusplus */

View File

@@ -1,121 +0,0 @@
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include "integer_gemm_kernels.h"
#include "beta_macros.h"
#define OLD_N_SIZE 8
// Byte offset into the int4-packed A storage for block (n_step, k_step): two
// 4-bit values share a byte, hence the /2 on the leading dimension and k.
#define PACKED_LD_STEP(n_step, k_step, ldb) (n_step * (ldb / 2) + (k_step / 2) * OLD_N_SIZE)
// Blocked integer GEMM driver: packs B then A into L1-sized buffers and
// dispatches the tile kernels from `arg`. Built twice: BETA_OPT accumulates
// directly into C; otherwise a scratch buffer is used and post_ops_func
// applies alpha/beta at the end of each N block.
void BETA_SUFF(gemm_driver)(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha,
                            const void* a, size_t lda, const BLASINT8 oa,
                            const void* b, size_t ldb, const BLASINT8 ob,
                            float beta, int32_t* c, size_t ldc, const int32_t* oc) {
  // Function-pointer bundle supplied by the interface layer.
  void (**gemm_kernels)(const void*, const void*, int32_t*, size_t, int64_t, int64_t) = arg->gemm_kernels;
  void (*pack_a_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8) = arg->pack_a_fun;
  void (*pack_b_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8) = arg->pack_b_fun;
  void (*beta_func)(int32_t*, const int32_t*, float, size_t, size_t, size_t) = arg->beta_func;
  // NOTE(review): a_indexing is loaded but never used below — A is addressed
  // via PACKED_LD_STEP instead (see the m_block loop); confirm this is
  // intentional for the int4-packed layout.
  size_t (*a_indexing)(size_t m, size_t n, size_t ld) = arg->a_indexing;
  size_t (*b_indexing)(size_t m, size_t n, size_t ld) = arg->b_indexing;
#ifndef BETA_OPT
  void (*post_ops_func)(float, const int32_t*, int32_t*, size_t, size_t, size_t) = arg->post_ops_func;
#endif // BETA_OPT
  const BLASINT8* a_typed = (const BLASINT8*) a;
  const BLASINT8* b_typed = (const BLASINT8*) b;
  // Pack buffers sized for one (M_BLOCK x K_BLOCK) / (K_BLOCK x N_BLOCK) tile.
  BLASINT8* bufferA = (BLASINT8*) aligned_alloc(ALIGNMENT, sizeof(BLASINT8) * K_BLOCK * M_BLOCK);
  BLASINT8* bufferB = (BLASINT8*) aligned_alloc(ALIGNMENT, sizeof(BLASINT8) * K_BLOCK * N_BLOCK);
  // Tmp buffer is not needed when (alpha = 1 and beta = 0/1)
#ifdef BETA_OPT
  int32_t* bufferC = c;
#else
  int32_t* bufferC = (int32_t*) aligned_alloc(ALIGNMENT, sizeof(int32_t) * m * N_BLOCK);
#endif
  if (!bufferA || !bufferB || !bufferC) {
    // free(NULL) is a no-op, so partial allocations are released safely.
    free(bufferA);
    free(bufferB);
    free(bufferC);
    printf("Integer GEMM unsuccessful allocation");
    return;
  }
  // printf("pack b beta: %f\n",beta);
  // Apply beta and the C offsets up front over the whole output.
  beta_func(c, oc, beta, m, n, ldc);
  for (size_t n_block = 0; n_block < n; n_block += N_BLOCK) {
    size_t n_block_size = n - n_block;
    if (n_block_size > N_BLOCK) {
      n_block_size = N_BLOCK;
    }
#ifndef BETA_OPT
    // fill bufferC w/ zeros
    for (size_t tmp_idx = 0; tmp_idx < (m * N_BLOCK); ++tmp_idx) {
      bufferC[tmp_idx] = 0;
    }
#endif // BETA_OPT
    if (alpha != 0.0f){
      for (size_t k_block = 0; k_block < k; k_block += K_BLOCK){
        size_t k_block_size = k - k_block;
        if (k_block_size > K_BLOCK) {
          k_block_size = K_BLOCK;
        }
        // Kernels consume K in KERNEL_K_STEP chunks; round the packed K up.
        size_t k_block_size_up = (k_block_size + KERNEL_K_STEP - 1) / KERNEL_K_STEP * KERNEL_K_STEP;
        const BLASINT8* curr_b_ptr = b_typed + b_indexing(k_block, n_block, ldb);
        pack_b_fun(bufferB, curr_b_ptr, n_block_size, k_block_size, ldb, ob);
        for (size_t m_block = 0; m_block < m; m_block += M_BLOCK) {
          size_t m_block_size = m - m_block;
          if (m_block_size > M_BLOCK) {
            m_block_size = M_BLOCK;
          }
          // NOTE(review): PACKED_LD_STEP is documented in terms of (n_step,
          // k_step, ldb) but is applied to A with lda here — verify the
          // int4-packed A layout really matches this formula.
          const BLASINT8* curr_a_ptr = a_typed + PACKED_LD_STEP(m_block, k_block, lda);
          pack_a_fun(bufferA, curr_a_ptr, m_block_size, k_block_size, lda, oa);
          // loop over bufferB, taking parts which fit into L1
          for (size_t n_sub_block = 0; n_sub_block < n_block_size; n_sub_block += KERNEL_N_STEP) {
            size_t n_sub_block_size = n_block_size - n_sub_block;
            if (n_sub_block_size > KERNEL_N_STEP) {
              n_sub_block_size = KERNEL_N_STEP;
            }
            BLASINT8* current_bufferB_ptr = bufferB + n_sub_block * k_block_size_up;
            // loop over bufferA, taking parts which fit into L1
            for (size_t m_sub_block = 0; m_sub_block < m_block_size; m_sub_block += KERNEL_M_STEP) {
              size_t m_sub_block_size = m_block_size - m_sub_block;
              if (m_sub_block_size > KERNEL_M_STEP) {
                m_sub_block_size = KERNEL_M_STEP;
              }
              BLASINT8* current_bufferA_ptr = bufferA + m_sub_block * k_block_size_up;
#ifdef BETA_OPT
              int32_t* current_bufferC_ptr = bufferC + n_block * ldc + n_sub_block * LDC(m, ldc) + m_sub_block + m_block;
#else
              int32_t* current_bufferC_ptr = bufferC + n_sub_block * LDC(m, ldc) + m_sub_block + m_block;
#endif
              // call kernel which performs loop over k_block_size
              // Kernel table is indexed by the remaining tile sizes (1-based).
              gemm_kernels[(n_sub_block_size - 1) + (m_sub_block_size - 1) * KERNEL_N_STEP](current_bufferA_ptr, current_bufferB_ptr, current_bufferC_ptr,
                                                                                           LDC(m, ldc), k_block_size_up, COMP_SV_LEN);
            }
          }
        }
      }
    }
#ifndef BETA_OPT
    // copy C data from bufferC multiplying by alpha and adding initial C data (scaled by beta)
    int32_t* current_c_ptr = c + n_block * ldc; // col major
    post_ops_func(alpha, bufferC, current_c_ptr, LDC(m, ldc), n_block_size, ldc);
#endif
  }
  free(bufferA);
  free(bufferB);
#ifndef BETA_OPT
  free(bufferC);
#endif
}

View File

@@ -1,154 +0,0 @@
#include <stdint.h>
#include <stdio.h>
//#include "cblas.h"
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
/* matrix saved in rows or cols */
// Local CBLAS enum definitions (the cblas.h include above is disabled);
// NOTE(review): these duplicate the definitions in the tester — keep the
// numeric values in sync everywhere.
typedef enum CBLAS_ORDER {
  CblasRowMajor = 101,
  CblasColMajor = 102
} CBLAS_ORDER;
/* matrix transpose or conjugate transpose */
typedef enum CBLAS_TRANSPOSE {
  CblasNoTrans = 111,
  CblasTrans = 112,
  CblasConjTrans = 113, // conjugate transpose
  CblasConjNoTrans = 114
} CBLAS_TRANSPOSE;
typedef CBLAS_ORDER CBLAS_LAYOUT;
// How the C-offset vector is applied: per row, per column, or one fixed value.
typedef enum CBLAS_OFFSET {
  CblasRowOffset = 171,
  CblasColOffset = 172,
  CblasFixOffset = 173
} CBLAS_OFFSET;
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
// Builds the fully-suffixed kernel symbol, e.g. gemm_kernel_s8u8s32_2x3, from
// the LHS/RHS element types and the (m, n) register-tile size.
#define ADD_KERNEL_SUFF(name, m_size, n_size) ADD_M_N_SIZES_MACRO(ADD_TYPES_MACRO(name, LHS_TYPE, RHS_TYPE), m_size, n_size)
// Micro-kernel dispatch table over the 4x4 grid of tile shapes; the driver
// indexes it as (m_size - 1) * KERNEL_N_STEP + (n_size - 1), so the row-major
// ordering below is part of the contract.
static void (*gemm_kernels[])(const void*, const void*, int32_t*, size_t, int64_t, int64_t) = {
    ADD_KERNEL_SUFF(gemm_kernel, 1, 1), ADD_KERNEL_SUFF(gemm_kernel, 1, 2),ADD_KERNEL_SUFF(gemm_kernel, 1, 3), ADD_KERNEL_SUFF(gemm_kernel, 1, 4),
    ADD_KERNEL_SUFF(gemm_kernel, 2, 1), ADD_KERNEL_SUFF(gemm_kernel, 2, 2),ADD_KERNEL_SUFF(gemm_kernel, 2, 3), ADD_KERNEL_SUFF(gemm_kernel, 2, 4),
    ADD_KERNEL_SUFF(gemm_kernel, 3, 1), ADD_KERNEL_SUFF(gemm_kernel, 3, 2),ADD_KERNEL_SUFF(gemm_kernel, 3, 3), ADD_KERNEL_SUFF(gemm_kernel, 3, 4),
    ADD_KERNEL_SUFF(gemm_kernel, 4, 1), ADD_KERNEL_SUFF(gemm_kernel, 4, 2),ADD_KERNEL_SUFF(gemm_kernel, 4, 3), ADD_KERNEL_SUFF(gemm_kernel, 4, 4)
};
// B-panel packing routines: [0] = no-transpose, [1] = transpose
// (selected below via pack_b_funs[trans == CblasTrans]).
static void (*pack_b_funs[])(void*, const void*, size_t, size_t, size_t, const BLASINT8) = {
    ADD_PACK_B_N_SUFF(pack_b),
    ADD_PACK_B_T_SUFF(pack_b)
};
// A-panel packing routines: [0] = no-transpose, [1] = transpose
// (selected below via pack_a_funs[trans == CblasTrans]).
static void (*pack_a_funs[])(void*, const void*, size_t, size_t, size_t, const BLASINT8) = {
    ADD_PACK_A_N_SUFF(pack_a),
    ADD_PACK_A_T_SUFF(pack_a)
};
// Fallback kernels for small problems, indexed as
//   transpose_bit0 + 2 * transpose_bit1 + 4 * offset_group
// where the table is grouped by offsetc kind: _f (fixed), _c (column), _r (row).
// The dispatcher below swaps which transpose flag feeds which bit (and the
// _c/_r groups) for the row-major case, since it solves the transposed problem.
static void (*small_kernels[])(const size_t, const size_t, const size_t, const float,
                               const void *, const size_t, const BLASINT8,
                               const void *, const size_t, const BLASINT8,
                               const float, int32_t *, const size_t, const int32_t *) = {
    ADD_TRANSP_MACRO(small_kernel, n, n, f), ADD_TRANSP_MACRO(small_kernel, n, t, f),
    ADD_TRANSP_MACRO(small_kernel, t, n, f), ADD_TRANSP_MACRO(small_kernel, t, t, f),
    ADD_TRANSP_MACRO(small_kernel, n, n, c), ADD_TRANSP_MACRO(small_kernel, n, t, c),
    ADD_TRANSP_MACRO(small_kernel, t, n, c), ADD_TRANSP_MACRO(small_kernel, t, t, c),
    ADD_TRANSP_MACRO(small_kernel, n, n, r), ADD_TRANSP_MACRO(small_kernel, n, t, r),
    ADD_TRANSP_MACRO(small_kernel, t, n, r), ADD_TRANSP_MACRO(small_kernel, t, t, r),
};
// beta/oc pre-processing kernels.  Naming: first letter = C layout
// (c = column major, r = row major), second = offsetc kind (f/c/r); the second
// half are the shortcut variants selected when alpha == 1 and beta is 0 or 1
// (indexed below as beta_funcs[beta_offset + opt_offset * 6]).
static void (*beta_funcs[])(int32_t*, const int32_t*, float, size_t, size_t, size_t) = {
    beta_cf_s8, beta_cc_s8, beta_cr_s8, beta_rf_s8, beta_rc_s8, beta_rr_s8,
    beta_cf_s8_opt, beta_cc_s8_opt, beta_cr_s8_opt, beta_rf_s8_opt, beta_rc_s8_opt, beta_rr_s8_opt
};
// Copy-back kernels (bufferC -> C, scaling by alpha): [0] = general path,
// [1] = shortcut used when alpha == 1 (selected via post_op_kernels[alpha == 1]).
static void (*post_op_kernels[])(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc) = {
    post_ops, post_ops_opt
};
// Linear index of element (m, n) in a row-major matrix with leading
// dimension ld: m full rows of ld entries, plus n entries into the row.
static size_t row_major_idx(size_t m, size_t n, size_t ld) {
    const size_t row_start = m * ld;
    return row_start + n;
}
// Linear index of element (m, n) in a column-major matrix with leading
// dimension ld: n full columns of ld entries, plus m entries into the column.
static size_t col_major_idx(size_t m, size_t n, size_t ld) {
    const size_t col_start = n * ld;
    return col_start + m;
}
// Element-address computation used while packing: [0] = column-major,
// [1] = row-major (selected below via compute_idx[trans == CblasTrans]).
static size_t (*compute_idx[])(size_t m, size_t n, size_t ld) = {
    col_major_idx,
    row_major_idx
};
// CblasFixOffset: a single oc value applies everywhere, so the step through
// the oc vector is always zero.
static size_t mov_oc_fix(size_t mi, size_t ni) {
    (void)mi;
    (void)ni;
    return 0;
}
// CblasColOffset: the oc entry is chosen by the row index mi.
static size_t mov_oc_col(size_t mi, size_t ni) {
    (void)ni;
    return mi;
}
// CblasRowOffset: the oc entry is chosen by the column index ni.
static size_t mov_oc_row(size_t mi, size_t ni) {
    (void)mi;
    return ni;
}
// Per-element step through the oc vector, one entry per CBLAS_OFFSET kind:
// [0] = fixed, [1] = column, [2] = row (matches the beta_offset encoding below).
static size_t (*move_oc[])(size_t, size_t) = {
    mov_oc_fix, mov_oc_col, mov_oc_row
};
// Exported GEMM entry point: C = alpha * op(A) * op(B) + beta * C, with the
// CBLAS-style integer offset vector oc applied according to offsetc.
// NOTE(review): despite "int4" in the name, dispatch goes to gemm_impl_8bit
// using the 8-bit kernel tables above -- presumably the int4 handling lives
// inside the kernels/packing; confirm.
EXTERNAL_API void ADD_TYPES_SUFF(prefill_int4_cblas_gemm)(
    const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,
    const size_t m, const size_t n, const size_t k, const float alpha,
    const void *a, const size_t lda, const BLASINT8 oa,
    const void *b, const size_t ldb, const BLASINT8 ob,
    const float beta, int32_t *c, const size_t ldc, const int32_t *oc) {
    // Shortcut flag: no real scaling work when alpha == 1 and beta is 0 or 1;
    // selects the *_opt half of beta_funcs.
    int opt_offset = ((alpha == 1.0f) && (beta == 0.0f || beta == 1.0f)) ? 1 : 0;
    if(Layout == CblasColMajor) {
        // 0 = fixed, 1 = column, 2 = row offset (indexes the _f/_c/_r groups).
        int beta_offset = (offsetc == CblasFixOffset) ? 0 : (offsetc == CblasColOffset ? 1:2);
        // Wire up the pipeline; field order must match struct int_gemm_funcs.
        int_gemm_funcs arg = {
            small_kernels[(transb == CblasTrans) + 2 * (transa == CblasTrans) + beta_offset * 4],
            gemm_kernels,
            pack_a_funs[transa == CblasTrans],
            pack_b_funs[transb == CblasTrans],
            beta_funcs[beta_offset + opt_offset * 6],
            post_op_kernels[alpha == 1],
            compute_idx[transa == CblasTrans],
            compute_idx[transb == CblasTrans],
            move_oc[beta_offset],
        };
        // 262144: small_switch threshold passed to gemm_impl_8bit, below which
        // the direct small-kernel path is taken -- units unverified from here.
        (gemm_impl_8bit(&arg, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc, 262144));
    } else if (Layout == CblasRowMajor) {
        // Row major is solved as the transposed column-major problem: A/B and
        // m/n are swapped in the call below, the transpose flags swap roles,
        // and the row/column offset kinds are exchanged to match.
        int beta_offset = (offsetc == CblasFixOffset) ? 3 : (offsetc == CblasColOffset ? 4 : 5);
        int beta_offset_small = (offsetc == CblasFixOffset) ? 0 : (offsetc == CblasColOffset ? 2 : 1);
        int_gemm_funcs arg = {
            small_kernels[(transa == CblasTrans) + 2 * (transb == CblasTrans) + beta_offset_small * 4],
            gemm_kernels,
            pack_a_funs[transb == CblasTrans],
            pack_b_funs[transa == CblasTrans],
            beta_funcs[beta_offset + opt_offset * 6],
            post_op_kernels[alpha == 1],
            compute_idx[transb == CblasTrans],
            compute_idx[transa == CblasTrans],
            move_oc[beta_offset_small]
        };
        (gemm_impl_8bit(&arg, n, m, k, alpha, b, ldb, ob, a, lda, oa, beta, c, ldc, oc, 262144));
    }
    else {
        // NOTE(review): error text goes to stdout with no newline;
        // fprintf(stderr, ...) would be more conventional.
        printf("Incorrect layout");
        return;
    }
}
#ifdef __cplusplus
}
#endif

View File

@@ -1,453 +0,0 @@
#include <stdint.h>
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
#ifdef __cplusplus
extern "C" {
#endif
// Expands to the fully-suffixed symbol for this instantiation,
// e.g. gemm_kernel_s8u8s32_2x3 (element types + M_SIZE x N_SIZE tile).
#define ADD_SUFFIX(name) ADD_M_N_SIZES_MACRO(ADD_TYPES_MACRO(name, LHS_TYPE, RHS_TYPE), M_SIZE, N_SIZE)
// Predicated SVE byte load: reg_name <- [ptr + idx vector-lengths], zeroing
// inactive lanes.
#define LD1B_PTR(reg_name, p, ptr, idx) "ld1b {" #reg_name ".b}, " #p "/z, [%[" #ptr "], #" #idx ", MUL VL]\n"
// Pairwise add of .s lanes.  NOTE(review): not used by the code visible below.
#define COMPUTE_ADDP(out, in1, in2) "addp " #out ".s, " #in1 ".s, " #in2 ".s\n"
// Pick the dot-product instruction from operand signedness: same-signed
// operands use sdot/udot; mixed signedness uses usdot, with the operand order
// arranged so the unsigned vector lands in usdot's unsigned slot -- TODO
// confirm against the USDOT operand definition.
#if (defined(LHS_INT) && defined(RHS_INT)) || (defined(LHS_UINT) && defined(RHS_UINT))
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) #lhs_type "dot " #out ".s, " #in1 ".b, " #in2 ".b\n"
#else
#ifdef LHS_INT
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in2 ".b, " #in1 ".b\n"
#else
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in1 ".b, " #in2 ".b\n"
#endif
#endif
// Extra indirection so LHS_TYPE/RHS_TYPE are macro-expanded before stringizing.
#define COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE) COMPUTE_DOT_TYPED(out, in1, in2, LHS_TYPE, RHS_TYPE)
#define COMPUTE_DOT(out, in1, in2) COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE)
// z16-z31 provide at most a 4x4 accumulator tile.
#if (N_SIZE > 4)
#error "N_SIZE can't be greater than 4"
#endif
#if (M_SIZE > 4)
#error "M_SIZE can't be greater than 4"
#endif
// Conditional load macros: z0-z3 hold the N_SIZE RHS vectors and z4-z7 the
// M_SIZE LHS vectors of one k-step; z8-z15 are the second bank used by the
// unrolled main loop.  Loads beyond N_SIZE/M_SIZE expand to nothing, so
// smaller tiles emit no dead instructions.
#define LOAD_Z0(p, ptr) LD1B_PTR(z0, p, ptr, 0)
#define LOAD_Z8(p, ptr) LD1B_PTR(z8, p, ptr, 0)
#if (N_SIZE > 1)
#define LOAD_Z1(p, ptr) LD1B_PTR(z1, p, ptr, 1)
#define LOAD_Z9(p, ptr) LD1B_PTR(z9, p, ptr, 1)
#else
#define LOAD_Z1(p, ptr)
#define LOAD_Z9(p, ptr)
#endif
#if (N_SIZE > 2)
#define LOAD_Z2(p, ptr) LD1B_PTR(z2, p, ptr, 2)
#define LOAD_Z10(p, ptr) LD1B_PTR(z10, p, ptr, 2)
#else
#define LOAD_Z2(p, ptr)
#define LOAD_Z10(p, ptr)
#endif
#if (N_SIZE > 3)
#define LOAD_Z3(p, ptr) LD1B_PTR(z3, p, ptr, 3)
#define LOAD_Z11(p, ptr) LD1B_PTR(z11, p, ptr, 3)
#else
#define LOAD_Z3(p, ptr)
#define LOAD_Z11(p, ptr)
#endif
#define LOAD_Z4(p, ptr) LD1B_PTR(z4, p, ptr, 0)
#define LOAD_Z12(p, ptr) LD1B_PTR(z12, p, ptr, 0)
#if (M_SIZE > 1)
#define LOAD_Z5(p, ptr) LD1B_PTR(z5, p, ptr, 1)
#define LOAD_Z13(p, ptr) LD1B_PTR(z13, p, ptr, 1)
#else
#define LOAD_Z5(p, ptr)
#define LOAD_Z13(p, ptr)
#endif
#if (M_SIZE > 2)
#define LOAD_Z6(p, ptr) LD1B_PTR(z6, p, ptr, 2)
#define LOAD_Z14(p, ptr) LD1B_PTR(z14, p, ptr, 2)
#else
#define LOAD_Z6(p, ptr)
#define LOAD_Z14(p, ptr)
#endif
#if (M_SIZE > 3)
#define LOAD_Z7(p, ptr) LD1B_PTR(z7, p, ptr, 3)
#define LOAD_Z15(p, ptr) LD1B_PTR(z15, p, ptr, 3)
#else
#define LOAD_Z7(p, ptr)
#define LOAD_Z15(p, ptr)
#endif
// macros for dot multiplication
// Accumulator map: z(16 + 4*m + n) accumulates tile element (row m, col n).
// Macros for lanes outside the M_SIZE x N_SIZE tile expand to nothing.
#define ACCUMULATE_Z16(lhs, rhs) COMPUTE_DOT(z16, lhs, rhs)
#if (N_SIZE > 1)
#define ACCUMULATE_Z17(lhs, rhs) COMPUTE_DOT(z17, lhs, rhs)
#else
#define ACCUMULATE_Z17(lhs, rhs)
#endif
#if (N_SIZE > 2)
#define ACCUMULATE_Z18(lhs, rhs) COMPUTE_DOT(z18, lhs, rhs)
#else
#define ACCUMULATE_Z18(lhs, rhs)
#endif
#if (N_SIZE > 3)
#define ACCUMULATE_Z19(lhs, rhs) COMPUTE_DOT(z19, lhs, rhs)
#else
#define ACCUMULATE_Z19(lhs, rhs)
#endif
#if (M_SIZE > 1)
#define ACCUMULATE_Z20(lhs, rhs) COMPUTE_DOT(z20, lhs, rhs)
#if (N_SIZE > 1)
#define ACCUMULATE_Z21(lhs, rhs) COMPUTE_DOT(z21, lhs, rhs)
#else
#define ACCUMULATE_Z21(lhs, rhs)
#endif
#if (N_SIZE > 2)
#define ACCUMULATE_Z22(lhs, rhs) COMPUTE_DOT(z22, lhs, rhs)
#else
#define ACCUMULATE_Z22(lhs, rhs)
#endif
#if (N_SIZE > 3)
#define ACCUMULATE_Z23(lhs, rhs) COMPUTE_DOT(z23, lhs, rhs)
#else
#define ACCUMULATE_Z23(lhs, rhs)
#endif
#else
#define ACCUMULATE_Z20(lhs, rhs)
#define ACCUMULATE_Z21(lhs, rhs)
#define ACCUMULATE_Z22(lhs, rhs)
#define ACCUMULATE_Z23(lhs, rhs)
#endif
#if (M_SIZE > 2)
#define ACCUMULATE_Z24(lhs, rhs) COMPUTE_DOT(z24, lhs, rhs)
#if (N_SIZE > 1)
#define ACCUMULATE_Z25(lhs, rhs) COMPUTE_DOT(z25, lhs, rhs)
#else
#define ACCUMULATE_Z25(lhs, rhs)
#endif
#if (N_SIZE > 2)
#define ACCUMULATE_Z26(lhs, rhs) COMPUTE_DOT(z26, lhs, rhs)
#else
#define ACCUMULATE_Z26(lhs, rhs)
#endif
#if (N_SIZE > 3)
#define ACCUMULATE_Z27(lhs, rhs) COMPUTE_DOT(z27, lhs, rhs)
#else
#define ACCUMULATE_Z27(lhs, rhs)
#endif
#else
#define ACCUMULATE_Z24(lhs, rhs)
#define ACCUMULATE_Z25(lhs, rhs)
#define ACCUMULATE_Z26(lhs, rhs)
#define ACCUMULATE_Z27(lhs, rhs)
#endif
#if (M_SIZE > 3)
#define ACCUMULATE_Z28(lhs, rhs) COMPUTE_DOT(z28, lhs, rhs)
#if (N_SIZE > 1)
#define ACCUMULATE_Z29(lhs, rhs) COMPUTE_DOT(z29, lhs, rhs)
#else
#define ACCUMULATE_Z29(lhs, rhs)
#endif
#if (N_SIZE > 2)
#define ACCUMULATE_Z30(lhs, rhs) COMPUTE_DOT(z30, lhs, rhs)
#else
#define ACCUMULATE_Z30(lhs, rhs)
#endif
#if (N_SIZE > 3)
#define ACCUMULATE_Z31(lhs, rhs) COMPUTE_DOT(z31, lhs, rhs)
#else
#define ACCUMULATE_Z31(lhs, rhs)
#endif
#else
#define ACCUMULATE_Z28(lhs, rhs)
#define ACCUMULATE_Z29(lhs, rhs)
#define ACCUMULATE_Z30(lhs, rhs)
#define ACCUMULATE_Z31(lhs, rhs)
#endif
// Advance a packed panel pointer by one k-step (move_lhs / move_rhs bytes).
#define MOVE_LHS_PTR(ptr) "add %[" #ptr "], %[" #ptr "], %[move_lhs]\n"
#define MOVE_RHS_PTR(ptr) "add %[" #ptr "], %[" #ptr "], %[move_rhs]\n"
// Write-back of one accumulator: load the current int32 from dst, reduce the
// accumulator lanes with saddv (result in d<reg_idx>, a SIMD register distinct
// from the GPR w<reg_idx>), move it to tmp_reg, add, and store with a 4-byte
// post-increment.
#define PROCESS_ACCUM(reg_idx, z_reg_idx, tmp_reg, dst, p) \
    "ldr w" #reg_idx ", [%[" #dst "]]\n" \
    "saddv d" #reg_idx ", " #p ", z" #z_reg_idx ".s\n" \
    "fmov " #tmp_reg ", d" #reg_idx "\n" \
    "add x" #reg_idx ", x" #reg_idx ", " #tmp_reg "\n" \
    "str w" #reg_idx ", [%[" #dst "]], #4\n"
// function logic
// Computes one M_SIZE x N_SIZE tile of the int32 accumulator matrix from
// packed operand panels.
//
//   lhs_ptr   packed A panel: per k-step of sv_len bytes, the M_SIZE row
//             slices lie consecutively (panel stride = M_SIZE * sv_len bytes).
//   rhs_ptr   packed B panel: same layout with N_SIZE column slices.
//   accum_ptr tile origin in the accumulator (column major: consecutive tile
//             rows 4 bytes apart, columns ldc elements apart).
//   ldc       accumulator leading dimension, in elements.
//   k_depth   total packed bytes per slice to consume.
//   sv_len    SVE vector length in bytes (one dot-product step consumes this).
//
// Register plan: z0-z3 / z8-z11 hold RHS vectors, z4-z7 / z12-z15 hold LHS
// vectors (two banks so the main loop pipelines two k-steps per iteration);
// z(16 + 4*m + n) accumulates tile element (row m, col n).
void ADD_SUFFIX(gemm_kernel)(const void *lhs_ptr, const void *rhs_ptr,
                             int32_t *accum_ptr, size_t ldc, int64_t k_depth,
                             int64_t sv_len) {
    int64_t run_k_depth = k_depth;
    int64_t run_sv_len = sv_len;
    int64_t run_2sv_len = 2 * sv_len;
    // byte strides that advance each packed panel by one k-step
    int64_t move_lhs = M_SIZE * sv_len;
    int64_t move_rhs = N_SIZE * sv_len;
    int32_t* dst_ptr = accum_ptr;
    // Turn ldc into the byte gap from the end of one written tile column
    // (M_SIZE stores of 4 bytes, post-incremented) to the start of the next.
    ldc -= M_SIZE;
    ldc *= 4;
    asm volatile(
        // predicate for operating on lhs and rhs is always true
        "ptrue p0.b, all\n"
        // Clear accumulators (interleaved with the first k-step's loads)
        LOAD_Z0(p0, rhs_ptr)
        "dup z16.s, #0\n"
        LOAD_Z1(p0, rhs_ptr)
        "dup z17.s, #0\n"
        LOAD_Z4(p0, lhs_ptr)
        "dup z18.s, #0\n"
        LOAD_Z5(p0, lhs_ptr)
        "dup z19.s, #0\n"
        LOAD_Z6(p0, lhs_ptr)
        "dup z20.s, #0\n"
        LOAD_Z7(p0, lhs_ptr)
        "dup z21.s, #0\n"
        LOAD_Z2(p0, rhs_ptr)
        "dup z22.s, #0\n"
        LOAD_Z3(p0, rhs_ptr)
        "dup z23.s, #0\n"
        "subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
        "dup z24.s, #0\n"
        // NOTE(review): x16 is overwritten by PROCESS_ACCUM below before it is
        // ever read, so this mov looks dead -- confirm before removing.
        "mov x16, %[dst_ptr]\n"
        "dup z25.s, #0\n"
        "dup z26.s, #0\n"
        "dup z27.s, #0\n"
        MOVE_LHS_PTR(lhs_ptr)
        "dup z28.s, #0\n"
        MOVE_RHS_PTR(rhs_ptr)
        "dup z29.s, #0\n"
        "dup z30.s, #0\n"
        "dup z31.s, #0\n"
        // Only one k-step in total: go straight to the tail.
        "ble 1f\n"
        "cmp %[run_k_depth], %[run_2sv_len]\n"
        "blt 2f\n"
        // Main loop: two k-steps per iteration, double-buffered through
        // z8-z15, with prefetch of the upcoming panel data.
        "3:\n"
        LOAD_Z12(p0, lhs_ptr)
        ACCUMULATE_Z16(z4,z0)
        ACCUMULATE_Z17(z4,z1)
        LOAD_Z13(p0, lhs_ptr)
        ACCUMULATE_Z18(z4,z2)
        ACCUMULATE_Z19(z4,z3)
        LOAD_Z8(p0, rhs_ptr)
        ACCUMULATE_Z20(z5,z0)
        ACCUMULATE_Z21(z5,z1)
        LOAD_Z9(p0, rhs_ptr)
        ACCUMULATE_Z22(z5,z2)
        ACCUMULATE_Z23(z5,z3)
        LOAD_Z10(p0,rhs_ptr)
        ACCUMULATE_Z24(z6,z0)
        ACCUMULATE_Z25(z6,z1)
        LOAD_Z11(p0,rhs_ptr)
        ACCUMULATE_Z26(z6,z2)
        MOVE_RHS_PTR(rhs_ptr)
        "prfw pldl1keep, p0, [%[rhs_ptr], #4, MUL VL]\n"
        ACCUMULATE_Z27(z6,z3)
        LOAD_Z14(p0, lhs_ptr)
        ACCUMULATE_Z28(z7,z0)
        ACCUMULATE_Z29(z7,z1)
        LOAD_Z15(p0,lhs_ptr)
        ACCUMULATE_Z30(z7,z2)
        MOVE_LHS_PTR(lhs_ptr)
        "prfw pldl1keep, p0, [%[lhs_ptr], #4, MUL VL]\n"
        ACCUMULATE_Z31(z7,z3)
        // Second unrolled k-step: consume bank z8-z15 while reloading z0-z7.
        LOAD_Z4(p0, lhs_ptr)
        ACCUMULATE_Z16(z12,z8)
        ACCUMULATE_Z17(z12,z9)
        LOAD_Z5(p0, lhs_ptr)
        ACCUMULATE_Z18(z12,z10)
        ACCUMULATE_Z19(z12,z11)
        LOAD_Z6(p0, lhs_ptr)
        ACCUMULATE_Z20(z13,z8)
        ACCUMULATE_Z21(z13,z9)
        LOAD_Z0(p0, rhs_ptr)
        "sub %[run_k_depth], %[run_k_depth], %[run_2sv_len]\n"
        ACCUMULATE_Z22(z13,z10)
        ACCUMULATE_Z23(z13,z11)
        LOAD_Z1(p0, rhs_ptr)
        ACCUMULATE_Z24(z14,z8)
        ACCUMULATE_Z25(z14,z9)
        LOAD_Z2(p0,rhs_ptr)
        ACCUMULATE_Z26(z14,z10)
        ACCUMULATE_Z27(z14,z11)
        LOAD_Z3(p0, rhs_ptr)
        ACCUMULATE_Z28(z15,z8)
        MOVE_RHS_PTR(rhs_ptr)
        "prfw pldl1keep, p0, [%[rhs_ptr], #4, MUL VL]\n"
        ACCUMULATE_Z29(z15,z9)
        LOAD_Z7(p0, lhs_ptr)
        "cmp %[run_k_depth], %[run_2sv_len]\n"
        ACCUMULATE_Z30(z15, z10)
        MOVE_LHS_PTR(lhs_ptr)
        "prfw pldl1keep, p0, [%[lhs_ptr], #4, MUL VL]\n"
        ACCUMULATE_Z31(z15,z11)
        "bge 3b\n"
        "cmp %[run_k_depth], #0\n"
        "ble 1f\n"
        // Remainder loop: one k-step per iteration.
        "2:\n"
        "subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
        ACCUMULATE_Z16(z4,z0)
        ACCUMULATE_Z17(z4,z1)
        ACCUMULATE_Z18(z4,z2)
        ACCUMULATE_Z19(z4,z3)
        LOAD_Z4(p0,lhs_ptr)
        ACCUMULATE_Z20(z5,z0)
        ACCUMULATE_Z21(z5,z1)
        ACCUMULATE_Z22(z5,z2)
        ACCUMULATE_Z23(z5,z3)
        LOAD_Z5(p0,lhs_ptr)
        ACCUMULATE_Z24(z6,z0)
        ACCUMULATE_Z25(z6,z1)
        ACCUMULATE_Z26(z6,z2)
        ACCUMULATE_Z27(z6,z3)
        LOAD_Z6(p0,lhs_ptr)
        ACCUMULATE_Z28(z7,z0)
        LOAD_Z0(p0,rhs_ptr)
        ACCUMULATE_Z29(z7,z1)
        LOAD_Z1(p0,rhs_ptr)
        ACCUMULATE_Z30(z7,z2)
        LOAD_Z2(p0,rhs_ptr)
        ACCUMULATE_Z31(z7,z3)
        LOAD_Z3(p0,rhs_ptr)
        MOVE_RHS_PTR(rhs_ptr)
        LOAD_Z7(p0,lhs_ptr)
        MOVE_LHS_PTR(lhs_ptr)
        "bgt 2b\n"
        // Tail: fold in the final loaded vectors, then write the tile back.
        "1:\n"
        ACCUMULATE_Z16(z4,z0)
        ACCUMULATE_Z17(z4,z1)
        ACCUMULATE_Z18(z4,z2)
        ACCUMULATE_Z19(z4,z3)
        ACCUMULATE_Z20(z5,z0)
        ACCUMULATE_Z21(z5,z1)
        ACCUMULATE_Z22(z5,z2)
        ACCUMULATE_Z23(z5,z3)
        ACCUMULATE_Z24(z6,z0)
        ACCUMULATE_Z25(z6,z1)
        ACCUMULATE_Z26(z6,z2)
        ACCUMULATE_Z27(z6,z3)
        ACCUMULATE_Z28(z7,z0)
        ACCUMULATE_Z29(z7,z1)
        ACCUMULATE_Z30(z7,z2)
        ACCUMULATE_Z31(z7,z3)
// Write-back, one tile column at a time; each PROCESS_ACCUM post-increments
// dst_ptr by 4 bytes and the ldc adds skip to the next accumulator column.
#if (N_SIZE > 0)
#if (M_SIZE > 0)
    PROCESS_ACCUM(0, 16, x16, dst_ptr, p0)
#endif
#if (M_SIZE > 1)
    PROCESS_ACCUM(4, 20, x17, dst_ptr, p0)
#endif
#if (M_SIZE > 2)
    PROCESS_ACCUM(8, 24, x18, dst_ptr, p0)
#endif
#if (M_SIZE > 3)
    PROCESS_ACCUM(12, 28, x17, dst_ptr, p0)
#endif
#endif
        "add %[dst_ptr], %[dst_ptr], %[ldc]\n"
#if (N_SIZE > 1)
#if (M_SIZE > 0)
    PROCESS_ACCUM(1, 17, x16, dst_ptr, p0)
#endif
#if (M_SIZE > 1)
    PROCESS_ACCUM(5, 21, x17,dst_ptr,p0)
#endif
#if (M_SIZE > 2)
    PROCESS_ACCUM(9, 25, x18, dst_ptr, p0)
#endif
#if (M_SIZE > 3)
    PROCESS_ACCUM(13, 29, x17, dst_ptr, p0)
#endif
#endif
        "add %[dst_ptr], %[dst_ptr], %[ldc]\n"
#if (N_SIZE > 2)
#if (M_SIZE > 0)
    PROCESS_ACCUM(2, 18, x16, dst_ptr, p0)
#endif
#if (M_SIZE > 1)
    PROCESS_ACCUM(6, 22, x17,dst_ptr,p0)
#endif
#if (M_SIZE > 2)
    PROCESS_ACCUM(10,26,x18,dst_ptr,p0)
#endif
#if (M_SIZE > 3)
    PROCESS_ACCUM(14,30,x17,dst_ptr,p0)
#endif
#endif
        "add %[dst_ptr], %[dst_ptr], %[ldc]\n"
#if (N_SIZE > 3)
#if (M_SIZE > 0)
    PROCESS_ACCUM(3, 19, x16, dst_ptr, p0)
#endif
#if (M_SIZE > 1)
    PROCESS_ACCUM(7, 23, x17,dst_ptr,p0)
#endif
#if (M_SIZE > 2)
    PROCESS_ACCUM(11,27,x18,dst_ptr,p0)
#endif
#if (M_SIZE > 3)
    PROCESS_ACCUM(15,31,x17,dst_ptr,p0)
#endif
#endif
        :
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [run_k_depth] "+r"(run_k_depth),
        // NOTE(review): "+wr" lets the compiler place dst_ptr in either an SVE
        // or a general register; "+r" alone looks intended -- confirm.
        [dst_ptr] "+wr"(dst_ptr)
        :
        [run_sv_len] "r"(run_sv_len), [run_2sv_len] "r"(run_2sv_len),
        [move_lhs] "r"(move_lhs), [move_rhs] "r"(move_rhs), [ldc] "r"(ldc),
        [accum_ptr] "r"(accum_ptr)
        :
        "cc", "memory",
        // NOTE(review): w0-w15 are the 32-bit views of x0-x15, so these clobber
        // the full registers; x18 is the platform register on some AAPCS64
        // targets -- confirm its use as scratch is safe for the target ABI.
        "w0","w1","w2","w3","w4","w5","w6","w7",
        "w8","w9","w10","w11","w12","w13","w14","w15",
        "x16","x17","x18","x19",
        "z0","z1","z2","z3","z4","z5","z6","z7",
        "z8","z9","z10","z11","z12","z13","z14","z15",
        "z16","z17","z18","z19","z20","z21","z22","z23",
        "z24","z25","z26","z27","z28","z29","z30","z31"
    );
}
#ifdef __cplusplus
}
#endif

View File

@@ -1,468 +0,0 @@
#ifndef __GEMM_INTFOUR_KERNELS_H__
#define __GEMM_INTFOUR_KERNELS_H__
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
#include <stddef.h>
// 8-bit matrix element types used throughout the integer GEMM API.
typedef int8_t BLASINT8;
typedef uint8_t BLASUINT8;
// Bundle of function pointers that configures the generic GEMM pipeline for a
// particular (layout, transa, transb, offsetc, alpha/beta) combination; filled
// in by the dispatching entry points and consumed by gemm_impl_8bit.  Field
// order is the initialization contract for the brace-initializers at the call
// sites -- do not reorder.
typedef struct {
    // Direct kernel used for problems below the small_switch threshold.
    void (*small_kernel)(const size_t, const size_t, const size_t, const float, const void*, const size_t, const BLASINT8,
                         const void*, const size_t, const BLASINT8, const float, int32_t*, const size_t, const int32_t*);
    // Micro-kernel table, indexed by the (m, n) register-tile shape.
    void (**gemm_kernels)(const void*, const void*, int32_t*, size_t, int64_t, int64_t);
    // Panel packing for A and B (transpose-specific variants).
    void (*pack_a_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8);
    void (*pack_b_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8);
    // Applies beta and the oc offset vector to C before accumulation.
    void (*beta_func)(int32_t*, const int32_t*, float, size_t, size_t, size_t);
    // Copies accumulator results back to C, scaling by alpha.
    void (*post_ops_func)(float, const int32_t*, int32_t*, size_t, size_t, size_t);
    // Element-index computation for A / B (row- vs column-major access).
    size_t (*a_indexing)(size_t, size_t, size_t);
    size_t (*b_indexing)(size_t, size_t, size_t);
    // Per-element step through the oc vector (fixed / per-row / per-column).
    size_t (*move_oc)(size_t, size_t);
} int_gemm_funcs;
// COMP_SV_LEN (the SVE vector length in bytes this build targets) must come
// from the build system; everything below is sized in terms of it.
#ifndef COMP_SV_LEN
#error "COMP_SV_LEN is not defined"
#endif
// Register-tile limits: the micro-kernels cover up to a 4x4 tile of C, and
// the packed K dimension advances one SVE vector per dot-product step.
#define KERNEL_M_STEP 4
#define KERNEL_N_STEP 4
#define KERNEL_K_STEP COMP_SV_LEN
// Cache-blocking sizes; each must tile evenly by the kernel step.
#define M_BLOCK 256
#if ((M_BLOCK % KERNEL_M_STEP) != 0)
#error "M_BLOCK % KERNEL_M_STEP != 0"
#endif
#define N_BLOCK 256
#if ((N_BLOCK % KERNEL_N_STEP) != 0)
#error "N_BLOCK % KERNEL_N_STEP != 0"
#endif
#define K_BLOCK 512
#if ((K_BLOCK % KERNEL_K_STEP) != 0)
#error "K_BLOCK % KERNEL_K_STEP != 0"
#endif
// Alignment for the packed panel buffers (page-sized) -- presumably chosen
// for allocator/TLB friendliness; confirm.
#define ALIGNMENT 4096
// Marks the public entry points as exported from the shared library.
#define EXTERNAL_API __attribute__((visibility("default")))
// Silences unused-parameter warnings in the table-driven helper stubs.
#define UNUSED(arg) ((void)(arg))
// general pipeline
void gemm_impl_8bit(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha, const void* a, size_t lda,
const BLASINT8 oa, const void* b, size_t ldb, const BLASINT8 ob, float beta, int32_t* c, size_t ldc,
const int32_t* oc, size_t small_switch);
// s8 kernel
void pack_b_s8_n(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
const BLASINT8 ob);
void pack_a_s8_n(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
const BLASINT8 oa);
void pack_b_s8_t(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
const BLASINT8 ob);
void pack_a_s8_t(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
const BLASINT8 oa);
// u8 kernels
void pack_b_u8_n(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
const BLASINT8 ob);
void pack_a_u8_n(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
const BLASINT8 oa);
void pack_b_u8_t(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
const BLASINT8 ob);
void pack_a_u8_t(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
const BLASINT8 oa);
// beta kernels
void beta_cf_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
void beta_cc_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
void beta_cr_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
void beta_rf_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
void beta_rc_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
void beta_rr_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
void beta_cf_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
size_t ldc);
void beta_cc_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
size_t ldc);
void beta_cr_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
size_t ldc);
void beta_rf_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
size_t ldc);
void beta_rc_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
size_t ldc);
void beta_rr_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
size_t ldc);
// post-ops kernels
void post_ops(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc);
void post_ops_opt(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc);
// drivers
void gemm_driver(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha, const void* a, size_t lda,
const BLASINT8 oa, const void* b, size_t ldb, const BLASINT8 ob, float beta, int32_t* c, size_t ldc,
const int32_t* oc);
void gemm_driver_opt(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha, const void* a, size_t lda,
const BLASINT8 oa, const void* b, size_t ldb, const BLASINT8 ob, float beta, int32_t* c,
size_t ldc, const int32_t* oc);
// matrix multiplication kernels
// s8s8s32 kernels
void gemm_kernel_s8s8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8s8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
// u8u8s32 kernels
void gemm_kernel_u8u8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8u8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
// s8u8s32 kernels
void gemm_kernel_s8u8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_s8u8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
// u8s8s32 kernels
void gemm_kernel_u8s8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
void gemm_kernel_u8s8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
int64_t sv_len);
// small kernels
// s8s8s32 kernels
void small_kernel_s8s8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8s8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
// s8u8s32 kernels
void small_kernel_s8u8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_s8u8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
// u8s8s32 kernels
void small_kernel_u8s8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8s8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
// u8u8s32 kernels
void small_kernel_u8u8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void small_kernel_u8u8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -1,158 +0,0 @@
#include <stdint.h>
#include <math.h>
#include <stdlib.h>
#include <stdio.h>
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
#if defined(OC_FIX)
#define OC_TYPE f
#define OC_IDX(mi, nn) 0
#elif defined(OC_COL)
#define OC_TYPE c
#define OC_IDX(mi, ni) mi
#else // OC_ROW
#define OC_TYPE r
#define OC_IDX(mi, ni) ni
#endif // OC_T
#if defined(TRANSA)
#if defined TRANSB
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, t, t, OC_TYPE)
#elif defined(NOTRANSB)
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, t, n, OC_TYPE)
#else
#error "Neither TRANSB or NOTRANSB is defined"
#endif
#elif defined(NOTRANSA)
#if defined TRANSB
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, n, t, OC_TYPE)
#elif defined(NOTRANSB)
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, n, n, OC_TYPE)
#else
#error "Neither TRANSB or NOTRANSB is defined"
#endif
#else
#error "Neither TRANSA or NOTRANSA is defined"
#endif
#if (defined (LHS_INT) && defined(RHS_INT)) || (defined (LHS_UINT) && defined(RHS_UINT))
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) #lhs_type "dot " #out ".s, " #in1 ".b, " #in2 ".b\n"
#else
#if defined(LHS_INT)
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in2 ".b," #in1 ".b\n"
#else // LHS_UINT
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in1 ".b," #in2 ".b\n"
#endif // LHS_INT
#endif // LHS_INT
#define COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE) COMPUTE_DOT_TYPED(out, in1, in2, LHS_TYPE, RHS_TYPE)
#define COMPUTE_DOT(out, in1, in2) COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE)
// Dot product of one length-k panel of A with one length-k panel of B on SVE,
// adding the per-element offsets *oa / *ob to the operands before the dot
// instruction. oa/ob must point to buffers of at least one vector of bytes
// (the caller broadcasts the scalar offsets into KERNEL_K_STEP-long buffers).
// Returns the int32 accumulator widened to double for the caller's scaling.
static inline double compute_dot(size_t k, const void *a, const BLASINT8* oa,
                                 const void *b, const BLASINT8* ob, int64_t sv_len) {
    int32_t accum = 0;
    int64_t run_k_depth = k;
    int64_t run_sv_len = sv_len;
    const void* lhs_ptr = a;
    const void* rhs_ptr = b;
    asm volatile(
        // z4: int32 partial sums; p0: all-true byte predicate for the offset loads.
        "dup z4.s, #0\n"
        "ptrue p0.b, all\n"
        // Load the offset vectors once (z0 = a-offsets, z1 = b-offsets).
        "ld1b {z0.b}, p0/z, [%[oa]]\n"
        "ld1b {z1.b}, p0/z, [%[ob]]\n"
        "1:\n"
        // Predicate covers min(remaining, VL) bytes; /z loads zero the inactive
        // lanes and the merging adds skip them, so they contribute 0 to the dot.
        "whilelt p1.b, xzr, %[run_k_depth]\n"
        "ld1b {z2.b}, p1/z, [%[lhs_ptr]]\n"
        "ld1b {z3.b}, p1/z, [%[rhs_ptr]]\n"
        "add z2.b, p1/m, z2.b, z0.b\n"
        "add z3.b, p1/m, z3.b, z1.b\n"
        "add %[lhs_ptr], %[lhs_ptr], %[run_sv_len]\n"
        "add %[rhs_ptr], %[rhs_ptr], %[run_sv_len]\n"
        // sdot/udot/usdot picked at compile time by the COMPUTE_DOT macro.
        COMPUTE_DOT(z4, z2, z3)
        // NOTE(review): the counter drops by 1 per sv_len-byte step while
        // whilelt treats it as a byte count — confirm the intended meaning of
        // k vs sv_len for depths larger than one vector length.
        "subs %[run_k_depth], %[run_k_depth], #1\n"
        "bgt 1b\n"
        // Horizontal add of the int32 lanes into d1, then fold into accum.
        "ptrue p2.s, all\n"
        "saddv d1, p2, z4.s\n"
        "fmov x2, d1\n"
        // NOTE(review): accum is int32_t but is updated with a 64-bit add via
        // a "+wr" constraint — confirm the width/constraint pairing on the
        // target compiler.
        "add %[accum], %[accum], x2\n"
        : // outputs
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [run_k_depth] "+wr"(run_k_depth),
        [accum] "+wr"(accum)
        : // inputs
        [run_sv_len] "r"(run_sv_len), [oa] "r"(oa), [ob] "r"(ob)
        : // clobbers
        "cc", "memory",
        "d1", "x2",
        "z0", "z1", "z2", "z3", "z4", "p0", "p1", "p2"
    );
    return (double) accum;
}
#if !defined(TRANSA) || defined(TRANSB)
// performs transposition (n0 * n1) -> (n1 * n0) assuming col major
static inline void simplest_transpose(const void *in, void *out, size_t n0, size_t ld0, size_t n1) {
// since we care only about size, we can use signed type always
BLASINT8* typed_in = (BLASINT8*) in;
BLASINT8* typed_out = (BLASINT8*) out;
for (size_t i = 0; i < n1; ++i) {
for (size_t j = 0; j < n0; ++j) {
typed_out[i + j * n1] = typed_in[j + i * ld0];
}
}
}
#endif // !defined(TRANSA) || defined(TRANSB)
// A in row-major, B in col-major
void ADD_SUFFIX(small_kernel)(const size_t m, const size_t n, const size_t k, const float alpha,
const void *a, const size_t lda, const BLASINT8 oa,
const void *b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t *c, const size_t ldc, const int32_t *oc) {
double double_alpha = (double) alpha;
// we use typed pointers only for indexing, so we don't care about signess
#ifdef TRANSA
BLASINT8* a_typed = (BLASINT8*) a;
const size_t used_lda = lda;
#else
BLASINT8* a_typed = (BLASINT8*) aligned_alloc(128, sizeof(BLASINT8) * m * k);
simplest_transpose(a, a_typed, m, lda, k);
const size_t used_lda = k;
#endif // TRANSA
#ifndef TRANSB
BLASINT8* b_typed = (BLASINT8*) b;
const size_t used_ldb = ldb;
#else // TRANSB
BLASINT8* b_typed = (BLASINT8*) aligned_alloc(128, sizeof(BLASINT8) * k * n);
simplest_transpose(b, b_typed, n, ldb, k);
const size_t used_ldb = k;
#endif // TRANSB
BLASINT8 oa_buf[KERNEL_K_STEP];
BLASINT8 ob_buf[KERNEL_K_STEP];
for (size_t i = 0; i < KERNEL_K_STEP; ++i) {
oa_buf[i] = oa;
ob_buf[i] = ob;
}
// printf("\n========\n");
for (size_t mi = 0; mi < m; ++mi) {
for (size_t ni = 0; ni < n; ++ni) {
// printf("mi = %lu, ni = %lu, oc_idx = %lu\n", mi, ni, OC_IDX(mi, ni));
double tmp = compute_dot(k, a_typed + mi * used_lda, oa_buf, b_typed + ni * used_ldb, ob_buf, KERNEL_K_STEP);
c[mi + ni * ldc] = round(tmp * double_alpha + ((double)(beta * ((float)c[mi + ni * ldc]) + oc[OC_IDX(mi, ni)])));
}
}
#ifdef TRANSA
free(a_typed);
#endif // TRANSA
#ifdef TRANSB
free(b_typed);
#endif // TRANSB
}
#ifdef __cplusplus
} // extern "C"
#endif /* __cplusplus */

View File

@@ -1,140 +0,0 @@
#include <stdint.h>
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
#if defined(TRANSA)
// row major
#define INDEXING_A(row_idx, col_idx, lda) ((col_idx) * (lda) + row_idx)
#define ADD_A_SUFF(name) ADD_PACK_A_T_SUFF(name)
#elif defined(NOTRANSA)
// col major
#define INDEXING_A(row_idx, col_idx, lda) ((row_idx) * (lda) + col_idx)
#define ADD_A_SUFF(name) ADD_PACK_A_N_SUFF(name)
#else
#error "Neither TRANSA or NOTRANSA is defined"
#endif
#define OLD_N_SIZE 8
#define NEW_N_SIZE 4
void ADD_A_SUFF(pack_a)(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda, const BLASINT8 oa) {
LHS_INT_TYPE* bufferA_typed = (LHS_INT_TYPE*) bufferA;
LHS_INT_TYPE* curr_a_ptr_typed = (LHS_INT_TYPE*) curr_a_ptr;
// printf("m_block_size:%lu ,k_block_size: %lu\n", m_block_size, k_block_size);
for(size_t old_split_n = 0; old_split_n < (m_block_size / OLD_N_SIZE); old_split_n++) {
for(size_t split_k = 0; split_k < (k_block_size / (KERNEL_K_STEP * 2)); split_k++) {
for(size_t old_idx_n = 0; old_idx_n < OLD_N_SIZE; old_idx_n++) {
for(size_t idx_k = 0; idx_k < KERNEL_K_STEP; idx_k++) {
size_t n_idx = old_split_n * OLD_N_SIZE + old_idx_n;
size_t new_split_n = n_idx / NEW_N_SIZE;
size_t new_idx_n = n_idx % NEW_N_SIZE;
size_t old_buff_idx =
old_split_n * OLD_N_SIZE * (lda / 2) +
split_k * OLD_N_SIZE * KERNEL_K_STEP +
old_idx_n * KERNEL_K_STEP +
idx_k;
uint8_t b01 = curr_a_ptr_typed[old_buff_idx];
uint8_t b0 = b01 & 0xF0;
uint8_t b1 = b01 << 4;
size_t new_buff_idx =
new_split_n * NEW_N_SIZE * k_block_size +
split_k * NEW_N_SIZE * (KERNEL_K_STEP * 2) +
new_idx_n * KERNEL_K_STEP +
idx_k;
bufferA_typed[new_buff_idx] = b0;
bufferA_typed[new_buff_idx + KERNEL_K_STEP * NEW_N_SIZE] = b1;
}
}
}
}
// for(size_t n_idx = 0; n_idx < m_block_size; n_idx++) {
// for(size_t k_idx = 0; k_idx < k_block_size; k_idx++) {
// size_t old_split_n = n_idx / OLD_N_SIZE;
// size_t old_idx_n = n_idx % OLD_N_SIZE;
// size_t new_split_n = n_idx / NEW_N_SIZE;
// size_t new_idx_n = n_idx % NEW_N_SIZE;
// size_t split_k = k_idx / KERNEL_K_STEP;
// size_t idx_k = k_idx % KERNEL_K_STEP;
// size_t old_buff_idx =
// old_split_n * OLD_N_SIZE * lda +
// split_k * OLD_N_SIZE * KERNEL_K_STEP +
// old_idx_n * KERNEL_K_STEP +
// idx_k;
// size_t new_buff_idx =
// new_split_n * NEW_N_SIZE * k_block_size +
// split_k * NEW_N_SIZE * KERNEL_K_STEP +
// new_idx_n * KERNEL_K_STEP +
// idx_k;
// bufferA_typed[new_buff_idx] = curr_a_ptr_typed[old_buff_idx] + oa;
// }
// }
// size_t k_block_size_up = (k_block_size + KERNEL_K_STEP - 1) / KERNEL_K_STEP * KERNEL_K_STEP;
// size_t k_portions = k_block_size / KERNEL_K_STEP;
// size_t k_resid = k_block_size - KERNEL_K_STEP * k_portions;
// size_t m_portions = m_block_size / KERNEL_M_STEP;
// size_t m_resid = m_block_size - KERNEL_M_STEP * m_portions;
// for (size_t im4 = 0; im4 < m_portions; ++im4) {
// for (size_t ik16 = 0; ik16 < k_portions; ++ik16) {
// for (size_t im = 0; im < KERNEL_M_STEP; ++im) {
// for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * KERNEL_M_STEP * ik16 + k_block_size_up * KERNEL_M_STEP * im4] =
// curr_a_ptr_typed[INDEXING_A((ik16 * KERNEL_K_STEP + ik), (im4 * KERNEL_M_STEP + im), lda)] + oa;
// }
// }
// }
// if (k_resid) {
// for (size_t im = 0; im < KERNEL_M_STEP; ++im) {
// for (size_t ik = 0; ik < k_resid; ++ik) {
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * KERNEL_M_STEP * k_portions + k_block_size_up * KERNEL_M_STEP * im4] =
// curr_a_ptr_typed[INDEXING_A((k_portions * KERNEL_K_STEP + ik), (im4 * KERNEL_M_STEP + im), lda)] + oa;
// }
// }
// for (size_t im = 0; im < KERNEL_M_STEP; ++im) {
// for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * KERNEL_M_STEP * k_portions + k_block_size_up * KERNEL_M_STEP * im4] = 0;
// }
// }
// }
// }
// if (m_resid) {
// for (size_t ik16 = 0; ik16 < k_portions; ++ik16) {
// for (size_t im = 0; im < m_resid; ++im) {
// for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * m_resid * ik16 + k_block_size_up * KERNEL_M_STEP * m_portions] =
// curr_a_ptr_typed[INDEXING_A((ik16 * KERNEL_K_STEP + ik), (m_portions * KERNEL_M_STEP + im), lda)] + oa;
// }
// }
// }
// if (k_resid) {
// for (size_t im = 0; im < m_resid; ++im) {
// for (size_t ik = 0; ik < k_resid; ++ik) {
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * m_resid * k_portions + k_block_size_up * KERNEL_M_STEP * m_portions] =
// curr_a_ptr_typed[INDEXING_A((k_portions * KERNEL_K_STEP + ik), (m_portions * KERNEL_M_STEP + im), lda)] + oa;
// }
// }
// for (size_t im = 0; im < m_resid; ++im) {
// for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * m_resid * k_portions + k_block_size_up * KERNEL_M_STEP * m_portions] = 0;
// }
// }
// }
// }
}
#ifdef __cplusplus
} // extern "C"
#endif /* __cplusplus */

View File

@@ -1,96 +0,0 @@
#include <stdint.h>
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
#if defined(TRANSB)
// row major
#define INDEXING_B(row_idx, col_idx, ldb) ((col_idx) * (ldb) + row_idx)
#define ADD_B_SUFF(name) ADD_PACK_B_T_SUFF(name)
#elif defined(NOTRANSB)
// col major
#define INDEXING_B(row_idx, col_idx, ldb) ((row_idx) * (ldb) + col_idx)
#define ADD_B_SUFF(name) ADD_PACK_B_N_SUFF(name)
#else
#error "Neither TRANSB or NOTRANSB is defined."
#endif
// Packs a B panel into the KERNEL_N_STEP x KERNEL_K_STEP blocked layout the
// compute kernels consume, adding the b-offset `ob` to every element.
// Partial depth blocks are zero-padded up to KERNEL_K_STEP; a partial column
// group (n_resid columns) is packed densely after the full groups.
// INDEXING_B abstracts row/col-major source addressing (TRANSB vs NOTRANSB).
// (Removed: trailing block of commented-out legacy packing code.)
void ADD_B_SUFF(pack_b)(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb, const BLASINT8 ob) {
    RHS_INT_TYPE* bufferB_typed = (RHS_INT_TYPE*) bufferB;
    RHS_INT_TYPE* curr_b_ptr_typed = (RHS_INT_TYPE*) curr_b_ptr;
    // Depth rounded up to a whole number of KERNEL_K_STEP blocks (zero padded).
    size_t k_block_size_up = (k_block_size + KERNEL_K_STEP - 1) / KERNEL_K_STEP * KERNEL_K_STEP;
    size_t k_portions = k_block_size / KERNEL_K_STEP;
    size_t k_resid = k_block_size - KERNEL_K_STEP * k_portions;
    size_t n_portions = n_block_size / KERNEL_N_STEP;
    size_t n_resid = n_block_size - KERNEL_N_STEP * n_portions;
    // Full KERNEL_N_STEP-wide column groups.
    for (size_t in4 = 0; in4 < n_portions; ++in4) {
        for (size_t in = 0; in < KERNEL_N_STEP; ++in) {
            for (size_t ik16 = 0; ik16 < k_block_size / KERNEL_K_STEP; ++ik16) {
                for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
                    bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * KERNEL_N_STEP * ik16 + k_block_size_up * KERNEL_N_STEP * in4] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * in4 + in), (KERNEL_K_STEP * ik16 + ik), ldb)] + ob;
                }
            }
            if (k_resid) {
                // Depth tail: copy k_resid elements, zero the rest of the block.
                for (size_t ik = 0; ik < k_resid; ++ik) {
                    bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * KERNEL_N_STEP * k_portions + k_block_size_up * KERNEL_N_STEP * in4] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * in4 + in), (KERNEL_K_STEP * k_portions + ik), ldb)] + ob;
                }
                for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
                    bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * KERNEL_N_STEP * k_portions + k_block_size_up * KERNEL_N_STEP * in4] = 0;
                }
            }
        }
    }
    // Remaining columns (fewer than KERNEL_N_STEP): packed with stride n_resid.
    if (n_resid) {
        for (size_t in = 0; in < n_resid; ++in) {
            for (size_t ik16 = 0; ik16 < k_block_size / KERNEL_K_STEP; ++ik16) {
                for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
                    bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * n_resid * ik16 + k_block_size_up * KERNEL_N_STEP * n_portions] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * n_portions + in), (KERNEL_K_STEP * ik16 + ik), ldb)] + ob;
                }
            }
            if (k_resid) {
                // Depth tail of the partial column group, zero padded as above.
                for (size_t ik = 0; ik < k_resid; ++ik) {
                    bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * n_resid * k_portions + k_block_size_up * KERNEL_N_STEP * n_portions] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * n_portions + in), (KERNEL_K_STEP * k_portions + ik), ldb)] + ob;
                }
                for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
                    bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * n_resid * k_portions + k_block_size_up * KERNEL_N_STEP * n_portions] = 0;
                }
            }
        }
    }
}
#ifdef __cplusplus
} // extern "C"
#endif /* __cplusplus */

View File

@@ -1,25 +0,0 @@
#include <stdint.h>
#include <math.h>
#include "integer_gemm_kernels.h"
#include "helping_macros.h"
#include "beta_macros.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
// Fixed OC
// Fixed OC
// Final alpha/rounding pass: C = round(staged_C + alpha * bufferC), stored as
// int32. bufferC holds the int32 A*B accumulators; current_c_ptr is read back
// through a float alias — presumably the driver stages beta * C (+ oc) into C
// with float encoding before this pass runs — TODO confirm against the driver.
void BETA_SUFF(post_ops)(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc) {
    // Alias C as float to read the staged values (see note above).
    float* current_c_float_ptr = (float*) current_c_ptr;
    double double_alpha = (double) alpha;
    for (size_t n_idx = 0; n_idx < n_block; ++n_idx) {
        for (size_t m_idx = 0; m_idx < m; ++m_idx) {
            // Combine in double, round to nearest, then narrow to int32.
            // LDC(m, ldc) picks the accumulator leading dimension (macro from
            // beta_macros.h, presumably beta-variant dependent — confirm).
            current_c_ptr[m_idx + n_idx * ldc] = round(((double)(current_c_float_ptr[m_idx + n_idx * ldc]))
                + double_alpha * ((double) bufferC[m_idx + n_idx * LDC(m, ldc)]));
        }
    }
}
#ifdef __cplusplus
} // extern "C"
#endif /* __cplusplus */

View File

@@ -1,35 +0,0 @@
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include "integer_gemm_kernels.h"
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
// here we will use BLASINT8, because we care only about type size, not signness (no math operations required)
// Routes an 8-bit GEMM either to the single-call small kernel (tiny problems)
// or to one of the blocked drivers. Only byte size matters for a/b here, so
// BLASINT8 offsets are used regardless of signedness.
void gemm_impl_8bit(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha,
                    const void* a, size_t lda, const BLASINT8 oa,
                    const void* b, size_t ldb, const BLASINT8 ob,
                    float beta, int32_t* c, size_t ldc, const int32_t* oc, size_t small_switch) {
    // Below the (experimentally measured) threshold the blocked machinery
    // costs more than it saves: run the naive kernel once and return.
    if (m * n * k < small_switch) {
        arg->small_kernel(m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
        return;
    }
    // alpha == 1 with beta in {0, 1} — presumably the optimized driver can
    // skip the explicit scaling passes for this corner case.
    const int trivial_scaling = (alpha == 1.0f) && (beta == 0.0f || beta == 1.0f);
    if (trivial_scaling) {
        gemm_driver_opt(arg, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
    } else {
        gemm_driver(arg, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
    }
}
#ifdef __cplusplus
}
#endif /* __cplusplus */

View File

@@ -1,30 +0,0 @@
# Standalone build for the integer GEMM tester; links against the prebuilt
# libint8gemm.so expected in the parent directory.
cmake_minimum_required(VERSION 3.14.1)
project(THREAD_TEST)
# Inherit optimization flags from the including build, plus OpenMP.
set(CMAKE_CXX_FLAGS ${CFLAGS_OPT})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
message(STATUS "Current source dir: ${CMAKE_CURRENT_SOURCE_DIR}/..")
# BLAS_PATH looks like a hook for an external BLAS install, but only the
# empty-BLAS_PATH branch is implemented below.
# NOTE(review): a non-empty BLAS_PATH aborts with a "can not find path" error —
# confirm whether this condition is intentionally inverted.
set(BLAS_PATH "")
message(STATUS "BLAS_PATH=${BLAS_PATH}")
if ("${BLAS_PATH}" STREQUAL "")
# Default: use the shared library built next to this test directory.
set(BLAS_LIB "${CMAKE_CURRENT_SOURCE_DIR}/../libint8gemm.so")
message(STATUS "BLAS_LIB=${BLAS_LIB}")
# Add threading library to linker
#find_package(Threads)
add_executable(integer_tester integer_gemm.cpp)
#target_include_directories(integer_tester PUBLIC "${BLAS_PATH}")
#target_include_directories(integer_tester PUBLIC "${BLAS_BUILD_PATH}")
target_link_libraries(integer_tester PRIVATE ${BLAS_LIB})
set_property(TARGET integer_tester PROPERTY CXX_STANDARD 17)
add_test(integer_tester ${CMAKE_CURRENT_BINARY_DIR}/integer_tester)
else ()
message(FATAL_ERROR "Can not find int8_gemm path, pls set right path!")
endif ()

View File

@@ -1,438 +0,0 @@
#include <malloc.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <limits>
#include <random>
#include <type_traits>
#include <vector>
/* matrix saved in rows or cols */
typedef enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 } CBLAS_ORDER;
/* matrix transpose or conjugate transpose */
typedef enum CBLAS_TRANSPOSE {
CblasNoTrans = 111,
CblasTrans = 112,
CblasConjTrans = 113, // conjugate transpose
CblasConjNoTrans = 114
} CBLAS_TRANSPOSE;
typedef CBLAS_ORDER CBLAS_LAYOUT;
typedef enum CBLAS_OFFSET { CblasRowOffset = 171, CblasColOffset = 172, CblasFixOffset = 173 } CBLAS_OFFSET;
typedef int8_t BLASINT8;
typedef uint8_t BLASUINT8;
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
void cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void cblas_gemm_u8u8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void cblas_gemm_s8u8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void cblas_gemm_u8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
#ifdef __cplusplus
}
#endif /* __cplusplus */
namespace test {
namespace tools {
/// Minimal STL-compatible allocator that hands out storage aligned to
/// `alignment` bytes (via memalign; freed with free()).
/// @tparam T         element type
/// @tparam alignment required alignment in bytes (default 128)
template <typename T, std::size_t alignment = 128>
struct aligned_allocator {
  using value_type = T;
  using pointer = T*;
  using const_pointer = const T*;
  using reference = T&;
  using const_reference = const T&;
  using size_type = std::size_t;
  using difference_type = std::ptrdiff_t;
  aligned_allocator() noexcept = default;
  /// FIX: converting constructor required by the C++ Allocator named
  /// requirement — without it, containers that rebind to an internal node
  /// type (e.g. std::map, std::list) cannot construct the rebound allocator.
  template <typename U>
  aligned_allocator(const aligned_allocator<U, alignment>&) noexcept {}
  template <typename U>
  struct rebind {
    typedef aligned_allocator<U, alignment> other;
  };
  /// Allocates n objects of type T; throws std::bad_array_new_length on
  /// size overflow and std::bad_alloc when memalign fails.
  [[nodiscard]] T* allocate(std::size_t n) {
    if (n > std::numeric_limits<std::size_t>::max() / sizeof(T)) throw std::bad_array_new_length();
    if (auto p = static_cast<T*>(memalign(alignment, n * sizeof(T)))) {
      return p;
    }
    throw std::bad_alloc();
  }
  /// Releases storage obtained from allocate(); n is ignored (free-based).
  void deallocate(T* p, std::size_t n) noexcept {
    (void)(n);
    free(p);
  }
  ~aligned_allocator() {}
};
// Allocators compare equal exactly when their storage is interchangeable:
// same element type and same alignment.
template <typename T, std::size_t alignment_1, typename U, std::size_t alignment_2>
bool operator==(const aligned_allocator<T, alignment_1>&, const aligned_allocator<U, alignment_2>&) {
  return std::is_same_v<T, U> && (alignment_1 == alignment_2);
}
template <typename T, std::size_t alignment_1, typename U, std::size_t alignment_2>
bool operator!=(const aligned_allocator<T, alignment_1>& lhs, const aligned_allocator<U, alignment_2>& rhs) {
  return !(lhs == rhs);
}
/// Measures the average wall-clock runtime of func(args...) in seconds.
/// One warm-up call sizes the measurement loop to roughly one second of
/// total work (minimum 3 iterations); the mean over that loop is returned.
/// @return average seconds per call (>= 0)
template <typename Func, typename... Args>
double timing(Func&& func, Args&&... args) {
  auto start_begin = std::chrono::steady_clock::now();
  std::forward<Func>(func)(std::forward<Args>(args)...);
  auto end_begin = std::chrono::steady_clock::now();
  const double time_begin =
      std::chrono::duration_cast<std::chrono::nanoseconds>(end_begin - start_begin).count() / 1e9;
  // BUGFIX: a warm-up that measures 0 ns (possible with coarse clocks) used
  // to compute 1.0 / 0.0 == inf and convert it to size_t (undefined
  // behavior). Guard the division and cap the iteration count.
  std::size_t n_run = 3;
  if (time_begin > 0.0) {
    n_run = std::max<std::size_t>(
        static_cast<std::size_t>(std::min(1.0 / time_begin, 1e9)), 3);
  }
  auto start = std::chrono::steady_clock::now();
  for (std::size_t i = 0; i < n_run; ++i) {
    std::forward<Func>(func)(std::forward<Args>(args)...);
  }
  auto end = std::chrono::steady_clock::now();
  const double time =
      std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count() / 1e9;
  return time / n_run;
}
} // namespace tools
namespace helpers {
// Number of elements the oc array must hold for the given offset mode:
// 1 for a fixed offset, m (rows) for column offsets, n (cols) for row
// offsets. Unknown modes log a diagnostic and yield 0.
std::size_t get_oc_size(CBLAS_OFFSET offset, std::size_t m, std::size_t n) {
  if (offset == CblasFixOffset) return 1;
  if (offset == CblasColOffset) return m;
  if (offset == CblasRowOffset) return n;
  std::cout << "Incorrect value of offset to the function " << __PRETTY_FUNCTION__ << std::endl;
  return 0;
}
// Picks the buffer holding the requested logical matrix. With a col-major
// layout, NoTrans maps to the non-transposed buffer; a row-major layout
// flips that choice (row-major NoTrans data lives in the transposed buffer).
template <typename T>
auto get_ab_matrix(CBLAS_LAYOUT lt, CBLAS_TRANSPOSE trans_, T&& non_trans_mtx, T&& trans_mtx) {
  const bool col_major = (lt == CblasColMajor);
  const bool no_trans = (trans_ == CblasNoTrans);
  return (col_major == no_trans) ? non_trans_mtx.data() : trans_mtx.data();
}
// Leading dimension matching get_ab_matrix's buffer choice: the natural ld
// for col-major+NoTrans / row-major+Trans, the transposed ld otherwise.
auto get_ldab(CBLAS_LAYOUT lt, CBLAS_TRANSPOSE trans_mtx, std::size_t ld_n, std::size_t ld_t) {
  const bool col_major = (lt == CblasColMajor);
  const bool no_trans = (trans_mtx == CblasNoTrans);
  return (col_major == no_trans) ? ld_n : ld_t;
}
// Returns a copy of the C matrix appropriate for the layout: the natural
// (col-major) buffer for CblasColMajor, the transposed buffer otherwise.
template <typename T>
auto get_c_matrix(CBLAS_LAYOUT lt, T&& non_trans_mtx, T&& trans_mtx) {
  return (lt == CblasColMajor) ? non_trans_mtx : trans_mtx;
}
// Leading dimension of C matching get_c_matrix's buffer choice.
auto get_ldc(CBLAS_LAYOUT lt, std::size_t ldc_n, std::size_t ldc_t) {
  return (lt == CblasColMajor) ? ldc_n : ldc_t;
}
// Compile-time dispatch to the cblas_gemm_* entry point matching the
// signedness of the A and B element types (s8/u8 in each position).
template <typename A_Type, typename B_Type>
void cblas_gemm_wrapper(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
                        const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
                        const A_Type* a, const size_t lda, const int8_t oa, const B_Type* b, const size_t ldb,
                        const int8_t ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc) {
  constexpr bool a_is_signed = std::is_same_v<A_Type, std::int8_t>;
  constexpr bool b_is_signed = std::is_same_v<B_Type, std::int8_t>;
  if constexpr (a_is_signed && b_is_signed) {
    cblas_gemm_s8s8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
  } else if constexpr (a_is_signed) {
    cblas_gemm_s8u8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
  } else if constexpr (b_is_signed) {
    cblas_gemm_u8s8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
  } else {
    cblas_gemm_u8u8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
  }
}
// Index into the C-offset vector for element (mi, ni): always 0 for a fixed
// offset, the row index for a column offset, the column index otherwise.
std::size_t return_oc_idx(const CBLAS_OFFSET offsetc, std::size_t mi, std::size_t ni) {
  if (offsetc == CblasFixOffset) {
    return 0;
  }
  return (offsetc == CblasColOffset) ? mi : ni;
}
} // namespace helpers
// Outcome of one verification pass.
enum class status_t { passed, failed };
// Streams the human-readable tag ("PASSED"/"FAILED") for a status value.
std::ostream& operator<<(std::ostream& os, const status_t& st) {
  switch (st) {
    case status_t::passed:
      os << "PASSED";
      break;
    case status_t::failed:
      os << "FAILED";
      break;
  }
  return os;
}
// Column-major scalar reference implementation of the offset GEMM:
//   C(mi,ni) = round(alpha * sum_k (A(mi,ki)+oa) * (B(ki,ni)+ob)
//                    + beta * C(mi,ni) + oc[idx(offsetc, mi, ni)])
// accumulated in 32-bit integers and combined in floating point exactly as
// written below.
template <typename A_Type, typename B_Type>
void ref_gemm(const CBLAS_OFFSET offsetc, const std::size_t m, const std::size_t n, const std::size_t k,
              const float alpha, const A_Type* a, const std::size_t lda, const std::int8_t oa, const B_Type* b,
              const std::size_t ldb, const std::int8_t ob, const float beta, std::int32_t* c, const std::size_t ldc,
              const std::int32_t* oc) {
  // Each C element is independent, so the (ni, mi) iteration order is free.
  for (std::size_t ni = 0; ni < n; ++ni) {
    for (std::size_t mi = 0; mi < m; ++mi) {
      std::int32_t acc = 0;
      for (std::size_t ki = 0; ki < k; ++ki) {
        acc += (a[mi + ki * lda] + oa) * (b[ki + ni * ldb] + ob);
      }
      c[mi + ni * ldc] = std::round(alpha * static_cast<double>(acc) +
                                    static_cast<double>(beta * static_cast<float>(c[mi + ni * ldc])) +
                                    static_cast<float>(oc[helpers::return_oc_idx(offsetc, mi, ni)]));
    }
  }
}
// Fills `buffer` with pseudo-random values in [0, 64], deterministically
// (fixed seed; note the generator is `static`, so successive calls within
// one instantiation continue the same sequence).
//
// Fix: std::uniform_int_distribution<DataType> with DataType = int8_t /
// uint8_t is undefined behavior — the standard requires IntType to be a
// non-character integer type (it is a hard error on MSVC). Draw 32-bit
// values and narrow instead; [0, 64] fits every DataType used here.
template <typename DataType>
void fill_random(DataType* buffer, std::size_t len) {
  static std::mt19937 generator(0);
  std::uniform_int_distribution<std::int32_t> dist(0, 64);
  for (std::size_t i = 0; i < len; i++) {
    buffer[i] = static_cast<DataType>(dist(generator));
  }
}
// Fills `buffer` with the fixed sentinel -8 so untouched elements are
// easy to spot when comparing results.
template <typename DataType>
void fill_const(DataType* buffer, std::size_t len) {
  DataType* const end = buffer + len;
  for (DataType* cur = buffer; cur != end; ++cur) {
    *cur = DataType{-8};
  }
}
// performs transposition (n0 * n1) -> (n1 * n0), assuming col major
//   in  : n0 x n1 matrix with leading dimension ld0 (ld0 >= n0)
//   out : n1 x n0 matrix with leading dimension ld1 (ld1 >= n1)
//
// Fix: the previous index expressions had i and j swapped relative to this
// documented contract (out[i + j*ld1] = in[j + i*ld0]), which reads and
// writes out of bounds whenever n0 != n1 (e.g. the m x k -> k x m call in
// test::gemm with m != k). For square inputs the assignment set is identical
// to before, so existing square-shape behavior is unchanged.
template <typename T>
void simplest_transpose(T* in, T* out, std::size_t n0, std::size_t n1, std::size_t ld0, std::size_t ld1) {
  for (std::size_t i = 0; i < n0; ++i) {
    for (std::size_t j = 0; j < n1; ++j) {
      out[j + i * ld1] = in[i + j * ld0];
    }
  }
}
// Element-wise equality check of the m x n col-major submatrices (leading
// dimension ld); padding rows beyond m are ignored. Returns failed on the
// first mismatch.
template <typename DataType>
status_t compare(DataType* ref, DataType* test, std::size_t m, std::size_t n, std::size_t ld) {
  for (std::size_t ni = 0; ni < n; ++ni) {
    for (std::size_t mi = 0; mi < m; ++mi) {
      const std::size_t idx = mi + ni * ld;
      if (ref[idx] != test[idx]) {
        return status_t::failed;
      }
    }
  }
  return status_t::passed;
}
// Debug helper: prints an m x n col-major matrix (leading dimension m) as
// rows of comma-separated 32-bit values, followed by a blank line.
template <typename DataType>
void print_matrix(DataType* buffer, std::size_t m, std::size_t n) {
  for (std::size_t row = 0; row < m; ++row) {
    for (std::size_t col = 0; col < n; ++col) {
      std::cout << static_cast<std::int32_t>(buffer[row + col * m]) << ", ";
    }
    std::cout << std::endl;
  }
  std::cout << std::endl;
}
// Exercises one (A_Type, B_Type) element-type combination of the integer
// GEMM across every C-offset mode, layout, and transpose flag, validating
// each call against ref_gemm. Prints '+' / '-' per combination.
// Environment knobs:
//   LD_STRIDE - pads every leading dimension so ld > matrix dim is covered.
//   ONLY_PERF - skips correctness checks; prints TFlops per combination and
//               the average instead.
// Returns status_t::passed only if every tested combination matched.
template <typename A_Type, typename B_Type>
status_t gemm(std::size_t m, std::size_t n, std::size_t k, float alpha, float beta) {
  // Fixed non-zero A/B offsets so the offset paths are actually exercised.
  std::int8_t oa = 4;
  std::int8_t ob = 9;
  // Leading dimensions for the non-transposed (_n) and transposed (_t) copies.
  std::size_t lda_n = m;
  std::size_t ldb_n = k;
  std::size_t ldc_n = m;
  std::size_t lda_t = k;
  std::size_t ldb_t = n;
  std::size_t ldc_t = n;
  if (std::getenv("LD_STRIDE")) {
    // Distinct arbitrary paddings to catch code that assumes ld == dim.
    lda_n += 2;
    ldb_n += 7;
    ldc_n += 3;
    lda_t += 8;
    ldb_t += 3;
    ldc_t += 23;
  }
  bool only_performance = false;
  if (std::getenv("ONLY_PERF")) {
    only_performance = true;
  }
  // 128-byte-aligned operand buffers; *_t hold the transposed layouts.
  std::vector<A_Type, tools::aligned_allocator<A_Type, 128>> a_n(lda_n * k);
  std::vector<A_Type, tools::aligned_allocator<A_Type, 128>> a_t(m * lda_t);
  std::vector<B_Type, tools::aligned_allocator<B_Type, 128>> b_n(ldb_n * n);
  std::vector<B_Type, tools::aligned_allocator<B_Type, 128>> b_t(k * ldb_t);
  std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_ref(ldc_n * n);
  std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_n(ldc_n * n);
  std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_t(m * ldc_t);
  // fill the whole array even if ld* > corresponding dim
  fill_random(a_n.data(), a_n.size());
  fill_random(b_n.data(), b_n.size());
  simplest_transpose(a_n.data(), a_t.data(), m, k, lda_n, lda_t);
  simplest_transpose(b_n.data(), b_t.data(), k, n, ldb_n, ldb_t);
  fill_const(c_ref.data(), ldc_n * n);
  c_n = c_ref;
  // NOTE(review): the argument order here does not match the
  // (in, out, n0, n1, ld0, ld1) usage of the two calls above (ldc_n is
  // passed where a dimension is expected and n where a leading dimension
  // is); this only lines up when m == n and LD_STRIDE is unset — verify
  // for non-square / padded shapes.
  simplest_transpose(c_n.data(), c_t.data(), m, ldc_n, n, ldc_t);
  auto return_st = status_t::passed;
  double total = 0;
  size_t cnt = 0;
  for (auto c_offset : {CblasFixOffset, CblasColOffset, CblasRowOffset}) {
    // C-offset vector sized per mode (1, m, or n) with distinct even values.
    std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> oc(helpers::get_oc_size(c_offset, m, n));
    for (std::size_t i = 0; i < oc.size(); ++i) {
      oc[i] = i + i;
    }
    // Fresh copy of C for the reference; ref_gemm updates it in place.
    std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_ref_copy = c_ref;
    if (!only_performance) {
      ref_gemm(c_offset, m, n, k, alpha, a_n.data(), lda_n, oa, b_n.data(), ldb_n, ob, beta, c_ref_copy.data(), ldc_n,
               oc.data());
    }
    for (auto layout : {CblasColMajor, CblasRowMajor}) {
      for (auto transa : {CblasNoTrans, CblasTrans}) {
        for (auto transb : {CblasNoTrans, CblasTrans}) {
          // get_c_matrix returns a fresh copy of C in the layout under test.
          auto&& c_tested = helpers::get_c_matrix(layout, c_n, c_t);
          if (!only_performance) {
            helpers::cblas_gemm_wrapper(
                layout, transa, transb, c_offset, m, n, k, alpha, helpers::get_ab_matrix(layout, transa, a_n, a_t),
                helpers::get_ldab(layout, transa, lda_n, lda_t), oa, helpers::get_ab_matrix(layout, transb, b_n, b_t),
                helpers::get_ldab(layout, transb, ldb_n, ldb_t), ob, beta, c_tested.data(),
                helpers::get_ldc(layout, ldc_n, ldc_t), oc.data());
            // transpose c_tested to col-major if required
            auto loc_st = status_t::passed;
            if (layout == CblasRowMajor) {
              std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_tested_n(ldc_n * n);
              // NOTE(review): as above, this argument order only matches the
              // simplest_transpose signature for square, unpadded shapes —
              // confirm otherwise.
              simplest_transpose(c_tested.data(), c_tested_n.data(), n, ldc_t, m, ldc_n);
              loc_st = compare(c_ref_copy.data(), c_tested_n.data(), m, n, ldc_n);
            } else {
              loc_st = compare(c_ref_copy.data(), c_tested.data(), m, n, ldc_n);
            }
            if (loc_st != status_t::passed) {
              std::cout << "-";  // this combination mismatched the reference
              return_st = status_t::failed;
            } else {
              std::cout << "+";  // this combination matched
            }
          } else {
            // TFlops: 2*m*n*k multiply-adds over the averaged wall time.
            double cur = (2.0 * m * n * k) /
                         tools::timing(helpers::cblas_gemm_wrapper<A_Type, B_Type>, layout, transa, transb, c_offset, m,
                                       n, k, alpha, helpers::get_ab_matrix(layout, transa, a_n, a_t),
                                       helpers::get_ldab(layout, transa, lda_n, lda_t), oa,
                                       helpers::get_ab_matrix(layout, transb, b_n, b_t),
                                       helpers::get_ldab(layout, transb, ldb_n, ldb_t), ob, beta, c_tested.data(),
                                       helpers::get_ldc(layout, ldc_n, ldc_t), oc.data()) /
                         1e12;
            total += cur;
            ++cnt;
            std::cout << cur << ", ";
          }
        }
      }
    }
  }
  if (only_performance) {
    std::cout << "Average " << total / cnt << " TFlops";
  }
  std::cout << " ";
  return return_st;
}
} // namespace test
// Entry point: parses optional positional overrides [m n k alpha beta] and
// runs all four signed/unsigned 8-bit GEMM type combinations.
int main(int argc, char** argv) {
  std::size_t m = 128;
  std::size_t n = 128;
  std::size_t k = 128;
  float alpha = 1.0f;
  float beta = 1.0f;
  // Flattened ladder: argc > i+1 implies argc > i, so these independent
  // checks are equivalent to the nested form.
  if (argc > 1) m = std::stoi(argv[1]);
  if (argc > 2) n = std::stoi(argv[2]);
  if (argc > 3) k = std::stoi(argv[3]);
  if (argc > 4) alpha = std::stof(argv[4]);
  if (argc > 5) beta = std::stof(argv[5]);
  std::cout << "Testing matrix m = " << m << ", n = " << n << ", k = " << k << ", alpha = " << alpha
            << ", beta = " << beta << std::endl;
  std::cout << "\tTesting i8i8i32: " << test::gemm<std::int8_t, std::int8_t>(m, n, k, alpha, beta) << std::endl;
  std::cout << "\tTesting i8u8i32: " << test::gemm<std::int8_t, std::uint8_t>(m, n, k, alpha, beta) << std::endl;
  std::cout << "\tTesting u8i8i32: " << test::gemm<std::uint8_t, std::int8_t>(m, n, k, alpha, beta) << std::endl;
  std::cout << "\tTesting u8u8i32: " << test::gemm<std::uint8_t, std::uint8_t>(m, n, k, alpha, beta) << std::endl;
  return 0;
}

View File

@@ -1,45 +0,0 @@
#pragma once
/*** gemm helper ***/
#include "../../api/common.h"
#ifdef __cplusplus
extern "C" {
#endif
// Tiling parameters of the 1x8 int8 GEMM micro-kernel below.
// COMP_SV_LEN: vector length (in bytes) the kernels are written for —
// NOTE(review): presumably the SVE vector width; confirm against the asm.
#define COMP_SV_LEN 32
// K_SIZE: depth of one packed K slice processed per kernel step.
#define K_SIZE COMP_SV_LEN
// M_SIZE / N_SIZE: rows of A and columns of B covered per kernel call.
#define M_SIZE 1
#define N_SIZE 8
// Column-major element index into B (row_idx, col_idx) with leading
// dimension ldb.
#define INDEXING_B(row_idx, col_idx, ldb) ((col_idx) * (ldb) + row_idx)
// Stringified inline-asm fragment: horizontally reduces accumulator vector
// z<z_reg_idx> (signed add across .s lanes under predicate p), adds the
// scalar into x<reg_idx> via tmp_reg, and stores the 32-bit result to dst
// with a post-increment of 4 bytes.
#define PROCESS_ACCUM(reg_idx, z_reg_idx, tmp_reg, dst, p) \
"mov w" #reg_idx \
", #0\n" \
"saddv d" #reg_idx ", " #p ", z" #z_reg_idx \
".s\n" \
"fmov " #tmp_reg ", d" #reg_idx \
"\n" \
"add x" #reg_idx ", x" #reg_idx ", " #tmp_reg \
"\n" \
"str w" #reg_idx ", [%[" #dst "]], #4\n"
// Stringified inline-asm fragment for int4 unpacking: copies src into dst,
// left-shifts dst bytes by `shift`, and masks src bytes with mask_reg1.
// NOTE(review): the mask_reg2 parameter is accepted but never used in the
// expansion — verify whether that is intentional.
#define INT4_CP_MASK_SHIFT_1x8(src_reg, dst_reg, mask_reg1, mask_reg2, shift) \
"movprfx z" #dst_reg ", z" #src_reg \
"\n" \
"lsl z" #dst_reg ".b, p0/m, z" #dst_reg ".b, #" #shift \
"\n" \
"and z" #src_reg ".b, p0/m, z" #src_reg ".b, z" #mask_reg1 ".b\n"
// Packs an n x k int8 B tile (leading dimension ldb) into bufferB for the
// 1x8 kernel; ob is the B offset — NOTE(review): exact packed layout and
// how ob is applied are defined in the implementation, not visible here.
void pack_b_1x8(void* bufferB, const void* cur_b_ptr, size_t n, size_t k, size_t ldb, const BLASINT8 ob);
// int4 variant of pack_b_1x8 (two values per byte).
void pack_b_1x8_int4(void* bufferB, const void* cur_b_ptr, size_t n, size_t k, size_t ldb, const BLASINT8 ob);
// 1x8 int8 GEMM micro-kernel: accumulates k_depth of lhs x rhs into
// accum_ptr (leading dimension ldc); sv_len is the vector length in use.
void gemm_kernel_1x8(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
                     int64_t sv_len);
// int4 variant of gemm_kernel_1x8 (callers pass halved k/ld values).
void gemm_kernel_1x8_int4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
                          int64_t sv_len);
#ifdef __cplusplus
}
#endif
/*** gemm helper ***/