mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-20 14:29:22 +00:00
[feat]: patch kml problem (#1704)
This commit is contained in:
@@ -1,56 +0,0 @@
|
||||
#include <stdexcept>
|
||||
|
||||
#include "../batch_gemm_api.hpp"
|
||||
#include "utils.hpp"
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
|
||||
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
|
||||
const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,
|
||||
int32_t* c, const size_t ldc, const int32_t* oc) {
|
||||
BLASINT8* ptrA = (BLASINT8*)a;
|
||||
BLASINT8* ptrB = (BLASINT8*)b;
|
||||
int32_t* ptrC = c;
|
||||
size_t split_n = n / N_SIZE;
|
||||
|
||||
for (size_t n_block = 0; n_block < split_n; n_block++) {
|
||||
BLASINT8* cur_ptrA = ptrA;
|
||||
BLASINT8* cur_ptrB = ptrB + n_block * (N_SIZE * k);
|
||||
int32_t* cur_ptrC = ptrC + n_block * N_SIZE;
|
||||
gemm_kernel_1x8(cur_ptrA, cur_ptrB, cur_ptrC, ldc, k, COMP_SV_LEN);
|
||||
}
|
||||
}
|
||||
|
||||
void decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
|
||||
const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
|
||||
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
|
||||
const int32_t* oc) {
|
||||
BLASINT8* ptrA = (BLASINT8*)a;
|
||||
BLASINT8* ptrB = (BLASINT8*)b;
|
||||
int32_t* ptrC = c;
|
||||
size_t split_n = n / N_SIZE;
|
||||
|
||||
for (size_t n_block = 0; n_block < split_n; n_block++) {
|
||||
BLASINT8* cur_ptrA = ptrA;
|
||||
BLASINT8* cur_ptrB = ptrB + n_block * (N_SIZE * (k / 2));
|
||||
int32_t* cur_ptrC = ptrC + n_block * N_SIZE;
|
||||
gemm_kernel_1x8_int4(cur_ptrA, cur_ptrB, cur_ptrC, (ldc / 2), (k / 2), COMP_SV_LEN);
|
||||
}
|
||||
}
|
||||
// Placeholder for an explicit B-reordering entry point. The decode path
// currently requires B to be packed by the caller, so this always throws.
void reorder_B_gemm(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
                    const size_t n, const size_t ldb, const void* b, void* b_reordered) {
  throw std::runtime_error("haven't supported reorder");
}
|
||||
|
||||
// Companion query for reorder_B_gemm: would report the buffer size needed for
// a reordered B. Unimplemented — always throws, so the missing return value
// is unreachable.
size_t get_reorder_B_size(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
                          const size_t n) {
  throw std::runtime_error("haven't supported reorder");
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,414 +0,0 @@
|
||||
#include "prefillgemm_int4/integer_gemm_kernels.h"
|
||||
#include "utils.hpp"
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void pack_b_1x8(void* bufferB, const void* cur_b_ptr, size_t n, size_t k, size_t ldb, const BLASINT8 ob) {
|
||||
BLASINT8* bufferB_typed = (BLASINT8*)bufferB;
|
||||
BLASINT8* cur_b_typed = (BLASINT8*)cur_b_ptr;
|
||||
|
||||
size_t split_n = n / N_SIZE;
|
||||
size_t split_k = k / K_SIZE;
|
||||
|
||||
// TODO::vectorization
|
||||
for (size_t np = 0; np < split_n; np++) {
|
||||
for (size_t n_idx = 0; n_idx < N_SIZE; n_idx++) {
|
||||
for (size_t kp = 0; kp < split_k; kp++) {
|
||||
for (size_t k_idx = 0; k_idx < K_SIZE; k_idx++) {
|
||||
bufferB_typed[np * (N_SIZE * k) + kp * (K_SIZE * N_SIZE) + n_idx * K_SIZE + k_idx] =
|
||||
cur_b_typed[INDEXING_B((kp * K_SIZE + k_idx), (np * N_SIZE + n_idx), ldb)] + ob;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void pack_b_1x8_int4(void* bufferB, const void* cur_b_ptr, size_t n, size_t k, size_t ldb, const BLASINT8 ob) {
|
||||
uint8_t* bufferB_typed = (uint8_t*)bufferB;
|
||||
uint8_t* cur_b_typed = (uint8_t*)cur_b_ptr;
|
||||
|
||||
#define RHS_MASK 0x0F
|
||||
#define LHS_MASK 0xF0
|
||||
|
||||
size_t split_n = n / N_SIZE;
|
||||
size_t split_k = k / K_SIZE;
|
||||
|
||||
// TODO::vectorization
|
||||
for (size_t np = 0; np < split_n; np++) {
|
||||
for (size_t n_idx = 0; n_idx < N_SIZE; n_idx++) {
|
||||
for (size_t kp = 0; kp < split_k; kp++) {
|
||||
for (size_t k_idx = 0; k_idx < K_SIZE; k_idx += 2) {
|
||||
uint8_t b01 = cur_b_typed[INDEXING_B((kp * K_SIZE + k_idx / 2), (np * N_SIZE + n_idx), ldb)];
|
||||
uint8_t b23 = cur_b_typed[INDEXING_B((kp * K_SIZE + k_idx / 2 + K_SIZE / 2), (np * N_SIZE + n_idx), ldb)];
|
||||
uint8_t b02 = (b01 & LHS_MASK) | ((b23 & LHS_MASK) >> 4);
|
||||
uint8_t b13 = (b23 & RHS_MASK) | ((b01 & RHS_MASK) << 4);
|
||||
bufferB_typed[np * (N_SIZE * k) + kp * (K_SIZE * N_SIZE) + n_idx * K_SIZE + k_idx] = b02;
|
||||
bufferB_typed[np * (N_SIZE * k) + kp * (K_SIZE * N_SIZE) + n_idx * K_SIZE + k_idx + 1] = b13;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SVE s8*s8 -> s32 micro-kernel: one row of packed A against eight packed
// B column-vectors, accumulating into z16..z23 via sdot and writing out
// through the PROCESS_ACCUM macro (defined elsewhere — verify its exact
// store/accumulate semantics there).
// Register roles visible below: z0..z7 = B vectors, z8/z9 = A (double
// buffered), z16..z23 = accumulators.
// The k loop is software-pipelined: an unrolled-by-2 main loop at label 3,
// a single-step remainder loop at label 2, and a final drain at label 1.
void gemm_kernel_1x8(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
                     int64_t sv_len) {
  int64_t run_k_depth = k_depth;   // remaining depth, counted down in sv_len steps
  int64_t run_sv_len = sv_len;
  int64_t run_2sv_len = 2 * sv_len;
  int64_t move_lhs = sv_len;        // A advances one vector per step
  int64_t move_rhs = N_SIZE * sv_len;  // B advances N_SIZE vectors per step
  int32_t* dst_ptr = accum_ptr;
  // Pre-bias ldc for PROCESS_ACCUM: skip past this tile's N_SIZE columns,
  // scaled to bytes (int32_t = 4 bytes).
  ldc -= N_SIZE;
  ldc *= 4;

  asm volatile(

      // Prologue: load first B tile and A vector, zero the accumulators,
      // interleaving loads with dups to hide latency.
      "ptrue p0.b, all\n"
      "ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
      "dup z16.s, #0\n"
      "ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
      "dup z17.s, #0\n"
      "ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
      "dup z18.s, #0\n"
      "ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
      "dup z19.s, #0\n"
      "ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
      "dup z20.s, #0\n"
      "ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
      "dup z21.s, #0\n"
      "ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
      "dup z22.s, #0\n"
      "ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
      "dup z23.s, #0\n"
      "ld1b {z8.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
      "subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
      "add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
      "add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"

      // Depth exhausted after one step -> drain directly.
      "ble 1f\n"

      // Fewer than two steps left -> skip the unrolled loop.
      "cmp %[run_k_depth], %[run_2sv_len]\n"
      "blt 2f\n"

      // Main loop: two depth steps per iteration, A double-buffered in z8/z9.
      "3:\n"
      "ld1b {z9.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
      "add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
      // First half: consume z8 against the current B tile while reloading it.
      "sdot z16.s, z8.b, z0.b\n"
      "ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
      "sdot z17.s, z8.b, z1.b\n"
      "ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
      "sdot z18.s, z8.b, z2.b\n"
      "ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
      "sdot z19.s, z8.b, z3.b\n"
      "ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
      "sdot z20.s, z8.b, z4.b\n"
      "ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
      "sdot z21.s, z8.b, z5.b\n"
      "ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
      "sdot z22.s, z8.b, z6.b\n"
      "ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
      "sdot z23.s, z8.b, z7.b\n"
      "ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
      "add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
      "sub %[run_k_depth], %[run_k_depth], %[run_2sv_len]\n"

      // Second half: consume z9 while prefetching the next A into z8.
      "ld1b {z8.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
      "add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
      "sdot z16.s, z9.b, z0.b\n"
      "ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
      "sdot z17.s, z9.b, z1.b\n"
      "ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
      "sdot z18.s, z9.b, z2.b\n"
      "ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
      "sdot z19.s, z9.b, z3.b\n"
      "ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
      "sdot z20.s, z9.b, z4.b\n"
      "ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
      "sdot z21.s, z9.b, z5.b\n"
      "ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
      "sdot z22.s, z9.b, z6.b\n"
      "ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
      "sdot z23.s, z9.b, z7.b\n"
      "ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
      "add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
      "cmp %[run_k_depth], %[run_2sv_len]\n"
      "bge 3b\n"

      "cmp %[run_k_depth], #0\n"
      "ble 1f\n"

      // Remainder loop: one depth step per iteration.
      "2:\n"
      "subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
      "sdot z16.s, z8.b, z0.b\n"
      "ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
      "sdot z17.s, z8.b, z1.b\n"
      "ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
      "sdot z18.s, z8.b, z2.b\n"
      "ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
      "sdot z19.s, z8.b, z3.b\n"
      "ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
      "sdot z20.s, z8.b, z4.b\n"
      "ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
      "sdot z21.s, z8.b, z5.b\n"
      "ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
      "sdot z22.s, z8.b, z6.b\n"
      "ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
      "sdot z23.s, z8.b, z7.b\n"
      "ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
      "add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
      "ld1b {z8.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
      "add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
      "bgt 2b\n"

      // Drain: last step's dot products, no further loads.
      "1:\n"
      "sdot z16.s, z8.b, z0.b\n"
      "sdot z17.s, z8.b, z1.b\n"
      "sdot z18.s, z8.b, z2.b\n"
      "sdot z19.s, z8.b, z3.b\n"
      "sdot z20.s, z8.b, z4.b\n"
      "sdot z21.s, z8.b, z5.b\n"
      "sdot z22.s, z8.b, z6.b\n"
      "sdot z23.s, z8.b, z7.b\n"

      // Write accumulators to C; x16..x19 are scratch shared between pairs.
      PROCESS_ACCUM(0, 16, x16, dst_ptr, p0) PROCESS_ACCUM(1, 17, x17, dst_ptr, p0)
      PROCESS_ACCUM(2, 18, x18, dst_ptr, p0) PROCESS_ACCUM(3, 19, x19, dst_ptr, p0)
      PROCESS_ACCUM(4, 20, x16, dst_ptr, p0) PROCESS_ACCUM(5, 21, x17, dst_ptr, p0)
      PROCESS_ACCUM(6, 22, x18, dst_ptr, p0) PROCESS_ACCUM(7, 23, x19, dst_ptr, p0)

      : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [run_k_depth] "+r"(run_k_depth), [dst_ptr] "+wr"(dst_ptr)
      : [run_sv_len] "r"(run_sv_len), [run_2sv_len] "r"(run_2sv_len), [move_lhs] "r"(move_lhs),
        [move_rhs] "r"(move_rhs), [ldc] "r"(ldc), [accum_ptr] "r"(accum_ptr)
      : "cc", "memory", "w0", "w1", "w2", "w3", "w4", "w5", "w6", "w7", "w8", "w9", "w10", "w11", "w12", "w13", "w14",
        "w15", "x16", "x17", "x18", "x19", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11",
        "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27",
        "z28", "z29", "z30", "z31");
}
|
||||
|
||||
// A: z8 ~ z11
|
||||
// B: z0 ~ z7
|
||||
// C: z16 ~ z23
|
||||
// M: z12 z13
|
||||
// MB: z14
|
||||
|
||||
// int4 variant of the 1x8 micro-kernel. B holds two 4-bit values per byte;
// z12/z13 are nibble masks and the INT4_CP_MASK_SHIFT_1x8 macro (defined
// elsewhere — verify its expansion) unpacks one B register into z14 so each
// accumulator gets two sdot contributions (z8 against the packed bytes,
// z9 against the unpacked nibbles in z14).
// Register roles per the comments preceding this function in the original
// file: A = z8..z11, B = z0..z7, C = z16..z23, masks = z12/z13, MB = z14.
// Same pipeline structure as gemm_kernel_1x8: unrolled-by-2 loop at label 3,
// remainder at label 2, drain at label 1.
void gemm_kernel_1x8_int4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
                          int64_t sv_len) {
  int64_t run_k_depth = k_depth;  // remaining (halved) depth counter
  int64_t run_sv_len = sv_len;
  int64_t run_2sv_len = 2 * sv_len;
  int64_t move_lhs = 2 * sv_len;  // A consumes two vectors (z8+z9) per step
  int64_t move_rhs = N_SIZE * sv_len;
  int32_t* dst_ptr = accum_ptr;
  // Pre-bias ldc for PROCESS_ACCUM (bytes past this tile's columns).
  ldc -= N_SIZE;
  ldc *= 4;

  asm volatile(
      // Prologue: masks, first B tile, first A pair, zeroed accumulators.
      "ptrue p0.b, all\n"
      "mov z12.b, #0xF0\n"  // mask high
      "mov z13.b, #0x0F\n"  // mask low
      "ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
      "dup z16.s, #0\n"
      "ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
      "dup z17.s, #0\n"
      "ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
      "dup z18.s, #0\n"
      "ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
      "dup z19.s, #0\n"
      "ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
      "dup z20.s, #0\n"
      "ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
      "dup z21.s, #0\n"
      "ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
      "dup z22.s, #0\n"
      "ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
      "dup z23.s, #0\n"

      "ld1b {z8.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
      "ld1b {z9.b}, p0/z, [%[lhs_ptr], #1, MUL VL]\n"
      "subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
      "add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
      "add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"

      // Depth exhausted after one step -> drain directly.
      "ble 1f\n"

      // Fewer than two steps left -> skip the unrolled loop.
      "cmp %[run_k_depth], %[run_2sv_len]\n"
      "blt 2f\n"

      // Main loop: two steps per iteration, A double-buffered (z8/z9 vs z10/z11).
      "3:\n"
      "ld1b {z10.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
      "ld1b {z11.b}, p0/z, [%[lhs_ptr], #1, MUL VL]\n"
      "add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
      // First half: consume z8/z9 against the current B tile while reloading.
      INT4_CP_MASK_SHIFT_1x8(0, 14, 12, 13, 4)
      "sdot z16.s, z8.b, z0.b\n"
      "sdot z16.s, z9.b, z14.b\n"
      "ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(1, 14, 12, 13, 4)
      "sdot z17.s, z8.b, z1.b\n"
      "sdot z17.s, z9.b, z14.b\n"
      "ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(2, 14, 12, 13, 4)
      "sdot z18.s, z8.b, z2.b\n"
      "sdot z18.s, z9.b, z14.b\n"
      "ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(3, 14, 12, 13, 4)
      "sdot z19.s, z8.b, z3.b\n"
      "sdot z19.s, z9.b, z14.b\n"
      "ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(4, 14, 12, 13, 4)
      "sdot z20.s, z8.b, z4.b\n"
      "sdot z20.s, z9.b, z14.b\n"
      "ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(5, 14, 12, 13, 4)
      "sdot z21.s, z8.b, z5.b\n"
      "sdot z21.s, z9.b, z14.b\n"
      "ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(6, 14, 12, 13, 4)
      "sdot z22.s, z8.b, z6.b\n"
      "sdot z22.s, z9.b, z14.b\n"
      "ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(7, 14, 12, 13, 4)
      "sdot z23.s, z8.b, z7.b\n"
      "sdot z23.s, z9.b, z14.b\n"
      "ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
      "add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
      "sub %[run_k_depth], %[run_k_depth], %[run_2sv_len]\n"

      // Second half: consume z10/z11 while prefetching the next A pair.
      "ld1b {z8.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
      "ld1b {z9.b}, p0/z, [%[lhs_ptr], #1, MUL VL]\n"
      "add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
      INT4_CP_MASK_SHIFT_1x8(0, 14, 12, 13, 4)
      "sdot z16.s, z10.b, z0.b\n"
      "sdot z16.s, z11.b, z14.b\n"
      "ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(1, 14, 12, 13, 4)
      "sdot z17.s, z10.b, z1.b\n"
      "sdot z17.s, z11.b, z14.b\n"
      "ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(2, 14, 12, 13, 4)
      "sdot z18.s, z10.b, z2.b\n"
      "sdot z18.s, z11.b, z14.b\n"
      "ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(3, 14, 12, 13, 4)
      "sdot z19.s, z10.b, z3.b\n"
      "sdot z19.s, z11.b, z14.b\n"
      "ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(4, 14, 12, 13, 4)
      "sdot z20.s, z10.b, z4.b\n"
      "sdot z20.s, z11.b, z14.b\n"
      "ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(5, 14, 12, 13, 4)
      "sdot z21.s, z10.b, z5.b\n"
      "sdot z21.s, z11.b, z14.b\n"
      "ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(6, 14, 12, 13, 4)
      "sdot z22.s, z10.b, z6.b\n"
      "sdot z22.s, z11.b, z14.b\n"
      "ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(7, 14, 12, 13, 4)
      "sdot z23.s, z10.b, z7.b\n"
      "sdot z23.s, z11.b, z14.b\n"
      "ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
      "add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
      "cmp %[run_k_depth], %[run_2sv_len]\n"
      "bge 3b\n"

      "cmp %[run_k_depth], #0\n"
      "ble 1f\n"

      // Remainder loop: one step per iteration.
      "2:\n"
      "subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
      INT4_CP_MASK_SHIFT_1x8(0, 14, 12, 13, 4)
      "sdot z16.s, z8.b, z0.b\n"
      "sdot z16.s, z9.b, z14.b\n"
      "ld1b {z0.b}, p0/z, [%[rhs_ptr], #0, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(1, 14, 12, 13, 4)
      "sdot z17.s, z8.b, z1.b\n"
      "sdot z17.s, z9.b, z14.b\n"
      "ld1b {z1.b}, p0/z, [%[rhs_ptr], #1, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(2, 14, 12, 13, 4)
      "sdot z18.s, z8.b, z2.b\n"
      "sdot z18.s, z9.b, z14.b\n"
      "ld1b {z2.b}, p0/z, [%[rhs_ptr], #2, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(3, 14, 12, 13, 4)
      "sdot z19.s, z8.b, z3.b\n"
      "sdot z19.s, z9.b, z14.b\n"
      "ld1b {z3.b}, p0/z, [%[rhs_ptr], #3, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(4, 14, 12, 13, 4)
      "sdot z20.s, z8.b, z4.b\n"
      "sdot z20.s, z9.b, z14.b\n"
      "ld1b {z4.b}, p0/z, [%[rhs_ptr], #4, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(5, 14, 12, 13, 4)
      "sdot z21.s, z8.b, z5.b\n"
      "sdot z21.s, z9.b, z14.b\n"
      "ld1b {z5.b}, p0/z, [%[rhs_ptr], #5, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(6, 14, 12, 13, 4)
      "sdot z22.s, z8.b, z6.b\n"
      "sdot z22.s, z9.b, z14.b\n"
      "ld1b {z6.b}, p0/z, [%[rhs_ptr], #6, MUL VL]\n"
      INT4_CP_MASK_SHIFT_1x8(7, 14, 12, 13, 4)
      "sdot z23.s, z8.b, z7.b\n"
      "sdot z23.s, z9.b, z14.b\n"
      "ld1b {z7.b}, p0/z, [%[rhs_ptr], #7, MUL VL]\n"
      "add %[rhs_ptr], %[rhs_ptr], %[move_rhs]\n"
      "ld1b {z8.b}, p0/z, [%[lhs_ptr], #0, MUL VL]\n"
      "ld1b {z9.b}, p0/z, [%[lhs_ptr], #1, MUL VL]\n"
      "add %[lhs_ptr], %[lhs_ptr], %[move_lhs]\n"
      "bgt 2b\n"

      // Drain: last step's dot products, no further loads.
      "1:\n"
      INT4_CP_MASK_SHIFT_1x8(0, 14, 12, 13, 4)
      "sdot z16.s, z8.b, z0.b\n"
      "sdot z16.s, z9.b, z14.b\n"
      INT4_CP_MASK_SHIFT_1x8(1, 14, 12, 13, 4)
      "sdot z17.s, z8.b, z1.b\n"
      "sdot z17.s, z9.b, z14.b\n"
      INT4_CP_MASK_SHIFT_1x8(2, 14, 12, 13, 4)
      "sdot z18.s, z8.b, z2.b\n"
      "sdot z18.s, z9.b, z14.b\n"
      INT4_CP_MASK_SHIFT_1x8(3, 14, 12, 13, 4)
      "sdot z19.s, z8.b, z3.b\n"
      "sdot z19.s, z9.b, z14.b\n"
      INT4_CP_MASK_SHIFT_1x8(4, 14, 12, 13, 4)
      "sdot z20.s, z8.b, z4.b\n"
      "sdot z20.s, z9.b, z14.b\n"
      INT4_CP_MASK_SHIFT_1x8(5, 14, 12, 13, 4)
      "sdot z21.s, z8.b, z5.b\n"
      "sdot z21.s, z9.b, z14.b\n"
      INT4_CP_MASK_SHIFT_1x8(6, 14, 12, 13, 4)
      "sdot z22.s, z8.b, z6.b\n"
      "sdot z22.s, z9.b, z14.b\n"
      INT4_CP_MASK_SHIFT_1x8(7, 14, 12, 13, 4)
      "sdot z23.s, z8.b, z7.b\n"
      "sdot z23.s, z9.b, z14.b\n"

      // Write accumulators to C; x16..x19 are scratch shared between pairs.
      PROCESS_ACCUM(0, 16, x16, dst_ptr, p0)
      PROCESS_ACCUM(1, 17, x17, dst_ptr, p0)
      PROCESS_ACCUM(2, 18, x18, dst_ptr, p0)
      PROCESS_ACCUM(3, 19, x19, dst_ptr, p0)
      PROCESS_ACCUM(4, 20, x16, dst_ptr, p0)
      PROCESS_ACCUM(5, 21, x17, dst_ptr, p0)
      PROCESS_ACCUM(6, 22, x18, dst_ptr, p0)
      PROCESS_ACCUM(7, 23, x19, dst_ptr, p0)

      :
      [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
      [run_k_depth] "+r"(run_k_depth),
      [dst_ptr] "+wr"(dst_ptr)
      :
      [run_sv_len] "r"(run_sv_len), [run_2sv_len] "r"(run_2sv_len),
      [move_lhs] "r"(move_lhs), [move_rhs] "r"(move_rhs), [ldc] "r"(ldc),
      [accum_ptr] "r"(accum_ptr)
      :
      "cc", "memory",
      "w0","w1","w2","w3","w4","w5","w6","w7",
      "w8","w9","w10","w11","w12","w13","w14","w15",
      "x16","x17","x18","x19",
      "z0","z1","z2","z3","z4","z5","z6","z7",
      "z8","z9","z10","z11","z12","z13","z14","z15",
      "z16","z17","z18","z19","z20","z21","z22","z23",
      "z24","z25","z26","z27","z28","z29","z30","z31"
  );
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,160 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.14)

set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED ON)

# can be compiled for SVE256 (32) or SVE512 (64)
set(SV_LENGTH 32) # in bytes

include_directories("${PROJECT_SOURCE_DIR}")
include_directories("${PROJECT_SOURCE_DIR}/include")
include_directories("${PROJECT_SOURCE_DIR}/standalone")
# -march pins the SVE + int8 matmul (i8mm) ISA the asm kernels require.
add_compile_options(-fPIC -fvisibility=hidden -fstack-protector-strong -march=armv8.3-a+sve+i8mm -O3)
# add_compile_options(-Wall -Wextra -Werror)

# sources are split into several groups: matmul kernels (sources are compiled multiple times),
#                                        packing kernels (sources are compiled multiple times),
#                                        sequential/parallel pipelines,
#                                        interface (sources are compiled multiple times)
# Each list accumulates $<TARGET_OBJECTS:...> generator expressions and is
# merged into the final shared library at the bottom of this file.
set(INT_GEMM_KERNELS "")
set(INT_PACK_KERNELS "")
set(INT_GEMM_INTERFACE "")
set(INT_GEMM_PARALLEL "")
set(INT_GEMM_SEQ "")
set(BETA_KERNELS "")
set(POST_OPS_KERNELS "")
set(INT_SMALL_KERNELS "")
set(GEMM_DRIVERS "")

# Supported precisions are i/u for A and B matrices (4 combinations)
set(LHS_TYPES LHS_INT LHS_UINT)
set(RHS_TYPES RHS_INT RHS_UINT)

# compile matrix-multiplication kernels multiple times:
# one object library per (M tile, N tile, lhs sign, rhs sign) combination,
# selected via preprocessor defines on the same source file.
set(INTEGER_MM_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_kernels.c")
set(M_SIZES 1 2 3 4)
set(N_SIZES 1 2 3 4)
foreach(M_SIZE ${M_SIZES})
  foreach(N_SIZE ${N_SIZES})
    foreach(LHS_TYPE ${LHS_TYPES})
      foreach(RHS_TYPE ${RHS_TYPES})
        add_library(integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH} OBJECT ${INTEGER_MM_KERNELS_SRC})
        target_compile_options(integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH} PUBLIC -DM_SIZE=${M_SIZE}
                               -DN_SIZE=${N_SIZE} -D${LHS_TYPE} -D${RHS_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
        target_include_directories(integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH} PUBLIC
                                   "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
        list(APPEND INT_GEMM_KERNELS $<TARGET_OBJECTS:integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH}>)
      endforeach()
    endforeach()
  endforeach()
endforeach()

# compile interface multiple times (one per lhs/rhs sign combination)
set(INTEGER_GEMM_IFACE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_interface.c")
foreach(LHS_TYPE ${LHS_TYPES})
  foreach(RHS_TYPE ${RHS_TYPES})
    add_library(integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH} OBJECT ${INTEGER_GEMM_IFACE_SRC})
    target_compile_options(integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH} PUBLIC -D${LHS_TYPE} -D${RHS_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
    target_include_directories(integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH} PUBLIC
                               "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
    list(APPEND INT_GEMM_INTERFACE $<TARGET_OBJECTS:integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH}>)
  endforeach()
endforeach()

# compile threading layer (currently disabled; sequential pipeline is used)
# set(INTEGER_GEMM_PAR_PIPE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/parallel_int_gemm_pipeline.c")
# add_library(integer_gemm_par_pipe_${SV_LENGTH} OBJECT ${INTEGER_GEMM_PAR_PIPE_SRC})
# target_compile_options(integer_gemm_par_pipe_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH})
# target_include_directories(integer_gemm_par_pipe_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
# list(APPEND INT_GEMM_PARALLEL $<TARGET_OBJECTS:integer_gemm_par_pipe_${SV_LENGTH}>)


# compile sequential layer
set(INTEGER_GEMMSEQ_PIPE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/sequential_int_gemm_pipeline.c")
add_library(integer_gemm_seq_pipe_${SV_LENGTH} OBJECT ${INTEGER_GEMMSEQ_PIPE_SRC})
target_compile_options(integer_gemm_seq_pipe_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH})
target_include_directories(integer_gemm_seq_pipe_${SV_LENGTH} PUBLIC
                           "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
list(APPEND INT_GEMM_SEQ $<TARGET_OBJECTS:integer_gemm_seq_pipe_${SV_LENGTH}>)

# compile packingA kernels (per lhs sign x transpose flag)
set(INT_PACK_A_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_pack_a_kernels.c")
set(TRANSA_VALS TRANSA NOTRANSA)
foreach(LHS_TYPE ${LHS_TYPES})
  foreach(TRANSA_VAL ${TRANSA_VALS})
    add_library(integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH} OBJECT ${INT_PACK_A_KERNELS_SRC})
    target_compile_options(integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH} -D${LHS_TYPE} -D${TRANSA_VAL})
    target_include_directories(integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
    list(APPEND INT_PACK_KERNELS $<TARGET_OBJECTS:integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH}>)
  endforeach()
endforeach()

# compile packingB kernels (per rhs sign x transpose flag)
set(INT_PACK_B_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_pack_b_kernels.c")
set(TRANSB_VALS TRANSB NOTRANSB)
foreach(RHS_TYPE ${RHS_TYPES})
  foreach(TRANSB_VAL ${TRANSB_VALS})
    add_library(integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH} OBJECT ${INT_PACK_B_KERNELS_SRC})
    target_compile_options(integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH} -D${RHS_TYPE} -D${TRANSB_VAL})
    target_include_directories(integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
    list(APPEND INT_PACK_KERNELS $<TARGET_OBJECTS:integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH}>)
  endforeach()
endforeach()

# compile beta kernels (optimized and generic beta handling)
set(BETA_OPTS BETA_OPT BETA_NO_OPT)

foreach(B_TYPE ${BETA_OPTS})
  add_library(beta_kernels_${B_TYPE} OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/integer_beta_kernels.c")
  target_compile_options(beta_kernels_${B_TYPE} PUBLIC -D${B_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
  target_include_directories(beta_kernels_${B_TYPE} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
  list(APPEND BETA_KERNELS $<TARGET_OBJECTS:beta_kernels_${B_TYPE}>)
endforeach()

# compile int gemm drivers
foreach(B_TYPE ${BETA_OPTS})
  add_library(gemm_driver_${B_TYPE} OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_driver.c")
  target_compile_options(gemm_driver_${B_TYPE} PUBLIC -D${B_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
  target_include_directories(gemm_driver_${B_TYPE} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
  list(APPEND GEMM_DRIVERS $<TARGET_OBJECTS:gemm_driver_${B_TYPE}>)
endforeach()

# compile post-ops kernels
foreach(B_TYPE ${BETA_OPTS})
  add_library(post_ops_kernels_${B_TYPE} OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/integer_post_ops_kernels.c")
  target_compile_options(post_ops_kernels_${B_TYPE} PUBLIC -D${B_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
  target_include_directories(post_ops_kernels_${B_TYPE} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
  list(APPEND POST_OPS_KERNELS $<TARGET_OBJECTS:post_ops_kernels_${B_TYPE}>)
endforeach()

# compile matrix-multiplication small kernels multiple times
# (sign x sign x transA x transB x C-offset mode)
set(SMALL_KERNELS_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_small_kernels.c")
set(TRANSA_VALS TRANSA NOTRANSA)
set(TRANSB_VALS TRANSB NOTRANSB)
set(OC_TYPES OC_FIX OC_COL OC_ROW)

foreach(LHS_TYPE ${LHS_TYPES})
  foreach(RHS_TYPE ${RHS_TYPES})
    foreach(TRANSA_VAL ${TRANSA_VALS})
      foreach(TRANSB_VAL ${TRANSB_VALS})
        foreach(OC_TYPE ${OC_TYPES})
          add_library(small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}${SV_LENGTH} OBJECT ${SMALL_KERNELS_KERNELS_SRC})
          target_compile_options(small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}${SV_LENGTH} PUBLIC -D${LHS_TYPE} -D${RHS_TYPE} -D${TRANSA_VAL} -D${TRANSB_VAL} -D${OC_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
          target_include_directories(small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
          list(APPEND INT_SMALL_KERNELS $<TARGET_OBJECTS:small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}${SV_LENGTH}>)
        endforeach()
      endforeach()
    endforeach()
  endforeach()
endforeach()

list(APPEND OBJ_FILES_STANDALONE_DIR ${INT_GEMM_KERNELS} ${INT_GEMM_INTERFACE} ${INT_GEMM_SEQ} ${INT_PACK_KERNELS} ${BETA_KERNELS} ${POST_OPS_KERNELS} ${INT_SMALL_KERNELS} ${GEMM_DRIVERS})
# set(OBJ_FILES_STANDALONE_DIR ${OBJ_FILES_STANDALONE_DIR} PARENT_SCOPE)
# all compiled object files are united into one object library
add_library(prefillint8gemm SHARED ${OBJ_FILES_STANDALONE_DIR})

set_target_properties(prefillint8gemm PROPERTIES
  ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/prefillint8gemm
  LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/prefillint8gemm
  RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/prefillint8gemm
)
|
||||
@@ -1,25 +0,0 @@
|
||||
#ifndef BETA_MACROS_H
#define BETA_MACROS_H

/* Selects how the C-offset array is indexed for the beta kernels:
 *   OC_FIX -> single scalar offset (index ignored)
 *   OC_COL -> one offset per row   (indexed by mi)
 *   default (row mode) -> one offset per column (indexed by ni)
 * OC_TYPE is the single-letter name-mangling tag (f/c/r).
 * Fix: the OC_FIX branch previously named its second parameter `nn` while
 * the other branches use `ni` — harmless (the body ignores it) but
 * inconsistent; all three branches now use (mi, ni). */
#if defined(OC_FIX)
#define OC_TYPE f
#define OC_IDX(mi, ni) 0
#elif defined(OC_COL)
#define OC_TYPE c
#define OC_IDX(mi, ni) mi
#else
#define OC_TYPE r
#define OC_IDX(mi, ni) ni
#endif

/* BETA_OPT builds the int32 in-place variant (suffix `_opt`, leading
 * dimension taken from ldc); the generic build converts to float and uses
 * the block height m as the stride. */
#if defined(BETA_OPT)
#define BETA_SUFF(name) name##_opt
#define LDC(m, ldc) ldc
#define DTYPE int32_t
#else
#define BETA_SUFF(name) name
#define LDC(m, ldc) m
#define DTYPE float
#endif

#endif
|
||||
@@ -1,63 +0,0 @@
|
||||
#ifndef __HELPING_MACROS_H__
#define __HELPING_MACROS_H__

#include <stdio.h>

#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */

/* Exactly one signedness flag may be set per side. */
#if defined(LHS_INT) && defined(LHS_UINT)
#error "Both LHS_INT and LHS_UINT are defined"
#endif

#if defined(RHS_INT) && defined(RHS_UINT)
#error "Both RHS_INT and RHS_UINT are defined"
#endif

/* Per-side element type and the single-letter mangling tag (s/u) derived
 * from the build flags. */
#ifdef LHS_INT
#define LHS_TYPE s
#define LHS_INT_TYPE int8_t
#endif
#ifdef LHS_UINT
#define LHS_TYPE u
#define LHS_INT_TYPE uint8_t
#endif
#ifdef RHS_INT
#define RHS_TYPE s
#define RHS_INT_TYPE int8_t
#endif
#ifdef RHS_UINT
#define RHS_TYPE u
#define RHS_INT_TYPE uint8_t
#endif

// mangling macros
/* Each ..._MACRO wrapper forces an extra expansion pass so that argument
 * macros (LHS_TYPE, RHS_TYPE, M_SIZE, ...) are substituted before pasting.
 * Result shapes, e.g.: ADD_M_N_SIZES(foo, 2, 4)  -> foo_2x4
 *                      ADD_TYPES(foo, s, u)      -> foo_s8u8s32           */
#define ADD_M_N_SIZES(name, m_size, n_size) name##_##m_size##x##n_size
#define ADD_M_N_SIZES_MACRO(name, m_size, n_size) ADD_M_N_SIZES(name, m_size, n_size)
#define ADD_TYPES(name, lhs_type, rhs_type) name##_##lhs_type##8##rhs_type##8s32
#define ADD_TYPES_MACRO(name, lhs_type, rhs_type) ADD_TYPES(name, lhs_type, rhs_type)
#define ADD_TYPES_SUFF(name) ADD_TYPES_MACRO(name, LHS_TYPE, RHS_TYPE)
/* Single-side packing-kernel names: name_<type>8_<n|t> for no-transpose /
 * transpose variants. */
#define ADD_ONE_TYPE_TRANSP(name, type, nt) name##_##type##8_##nt
#define ADD_ONE_TYPE_TRANSP_MACRO(name, type, nt) ADD_ONE_TYPE_TRANSP(name, type, nt)
#define ADD_PACK_A_N_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, LHS_TYPE, n)
#define ADD_PACK_B_N_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, RHS_TYPE, n)
#define ADD_PACK_A_T_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, LHS_TYPE, t)
#define ADD_PACK_B_T_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, RHS_TYPE, t)
/* Full small-kernel names: both element types, both transpose flags and the
 * C-offset mode, e.g. foo_s8u8s32_nt_r. */
#define ADD_TWO_TYPES_TRANSP(name, lhs_type, rhs_type, a_t, b_t, oc_t) \
  name##_##lhs_type##8##rhs_type##8s32##_##a_t##b_t##_##oc_t
#define ADD_TWO_TYPES_TRANSP_MACRO(name, lhs_type, rhs_type, a_t, b_t, oc_t) \
  ADD_TWO_TYPES_TRANSP(name, lhs_type, rhs_type, a_t, b_t, oc_t)
#define ADD_TRANSP_MACRO(name, a_t, b_t, oc_t) ADD_TWO_TYPES_TRANSP_MACRO(name, LHS_TYPE, RHS_TYPE, a_t, b_t, oc_t)

/* Threaded builds get a `_thread` suffix so both variants can coexist. */
#ifdef ENABLE_THREADING
#define ADD_THREAD_SUFF(name) name##_thread
#else
#define ADD_THREAD_SUFF(name) name
#endif

#ifdef __cplusplus
}
#endif

#endif
|
||||
@@ -1,82 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
#include "beta_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
// Column major C, Fixed C_offset
|
||||
void BETA_SUFF(beta_cf_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
|
||||
DTYPE oc_val = (DTYPE)*oc;
|
||||
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
|
||||
DTYPE beta_val = (DTYPE)beta;
|
||||
for (size_t ni = 0; ni < n_block_size; ++ni) {
|
||||
for (size_t mi = 0; mi < m_block_size; ++mi) {
|
||||
c_typed_ptr[ldc * ni + mi] = beta_val * ((DTYPE)c_ptr[ldc * ni + mi]) + oc_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Column major C, Column major C_offset
|
||||
void BETA_SUFF(beta_cc_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
|
||||
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
|
||||
DTYPE beta_val = (DTYPE)beta;
|
||||
for (size_t ni = 0; ni < n_block_size; ++ni) {
|
||||
for (size_t mi = 0; mi < m_block_size; ++mi) {
|
||||
c_typed_ptr[ldc * ni + mi] = beta_val * ((DTYPE)c_ptr[ldc * ni + mi]) + ((DTYPE)oc[mi]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Column major C, Row major C_offset
|
||||
void BETA_SUFF(beta_cr_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
|
||||
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
|
||||
DTYPE beta_val = (DTYPE)beta;
|
||||
for (size_t ni = 0; ni < n_block_size; ++ni) {
|
||||
for (size_t mi = 0; mi < m_block_size; ++mi) {
|
||||
c_typed_ptr[ldc * ni + mi] = beta_val * ((DTYPE)c_ptr[ldc * ni + mi]) + ((DTYPE)oc[ni]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Row major C, Fixed C_offset
|
||||
// for row-major we actually swap m and n values so we reswap it here again
|
||||
void BETA_SUFF(beta_rf_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
|
||||
DTYPE oc_val = (DTYPE)*oc;
|
||||
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
|
||||
DTYPE beta_val = (DTYPE)beta;
|
||||
for (size_t mi = 0; mi < n_block_size; ++mi) {
|
||||
for (size_t ni = 0; ni < m_block_size; ++ni) {
|
||||
c_typed_ptr[ldc * mi + ni] = beta_val * ((DTYPE)c_ptr[ldc * mi + ni]) + oc_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Row major C, Column major C_offset
|
||||
// for row-major we actually swap m and n values so we reswap it here again
|
||||
void BETA_SUFF(beta_rc_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
|
||||
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
|
||||
DTYPE beta_val = (DTYPE)beta;
|
||||
for (size_t mi = 0; mi < n_block_size; ++mi) {
|
||||
for (size_t ni = 0; ni < m_block_size; ++ni) {
|
||||
c_typed_ptr[ldc * mi + ni] = beta_val * ((DTYPE)c_ptr[ldc * mi + ni]) + ((DTYPE)oc[mi]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Row major C, Row major C_offset
|
||||
// for row-major we actually swap m and n values so we reswap it here again
|
||||
void BETA_SUFF(beta_rr_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
|
||||
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
|
||||
DTYPE beta_val = (DTYPE)beta;
|
||||
for (size_t mi = 0; mi < n_block_size; ++mi) {
|
||||
for (size_t ni = 0; ni < m_block_size; ++ni) {
|
||||
c_typed_ptr[ldc * mi + ni] = beta_val * ((DTYPE)c_ptr[ldc * mi + ni]) + ((DTYPE)oc[ni]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif /* __cplusplus */
|
||||
@@ -1,117 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "beta_macros.h"
|
||||
|
||||
#define OLD_N_SIZE 8
|
||||
#define PACKED_LD_STEP(n_step, k_step, ldb) (n_step * ldb + k_step * OLD_N_SIZE)
|
||||
|
||||
void BETA_SUFF(gemm_driver)(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha,
|
||||
const void* a, size_t lda, const BLASINT8 oa,
|
||||
const void* b, size_t ldb, const BLASINT8 ob,
|
||||
float beta, int32_t* c, size_t ldc, const int32_t* oc) {
|
||||
|
||||
void (**gemm_kernels)(const void*, const void*, int32_t*, size_t, int64_t, int64_t) = arg->gemm_kernels;
|
||||
void (*pack_a_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8) = arg->pack_a_fun;
|
||||
void (*pack_b_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8) = arg->pack_b_fun;
|
||||
void (*beta_func)(int32_t*, const int32_t*, float, size_t, size_t, size_t) = arg->beta_func;
|
||||
size_t (*a_indexing)(size_t m, size_t n, size_t ld) = arg->a_indexing;
|
||||
size_t (*b_indexing)(size_t m, size_t n, size_t ld) = arg->b_indexing;
|
||||
|
||||
#ifndef BETA_OPT
|
||||
void (*post_ops_func)(float, const int32_t*, int32_t*, size_t, size_t, size_t) = arg->post_ops_func;
|
||||
#endif // BETA_OPT
|
||||
|
||||
const BLASINT8* a_typed = (const BLASINT8*) a;
|
||||
const BLASINT8* b_typed = (const BLASINT8*) b;
|
||||
|
||||
BLASINT8* bufferA = (BLASINT8*) aligned_alloc(ALIGNMENT, sizeof(BLASINT8) * K_BLOCK * M_BLOCK);
|
||||
BLASINT8* bufferB = (BLASINT8*) aligned_alloc(ALIGNMENT, sizeof(BLASINT8) * K_BLOCK * N_BLOCK);
|
||||
|
||||
// Tmp buffer is not needed when (alpha = 1 and beta = 0/1)
|
||||
#ifdef BETA_OPT
|
||||
int32_t* bufferC = c;
|
||||
#else
|
||||
int32_t* bufferC = (int32_t*) aligned_alloc(ALIGNMENT, sizeof(int32_t) * m * N_BLOCK);
|
||||
#endif
|
||||
|
||||
if (!bufferA || !bufferB || !bufferC) {
|
||||
free(bufferA);
|
||||
free(bufferB);
|
||||
free(bufferC);
|
||||
printf("Integer GEMM unsuccessful allocation");
|
||||
return;
|
||||
}
|
||||
// printf("pack b beta: %f\n",beta);
|
||||
beta_func(c, oc, beta, m, n, ldc);
|
||||
|
||||
for (size_t n_block = 0; n_block < n; n_block += N_BLOCK) {
|
||||
size_t n_block_size = n - n_block;
|
||||
if (n_block_size > N_BLOCK) {
|
||||
n_block_size = N_BLOCK;
|
||||
}
|
||||
|
||||
#ifndef BETA_OPT
|
||||
// fill bufferC w/ zeros
|
||||
for (size_t tmp_idx = 0; tmp_idx < (m * N_BLOCK); ++tmp_idx) {
|
||||
bufferC[tmp_idx] = 0;
|
||||
}
|
||||
#endif // BETA_OPT
|
||||
|
||||
if (alpha != 0.0f){
|
||||
for (size_t k_block = 0; k_block < k; k_block += K_BLOCK){
|
||||
size_t k_block_size = k - k_block;
|
||||
if (k_block_size > K_BLOCK) {
|
||||
k_block_size = K_BLOCK;
|
||||
}
|
||||
size_t k_block_size_up = (k_block_size + KERNEL_K_STEP - 1) / KERNEL_K_STEP * KERNEL_K_STEP;
|
||||
|
||||
const BLASINT8* curr_b_ptr = b_typed + b_indexing(k_block, n_block, ldb);
|
||||
|
||||
pack_b_fun(bufferB, curr_b_ptr, n_block_size, k_block_size, ldb, ob);
|
||||
for (size_t m_block = 0; m_block < m; m_block += M_BLOCK) {
|
||||
size_t m_block_size = m - m_block;
|
||||
if (m_block_size > M_BLOCK) {
|
||||
m_block_size = M_BLOCK;
|
||||
}
|
||||
const BLASINT8* curr_a_ptr = a_typed + PACKED_LD_STEP(m_block, k_block, lda);
|
||||
pack_a_fun(bufferA, curr_a_ptr, m_block_size, k_block_size, lda, oa);
|
||||
// loop over bufferB, taking parts which fit into L1
|
||||
for (size_t n_sub_block = 0; n_sub_block < n_block_size; n_sub_block += KERNEL_N_STEP) {
|
||||
size_t n_sub_block_size = n_block_size - n_sub_block;
|
||||
if (n_sub_block_size > KERNEL_N_STEP) {
|
||||
n_sub_block_size = KERNEL_N_STEP;
|
||||
}
|
||||
BLASINT8* current_bufferB_ptr = bufferB + n_sub_block * k_block_size_up;
|
||||
// loop over bufferA, taking parts which fit into L1
|
||||
for (size_t m_sub_block = 0; m_sub_block < m_block_size; m_sub_block += KERNEL_M_STEP) {
|
||||
size_t m_sub_block_size = m_block_size - m_sub_block;
|
||||
if (m_sub_block_size > KERNEL_M_STEP) {
|
||||
m_sub_block_size = KERNEL_M_STEP;
|
||||
}
|
||||
BLASINT8* current_bufferA_ptr = bufferA + m_sub_block * k_block_size_up;
|
||||
int32_t* current_bufferC_ptr = bufferC + n_block * ldc + n_sub_block * LDC(m, ldc) + m_sub_block + m_block;
|
||||
// call kernel which performs loop over k_block_size
|
||||
gemm_kernels[(n_sub_block_size - 1) + (m_sub_block_size - 1) * KERNEL_N_STEP](current_bufferA_ptr, current_bufferB_ptr, current_bufferC_ptr,
|
||||
LDC(m, ldc), k_block_size_up, COMP_SV_LEN);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef BETA_OPT
|
||||
// copy C data from bufferC multiplying by alpha and adding initial C data (scaled by beta)
|
||||
int32_t* current_c_ptr = c + n_block * ldc; // col major
|
||||
post_ops_func(alpha, bufferC, current_c_ptr, LDC(m, ldc), n_block_size, ldc);
|
||||
#endif
|
||||
}
|
||||
|
||||
free(bufferA);
|
||||
free(bufferB);
|
||||
|
||||
#ifndef BETA_OPT
|
||||
free(bufferC);
|
||||
#endif
|
||||
}
|
||||
@@ -1,154 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
//#include "cblas.h"
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
|
||||
|
||||
/* matrix saved in rows or cols */
|
||||
typedef enum CBLAS_ORDER {
|
||||
CblasRowMajor = 101,
|
||||
CblasColMajor = 102
|
||||
} CBLAS_ORDER;
|
||||
|
||||
/* matrix transpose or conjugate transpose */
|
||||
typedef enum CBLAS_TRANSPOSE {
|
||||
CblasNoTrans = 111,
|
||||
CblasTrans = 112,
|
||||
CblasConjTrans = 113, // conjugate transpose
|
||||
CblasConjNoTrans = 114
|
||||
} CBLAS_TRANSPOSE;
|
||||
|
||||
typedef CBLAS_ORDER CBLAS_LAYOUT;
|
||||
|
||||
typedef enum CBLAS_OFFSET {
|
||||
CblasRowOffset = 171,
|
||||
CblasColOffset = 172,
|
||||
CblasFixOffset = 173
|
||||
} CBLAS_OFFSET;
|
||||
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
#define ADD_KERNEL_SUFF(name, m_size, n_size) ADD_M_N_SIZES_MACRO(ADD_TYPES_MACRO(name, LHS_TYPE, RHS_TYPE), m_size, n_size)
|
||||
|
||||
static void (*gemm_kernels[])(const void*, const void*, int32_t*, size_t, int64_t, int64_t) = {
|
||||
ADD_KERNEL_SUFF(gemm_kernel, 1, 1), ADD_KERNEL_SUFF(gemm_kernel, 1, 2),ADD_KERNEL_SUFF(gemm_kernel, 1, 3), ADD_KERNEL_SUFF(gemm_kernel, 1, 4),
|
||||
ADD_KERNEL_SUFF(gemm_kernel, 2, 1), ADD_KERNEL_SUFF(gemm_kernel, 2, 2),ADD_KERNEL_SUFF(gemm_kernel, 2, 3), ADD_KERNEL_SUFF(gemm_kernel, 2, 4),
|
||||
ADD_KERNEL_SUFF(gemm_kernel, 3, 1), ADD_KERNEL_SUFF(gemm_kernel, 3, 2),ADD_KERNEL_SUFF(gemm_kernel, 3, 3), ADD_KERNEL_SUFF(gemm_kernel, 3, 4),
|
||||
ADD_KERNEL_SUFF(gemm_kernel, 4, 1), ADD_KERNEL_SUFF(gemm_kernel, 4, 2),ADD_KERNEL_SUFF(gemm_kernel, 4, 3), ADD_KERNEL_SUFF(gemm_kernel, 4, 4)
|
||||
};
|
||||
|
||||
static void (*pack_b_funs[])(void*, const void*, size_t, size_t, size_t, const BLASINT8) = {
|
||||
ADD_PACK_B_N_SUFF(pack_b),
|
||||
ADD_PACK_B_T_SUFF(pack_b)
|
||||
};
|
||||
|
||||
static void (*pack_a_funs[])(void*, const void*, size_t, size_t, size_t, const BLASINT8) = {
|
||||
ADD_PACK_A_N_SUFF(pack_a),
|
||||
ADD_PACK_A_T_SUFF(pack_a)
|
||||
};
|
||||
|
||||
static void (*small_kernels[])(const size_t, const size_t, const size_t, const float,
|
||||
const void *, const size_t, const BLASINT8,
|
||||
const void *, const size_t, const BLASINT8,
|
||||
const float, int32_t *, const size_t, const int32_t *) = {
|
||||
ADD_TRANSP_MACRO(small_kernel, n, n, f), ADD_TRANSP_MACRO(small_kernel, n, t, f),
|
||||
ADD_TRANSP_MACRO(small_kernel, t, n, f), ADD_TRANSP_MACRO(small_kernel, t, t, f),
|
||||
ADD_TRANSP_MACRO(small_kernel, n, n, c), ADD_TRANSP_MACRO(small_kernel, n, t, c),
|
||||
ADD_TRANSP_MACRO(small_kernel, t, n, c), ADD_TRANSP_MACRO(small_kernel, t, t, c),
|
||||
ADD_TRANSP_MACRO(small_kernel, n, n, r), ADD_TRANSP_MACRO(small_kernel, n, t, r),
|
||||
ADD_TRANSP_MACRO(small_kernel, t, n, r), ADD_TRANSP_MACRO(small_kernel, t, t, r),
|
||||
};
|
||||
|
||||
static void (*beta_funcs[])(int32_t*, const int32_t*, float, size_t, size_t, size_t) = {
|
||||
beta_cf_s8, beta_cc_s8, beta_cr_s8, beta_rf_s8, beta_rc_s8, beta_rr_s8,
|
||||
beta_cf_s8_opt, beta_cc_s8_opt, beta_cr_s8_opt, beta_rf_s8_opt, beta_rc_s8_opt, beta_rr_s8_opt
|
||||
};
|
||||
|
||||
static void (*post_op_kernels[])(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc) = {
|
||||
post_ops, post_ops_opt
|
||||
};
|
||||
|
||||
static size_t row_major_idx(size_t m, size_t n, size_t ld) {
|
||||
return ld * m + n;
|
||||
}
|
||||
|
||||
static size_t col_major_idx(size_t m, size_t n, size_t ld) {
|
||||
return m + ld * n;
|
||||
}
|
||||
|
||||
static size_t (*compute_idx[])(size_t m, size_t n, size_t ld) = {
|
||||
col_major_idx,
|
||||
row_major_idx
|
||||
};
|
||||
|
||||
static size_t mov_oc_fix(size_t mi, size_t ni) {
|
||||
UNUSED(mi);
|
||||
UNUSED(ni);
|
||||
return 0;
|
||||
}
|
||||
static size_t mov_oc_col(size_t mi, size_t ni){
|
||||
UNUSED(ni);
|
||||
return mi;
|
||||
}
|
||||
|
||||
static size_t mov_oc_row(size_t mi, size_t ni) {
|
||||
UNUSED(mi);
|
||||
return ni;
|
||||
}
|
||||
|
||||
static size_t (*move_oc[])(size_t, size_t) = {
|
||||
mov_oc_fix, mov_oc_col, mov_oc_row
|
||||
};
|
||||
|
||||
EXTERNAL_API void ADD_TYPES_SUFF(prefill_cblas_gemm)(
|
||||
const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,
|
||||
const size_t m, const size_t n, const size_t k, const float alpha,
|
||||
const void *a, const size_t lda, const BLASINT8 oa,
|
||||
const void *b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t *c, const size_t ldc, const int32_t *oc) {
|
||||
|
||||
int opt_offset = ((alpha == 1.0f) && (beta == 0.0f || beta == 1.0f)) ? 1 : 0;
|
||||
if(Layout == CblasColMajor) {
|
||||
int beta_offset = (offsetc == CblasFixOffset) ? 0 : (offsetc == CblasColOffset ? 1:2);
|
||||
int_gemm_funcs arg = {
|
||||
small_kernels[(transb == CblasTrans) + 2 * (transa == CblasTrans) + beta_offset * 4],
|
||||
gemm_kernels,
|
||||
pack_a_funs[transa == CblasTrans],
|
||||
pack_b_funs[transb == CblasTrans],
|
||||
beta_funcs[beta_offset + opt_offset * 6],
|
||||
post_op_kernels[alpha == 1],
|
||||
compute_idx[transa == CblasTrans],
|
||||
compute_idx[transb == CblasTrans],
|
||||
move_oc[beta_offset],
|
||||
};
|
||||
(gemm_impl_8bit(&arg, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc, 262144));
|
||||
} else if (Layout == CblasRowMajor) {
|
||||
int beta_offset = (offsetc == CblasFixOffset) ? 3 : (offsetc == CblasColOffset ? 4 : 5);
|
||||
int beta_offset_small = (offsetc == CblasFixOffset) ? 0 : (offsetc == CblasColOffset ? 2 : 1);
|
||||
int_gemm_funcs arg = {
|
||||
small_kernels[(transa == CblasTrans) + 2 * (transb == CblasTrans) + beta_offset_small * 4],
|
||||
gemm_kernels,
|
||||
pack_a_funs[transb == CblasTrans],
|
||||
pack_b_funs[transa == CblasTrans],
|
||||
beta_funcs[beta_offset + opt_offset * 6],
|
||||
post_op_kernels[alpha == 1],
|
||||
compute_idx[transb == CblasTrans],
|
||||
compute_idx[transa == CblasTrans],
|
||||
move_oc[beta_offset_small]
|
||||
};
|
||||
(gemm_impl_8bit(&arg, n, m, k, alpha, b, ldb, ob, a, lda, oa, beta, c, ldc, oc, 262144));
|
||||
}
|
||||
else {
|
||||
printf("Incorrect layout");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,453 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define ADD_SUFFIX(name) ADD_M_N_SIZES_MACRO(ADD_TYPES_MACRO(name, LHS_TYPE, RHS_TYPE), M_SIZE, N_SIZE)
|
||||
|
||||
#define LD1B_PTR(reg_name, p, ptr, idx) "ld1b {" #reg_name ".b}, " #p "/z, [%[" #ptr "], #" #idx ", MUL VL]\n"
|
||||
#define COMPUTE_ADDP(out, in1, in2) "addp " #out ".s, " #in1 ".s, " #in2 ".s\n"
|
||||
#if (defined(LHS_INT) && defined(RHS_INT)) || (defined(LHS_UINT) && defined(RHS_UINT))
|
||||
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) #lhs_type "dot " #out ".s, " #in1 ".b, " #in2 ".b\n"
|
||||
#else
|
||||
#ifdef LHS_INT
|
||||
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in2 ".b, " #in1 ".b\n"
|
||||
#else
|
||||
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in1 ".b, " #in2 ".b\n"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE) COMPUTE_DOT_TYPED(out, in1, in2, LHS_TYPE, RHS_TYPE)
|
||||
#define COMPUTE_DOT(out, in1, in2) COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE)
|
||||
|
||||
#if (N_SIZE > 4)
|
||||
#error "N_SIZE can't be greater than 4"
|
||||
#endif
|
||||
|
||||
#if (M_SIZE > 4)
|
||||
#error "M_SIZE can't be greater than 4"
|
||||
#endif
|
||||
|
||||
#define LOAD_Z0(p, ptr) LD1B_PTR(z0, p, ptr, 0)
|
||||
#define LOAD_Z8(p, ptr) LD1B_PTR(z8, p, ptr, 0)
|
||||
|
||||
#if (N_SIZE > 1)
|
||||
#define LOAD_Z1(p, ptr) LD1B_PTR(z1, p, ptr, 1)
|
||||
#define LOAD_Z9(p, ptr) LD1B_PTR(z9, p, ptr, 1)
|
||||
#else
|
||||
#define LOAD_Z1(p, ptr)
|
||||
#define LOAD_Z9(p, ptr)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 2)
|
||||
#define LOAD_Z2(p, ptr) LD1B_PTR(z2, p, ptr, 2)
|
||||
#define LOAD_Z10(p, ptr) LD1B_PTR(z10, p, ptr, 2)
|
||||
#else
|
||||
#define LOAD_Z2(p, ptr)
|
||||
#define LOAD_Z10(p, ptr)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 3)
|
||||
#define LOAD_Z3(p, ptr) LD1B_PTR(z3, p, ptr, 3)
|
||||
#define LOAD_Z11(p, ptr) LD1B_PTR(z11, p, ptr, 3)
|
||||
#else
|
||||
#define LOAD_Z3(p, ptr)
|
||||
#define LOAD_Z11(p, ptr)
|
||||
#endif
|
||||
|
||||
#define LOAD_Z4(p, ptr) LD1B_PTR(z4, p, ptr, 0)
|
||||
#define LOAD_Z12(p, ptr) LD1B_PTR(z12, p, ptr, 0)
|
||||
|
||||
#if (M_SIZE > 1)
|
||||
#define LOAD_Z5(p, ptr) LD1B_PTR(z5, p, ptr, 1)
|
||||
#define LOAD_Z13(p, ptr) LD1B_PTR(z13, p, ptr, 1)
|
||||
#else
|
||||
#define LOAD_Z5(p, ptr)
|
||||
#define LOAD_Z13(p, ptr)
|
||||
#endif
|
||||
|
||||
#if (M_SIZE > 2)
|
||||
#define LOAD_Z6(p, ptr) LD1B_PTR(z6, p, ptr, 2)
|
||||
#define LOAD_Z14(p, ptr) LD1B_PTR(z14, p, ptr, 2)
|
||||
#else
|
||||
#define LOAD_Z6(p, ptr)
|
||||
#define LOAD_Z14(p, ptr)
|
||||
#endif
|
||||
|
||||
#if (M_SIZE > 3)
|
||||
#define LOAD_Z7(p, ptr) LD1B_PTR(z7, p, ptr, 3)
|
||||
#define LOAD_Z15(p, ptr) LD1B_PTR(z15, p, ptr, 3)
|
||||
#else
|
||||
#define LOAD_Z7(p, ptr)
|
||||
#define LOAD_Z15(p, ptr)
|
||||
#endif
|
||||
|
||||
// macros for dot multiplication
|
||||
#define ACCUMULATE_Z16(lhs, rhs) COMPUTE_DOT(z16, lhs, rhs)
|
||||
|
||||
#if (N_SIZE > 1)
|
||||
#define ACCUMULATE_Z17(lhs, rhs) COMPUTE_DOT(z17, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z17(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 2)
|
||||
#define ACCUMULATE_Z18(lhs, rhs) COMPUTE_DOT(z18, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z18(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 3)
|
||||
#define ACCUMULATE_Z19(lhs, rhs) COMPUTE_DOT(z19, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z19(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (M_SIZE > 1)
|
||||
#define ACCUMULATE_Z20(lhs, rhs) COMPUTE_DOT(z20, lhs, rhs)
|
||||
#if (N_SIZE > 1)
|
||||
#define ACCUMULATE_Z21(lhs, rhs) COMPUTE_DOT(z21, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z21(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 2)
|
||||
#define ACCUMULATE_Z22(lhs, rhs) COMPUTE_DOT(z22, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z22(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 3)
|
||||
#define ACCUMULATE_Z23(lhs, rhs) COMPUTE_DOT(z23, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z23(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#else
|
||||
#define ACCUMULATE_Z20(lhs, rhs)
|
||||
#define ACCUMULATE_Z21(lhs, rhs)
|
||||
#define ACCUMULATE_Z22(lhs, rhs)
|
||||
#define ACCUMULATE_Z23(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (M_SIZE > 2)
|
||||
#define ACCUMULATE_Z24(lhs, rhs) COMPUTE_DOT(z24, lhs, rhs)
|
||||
|
||||
#if (N_SIZE > 1)
|
||||
#define ACCUMULATE_Z25(lhs, rhs) COMPUTE_DOT(z25, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z25(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 2)
|
||||
#define ACCUMULATE_Z26(lhs, rhs) COMPUTE_DOT(z26, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z26(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 3)
|
||||
#define ACCUMULATE_Z27(lhs, rhs) COMPUTE_DOT(z27, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z27(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#else
|
||||
#define ACCUMULATE_Z24(lhs, rhs)
|
||||
#define ACCUMULATE_Z25(lhs, rhs)
|
||||
#define ACCUMULATE_Z26(lhs, rhs)
|
||||
#define ACCUMULATE_Z27(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (M_SIZE > 3)
|
||||
#define ACCUMULATE_Z28(lhs, rhs) COMPUTE_DOT(z28, lhs, rhs)
|
||||
|
||||
#if (N_SIZE > 1)
|
||||
#define ACCUMULATE_Z29(lhs, rhs) COMPUTE_DOT(z29, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z29(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 2)
|
||||
#define ACCUMULATE_Z30(lhs, rhs) COMPUTE_DOT(z30, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z30(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 3)
|
||||
#define ACCUMULATE_Z31(lhs, rhs) COMPUTE_DOT(z31, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z31(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#else
|
||||
#define ACCUMULATE_Z28(lhs, rhs)
|
||||
#define ACCUMULATE_Z29(lhs, rhs)
|
||||
#define ACCUMULATE_Z30(lhs, rhs)
|
||||
#define ACCUMULATE_Z31(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#define MOVE_LHS_PTR(ptr) "add %[" #ptr "], %[" #ptr "], %[move_lhs]\n"
|
||||
#define MOVE_RHS_PTR(ptr) "add %[" #ptr "], %[" #ptr "], %[move_rhs]\n"
|
||||
|
||||
#define PROCESS_ACCUM(reg_idx, z_reg_idx, tmp_reg, dst, p) \
|
||||
"ldr w" #reg_idx ", [%[" #dst "]]\n" \
|
||||
"saddv d" #reg_idx ", " #p ", z" #z_reg_idx ".s\n" \
|
||||
"fmov " #tmp_reg ", d" #reg_idx "\n" \
|
||||
"add x" #reg_idx ", x" #reg_idx ", " #tmp_reg "\n" \
|
||||
"str w" #reg_idx ", [%[" #dst "]], #4\n"
|
||||
|
||||
// function logic
|
||||
void ADD_SUFFIX(gemm_kernel)(const void *lhs_ptr, const void *rhs_ptr,
|
||||
int32_t *accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len) {
|
||||
int64_t run_k_depth = k_depth;
|
||||
int64_t run_sv_len = sv_len;
|
||||
int64_t run_2sv_len = 2 * sv_len;
|
||||
int64_t move_lhs = M_SIZE * sv_len;
|
||||
int64_t move_rhs = N_SIZE * sv_len;
|
||||
int32_t* dst_ptr = accum_ptr;
|
||||
ldc -= M_SIZE;
|
||||
ldc *= 4;
|
||||
asm volatile(
|
||||
// predicate for operating on lhs and rhs is always true
|
||||
"ptrue p0.b, all\n"
|
||||
// Clear accumulators
|
||||
LOAD_Z0(p0, rhs_ptr)
|
||||
"dup z16.s, #0\n"
|
||||
LOAD_Z1(p0, rhs_ptr)
|
||||
"dup z17.s, #0\n"
|
||||
LOAD_Z4(p0, lhs_ptr)
|
||||
"dup z18.s, #0\n"
|
||||
LOAD_Z5(p0, lhs_ptr)
|
||||
"dup z19.s, #0\n"
|
||||
LOAD_Z6(p0, lhs_ptr)
|
||||
"dup z20.s, #0\n"
|
||||
LOAD_Z7(p0, lhs_ptr)
|
||||
"dup z21.s, #0\n"
|
||||
LOAD_Z2(p0, rhs_ptr)
|
||||
"dup z22.s, #0\n"
|
||||
LOAD_Z3(p0, rhs_ptr)
|
||||
"dup z23.s, #0\n"
|
||||
"subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
|
||||
"dup z24.s, #0\n"
|
||||
"mov x16, %[dst_ptr]\n"
|
||||
"dup z25.s, #0\n"
|
||||
"dup z26.s, #0\n"
|
||||
"dup z27.s, #0\n"
|
||||
MOVE_LHS_PTR(lhs_ptr)
|
||||
"dup z28.s, #0\n"
|
||||
MOVE_RHS_PTR(rhs_ptr)
|
||||
"dup z29.s, #0\n"
|
||||
"dup z30.s, #0\n"
|
||||
"dup z31.s, #0\n"
|
||||
|
||||
"ble 1f\n"
|
||||
|
||||
"cmp %[run_k_depth], %[run_2sv_len]\n"
|
||||
"blt 2f\n"
|
||||
|
||||
"3:\n"
|
||||
LOAD_Z12(p0, lhs_ptr)
|
||||
ACCUMULATE_Z16(z4,z0)
|
||||
ACCUMULATE_Z17(z4,z1)
|
||||
LOAD_Z13(p0, lhs_ptr)
|
||||
ACCUMULATE_Z18(z4,z2)
|
||||
ACCUMULATE_Z19(z4,z3)
|
||||
LOAD_Z8(p0, rhs_ptr)
|
||||
ACCUMULATE_Z20(z5,z0)
|
||||
ACCUMULATE_Z21(z5,z1)
|
||||
LOAD_Z9(p0, rhs_ptr)
|
||||
ACCUMULATE_Z22(z5,z2)
|
||||
ACCUMULATE_Z23(z5,z3)
|
||||
LOAD_Z10(p0,rhs_ptr)
|
||||
ACCUMULATE_Z24(z6,z0)
|
||||
ACCUMULATE_Z25(z6,z1)
|
||||
LOAD_Z11(p0,rhs_ptr)
|
||||
ACCUMULATE_Z26(z6,z2)
|
||||
MOVE_RHS_PTR(rhs_ptr)
|
||||
"prfw pldl1keep, p0, [%[rhs_ptr], #4, MUL VL]\n"
|
||||
ACCUMULATE_Z27(z6,z3)
|
||||
LOAD_Z14(p0, lhs_ptr)
|
||||
ACCUMULATE_Z28(z7,z0)
|
||||
ACCUMULATE_Z29(z7,z1)
|
||||
LOAD_Z15(p0,lhs_ptr)
|
||||
ACCUMULATE_Z30(z7,z2)
|
||||
MOVE_LHS_PTR(lhs_ptr)
|
||||
"prfw pldl1keep, p0, [%[lhs_ptr], #4, MUL VL]\n"
|
||||
ACCUMULATE_Z31(z7,z3)
|
||||
|
||||
LOAD_Z4(p0, lhs_ptr)
|
||||
ACCUMULATE_Z16(z12,z8)
|
||||
ACCUMULATE_Z17(z12,z9)
|
||||
LOAD_Z5(p0, lhs_ptr)
|
||||
ACCUMULATE_Z18(z12,z10)
|
||||
ACCUMULATE_Z19(z12,z11)
|
||||
LOAD_Z6(p0, lhs_ptr)
|
||||
ACCUMULATE_Z20(z13,z8)
|
||||
ACCUMULATE_Z21(z13,z9)
|
||||
LOAD_Z0(p0, rhs_ptr)
|
||||
"sub %[run_k_depth], %[run_k_depth], %[run_2sv_len]\n"
|
||||
ACCUMULATE_Z22(z13,z10)
|
||||
ACCUMULATE_Z23(z13,z11)
|
||||
LOAD_Z1(p0, rhs_ptr)
|
||||
ACCUMULATE_Z24(z14,z8)
|
||||
ACCUMULATE_Z25(z14,z9)
|
||||
LOAD_Z2(p0,rhs_ptr)
|
||||
ACCUMULATE_Z26(z14,z10)
|
||||
ACCUMULATE_Z27(z14,z11)
|
||||
LOAD_Z3(p0, rhs_ptr)
|
||||
ACCUMULATE_Z28(z15,z8)
|
||||
MOVE_RHS_PTR(rhs_ptr)
|
||||
"prfw pldl1keep, p0, [%[rhs_ptr], #4, MUL VL]\n"
|
||||
ACCUMULATE_Z29(z15,z9)
|
||||
LOAD_Z7(p0, lhs_ptr)
|
||||
"cmp %[run_k_depth], %[run_2sv_len]\n"
|
||||
ACCUMULATE_Z30(z15, z10)
|
||||
MOVE_LHS_PTR(lhs_ptr)
|
||||
"prfw pldl1keep, p0, [%[lhs_ptr], #4, MUL VL]\n"
|
||||
ACCUMULATE_Z31(z15,z11)
|
||||
"bge 3b\n"
|
||||
|
||||
"cmp %[run_k_depth], #0\n"
|
||||
"ble 1f\n"
|
||||
|
||||
"2:\n"
|
||||
"subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
|
||||
ACCUMULATE_Z16(z4,z0)
|
||||
ACCUMULATE_Z17(z4,z1)
|
||||
ACCUMULATE_Z18(z4,z2)
|
||||
ACCUMULATE_Z19(z4,z3)
|
||||
LOAD_Z4(p0,lhs_ptr)
|
||||
ACCUMULATE_Z20(z5,z0)
|
||||
ACCUMULATE_Z21(z5,z1)
|
||||
ACCUMULATE_Z22(z5,z2)
|
||||
ACCUMULATE_Z23(z5,z3)
|
||||
LOAD_Z5(p0,lhs_ptr)
|
||||
ACCUMULATE_Z24(z6,z0)
|
||||
ACCUMULATE_Z25(z6,z1)
|
||||
ACCUMULATE_Z26(z6,z2)
|
||||
ACCUMULATE_Z27(z6,z3)
|
||||
LOAD_Z6(p0,lhs_ptr)
|
||||
ACCUMULATE_Z28(z7,z0)
|
||||
LOAD_Z0(p0,rhs_ptr)
|
||||
ACCUMULATE_Z29(z7,z1)
|
||||
LOAD_Z1(p0,rhs_ptr)
|
||||
ACCUMULATE_Z30(z7,z2)
|
||||
LOAD_Z2(p0,rhs_ptr)
|
||||
ACCUMULATE_Z31(z7,z3)
|
||||
LOAD_Z3(p0,rhs_ptr)
|
||||
MOVE_RHS_PTR(rhs_ptr)
|
||||
LOAD_Z7(p0,lhs_ptr)
|
||||
MOVE_LHS_PTR(lhs_ptr)
|
||||
"bgt 2b\n"
|
||||
|
||||
"1:\n"
|
||||
ACCUMULATE_Z16(z4,z0)
|
||||
ACCUMULATE_Z17(z4,z1)
|
||||
ACCUMULATE_Z18(z4,z2)
|
||||
ACCUMULATE_Z19(z4,z3)
|
||||
ACCUMULATE_Z20(z5,z0)
|
||||
ACCUMULATE_Z21(z5,z1)
|
||||
ACCUMULATE_Z22(z5,z2)
|
||||
ACCUMULATE_Z23(z5,z3)
|
||||
ACCUMULATE_Z24(z6,z0)
|
||||
ACCUMULATE_Z25(z6,z1)
|
||||
ACCUMULATE_Z26(z6,z2)
|
||||
ACCUMULATE_Z27(z6,z3)
|
||||
ACCUMULATE_Z28(z7,z0)
|
||||
ACCUMULATE_Z29(z7,z1)
|
||||
ACCUMULATE_Z30(z7,z2)
|
||||
ACCUMULATE_Z31(z7,z3)
|
||||
|
||||
#if (N_SIZE > 0)
|
||||
#if (M_SIZE > 0)
|
||||
PROCESS_ACCUM(0, 16, x16, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 1)
|
||||
PROCESS_ACCUM(4, 20, x17, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 2)
|
||||
PROCESS_ACCUM(8, 24, x18, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 3)
|
||||
PROCESS_ACCUM(12, 28, x17, dst_ptr, p0)
|
||||
#endif
|
||||
#endif
|
||||
"add %[dst_ptr], %[dst_ptr], %[ldc]\n"
|
||||
|
||||
#if (N_SIZE > 1)
|
||||
#if (M_SIZE > 0)
|
||||
PROCESS_ACCUM(1, 17, x16, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 1)
|
||||
PROCESS_ACCUM(5, 21, x17,dst_ptr,p0)
|
||||
#endif
|
||||
#if (M_SIZE > 2)
|
||||
PROCESS_ACCUM(9, 25, x18, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 3)
|
||||
PROCESS_ACCUM(13, 29, x17, dst_ptr, p0)
|
||||
#endif
|
||||
#endif
|
||||
"add %[dst_ptr], %[dst_ptr], %[ldc]\n"
|
||||
|
||||
|
||||
#if (N_SIZE > 2)
|
||||
#if (M_SIZE > 0)
|
||||
PROCESS_ACCUM(2, 18, x16, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 1)
|
||||
PROCESS_ACCUM(6, 22, x17,dst_ptr,p0)
|
||||
#endif
|
||||
#if (M_SIZE > 2)
|
||||
PROCESS_ACCUM(10,26,x18,dst_ptr,p0)
|
||||
#endif
|
||||
#if (M_SIZE > 3)
|
||||
PROCESS_ACCUM(14,30,x17,dst_ptr,p0)
|
||||
#endif
|
||||
#endif
|
||||
"add %[dst_ptr], %[dst_ptr], %[ldc]\n"
|
||||
|
||||
|
||||
#if (N_SIZE > 3)
|
||||
#if (M_SIZE > 0)
|
||||
PROCESS_ACCUM(3, 19, x16, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 1)
|
||||
PROCESS_ACCUM(7, 23, x17,dst_ptr,p0)
|
||||
#endif
|
||||
#if (M_SIZE > 2)
|
||||
PROCESS_ACCUM(11,27,x18,dst_ptr,p0)
|
||||
#endif
|
||||
#if (M_SIZE > 3)
|
||||
PROCESS_ACCUM(15,31,x17,dst_ptr,p0)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
:
|
||||
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
|
||||
[run_k_depth] "+r"(run_k_depth),
|
||||
[dst_ptr] "+wr"(dst_ptr)
|
||||
:
|
||||
[run_sv_len] "r"(run_sv_len), [run_2sv_len] "r"(run_2sv_len),
|
||||
[move_lhs] "r"(move_lhs), [move_rhs] "r"(move_rhs), [ldc] "r"(ldc),
|
||||
[accum_ptr] "r"(accum_ptr)
|
||||
:
|
||||
"cc", "memory",
|
||||
"w0","w1","w2","w3","w4","w5","w6","w7",
|
||||
"w8","w9","w10","w11","w12","w13","w14","w15",
|
||||
"x16","x17","x18","x19",
|
||||
"z0","z1","z2","z3","z4","z5","z6","z7",
|
||||
"z8","z9","z10","z11","z12","z13","z14","z15",
|
||||
"z16","z17","z18","z19","z20","z21","z22","z23",
|
||||
"z24","z25","z26","z27","z28","z29","z30","z31"
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,468 +0,0 @@
|
||||
#ifndef __GEMM_INTEGER_KERNELS_H__
|
||||
#define __GEMM_INTEGER_KERNELS_H__
|
||||
|
||||
#include <stdint.h>
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
typedef int8_t BLASINT8;
|
||||
typedef uint8_t BLASUINT8;
|
||||
|
||||
typedef struct {
|
||||
void (*small_kernel)(const size_t, const size_t, const size_t, const float, const void*, const size_t, const BLASINT8,
|
||||
const void*, const size_t, const BLASINT8, const float, int32_t*, const size_t, const int32_t*);
|
||||
void (**gemm_kernels)(const void*, const void*, int32_t*, size_t, int64_t, int64_t);
|
||||
void (*pack_a_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8);
|
||||
void (*pack_b_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8);
|
||||
void (*beta_func)(int32_t*, const int32_t*, float, size_t, size_t, size_t);
|
||||
void (*post_ops_func)(float, const int32_t*, int32_t*, size_t, size_t, size_t);
|
||||
size_t (*a_indexing)(size_t, size_t, size_t);
|
||||
size_t (*b_indexing)(size_t, size_t, size_t);
|
||||
size_t (*move_oc)(size_t, size_t);
|
||||
|
||||
} int_gemm_funcs;
|
||||
|
||||
#ifndef COMP_SV_LEN
|
||||
#error "COMP_SV_LEN is not defined"
|
||||
#endif
|
||||
|
||||
#define KERNEL_M_STEP 4
|
||||
#define KERNEL_N_STEP 4
|
||||
#define KERNEL_K_STEP COMP_SV_LEN
|
||||
|
||||
#define M_BLOCK 256
|
||||
#if ((M_BLOCK % KERNEL_M_STEP) != 0)
|
||||
#error "M_BLOCK % KERNEL_M_STEP != 0"
|
||||
#endif
|
||||
#define N_BLOCK 256
|
||||
#if ((N_BLOCK % KERNEL_N_STEP) != 0)
|
||||
#error "N_BLOCK % KERNEL_N_STEP != 0"
|
||||
#endif
|
||||
#define K_BLOCK 512
|
||||
#if ((K_BLOCK % KERNEL_K_STEP) != 0)
|
||||
#error "K_BLOCK % KERNEL_K_STEP != 0"
|
||||
#endif
|
||||
|
||||
#define ALIGNMENT 4096
|
||||
|
||||
#define EXTERNAL_API __attribute__((visibility("default")))
|
||||
#define UNUSED(arg) ((void)(arg))
|
||||
|
||||
// general pipeline
|
||||
void gemm_impl_8bit(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha, const void* a, size_t lda,
|
||||
const BLASINT8 oa, const void* b, size_t ldb, const BLASINT8 ob, float beta, int32_t* c, size_t ldc,
|
||||
const int32_t* oc, size_t small_switch);
|
||||
// s8 kernel
|
||||
void pack_b_s8_n(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
|
||||
const BLASINT8 ob);
|
||||
|
||||
void pack_a_s8_n(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
|
||||
const BLASINT8 oa);
|
||||
|
||||
void pack_b_s8_t(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
|
||||
const BLASINT8 ob);
|
||||
|
||||
void pack_a_s8_t(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
|
||||
const BLASINT8 oa);
|
||||
|
||||
// u8 kernels
|
||||
void pack_b_u8_n(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
|
||||
const BLASINT8 ob);
|
||||
|
||||
void pack_a_u8_n(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
|
||||
const BLASINT8 oa);
|
||||
|
||||
void pack_b_u8_t(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
|
||||
const BLASINT8 ob);
|
||||
|
||||
void pack_a_u8_t(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
|
||||
const BLASINT8 oa);
|
||||
|
||||
// beta kernels
|
||||
void beta_cf_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
|
||||
void beta_cc_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
|
||||
void beta_cr_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
|
||||
void beta_rf_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
|
||||
void beta_rc_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
|
||||
void beta_rr_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
|
||||
|
||||
void beta_cf_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
|
||||
size_t ldc);
|
||||
void beta_cc_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
|
||||
size_t ldc);
|
||||
void beta_cr_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
|
||||
size_t ldc);
|
||||
void beta_rf_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
|
||||
size_t ldc);
|
||||
void beta_rc_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
|
||||
size_t ldc);
|
||||
void beta_rr_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
|
||||
size_t ldc);
|
||||
|
||||
// post-ops kernels
|
||||
void post_ops(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc);
|
||||
void post_ops_opt(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc);
|
||||
|
||||
// drivers
|
||||
void gemm_driver(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha, const void* a, size_t lda,
|
||||
const BLASINT8 oa, const void* b, size_t ldb, const BLASINT8 ob, float beta, int32_t* c, size_t ldc,
|
||||
const int32_t* oc);
|
||||
void gemm_driver_opt(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha, const void* a, size_t lda,
|
||||
const BLASINT8 oa, const void* b, size_t ldb, const BLASINT8 ob, float beta, int32_t* c,
|
||||
size_t ldc, const int32_t* oc);
|
||||
|
||||
// matrix multiplication kernels
|
||||
|
||||
// s8s8s32 kernels
|
||||
void gemm_kernel_s8s8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_s8s8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_s8s8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_s8s8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
// u8u8s32 kernels
|
||||
|
||||
void gemm_kernel_u8u8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_u8u8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_u8u8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_u8u8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
// s8u8s32 kernels
|
||||
void gemm_kernel_s8u8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_s8u8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_s8u8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_s8u8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
// u8s8s32 kernels
|
||||
void gemm_kernel_u8s8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_u8s8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_u8s8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_u8s8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
// small kernels
|
||||
// s8s8s32 kernels
|
||||
void small_kernel_s8s8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
// s8u8s32 kernels
|
||||
|
||||
void small_kernel_s8u8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
// u8s8s32 kernels
|
||||
void small_kernel_u8s8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
// u8u8s32 kernels
|
||||
void small_kernel_u8u8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@@ -1,158 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
#if defined(OC_FIX)
|
||||
#define OC_TYPE f
|
||||
#define OC_IDX(mi, nn) 0
|
||||
#elif defined(OC_COL)
|
||||
#define OC_TYPE c
|
||||
#define OC_IDX(mi, ni) mi
|
||||
#else // OC_ROW
|
||||
#define OC_TYPE r
|
||||
#define OC_IDX(mi, ni) ni
|
||||
#endif // OC_T
|
||||
|
||||
#if defined(TRANSA)
|
||||
#if defined TRANSB
|
||||
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, t, t, OC_TYPE)
|
||||
#elif defined(NOTRANSB)
|
||||
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, t, n, OC_TYPE)
|
||||
#else
|
||||
#error "Neither TRANSB or NOTRANSB is defined"
|
||||
#endif
|
||||
#elif defined(NOTRANSA)
|
||||
#if defined TRANSB
|
||||
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, n, t, OC_TYPE)
|
||||
#elif defined(NOTRANSB)
|
||||
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, n, n, OC_TYPE)
|
||||
#else
|
||||
#error "Neither TRANSB or NOTRANSB is defined"
|
||||
#endif
|
||||
#else
|
||||
#error "Neither TRANSA or NOTRANSA is defined"
|
||||
#endif
|
||||
|
||||
#if (defined (LHS_INT) && defined(RHS_INT)) || (defined (LHS_UINT) && defined(RHS_UINT))
|
||||
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) #lhs_type "dot " #out ".s, " #in1 ".b, " #in2 ".b\n"
|
||||
#else
|
||||
#if defined(LHS_INT)
|
||||
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in2 ".b," #in1 ".b\n"
|
||||
#else // LHS_UINT
|
||||
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in1 ".b," #in2 ".b\n"
|
||||
#endif // LHS_INT
|
||||
#endif // LHS_INT
|
||||
|
||||
#define COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE) COMPUTE_DOT_TYPED(out, in1, in2, LHS_TYPE, RHS_TYPE)
|
||||
#define COMPUTE_DOT(out, in1, in2) COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE)
|
||||
|
||||
static inline double compute_dot(size_t k, const void *a, const BLASINT8* oa,
|
||||
const void *b, const BLASINT8* ob, int64_t sv_len) {
|
||||
int32_t accum = 0;
|
||||
int64_t run_k_depth = k;
|
||||
int64_t run_sv_len = sv_len;
|
||||
const void* lhs_ptr = a;
|
||||
const void* rhs_ptr = b;
|
||||
asm volatile(
|
||||
"dup z4.s, #0\n"
|
||||
"ptrue p0.b, all\n"
|
||||
"ld1b {z0.b}, p0/z, [%[oa]]\n"
|
||||
"ld1b {z1.b}, p0/z, [%[ob]]\n"
|
||||
"1:\n"
|
||||
"whilelt p1.b, xzr, %[run_k_depth]\n"
|
||||
"ld1b {z2.b}, p1/z, [%[lhs_ptr]]\n"
|
||||
"ld1b {z3.b}, p1/z, [%[rhs_ptr]]\n"
|
||||
"add z2.b, p1/m, z2.b, z0.b\n"
|
||||
"add z3.b, p1/m, z3.b, z1.b\n"
|
||||
"add %[lhs_ptr], %[lhs_ptr], %[run_sv_len]\n"
|
||||
"add %[rhs_ptr], %[rhs_ptr], %[run_sv_len]\n"
|
||||
COMPUTE_DOT(z4, z2, z3)
|
||||
"subs %[run_k_depth], %[run_k_depth], #1\n"
|
||||
"bgt 1b\n"
|
||||
"ptrue p2.s, all\n"
|
||||
"saddv d1, p2, z4.s\n"
|
||||
"fmov x2, d1\n"
|
||||
"add %[accum], %[accum], x2\n"
|
||||
: // outputs
|
||||
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
|
||||
[run_k_depth] "+wr"(run_k_depth),
|
||||
[accum] "+wr"(accum)
|
||||
: // inputs
|
||||
[run_sv_len] "r"(run_sv_len), [oa] "r"(oa), [ob] "r"(ob)
|
||||
: // clobbers
|
||||
"cc", "memory",
|
||||
"d1", "x2",
|
||||
"z0", "z1", "z2", "z3", "z4", "p0", "p1", "p2"
|
||||
);
|
||||
return (double) accum;
|
||||
}
|
||||
|
||||
#if !defined(TRANSA) || defined(TRANSB)
|
||||
// performs transposition (n0 * n1) -> (n1 * n0) assuming col major
|
||||
static inline void simplest_transpose(const void *in, void *out, size_t n0, size_t ld0, size_t n1) {
|
||||
// since we care only about size, we can use signed type always
|
||||
BLASINT8* typed_in = (BLASINT8*) in;
|
||||
BLASINT8* typed_out = (BLASINT8*) out;
|
||||
for (size_t i = 0; i < n1; ++i) {
|
||||
for (size_t j = 0; j < n0; ++j) {
|
||||
typed_out[i + j * n1] = typed_in[j + i * ld0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // !defined(TRANSA) || defined(TRANSB)
|
||||
|
||||
// A in row-major, B in col-major
|
||||
void ADD_SUFFIX(small_kernel)(const size_t m, const size_t n, const size_t k, const float alpha,
|
||||
const void *a, const size_t lda, const BLASINT8 oa,
|
||||
const void *b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t *c, const size_t ldc, const int32_t *oc) {
|
||||
double double_alpha = (double) alpha;
|
||||
// we use typed pointers only for indexing, so we don't care about signess
|
||||
#ifdef TRANSA
|
||||
BLASINT8* a_typed = (BLASINT8*) a;
|
||||
const size_t used_lda = lda;
|
||||
#else
|
||||
BLASINT8* a_typed = (BLASINT8*) aligned_alloc(128, sizeof(BLASINT8) * m * k);
|
||||
simplest_transpose(a, a_typed, m, lda, k);
|
||||
const size_t used_lda = k;
|
||||
#endif // TRANSA
|
||||
#ifndef TRANSB
|
||||
BLASINT8* b_typed = (BLASINT8*) b;
|
||||
const size_t used_ldb = ldb;
|
||||
#else // TRANSB
|
||||
BLASINT8* b_typed = (BLASINT8*) aligned_alloc(128, sizeof(BLASINT8) * k * n);
|
||||
simplest_transpose(b, b_typed, n, ldb, k);
|
||||
const size_t used_ldb = k;
|
||||
#endif // TRANSB
|
||||
BLASINT8 oa_buf[KERNEL_K_STEP];
|
||||
BLASINT8 ob_buf[KERNEL_K_STEP];
|
||||
for (size_t i = 0; i < KERNEL_K_STEP; ++i) {
|
||||
oa_buf[i] = oa;
|
||||
ob_buf[i] = ob;
|
||||
}
|
||||
// printf("\n========\n");
|
||||
for (size_t mi = 0; mi < m; ++mi) {
|
||||
for (size_t ni = 0; ni < n; ++ni) {
|
||||
// printf("mi = %lu, ni = %lu, oc_idx = %lu\n", mi, ni, OC_IDX(mi, ni));
|
||||
double tmp = compute_dot(k, a_typed + mi * used_lda, oa_buf, b_typed + ni * used_ldb, ob_buf, KERNEL_K_STEP);
|
||||
c[mi + ni * ldc] = round(tmp * double_alpha + ((double)(beta * ((float)c[mi + ni * ldc]) + oc[OC_IDX(mi, ni)])));
|
||||
}
|
||||
}
|
||||
#ifdef TRANSA
|
||||
free(a_typed);
|
||||
#endif // TRANSA
|
||||
#ifdef TRANSB
|
||||
free(b_typed);
|
||||
#endif // TRANSB
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif /* __cplusplus */
|
||||
@@ -1,134 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
#if defined(TRANSA)
|
||||
// row major
|
||||
#define INDEXING_A(row_idx, col_idx, lda) ((col_idx) * (lda) + row_idx)
|
||||
#define ADD_A_SUFF(name) ADD_PACK_A_T_SUFF(name)
|
||||
#elif defined(NOTRANSA)
|
||||
// col major
|
||||
#define INDEXING_A(row_idx, col_idx, lda) ((row_idx) * (lda) + col_idx)
|
||||
#define ADD_A_SUFF(name) ADD_PACK_A_N_SUFF(name)
|
||||
#else
|
||||
#error "Neither TRANSA or NOTRANSA is defined"
|
||||
#endif
|
||||
|
||||
#define OLD_N_SIZE 8
|
||||
#define NEW_N_SIZE 4
|
||||
|
||||
// Re-interleaves an A panel that was packed with the legacy OLD_N_SIZE(=8)
// row grouping into the NEW_N_SIZE(=4) grouping expected by the current
// micro-kernel. The input is assumed to already be tiled as
// [old_split_n][split_k][old_idx_n][idx_k] with KERNEL_K_STEP-deep k tiles;
// only the row-group width changes, element values are copied unmodified.
//
// NOTE(review): `oa` is unused here — the legacy packing (since removed)
// added the offset in-place, so presumably the offset is now applied
// elsewhere in the pipeline; confirm against the compute kernels.
// NOTE(review): `lda` here acts as the stride between old 8-row groups of
// the already-packed source, not a conventional leading dimension —
// confirm against the caller.
void ADD_A_SUFF(pack_a)(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda, const BLASINT8 oa) {
  LHS_INT_TYPE* bufferA_typed = (LHS_INT_TYPE*) bufferA;
  LHS_INT_TYPE* curr_a_ptr_typed = (LHS_INT_TYPE*) curr_a_ptr;

  for(size_t old_split_n = 0; old_split_n < (m_block_size / OLD_N_SIZE); old_split_n++) {
    for(size_t split_k = 0; split_k < (k_block_size / KERNEL_K_STEP); split_k++) {
      for(size_t old_idx_n = 0; old_idx_n < OLD_N_SIZE; old_idx_n++) {
        for(size_t idx_k = 0; idx_k < KERNEL_K_STEP; idx_k++) {
          // Absolute row index, then its position in the new 4-wide grouping.
          size_t n_idx = old_split_n * OLD_N_SIZE + old_idx_n;
          size_t new_split_n = n_idx / NEW_N_SIZE;
          size_t new_idx_n = n_idx % NEW_N_SIZE;

          // Flattened index in the old [group][k-tile][row][k] layout.
          size_t old_buff_idx =
              old_split_n * OLD_N_SIZE * lda +
              split_k * OLD_N_SIZE * KERNEL_K_STEP +
              old_idx_n * KERNEL_K_STEP +
              idx_k;
          // Flattened index in the new [group][k-tile][row][k] layout.
          size_t new_buff_idx =
              new_split_n * NEW_N_SIZE * k_block_size +
              split_k * NEW_N_SIZE * KERNEL_K_STEP +
              new_idx_n * KERNEL_K_STEP +
              idx_k;
          bufferA_typed[new_buff_idx] = curr_a_ptr_typed[old_buff_idx];
        }
      }
    }
  }

  // Two earlier packing implementations (direct flat remap, and a
  // KERNEL_M_STEP-tiled pack from unpacked A with oa applied in-place and
  // zero padding of k/m residuals) were removed here as dead commented-out
  // code.
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif /* __cplusplus */
|
||||
@@ -1,96 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
#if defined(TRANSB)
|
||||
// row major
|
||||
#define INDEXING_B(row_idx, col_idx, ldb) ((col_idx) * (ldb) + row_idx)
|
||||
#define ADD_B_SUFF(name) ADD_PACK_B_T_SUFF(name)
|
||||
#elif defined(NOTRANSB)
|
||||
// col major
|
||||
#define INDEXING_B(row_idx, col_idx, ldb) ((row_idx) * (ldb) + col_idx)
|
||||
#define ADD_B_SUFF(name) ADD_PACK_B_N_SUFF(name)
|
||||
#else
|
||||
#error "Neither TRANSB or NOTRANSB is defined."
|
||||
#endif
|
||||
|
||||
// Packs a B panel (n_block_size x k_block_size) into the micro-kernel tile
// layout: KERNEL_N_STEP columns interleaved with KERNEL_K_STEP-deep k tiles,
// with the k dimension zero-padded up to a multiple of KERNEL_K_STEP and the
// ob offset added to every element. INDEXING_B abstracts over the TRANSB /
// NOTRANSB source storage order.
void ADD_B_SUFF(pack_b)(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb, const BLASINT8 ob) {
  RHS_INT_TYPE* bufferB_typed = (RHS_INT_TYPE*) bufferB;
  RHS_INT_TYPE* curr_b_ptr_typed = (RHS_INT_TYPE*) curr_b_ptr;
  // k rounded up to a KERNEL_K_STEP multiple; full tiles; leftover depth.
  size_t k_block_size_up = (k_block_size + KERNEL_K_STEP - 1) / KERNEL_K_STEP * KERNEL_K_STEP;
  size_t k_portions = k_block_size / KERNEL_K_STEP;
  size_t k_resid = k_block_size - KERNEL_K_STEP * k_portions;

  // Full KERNEL_N_STEP-wide column groups; leftover columns.
  size_t n_portions = n_block_size / KERNEL_N_STEP;
  size_t n_resid = n_block_size - KERNEL_N_STEP * n_portions;

  // Full column groups.
  for (size_t in4 = 0; in4 < n_portions; ++in4) {
    for (size_t in = 0; in < KERNEL_N_STEP; ++in) {
      // Full-depth k tiles.
      for (size_t ik16 = 0; ik16 < k_block_size / KERNEL_K_STEP; ++ik16) {
        for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
          bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * KERNEL_N_STEP * ik16 + k_block_size_up * KERNEL_N_STEP * in4] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * in4 + in), (KERNEL_K_STEP * ik16 + ik), ldb)] + ob;
        }
      }

      // Partial final k tile: copy the remainder, zero-pad the rest.
      if (k_resid) {
        for (size_t ik = 0; ik < k_resid; ++ik) {
          bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * KERNEL_N_STEP * k_portions + k_block_size_up * KERNEL_N_STEP * in4] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * in4 + in), (KERNEL_K_STEP * k_portions + ik), ldb)] + ob;
        }
        for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
          bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * KERNEL_N_STEP * k_portions + k_block_size_up * KERNEL_N_STEP * in4] = 0;
        }
      }
    }
  }
  // Leftover column group (fewer than KERNEL_N_STEP columns): same tiling,
  // but the inner column stride shrinks to n_resid.
  if (n_resid) {
    for (size_t in = 0; in < n_resid; ++in) {
      for (size_t ik16 = 0; ik16 < k_block_size / KERNEL_K_STEP; ++ik16) {
        for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
          bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * n_resid * ik16 + k_block_size_up * KERNEL_N_STEP * n_portions] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * n_portions + in), (KERNEL_K_STEP * ik16 + ik), ldb)] + ob;
        }
      }
      if (k_resid) {
        for (size_t ik = 0; ik < k_resid; ++ik) {
          bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * n_resid * k_portions + k_block_size_up * KERNEL_N_STEP * n_portions] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * n_portions + in), (KERNEL_K_STEP * k_portions + ik), ldb)] + ob;
        }
        for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
          bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * n_resid * k_portions + k_block_size_up * KERNEL_N_STEP * n_portions] = 0;
        }
      }
    }
  }

  // A legacy flat-remap implementation (OLD_N_SIZE -> NEW_N_SIZE regrouping)
  // was removed here as dead commented-out code.
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif /* __cplusplus */
|
||||
@@ -1,25 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
#include "beta_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
// Fixed OC
|
||||
void BETA_SUFF(post_ops)(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc) {
|
||||
float* current_c_float_ptr = (float*) current_c_ptr;
|
||||
double double_alpha = (double) alpha;
|
||||
for (size_t n_idx = 0; n_idx < n_block; ++n_idx) {
|
||||
for (size_t m_idx = 0; m_idx < m; ++m_idx) {
|
||||
current_c_ptr[m_idx + n_idx * ldc] = round(((double)(current_c_float_ptr[m_idx + n_idx * ldc]))
|
||||
+ double_alpha * ((double) bufferC[m_idx + n_idx * LDC(m, ldc)]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif /* __cplusplus */
|
||||
@@ -1,35 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
// here we will use BLASINT8, because we care only about type size, not signness (no math operations required)
|
||||
// here we will use BLASINT8, because we care only about type size, not signness (no math operations required)
// Top-level dispatch for an 8-bit GEMM call: tiny problems go straight to
// the naive small kernel, everything else to the blocked drivers — the
// optimized driver when alpha/beta take their trivial values, the generic
// one otherwise. The concrete kernels are supplied through `arg`.
void gemm_impl_8bit(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha,
                    const void* a, size_t lda, const BLASINT8 oa,
                    const void* b, size_t ldb, const BLASINT8 ob,
                    float beta, int32_t* c, size_t ldc, const int32_t* oc, size_t small_switch) {
  if (m * n * k < small_switch) { // experimentally measured constant
    arg->small_kernel(m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
    return;
  }

  // Corner cases optimizations
  const int trivial_scalars = (alpha == 1.0f) && (beta == 0.0f || beta == 1.0f);
  if (trivial_scalars) {
    gemm_driver_opt(arg, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
  } else {
    gemm_driver(arg, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
  }
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif /* __cplusplus */
|
||||
@@ -1,30 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.14.1)

project(THREAD_TEST)

# Inherit the optimization flags provided by the parent build and add OpenMP.
set(CMAKE_CXX_FLAGS ${CFLAGS_OPT})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")

message(STATUS "Current source dir: ${CMAKE_CURRENT_SOURCE_DIR}/..")
set(BLAS_PATH "")
message(STATUS "BLAS_PATH=${BLAS_PATH}")

# NOTE(review): BLAS_PATH is hard-coded to "" just above, so this branch is
# always taken and the library is picked up from the parent directory; the
# else() FATAL_ERROR is unreachable as written. Confirm the intended
# semantics of a non-empty BLAS_PATH.
if ("${BLAS_PATH}" STREQUAL "")

    # Link the prebuilt int8 GEMM shared library from the parent directory.
    set(BLAS_LIB "${CMAKE_CURRENT_SOURCE_DIR}/../libint8gemm.so")
    message(STATUS "BLAS_LIB=${BLAS_LIB}")

    # Add threading library to linker
    #find_package(Threads)

    add_executable(integer_tester integer_gemm.cpp)
    #target_include_directories(integer_tester PUBLIC "${BLAS_PATH}")
    #target_include_directories(integer_tester PUBLIC "${BLAS_BUILD_PATH}")
    target_link_libraries(integer_tester PRIVATE ${BLAS_LIB})
    set_property(TARGET integer_tester PROPERTY CXX_STANDARD 17)
    add_test(integer_tester ${CMAKE_CURRENT_BINARY_DIR}/integer_tester)
else ()
    message(FATAL_ERROR "Can not find int8_gemm path, pls set right path!")
endif ()
|
||||
@@ -1,438 +0,0 @@
|
||||
#include <malloc.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
/* matrix saved in rows or cols */
typedef enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 } CBLAS_ORDER;

/* matrix transpose or conjugate transpose */
typedef enum CBLAS_TRANSPOSE {
  CblasNoTrans = 111,
  CblasTrans = 112,
  CblasConjTrans = 113,  // conjugate transpose
  CblasConjNoTrans = 114
} CBLAS_TRANSPOSE;

typedef CBLAS_ORDER CBLAS_LAYOUT;

// How the oc offset array is applied to C: one value per column of C (Row,
// n values), one per row of C (Col, m values), or a single value for the
// whole matrix (Fix) — see helpers::get_oc_size below.
typedef enum CBLAS_OFFSET { CblasRowOffset = 171, CblasColOffset = 172, CblasFixOffset = 173 } CBLAS_OFFSET;

// 8-bit element types used by the integer GEMM entry points.
typedef int8_t BLASINT8;
typedef uint8_t BLASUINT8;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
void cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
|
||||
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
|
||||
const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
|
||||
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void cblas_gemm_u8u8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
|
||||
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
|
||||
const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
|
||||
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void cblas_gemm_s8u8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
|
||||
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
|
||||
const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
|
||||
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void cblas_gemm_u8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
|
||||
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
|
||||
const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
|
||||
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif /* __cplusplus */
|
||||
|
||||
namespace test {
|
||||
|
||||
namespace tools {
|
||||
|
||||
// Minimal STL-compatible allocator returning `alignment`-byte aligned
// storage (via memalign). Used by the test harness to mimic the alignment
// guarantees of the production buffers.
template <typename T, std::size_t alignment = 128>
struct aligned_allocator {
  using value_type = T;
  using pointer = T*;
  using const_pointer = const T*;
  using reference = T&;
  using const_reference = const T&;
  using size_type = std::size_t;
  using difference_type = std::ptrdiff_t;

  template <typename U>
  struct rebind {
    typedef aligned_allocator<U, alignment> other;
  };

  // BUGFIX: the Allocator named requirements demand a default constructor
  // and a converting constructor from a rebound allocator; without them,
  // node-based containers (which construct rebind<Node>::other from the
  // user-supplied allocator) fail to compile. The allocator is stateless,
  // so both are trivial.
  aligned_allocator() noexcept = default;
  template <typename U>
  aligned_allocator(const aligned_allocator<U, alignment>&) noexcept {}

  // Allocates n objects worth of aligned storage.
  // Throws bad_array_new_length on size overflow, bad_alloc on failure.
  [[nodiscard]] T* allocate(std::size_t n) {
    if (n > std::numeric_limits<std::size_t>::max() / sizeof(T)) throw std::bad_array_new_length();
    if (auto p = static_cast<T*>(memalign(alignment, n * sizeof(T)))) {
      return p;
    }

    throw std::bad_alloc();
  }

  // memalign'd memory is released with plain free(); n is unused.
  void deallocate(T* p, std::size_t n) noexcept {
    (void)(n);
    free(p);
  }

  ~aligned_allocator() {}
};
|
||||
template <typename T, std::size_t alignment_1, typename U, std::size_t alignment_2>
|
||||
bool operator==(const aligned_allocator<T, alignment_1>&, const aligned_allocator<U, alignment_2>&) {
|
||||
return (alignment_1 == alignment_2) && std::is_same_v<T, U>;
|
||||
}
|
||||
|
||||
template <typename T, std::size_t alignment_1, typename U, std::size_t alignment_2>
|
||||
bool operator!=(const aligned_allocator<T, alignment_1>& lhs, const aligned_allocator<U, alignment_2>& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
// Micro-benchmark helper: runs `func(args...)` once to estimate its cost,
// derives an iteration count that targets roughly one second of total
// measurement (never fewer than 3 runs), then returns the mean wall-clock
// seconds per call measured over that batch.
// NOTE: by design a single call to timing() takes on the order of a second.
template <typename Func, typename... Args>
double timing(Func&& func, Args&&... args) {
  double time = 0.0;
  double time_begin = 0.0;
  std::size_t n_run = 0;

  // Calibration: one timed invocation.
  auto start_begin = std::chrono::steady_clock::now();
  std::forward<Func>(func)(std::forward<Args>(args)...);
  auto end_begin = std::chrono::steady_clock::now();

  time_begin = std::chrono::duration_cast<std::chrono::nanoseconds>(end_begin - start_begin).count() / 1e9;
  // Enough repetitions to fill ~1 second, but at least 3.
  n_run = std::max<std::size_t>(std::size_t(1.0 / time_begin), 3);

  // Measured batch.
  auto start = std::chrono::steady_clock::now();
  for (std::size_t i = 0; i < n_run; ++i) {
    std::forward<Func>(func)(std::forward<Args>(args)...);
  }
  auto end = std::chrono::steady_clock::now();

  time += std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count() / 1e9;
  return time / n_run;
}
|
||||
} // namespace tools
|
||||
namespace helpers {
|
||||
// Number of entries in the oc (offset) array implied by the offset mode:
// one value for fixed offsets, one per row (m) for column offsets, one per
// column (n) for row offsets; 0 (with a diagnostic) for anything else.
std::size_t get_oc_size(CBLAS_OFFSET offset, std::size_t m, std::size_t n) {
  if (offset == CblasFixOffset) {
    return 1;
  }
  if (offset == CblasColOffset) {
    return m;
  }
  if (offset == CblasRowOffset) {
    return n;
  }
  std::cout << "Incorrect value of offset to the function " << __PRETTY_FUNCTION__ << std::endl;
  return 0;
}
|
||||
template <typename T>
|
||||
auto get_ab_matrix(CBLAS_LAYOUT lt, CBLAS_TRANSPOSE trans_, T&& non_trans_mtx, T&& trans_mtx) {
|
||||
if (lt == CblasColMajor) {
|
||||
if (trans_ == CblasNoTrans) {
|
||||
return non_trans_mtx.data();
|
||||
} else {
|
||||
return trans_mtx.data();
|
||||
}
|
||||
} else {
|
||||
if (trans_ == CblasNoTrans) {
|
||||
return trans_mtx.data();
|
||||
} else {
|
||||
return non_trans_mtx.data();
|
||||
}
|
||||
}
|
||||
}
|
||||
// Leading dimension matching get_ab_matrix's choice: the "natural" ld when
// layout and transpose agree (col-major+no-trans or row-major+trans), the
// transposed-copy ld otherwise.
auto get_ldab(CBLAS_LAYOUT lt, CBLAS_TRANSPOSE trans_mtx, std::size_t ld_n, std::size_t ld_t) {
  const bool col_major = (lt == CblasColMajor);
  const bool no_trans = (trans_mtx == CblasNoTrans);
  return (col_major == no_trans) ? ld_n : ld_t;
}
|
||||
|
||||
// returns copy of the matrix
|
||||
template <typename T>
|
||||
auto get_c_matrix(CBLAS_LAYOUT lt, T&& non_trans_mtx, T&& trans_mtx) {
|
||||
if (lt == CblasColMajor) {
|
||||
return non_trans_mtx;
|
||||
} else {
|
||||
return trans_mtx;
|
||||
}
|
||||
}
|
||||
|
||||
// Leading dimension matching get_c_matrix's choice of C buffer.
auto get_ldc(CBLAS_LAYOUT lt, std::size_t ldc_n, std::size_t ldc_t) {
  return (lt == CblasColMajor) ? ldc_n : ldc_t;
}
|
||||
template <typename A_Type, typename B_Type>
|
||||
void cblas_gemm_wrapper(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
|
||||
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
|
||||
const A_Type* a, const size_t lda, const int8_t oa, const B_Type* b, const size_t ldb,
|
||||
const int8_t ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc) {
|
||||
if constexpr (std::is_same_v<A_Type, std::int8_t>) {
|
||||
if constexpr (std::is_same_v<B_Type, std::int8_t>) {
|
||||
cblas_gemm_s8s8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
|
||||
} else {
|
||||
cblas_gemm_s8u8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same_v<B_Type, std::int8_t>) {
|
||||
cblas_gemm_u8s8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
|
||||
} else {
|
||||
cblas_gemm_u8u8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t return_oc_idx(const CBLAS_OFFSET offsetc, std::size_t mi, std::size_t ni) {
|
||||
return (offsetc == CblasFixOffset) ? 0 : ((offsetc == CblasColOffset) ? mi : ni);
|
||||
}
|
||||
} // namespace helpers
|
||||
enum class status_t { passed, failed };

// Pretty-printer for test verdicts ("PASSED" / "FAILED").
std::ostream& operator<<(std::ostream& os, const status_t& st) {
  switch (st) {
    case status_t::passed:
      os << "PASSED";
      break;
    case status_t::failed:
      os << "FAILED";
      break;
  }
  return os;
}
|
||||
// column major
|
||||
template <typename A_Type, typename B_Type>
|
||||
void ref_gemm(const CBLAS_OFFSET offsetc, const std::size_t m, const std::size_t n, const std::size_t k,
|
||||
const float alpha, const A_Type* a, const std::size_t lda, const std::int8_t oa, const B_Type* b,
|
||||
const std::size_t ldb, const std::int8_t ob, const float beta, std::int32_t* c, const std::size_t ldc,
|
||||
const std::int32_t* oc) {
|
||||
for (std::size_t mi = 0; mi < m; ++mi) {
|
||||
for (std::size_t ni = 0; ni < n; ++ni) {
|
||||
std::int32_t tmp = 0;
|
||||
for (std::size_t ki = 0; ki < k; ++ki) {
|
||||
tmp += (a[mi + ki * lda] + oa) * (b[ki + ni * ldb] + ob);
|
||||
}
|
||||
c[mi + ni * ldc] = std::round(alpha * static_cast<double>(tmp) +
|
||||
static_cast<double>(beta * static_cast<float>(c[mi + ni * ldc])) +
|
||||
static_cast<float>(oc[helpers::return_oc_idx(offsetc, mi, ni)]));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fills `buffer` with deterministic pseudo-random values in [0, 64]
// (fixed seed, shared generator across calls).
// BUGFIX: std::uniform_int_distribution is undefined behavior for
// character types such as int8_t/uint8_t (the standard only allows
// short/int/long/long long and their unsigned variants), so the
// distribution runs over int and the result is narrowed to DataType —
// always safe, since 0..64 fits any 8-bit type.
template <typename DataType>
void fill_random(DataType* buffer, std::size_t len) {
  static std::mt19937 generator(0);
  std::uniform_int_distribution<int> dist(0, 64);
  for (std::size_t i = 0; i < len; i++) {
    buffer[i] = static_cast<DataType>(dist(generator));
  }
}
|
||||
|
||||
// Fills `buffer` with the constant -8 (arbitrary nonzero seed for C).
template <typename DataType>
void fill_const(DataType* buffer, std::size_t len) {
  std::fill_n(buffer, len, DataType{-8});
}
|
||||
// performs transposition (n0 * n1) -> (n1 * n0), assuming col major
// ld0/ld1 are the leading dimensions of `in` and `out` respectively.
template <typename T>
void simplest_transpose(T* in, T* out, std::size_t n0, std::size_t n1, std::size_t ld0, std::size_t ld1) {
  for (std::size_t j = 0; j < n1; ++j) {
    for (std::size_t i = 0; i < n0; ++i) {
      out[i + j * ld1] = in[j + i * ld0];
    }
  }
}
|
||||
|
||||
template <typename DataType>
|
||||
status_t compare(DataType* ref, DataType* test, std::size_t m, std::size_t n, std::size_t ld) {
|
||||
for (std::size_t mi = 0; mi < m; ++mi) {
|
||||
for (std::size_t ni = 0; ni < n; ++ni) {
|
||||
if (ref[mi + ni * ld] != test[mi + ni * ld]) {
|
||||
return status_t::failed;
|
||||
}
|
||||
}
|
||||
}
|
||||
return status_t::passed;
|
||||
}
|
||||
// Debug dump of an (m x n) col-major matrix with ld == m, printed row by
// row as signed integers.
template <typename DataType>
void print_matrix(DataType* buffer, std::size_t m, std::size_t n) {
  for (std::size_t row = 0; row < m; ++row) {
    for (std::size_t col = 0; col < n; ++col) {
      std::cout << static_cast<std::int32_t>(buffer[row + col * m]) << ", ";
    }
    std::cout << std::endl;
  }
  std::cout << std::endl;
}
|
||||
// End-to-end check of one (A_Type, B_Type) GEMM variant: builds random A/B
// and constant C in both natural and transposed storage, then sweeps every
// (offset mode, layout, transa, transb) combination, comparing the library
// result against ref_gemm. With ONLY_PERF set it times the calls instead
// and prints per-combination TFlops. LD_STRIDE pads the leading dimensions
// to exercise non-contiguous strides.
template <typename A_Type, typename B_Type>
status_t gemm(std::size_t m, std::size_t n, std::size_t k, float alpha, float beta) {
  // Fixed nonzero input offsets.
  std::int8_t oa = 4;
  std::int8_t ob = 9;

  // Leading dimensions for natural (col-major) storage...
  std::size_t lda_n = m;
  std::size_t ldb_n = k;
  std::size_t ldc_n = m;

  // ...and for the transposed copies.
  std::size_t lda_t = k;
  std::size_t ldb_t = n;
  std::size_t ldc_t = n;

  if (std::getenv("LD_STRIDE")) {
    lda_n += 2;
    ldb_n += 7;
    ldc_n += 3;

    lda_t += 8;
    ldb_t += 3;
    ldc_t += 23;
  }

  bool only_performance = false;
  if (std::getenv("ONLY_PERF")) {
    only_performance = true;
  }
  std::vector<A_Type, tools::aligned_allocator<A_Type, 128>> a_n(lda_n * k);
  std::vector<A_Type, tools::aligned_allocator<A_Type, 128>> a_t(m * lda_t);
  std::vector<B_Type, tools::aligned_allocator<B_Type, 128>> b_n(ldb_n * n);
  std::vector<B_Type, tools::aligned_allocator<B_Type, 128>> b_t(k * ldb_t);
  std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_ref(ldc_n * n);
  std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_n(ldc_n * n);
  std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_t(m * ldc_t);

  // fill the whole array even if ld* > corresponding dim
  fill_random(a_n.data(), a_n.size());
  fill_random(b_n.data(), b_n.size());

  simplest_transpose(a_n.data(), a_t.data(), m, k, lda_n, lda_t);
  simplest_transpose(b_n.data(), b_t.data(), k, n, ldb_n, ldb_t);

  fill_const(c_ref.data(), ldc_n * n);
  c_n = c_ref;

  simplest_transpose(c_n.data(), c_t.data(), m, ldc_n, n, ldc_t);

  auto return_st = status_t::passed;
  double total = 0;
  size_t cnt = 0;
  for (auto c_offset : {CblasFixOffset, CblasColOffset, CblasRowOffset}) {
    // oc sized per offset mode; values 0, 2, 4, ...
    std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> oc(helpers::get_oc_size(c_offset, m, n));
    for (std::size_t i = 0; i < oc.size(); ++i) {
      oc[i] = i + i;
    }

    // Fresh reference C for every offset mode.
    std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_ref_copy = c_ref;
    if (!only_performance) {
      ref_gemm(c_offset, m, n, k, alpha, a_n.data(), lda_n, oa, b_n.data(), ldb_n, ob, beta, c_ref_copy.data(), ldc_n,
               oc.data());
    }
    // NOTE(review): c_n / c_t are NOT re-initialized between the layout /
    // transpose combinations below, so each library call sees the previous
    // call's output as its C input — with beta != 0 that looks like it
    // would diverge from the single-shot reference; confirm intended.
    for (auto layout : {CblasColMajor, CblasRowMajor}) {
      for (auto transa : {CblasNoTrans, CblasTrans}) {
        for (auto transb : {CblasNoTrans, CblasTrans}) {
          auto&& c_tested = helpers::get_c_matrix(layout, c_n, c_t);
          if (!only_performance) {
            helpers::cblas_gemm_wrapper(
                layout, transa, transb, c_offset, m, n, k, alpha, helpers::get_ab_matrix(layout, transa, a_n, a_t),
                helpers::get_ldab(layout, transa, lda_n, lda_t), oa, helpers::get_ab_matrix(layout, transb, b_n, b_t),
                helpers::get_ldab(layout, transb, ldb_n, ldb_t), ob, beta, c_tested.data(),
                helpers::get_ldc(layout, ldc_n, ldc_t), oc.data())

            // transpose c_tested to col-major if required
            auto loc_st = status_t::passed;
            if (layout == CblasRowMajor) {
              std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_tested_n(ldc_n * n);
              simplest_transpose(c_tested.data(), c_tested_n.data(), n, ldc_t, m, ldc_n);
              loc_st = compare(c_ref_copy.data(), c_tested_n.data(), m, n, ldc_n);
            } else {
              loc_st = compare(c_ref_copy.data(), c_tested.data(), m, n, ldc_n);
            }
            // "+" for a passing combination, "-" for a failing one.
            if (loc_st != status_t::passed) {
              std::cout << "-";
              return_st = status_t::failed;
            } else {
              std::cout << "+";
            }
          } else {
            // TFlops for this combination: 2*m*n*k ops over mean seconds.
            double cur = (2.0 * m * n * k) /
                         tools::timing(helpers::cblas_gemm_wrapper<A_Type, B_Type>, layout, transa, transb, c_offset, m,
                                       n, k, alpha, helpers::get_ab_matrix(layout, transa, a_n, a_t),
                                       helpers::get_ldab(layout, transa, lda_n, lda_t), oa,
                                       helpers::get_ab_matrix(layout, transb, b_n, b_t),
                                       helpers::get_ldab(layout, transb, ldb_n, ldb_t), ob, beta, c_tested.data(),
                                       helpers::get_ldc(layout, ldc_n, ldc_t), oc.data()) /
                         1e12;
            total += cur;
            ++cnt;

            std::cout << cur << ", ";
          }
        }
      }
    }
  }
  if (only_performance) {
    std::cout << "Average " << total / cnt << " TFlops";
  }
  std::cout << " ";
  return return_st;
}
|
||||
} // namespace test
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::size_t m = 128;
|
||||
std::size_t n = 128;
|
||||
std::size_t k = 128;
|
||||
float alpha = 1.0f;
|
||||
float beta = 1.0f;
|
||||
|
||||
if (argc > 1) {
|
||||
m = std::stoi(argv[1]);
|
||||
if (argc > 2) {
|
||||
n = std::stoi(argv[2]);
|
||||
if (argc > 3) {
|
||||
k = std::stoi(argv[3]);
|
||||
if (argc > 4) {
|
||||
alpha = std::stof(argv[4]);
|
||||
if (argc > 5) {
|
||||
beta = std::stof(argv[5]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::cout << "Testing matrix m = " << m << ", n = " << n << ", k = " << k << ", alpha = " << alpha
|
||||
<< ", beta = " << beta << std::endl;
|
||||
|
||||
std::cout << "\tTesting i8i8i32: " << test::gemm<std::int8_t, std::int8_t>(m, n, k, alpha, beta) << std::endl;
|
||||
std::cout << "\tTesting i8u8i32: " << test::gemm<std::int8_t, std::uint8_t>(m, n, k, alpha, beta) << std::endl;
|
||||
std::cout << "\tTesting u8i8i32: " << test::gemm<std::uint8_t, std::int8_t>(m, n, k, alpha, beta) << std::endl;
|
||||
std::cout << "\tTesting u8u8i32: " << test::gemm<std::uint8_t, std::uint8_t>(m, n, k, alpha, beta) << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.14)
|
||||
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_C_STANDARD_REQUIRED ON)
|
||||
|
||||
# can be compiled for SVE256 (32) or SVE512 (64)
|
||||
set(SV_LENGTH 32) # in bytes
|
||||
|
||||
include_directories("${PROJECT_SOURCE_DIR}")
|
||||
include_directories("${PROJECT_SOURCE_DIR}/include")
|
||||
include_directories("${PROJECT_SOURCE_DIR}/standalone")
|
||||
add_compile_options(-fPIC -fvisibility=hidden -fstack-protector-strong -march=armv8.3-a+sve+i8mm -O3)
|
||||
# add_compile_options(-Wall -Wextra -Werror)
|
||||
|
||||
# sources are split into several groups: matmul kernels (sources are compiled multiple times),
|
||||
# packing kernels (sources are compiled multiple times),
|
||||
# sequential/parallel pipelines,
|
||||
# interface (sources are compiled multiple times)
|
||||
set(INT_GEMM_KERNELS "")
|
||||
set(INT_PACK_KERNELS "")
|
||||
set(INT_GEMM_INTERFACE "")
|
||||
set(INT_GEMM_PARALLEL "")
|
||||
set(INT_GEMM_SEQ "")
|
||||
set(BETA_KERNELS "")
|
||||
set(POST_OPS_KERNELS "")
|
||||
set(INT_SMALL_KERNELS "")
|
||||
set(GEMM_DRIVERS "")
|
||||
|
||||
# Supported precisions are i/u for A and B matrices (4 combinations)
|
||||
set(LHS_TYPES LHS_INT LHS_UINT)
|
||||
set(RHS_TYPES RHS_INT RHS_UINT)
|
||||
|
||||
# compile matrix-multiplication kernels multiple times
|
||||
set(INTEGER_MM_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_kernels.c")
|
||||
set(M_SIZES 1 2 3 4)
|
||||
set(N_SIZES 1 2 3 4)
|
||||
foreach(M_SIZE ${M_SIZES})
|
||||
foreach(N_SIZE ${N_SIZES})
|
||||
foreach(LHS_TYPE ${LHS_TYPES})
|
||||
foreach(RHS_TYPE ${RHS_TYPES})
|
||||
add_library(int4_integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH} OBJECT ${INTEGER_MM_KERNELS_SRC})
|
||||
target_compile_options(int4_integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH} PUBLIC -DM_SIZE=${M_SIZE}
|
||||
-DN_SIZE=${N_SIZE} -D${LHS_TYPE} -D${RHS_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
|
||||
target_include_directories(int4_integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH} PUBLIC
|
||||
"${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
|
||||
list(APPEND INT_GEMM_KERNELS $<TARGET_OBJECTS:int4_integer_gemm_kernels_${LHS_TYPE}_${RHS_TYPE}_${M_SIZE}_${N_SIZE}_${SV_LENGTH}>)
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
# compile interface multiple times
|
||||
set(INTEGER_GEMM_IFACE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_interface.c")
|
||||
foreach(LHS_TYPE ${LHS_TYPES})
|
||||
foreach(RHS_TYPE ${RHS_TYPES})
|
||||
add_library(int4_integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH} OBJECT ${INTEGER_GEMM_IFACE_SRC})
|
||||
target_compile_options(int4_integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH} PUBLIC -D${LHS_TYPE} -D${RHS_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
|
||||
target_include_directories(int4_integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH} PUBLIC
|
||||
"${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
|
||||
list(APPEND INT_GEMM_INTERFACE $<TARGET_OBJECTS:int4_integer_gemm_interface_${LHS_TYPE}_${RHS_TYPE}_${SV_LENGTH}>)
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
# compile threading layer
|
||||
# set(INTEGER_GEMM_PAR_PIPE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/parallel_int_gemm_pipeline.c")
|
||||
# add_library(integer_gemm_par_pipe_${SV_LENGTH} OBJECT ${INTEGER_GEMM_PAR_PIPE_SRC})
|
||||
# target_compile_options(int4_integer_gemm_par_pipe_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH})
|
||||
# target_include_directories(int4_integer_gemm_par_pipe_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
|
||||
# list(APPEND INT_GEMM_PARALLEL $<TARGET_OBJECTS:int4_integer_gemm_par_pipe_${SV_LENGTH}>)
|
||||
|
||||
|
||||
# compile sequential layer
|
||||
set(INTEGER_GEMMSEQ_PIPE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/sequential_int_gemm_pipeline.c")
|
||||
add_library(int4_integer_gemm_seq_pipe_${SV_LENGTH} OBJECT ${INTEGER_GEMMSEQ_PIPE_SRC})
|
||||
target_compile_options(int4_integer_gemm_seq_pipe_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH})
|
||||
target_include_directories(int4_integer_gemm_seq_pipe_${SV_LENGTH} PUBLIC
|
||||
"${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
|
||||
list(APPEND INT_GEMM_SEQ $<TARGET_OBJECTS:int4_integer_gemm_seq_pipe_${SV_LENGTH}>)
|
||||
|
||||
# compile packingA kernels
|
||||
set(INT_PACK_A_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_pack_a_kernels.c")
|
||||
set(TRANSA_VALS TRANSA NOTRANSA)
|
||||
foreach(LHS_TYPE ${LHS_TYPES})
|
||||
foreach(TRANSA_VAL ${TRANSA_VALS})
|
||||
add_library(int4_integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH} OBJECT ${INT_PACK_A_KERNELS_SRC})
|
||||
target_compile_options(int4_integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH} -D${LHS_TYPE} -D${TRANSA_VAL})
|
||||
target_include_directories(int4_integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
|
||||
list(APPEND INT_PACK_KERNELS $<TARGET_OBJECTS:int4_integer_gemm_pack_a_${LHS_TYPE}_${TRANSA_VAL}_${SV_LENGTH}>)
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
# compile packingB kernels
|
||||
set(INT_PACK_B_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_pack_b_kernels.c")
|
||||
set(TRANSB_VALS TRANSB NOTRANSB)
|
||||
foreach(RHS_TYPE ${RHS_TYPES})
|
||||
foreach(TRANSB_VAL ${TRANSB_VALS})
|
||||
add_library(int4_integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH} OBJECT ${INT_PACK_B_KERNELS_SRC})
|
||||
target_compile_options(int4_integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH} PUBLIC -DCOMP_SV_LEN=${SV_LENGTH} -D${RHS_TYPE} -D${TRANSB_VAL})
|
||||
target_include_directories(int4_integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
|
||||
list(APPEND INT_PACK_KERNELS $<TARGET_OBJECTS:int4_integer_gemm_pack_b_${RHS_TYPE}_${TRANSB_VAL}_${SV_LENGTH}>)
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
# compile beta kernels
|
||||
set(BETA_OPTS BETA_OPT BETA_NO_OPT)
|
||||
|
||||
foreach(B_TYPE ${BETA_OPTS})
|
||||
add_library(int4_beta_kernels_${B_TYPE} OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/integer_beta_kernels.c")
|
||||
target_compile_options(int4_beta_kernels_${B_TYPE} PUBLIC -D${B_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
|
||||
target_include_directories(int4_beta_kernels_${B_TYPE} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
|
||||
list(APPEND BETA_KERNELS $<TARGET_OBJECTS:int4_beta_kernels_${B_TYPE}>)
|
||||
endforeach()
|
||||
|
||||
# compile int gemm drivers
|
||||
foreach(B_TYPE ${BETA_OPTS})
|
||||
add_library(int4_gemm_driver_${B_TYPE} OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_driver.c")
|
||||
target_compile_options(int4_gemm_driver_${B_TYPE} PUBLIC -D${B_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
|
||||
target_include_directories(int4_gemm_driver_${B_TYPE} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
|
||||
list(APPEND GEMM_DRIVERS $<TARGET_OBJECTS:int4_gemm_driver_${B_TYPE}>)
|
||||
endforeach()
|
||||
|
||||
# compile post-ops kernels
|
||||
foreach(B_TYPE ${BETA_OPTS})
|
||||
add_library(int4_post_ops_kernels_${B_TYPE} OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/integer_post_ops_kernels.c")
|
||||
target_compile_options(int4_post_ops_kernels_${B_TYPE} PUBLIC -D${B_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
|
||||
target_include_directories(int4_post_ops_kernels_${B_TYPE} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
|
||||
list(APPEND POST_OPS_KERNELS $<TARGET_OBJECTS:int4_post_ops_kernels_${B_TYPE}>)
|
||||
endforeach()
|
||||
|
||||
# compile matrix-multiplication small kernels multiple times
|
||||
set(SMALL_KERNELS_KERNELS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/integer_gemm_small_kernels.c")
|
||||
set(TRANSA_VALS TRANSA NOTRANSA)
|
||||
set(TRANSB_VALS TRANSB NOTRANSB)
|
||||
set(OC_TYPES OC_FIX OC_COL OC_ROW)
|
||||
|
||||
foreach(LHS_TYPE ${LHS_TYPES})
|
||||
foreach(RHS_TYPE ${RHS_TYPES})
|
||||
foreach(TRANSA_VAL ${TRANSA_VALS})
|
||||
foreach(TRANSB_VAL ${TRANSB_VALS})
|
||||
foreach(OC_TYPE ${OC_TYPES})
|
||||
add_library(int4_small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}${SV_LENGTH} OBJECT ${SMALL_KERNELS_KERNELS_SRC})
|
||||
target_compile_options(int4_small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}${SV_LENGTH} PUBLIC -D${LHS_TYPE} -D${RHS_TYPE} -D${TRANSA_VAL} -D${TRANSB_VAL} -D${OC_TYPE} -DCOMP_SV_LEN=${SV_LENGTH})
|
||||
target_include_directories(int4_small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}${SV_LENGTH} PUBLIC "${CMAKE_SOURCE_DIR};${BLAS_INCLUDE_DIR};${DRIVER_INCLUDE_DIR};${CMAKE_BINARY_DIR}")
|
||||
list(APPEND INT_SMALL_KERNELS $<TARGET_OBJECTS:int4_small_kernels_${LHS_TYPE}_${RHS_TYPE}_${TRANSA_VAL}_${TRANSB_VAL}_${OC_TYPE}${SV_LENGTH}>)
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
list(APPEND OBJ_FILES_STANDALONE_DIR ${INT_GEMM_KERNELS} ${INT_GEMM_INTERFACE} ${INT_GEMM_SEQ} ${INT_PACK_KERNELS} ${BETA_KERNELS} ${POST_OPS_KERNELS} ${INT_SMALL_KERNELS} ${GEMM_DRIVERS})
|
||||
# set(OBJ_FILES_STANDALONE_DIR ${OBJ_FILES_STANDALONE_DIR} PARENT_SCOPE)
|
||||
# all compiled object files are united into one object library
|
||||
add_library(prefillint4gemm SHARED ${OBJ_FILES_STANDALONE_DIR})
|
||||
|
||||
set_target_properties(prefillint4gemm PROPERTIES
|
||||
ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/prefillint4gemm
|
||||
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/prefillint4gemm
|
||||
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/prefillint4gemm
|
||||
)
|
||||
@@ -1,25 +0,0 @@
|
||||
#ifndef BETA_MACROS_H
|
||||
#define BETA_MACROS_H
|
||||
|
||||
#if defined(OC_FIX)
|
||||
#define OC_TYPE f
|
||||
#define OC_IDX(mi, nn) 0
|
||||
#elif defined(OC_COL)
|
||||
#define OC_TYPE c
|
||||
#define OC_IDX(mi, ni) mi
|
||||
#else
|
||||
#define OC_TYPE r
|
||||
#define OC_IDX(mi, ni) ni
|
||||
#endif
|
||||
|
||||
#if defined(BETA_OPT)
|
||||
#define BETA_SUFF(name) name##_opt
|
||||
#define LDC(m, ldc) ldc
|
||||
#define DTYPE int32_t
|
||||
#else
|
||||
#define BETA_SUFF(name) name
|
||||
#define LDC(m, ldc) m
|
||||
#define DTYPE float
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -1,63 +0,0 @@
|
||||
#ifndef __HELPING_MACROS_INT4_H__
|
||||
#define __HELPING_MACROS_INT4_H__
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
#if defined(LHS_INT) && defined(LHS_UINT)
|
||||
#error "Both LHS_INT and LHS_UINT are defined"
|
||||
#endif
|
||||
|
||||
#if defined(RHS_INT) && defined(RHS_UINT)
|
||||
#error "Both RHS_INT and RHS_UINT are defined"
|
||||
#endif
|
||||
|
||||
#ifdef LHS_INT
|
||||
#define LHS_TYPE s
|
||||
#define LHS_INT_TYPE int8_t
|
||||
#endif
|
||||
#ifdef LHS_UINT
|
||||
#define LHS_TYPE u
|
||||
#define LHS_INT_TYPE uint8_t
|
||||
#endif
|
||||
#ifdef RHS_INT
|
||||
#define RHS_TYPE s
|
||||
#define RHS_INT_TYPE int8_t
|
||||
#endif
|
||||
#ifdef RHS_UINT
|
||||
#define RHS_TYPE u
|
||||
#define RHS_INT_TYPE uint8_t
|
||||
#endif
|
||||
|
||||
// mangling macros
|
||||
#define ADD_M_N_SIZES(name, m_size, n_size) name##_##m_size##x##n_size
|
||||
#define ADD_M_N_SIZES_MACRO(name, m_size, n_size) ADD_M_N_SIZES(name, m_size, n_size)
|
||||
#define ADD_TYPES(name, lhs_type, rhs_type) name##_##lhs_type##8##rhs_type##8s32
|
||||
#define ADD_TYPES_MACRO(name, lhs_type, rhs_type) ADD_TYPES(name, lhs_type, rhs_type)
|
||||
#define ADD_TYPES_SUFF(name) ADD_TYPES_MACRO(name, LHS_TYPE, RHS_TYPE)
|
||||
#define ADD_ONE_TYPE_TRANSP(name, type, nt) name##_##type##8_##nt
|
||||
#define ADD_ONE_TYPE_TRANSP_MACRO(name, type, nt) ADD_ONE_TYPE_TRANSP(name, type, nt)
|
||||
#define ADD_PACK_A_N_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, LHS_TYPE, n)
|
||||
#define ADD_PACK_B_N_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, RHS_TYPE, n)
|
||||
#define ADD_PACK_A_T_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, LHS_TYPE, t)
|
||||
#define ADD_PACK_B_T_SUFF(name) ADD_ONE_TYPE_TRANSP_MACRO(name, RHS_TYPE, t)
|
||||
#define ADD_TWO_TYPES_TRANSP(name, lhs_type, rhs_type, a_t, b_t, oc_t) \
|
||||
name##_##lhs_type##8##rhs_type##8s32##_##a_t##b_t##_##oc_t
|
||||
#define ADD_TWO_TYPES_TRANSP_MACRO(name, lhs_type, rhs_type, a_t, b_t, oc_t) \
|
||||
ADD_TWO_TYPES_TRANSP(name, lhs_type, rhs_type, a_t, b_t, oc_t)
|
||||
#define ADD_TRANSP_MACRO(name, a_t, b_t, oc_t) ADD_TWO_TYPES_TRANSP_MACRO(name, LHS_TYPE, RHS_TYPE, a_t, b_t, oc_t)
|
||||
|
||||
#ifdef ENABLE_THREADING
|
||||
#define ADD_THREAD_SUFF(name) name##_thread
|
||||
#else
|
||||
#define ADD_THREAD_SUFF(name) name
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -1,82 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
#include "beta_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
// Column major C, Fixed C_offset
|
||||
void BETA_SUFF(beta_cf_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
|
||||
DTYPE oc_val = (DTYPE)*oc;
|
||||
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
|
||||
DTYPE beta_val = (DTYPE)beta;
|
||||
for (size_t ni = 0; ni < n_block_size; ++ni) {
|
||||
for (size_t mi = 0; mi < m_block_size; ++mi) {
|
||||
c_typed_ptr[ldc * ni + mi] = beta_val * ((DTYPE)c_ptr[ldc * ni + mi]) + oc_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Column major C, Column major C_offset
|
||||
void BETA_SUFF(beta_cc_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
|
||||
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
|
||||
DTYPE beta_val = (DTYPE)beta;
|
||||
for (size_t ni = 0; ni < n_block_size; ++ni) {
|
||||
for (size_t mi = 0; mi < m_block_size; ++mi) {
|
||||
c_typed_ptr[ldc * ni + mi] = beta_val * ((DTYPE)c_ptr[ldc * ni + mi]) + ((DTYPE)oc[mi]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Column major C, Row major C_offset
|
||||
void BETA_SUFF(beta_cr_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
|
||||
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
|
||||
DTYPE beta_val = (DTYPE)beta;
|
||||
for (size_t ni = 0; ni < n_block_size; ++ni) {
|
||||
for (size_t mi = 0; mi < m_block_size; ++mi) {
|
||||
c_typed_ptr[ldc * ni + mi] = beta_val * ((DTYPE)c_ptr[ldc * ni + mi]) + ((DTYPE)oc[ni]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Row major C, Fixed C_offset
|
||||
// for row-major we actually swap m and n values so we reswap it here again
|
||||
void BETA_SUFF(beta_rf_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
|
||||
DTYPE oc_val = (DTYPE)*oc;
|
||||
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
|
||||
DTYPE beta_val = (DTYPE)beta;
|
||||
for (size_t mi = 0; mi < n_block_size; ++mi) {
|
||||
for (size_t ni = 0; ni < m_block_size; ++ni) {
|
||||
c_typed_ptr[ldc * mi + ni] = beta_val * ((DTYPE)c_ptr[ldc * mi + ni]) + oc_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Row major C, Column major C_offset
|
||||
// for row-major we actually swap m and n values so we reswap it here again
|
||||
void BETA_SUFF(beta_rc_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
|
||||
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
|
||||
DTYPE beta_val = (DTYPE)beta;
|
||||
for (size_t mi = 0; mi < n_block_size; ++mi) {
|
||||
for (size_t ni = 0; ni < m_block_size; ++ni) {
|
||||
c_typed_ptr[ldc * mi + ni] = beta_val * ((DTYPE)c_ptr[ldc * mi + ni]) + ((DTYPE)oc[mi]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Row major C, Row major C_offset
|
||||
// for row-major we actually swap m and n values so we reswap it here again
|
||||
void BETA_SUFF(beta_rr_s8)(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc) {
|
||||
DTYPE* c_typed_ptr = (DTYPE*) c_ptr;
|
||||
DTYPE beta_val = (DTYPE)beta;
|
||||
for (size_t mi = 0; mi < n_block_size; ++mi) {
|
||||
for (size_t ni = 0; ni < m_block_size; ++ni) {
|
||||
c_typed_ptr[ldc * mi + ni] = beta_val * ((DTYPE)c_ptr[ldc * mi + ni]) + ((DTYPE)oc[ni]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif /* __cplusplus */
|
||||
@@ -1,121 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "beta_macros.h"
|
||||
|
||||
#define OLD_N_SIZE 8
|
||||
#define PACKED_LD_STEP(n_step, k_step, ldb) (n_step * (ldb / 2) + (k_step / 2) * OLD_N_SIZE)
|
||||
|
||||
void BETA_SUFF(gemm_driver)(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha,
|
||||
const void* a, size_t lda, const BLASINT8 oa,
|
||||
const void* b, size_t ldb, const BLASINT8 ob,
|
||||
float beta, int32_t* c, size_t ldc, const int32_t* oc) {
|
||||
|
||||
void (**gemm_kernels)(const void*, const void*, int32_t*, size_t, int64_t, int64_t) = arg->gemm_kernels;
|
||||
void (*pack_a_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8) = arg->pack_a_fun;
|
||||
void (*pack_b_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8) = arg->pack_b_fun;
|
||||
void (*beta_func)(int32_t*, const int32_t*, float, size_t, size_t, size_t) = arg->beta_func;
|
||||
size_t (*a_indexing)(size_t m, size_t n, size_t ld) = arg->a_indexing;
|
||||
size_t (*b_indexing)(size_t m, size_t n, size_t ld) = arg->b_indexing;
|
||||
|
||||
#ifndef BETA_OPT
|
||||
void (*post_ops_func)(float, const int32_t*, int32_t*, size_t, size_t, size_t) = arg->post_ops_func;
|
||||
#endif // BETA_OPT
|
||||
|
||||
const BLASINT8* a_typed = (const BLASINT8*) a;
|
||||
const BLASINT8* b_typed = (const BLASINT8*) b;
|
||||
|
||||
BLASINT8* bufferA = (BLASINT8*) aligned_alloc(ALIGNMENT, sizeof(BLASINT8) * K_BLOCK * M_BLOCK);
|
||||
BLASINT8* bufferB = (BLASINT8*) aligned_alloc(ALIGNMENT, sizeof(BLASINT8) * K_BLOCK * N_BLOCK);
|
||||
|
||||
// Tmp buffer is not needed when (alpha = 1 and beta = 0/1)
|
||||
#ifdef BETA_OPT
|
||||
int32_t* bufferC = c;
|
||||
#else
|
||||
int32_t* bufferC = (int32_t*) aligned_alloc(ALIGNMENT, sizeof(int32_t) * m * N_BLOCK);
|
||||
#endif
|
||||
|
||||
if (!bufferA || !bufferB || !bufferC) {
|
||||
free(bufferA);
|
||||
free(bufferB);
|
||||
free(bufferC);
|
||||
printf("Integer GEMM unsuccessful allocation");
|
||||
return;
|
||||
}
|
||||
// printf("pack b beta: %f\n",beta);
|
||||
beta_func(c, oc, beta, m, n, ldc);
|
||||
|
||||
for (size_t n_block = 0; n_block < n; n_block += N_BLOCK) {
|
||||
size_t n_block_size = n - n_block;
|
||||
if (n_block_size > N_BLOCK) {
|
||||
n_block_size = N_BLOCK;
|
||||
}
|
||||
|
||||
#ifndef BETA_OPT
|
||||
// fill bufferC w/ zeros
|
||||
for (size_t tmp_idx = 0; tmp_idx < (m * N_BLOCK); ++tmp_idx) {
|
||||
bufferC[tmp_idx] = 0;
|
||||
}
|
||||
#endif // BETA_OPT
|
||||
|
||||
if (alpha != 0.0f){
|
||||
for (size_t k_block = 0; k_block < k; k_block += K_BLOCK){
|
||||
size_t k_block_size = k - k_block;
|
||||
if (k_block_size > K_BLOCK) {
|
||||
k_block_size = K_BLOCK;
|
||||
}
|
||||
size_t k_block_size_up = (k_block_size + KERNEL_K_STEP - 1) / KERNEL_K_STEP * KERNEL_K_STEP;
|
||||
|
||||
const BLASINT8* curr_b_ptr = b_typed + b_indexing(k_block, n_block, ldb);
|
||||
|
||||
pack_b_fun(bufferB, curr_b_ptr, n_block_size, k_block_size, ldb, ob);
|
||||
for (size_t m_block = 0; m_block < m; m_block += M_BLOCK) {
|
||||
size_t m_block_size = m - m_block;
|
||||
if (m_block_size > M_BLOCK) {
|
||||
m_block_size = M_BLOCK;
|
||||
}
|
||||
const BLASINT8* curr_a_ptr = a_typed + PACKED_LD_STEP(m_block, k_block, lda);
|
||||
pack_a_fun(bufferA, curr_a_ptr, m_block_size, k_block_size, lda, oa);
|
||||
// loop over bufferB, taking parts which fit into L1
|
||||
for (size_t n_sub_block = 0; n_sub_block < n_block_size; n_sub_block += KERNEL_N_STEP) {
|
||||
size_t n_sub_block_size = n_block_size - n_sub_block;
|
||||
if (n_sub_block_size > KERNEL_N_STEP) {
|
||||
n_sub_block_size = KERNEL_N_STEP;
|
||||
}
|
||||
BLASINT8* current_bufferB_ptr = bufferB + n_sub_block * k_block_size_up;
|
||||
// loop over bufferA, taking parts which fit into L1
|
||||
for (size_t m_sub_block = 0; m_sub_block < m_block_size; m_sub_block += KERNEL_M_STEP) {
|
||||
size_t m_sub_block_size = m_block_size - m_sub_block;
|
||||
if (m_sub_block_size > KERNEL_M_STEP) {
|
||||
m_sub_block_size = KERNEL_M_STEP;
|
||||
}
|
||||
BLASINT8* current_bufferA_ptr = bufferA + m_sub_block * k_block_size_up;
|
||||
#ifdef BETA_OPT
|
||||
int32_t* current_bufferC_ptr = bufferC + n_block * ldc + n_sub_block * LDC(m, ldc) + m_sub_block + m_block;
|
||||
#else
|
||||
int32_t* current_bufferC_ptr = bufferC + n_sub_block * LDC(m, ldc) + m_sub_block + m_block;
|
||||
#endif
|
||||
// call kernel which performs loop over k_block_size
|
||||
gemm_kernels[(n_sub_block_size - 1) + (m_sub_block_size - 1) * KERNEL_N_STEP](current_bufferA_ptr, current_bufferB_ptr, current_bufferC_ptr,
|
||||
LDC(m, ldc), k_block_size_up, COMP_SV_LEN);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef BETA_OPT
|
||||
// copy C data from bufferC multiplying by alpha and adding initial C data (scaled by beta)
|
||||
int32_t* current_c_ptr = c + n_block * ldc; // col major
|
||||
post_ops_func(alpha, bufferC, current_c_ptr, LDC(m, ldc), n_block_size, ldc);
|
||||
#endif
|
||||
}
|
||||
|
||||
free(bufferA);
|
||||
free(bufferB);
|
||||
|
||||
#ifndef BETA_OPT
|
||||
free(bufferC);
|
||||
#endif
|
||||
}
|
||||
@@ -1,154 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
//#include "cblas.h"
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
|
||||
|
||||
/* matrix saved in rows or cols */
|
||||
typedef enum CBLAS_ORDER {
|
||||
CblasRowMajor = 101,
|
||||
CblasColMajor = 102
|
||||
} CBLAS_ORDER;
|
||||
|
||||
/* matrix transpose or conjugate transpose */
|
||||
typedef enum CBLAS_TRANSPOSE {
|
||||
CblasNoTrans = 111,
|
||||
CblasTrans = 112,
|
||||
CblasConjTrans = 113, // conjugate transpose
|
||||
CblasConjNoTrans = 114
|
||||
} CBLAS_TRANSPOSE;
|
||||
|
||||
typedef CBLAS_ORDER CBLAS_LAYOUT;
|
||||
|
||||
typedef enum CBLAS_OFFSET {
|
||||
CblasRowOffset = 171,
|
||||
CblasColOffset = 172,
|
||||
CblasFixOffset = 173
|
||||
} CBLAS_OFFSET;
|
||||
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
#define ADD_KERNEL_SUFF(name, m_size, n_size) ADD_M_N_SIZES_MACRO(ADD_TYPES_MACRO(name, LHS_TYPE, RHS_TYPE), m_size, n_size)
|
||||
|
||||
static void (*gemm_kernels[])(const void*, const void*, int32_t*, size_t, int64_t, int64_t) = {
|
||||
ADD_KERNEL_SUFF(gemm_kernel, 1, 1), ADD_KERNEL_SUFF(gemm_kernel, 1, 2),ADD_KERNEL_SUFF(gemm_kernel, 1, 3), ADD_KERNEL_SUFF(gemm_kernel, 1, 4),
|
||||
ADD_KERNEL_SUFF(gemm_kernel, 2, 1), ADD_KERNEL_SUFF(gemm_kernel, 2, 2),ADD_KERNEL_SUFF(gemm_kernel, 2, 3), ADD_KERNEL_SUFF(gemm_kernel, 2, 4),
|
||||
ADD_KERNEL_SUFF(gemm_kernel, 3, 1), ADD_KERNEL_SUFF(gemm_kernel, 3, 2),ADD_KERNEL_SUFF(gemm_kernel, 3, 3), ADD_KERNEL_SUFF(gemm_kernel, 3, 4),
|
||||
ADD_KERNEL_SUFF(gemm_kernel, 4, 1), ADD_KERNEL_SUFF(gemm_kernel, 4, 2),ADD_KERNEL_SUFF(gemm_kernel, 4, 3), ADD_KERNEL_SUFF(gemm_kernel, 4, 4)
|
||||
};
|
||||
|
||||
static void (*pack_b_funs[])(void*, const void*, size_t, size_t, size_t, const BLASINT8) = {
|
||||
ADD_PACK_B_N_SUFF(pack_b),
|
||||
ADD_PACK_B_T_SUFF(pack_b)
|
||||
};
|
||||
|
||||
static void (*pack_a_funs[])(void*, const void*, size_t, size_t, size_t, const BLASINT8) = {
|
||||
ADD_PACK_A_N_SUFF(pack_a),
|
||||
ADD_PACK_A_T_SUFF(pack_a)
|
||||
};
|
||||
|
||||
static void (*small_kernels[])(const size_t, const size_t, const size_t, const float,
|
||||
const void *, const size_t, const BLASINT8,
|
||||
const void *, const size_t, const BLASINT8,
|
||||
const float, int32_t *, const size_t, const int32_t *) = {
|
||||
ADD_TRANSP_MACRO(small_kernel, n, n, f), ADD_TRANSP_MACRO(small_kernel, n, t, f),
|
||||
ADD_TRANSP_MACRO(small_kernel, t, n, f), ADD_TRANSP_MACRO(small_kernel, t, t, f),
|
||||
ADD_TRANSP_MACRO(small_kernel, n, n, c), ADD_TRANSP_MACRO(small_kernel, n, t, c),
|
||||
ADD_TRANSP_MACRO(small_kernel, t, n, c), ADD_TRANSP_MACRO(small_kernel, t, t, c),
|
||||
ADD_TRANSP_MACRO(small_kernel, n, n, r), ADD_TRANSP_MACRO(small_kernel, n, t, r),
|
||||
ADD_TRANSP_MACRO(small_kernel, t, n, r), ADD_TRANSP_MACRO(small_kernel, t, t, r),
|
||||
};
|
||||
|
||||
static void (*beta_funcs[])(int32_t*, const int32_t*, float, size_t, size_t, size_t) = {
|
||||
beta_cf_s8, beta_cc_s8, beta_cr_s8, beta_rf_s8, beta_rc_s8, beta_rr_s8,
|
||||
beta_cf_s8_opt, beta_cc_s8_opt, beta_cr_s8_opt, beta_rf_s8_opt, beta_rc_s8_opt, beta_rr_s8_opt
|
||||
};
|
||||
|
||||
static void (*post_op_kernels[])(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc) = {
|
||||
post_ops, post_ops_opt
|
||||
};
|
||||
|
||||
static size_t row_major_idx(size_t m, size_t n, size_t ld) {
|
||||
return ld * m + n;
|
||||
}
|
||||
|
||||
static size_t col_major_idx(size_t m, size_t n, size_t ld) {
|
||||
return m + ld * n;
|
||||
}
|
||||
|
||||
static size_t (*compute_idx[])(size_t m, size_t n, size_t ld) = {
|
||||
col_major_idx,
|
||||
row_major_idx
|
||||
};
|
||||
|
||||
static size_t mov_oc_fix(size_t mi, size_t ni) {
|
||||
UNUSED(mi);
|
||||
UNUSED(ni);
|
||||
return 0;
|
||||
}
|
||||
static size_t mov_oc_col(size_t mi, size_t ni){
|
||||
UNUSED(ni);
|
||||
return mi;
|
||||
}
|
||||
|
||||
static size_t mov_oc_row(size_t mi, size_t ni) {
|
||||
UNUSED(mi);
|
||||
return ni;
|
||||
}
|
||||
|
||||
static size_t (*move_oc[])(size_t, size_t) = {
|
||||
mov_oc_fix, mov_oc_col, mov_oc_row
|
||||
};
|
||||
|
||||
EXTERNAL_API void ADD_TYPES_SUFF(prefill_int4_cblas_gemm)(
|
||||
const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,
|
||||
const size_t m, const size_t n, const size_t k, const float alpha,
|
||||
const void *a, const size_t lda, const BLASINT8 oa,
|
||||
const void *b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t *c, const size_t ldc, const int32_t *oc) {
|
||||
|
||||
int opt_offset = ((alpha == 1.0f) && (beta == 0.0f || beta == 1.0f)) ? 1 : 0;
|
||||
if(Layout == CblasColMajor) {
|
||||
int beta_offset = (offsetc == CblasFixOffset) ? 0 : (offsetc == CblasColOffset ? 1:2);
|
||||
int_gemm_funcs arg = {
|
||||
small_kernels[(transb == CblasTrans) + 2 * (transa == CblasTrans) + beta_offset * 4],
|
||||
gemm_kernels,
|
||||
pack_a_funs[transa == CblasTrans],
|
||||
pack_b_funs[transb == CblasTrans],
|
||||
beta_funcs[beta_offset + opt_offset * 6],
|
||||
post_op_kernels[alpha == 1],
|
||||
compute_idx[transa == CblasTrans],
|
||||
compute_idx[transb == CblasTrans],
|
||||
move_oc[beta_offset],
|
||||
};
|
||||
(gemm_impl_8bit(&arg, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc, 262144));
|
||||
} else if (Layout == CblasRowMajor) {
|
||||
int beta_offset = (offsetc == CblasFixOffset) ? 3 : (offsetc == CblasColOffset ? 4 : 5);
|
||||
int beta_offset_small = (offsetc == CblasFixOffset) ? 0 : (offsetc == CblasColOffset ? 2 : 1);
|
||||
int_gemm_funcs arg = {
|
||||
small_kernels[(transa == CblasTrans) + 2 * (transb == CblasTrans) + beta_offset_small * 4],
|
||||
gemm_kernels,
|
||||
pack_a_funs[transb == CblasTrans],
|
||||
pack_b_funs[transa == CblasTrans],
|
||||
beta_funcs[beta_offset + opt_offset * 6],
|
||||
post_op_kernels[alpha == 1],
|
||||
compute_idx[transb == CblasTrans],
|
||||
compute_idx[transa == CblasTrans],
|
||||
move_oc[beta_offset_small]
|
||||
};
|
||||
(gemm_impl_8bit(&arg, n, m, k, alpha, b, ldb, ob, a, lda, oa, beta, c, ldc, oc, 262144));
|
||||
}
|
||||
else {
|
||||
printf("Incorrect layout");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,453 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define ADD_SUFFIX(name) ADD_M_N_SIZES_MACRO(ADD_TYPES_MACRO(name, LHS_TYPE, RHS_TYPE), M_SIZE, N_SIZE)
|
||||
|
||||
#define LD1B_PTR(reg_name, p, ptr, idx) "ld1b {" #reg_name ".b}, " #p "/z, [%[" #ptr "], #" #idx ", MUL VL]\n"
|
||||
#define COMPUTE_ADDP(out, in1, in2) "addp " #out ".s, " #in1 ".s, " #in2 ".s\n"
|
||||
#if (defined(LHS_INT) && defined(RHS_INT)) || (defined(LHS_UINT) && defined(RHS_UINT))
|
||||
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) #lhs_type "dot " #out ".s, " #in1 ".b, " #in2 ".b\n"
|
||||
#else
|
||||
#ifdef LHS_INT
|
||||
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in2 ".b, " #in1 ".b\n"
|
||||
#else
|
||||
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in1 ".b, " #in2 ".b\n"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE) COMPUTE_DOT_TYPED(out, in1, in2, LHS_TYPE, RHS_TYPE)
|
||||
#define COMPUTE_DOT(out, in1, in2) COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE)
|
||||
|
||||
#if (N_SIZE > 4)
|
||||
#error "N_SIZE can't be greater than 4"
|
||||
#endif
|
||||
|
||||
#if (M_SIZE > 4)
|
||||
#error "M_SIZE can't be greater than 4"
|
||||
#endif
|
||||
|
||||
#define LOAD_Z0(p, ptr) LD1B_PTR(z0, p, ptr, 0)
|
||||
#define LOAD_Z8(p, ptr) LD1B_PTR(z8, p, ptr, 0)
|
||||
|
||||
#if (N_SIZE > 1)
|
||||
#define LOAD_Z1(p, ptr) LD1B_PTR(z1, p, ptr, 1)
|
||||
#define LOAD_Z9(p, ptr) LD1B_PTR(z9, p, ptr, 1)
|
||||
#else
|
||||
#define LOAD_Z1(p, ptr)
|
||||
#define LOAD_Z9(p, ptr)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 2)
|
||||
#define LOAD_Z2(p, ptr) LD1B_PTR(z2, p, ptr, 2)
|
||||
#define LOAD_Z10(p, ptr) LD1B_PTR(z10, p, ptr, 2)
|
||||
#else
|
||||
#define LOAD_Z2(p, ptr)
|
||||
#define LOAD_Z10(p, ptr)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 3)
|
||||
#define LOAD_Z3(p, ptr) LD1B_PTR(z3, p, ptr, 3)
|
||||
#define LOAD_Z11(p, ptr) LD1B_PTR(z11, p, ptr, 3)
|
||||
#else
|
||||
#define LOAD_Z3(p, ptr)
|
||||
#define LOAD_Z11(p, ptr)
|
||||
#endif
|
||||
|
||||
#define LOAD_Z4(p, ptr) LD1B_PTR(z4, p, ptr, 0)
|
||||
#define LOAD_Z12(p, ptr) LD1B_PTR(z12, p, ptr, 0)
|
||||
|
||||
#if (M_SIZE > 1)
|
||||
#define LOAD_Z5(p, ptr) LD1B_PTR(z5, p, ptr, 1)
|
||||
#define LOAD_Z13(p, ptr) LD1B_PTR(z13, p, ptr, 1)
|
||||
#else
|
||||
#define LOAD_Z5(p, ptr)
|
||||
#define LOAD_Z13(p, ptr)
|
||||
#endif
|
||||
|
||||
#if (M_SIZE > 2)
|
||||
#define LOAD_Z6(p, ptr) LD1B_PTR(z6, p, ptr, 2)
|
||||
#define LOAD_Z14(p, ptr) LD1B_PTR(z14, p, ptr, 2)
|
||||
#else
|
||||
#define LOAD_Z6(p, ptr)
|
||||
#define LOAD_Z14(p, ptr)
|
||||
#endif
|
||||
|
||||
#if (M_SIZE > 3)
|
||||
#define LOAD_Z7(p, ptr) LD1B_PTR(z7, p, ptr, 3)
|
||||
#define LOAD_Z15(p, ptr) LD1B_PTR(z15, p, ptr, 3)
|
||||
#else
|
||||
#define LOAD_Z7(p, ptr)
|
||||
#define LOAD_Z15(p, ptr)
|
||||
#endif
|
||||
|
||||
// macros for dot multiplication
|
||||
#define ACCUMULATE_Z16(lhs, rhs) COMPUTE_DOT(z16, lhs, rhs)
|
||||
|
||||
#if (N_SIZE > 1)
|
||||
#define ACCUMULATE_Z17(lhs, rhs) COMPUTE_DOT(z17, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z17(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 2)
|
||||
#define ACCUMULATE_Z18(lhs, rhs) COMPUTE_DOT(z18, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z18(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 3)
|
||||
#define ACCUMULATE_Z19(lhs, rhs) COMPUTE_DOT(z19, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z19(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (M_SIZE > 1)
|
||||
#define ACCUMULATE_Z20(lhs, rhs) COMPUTE_DOT(z20, lhs, rhs)
|
||||
#if (N_SIZE > 1)
|
||||
#define ACCUMULATE_Z21(lhs, rhs) COMPUTE_DOT(z21, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z21(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 2)
|
||||
#define ACCUMULATE_Z22(lhs, rhs) COMPUTE_DOT(z22, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z22(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 3)
|
||||
#define ACCUMULATE_Z23(lhs, rhs) COMPUTE_DOT(z23, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z23(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#else
|
||||
#define ACCUMULATE_Z20(lhs, rhs)
|
||||
#define ACCUMULATE_Z21(lhs, rhs)
|
||||
#define ACCUMULATE_Z22(lhs, rhs)
|
||||
#define ACCUMULATE_Z23(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (M_SIZE > 2)
|
||||
#define ACCUMULATE_Z24(lhs, rhs) COMPUTE_DOT(z24, lhs, rhs)
|
||||
|
||||
#if (N_SIZE > 1)
|
||||
#define ACCUMULATE_Z25(lhs, rhs) COMPUTE_DOT(z25, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z25(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 2)
|
||||
#define ACCUMULATE_Z26(lhs, rhs) COMPUTE_DOT(z26, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z26(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 3)
|
||||
#define ACCUMULATE_Z27(lhs, rhs) COMPUTE_DOT(z27, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z27(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#else
|
||||
#define ACCUMULATE_Z24(lhs, rhs)
|
||||
#define ACCUMULATE_Z25(lhs, rhs)
|
||||
#define ACCUMULATE_Z26(lhs, rhs)
|
||||
#define ACCUMULATE_Z27(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (M_SIZE > 3)
|
||||
#define ACCUMULATE_Z28(lhs, rhs) COMPUTE_DOT(z28, lhs, rhs)
|
||||
|
||||
#if (N_SIZE > 1)
|
||||
#define ACCUMULATE_Z29(lhs, rhs) COMPUTE_DOT(z29, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z29(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 2)
|
||||
#define ACCUMULATE_Z30(lhs, rhs) COMPUTE_DOT(z30, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z30(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#if (N_SIZE > 3)
|
||||
#define ACCUMULATE_Z31(lhs, rhs) COMPUTE_DOT(z31, lhs, rhs)
|
||||
#else
|
||||
#define ACCUMULATE_Z31(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#else
|
||||
#define ACCUMULATE_Z28(lhs, rhs)
|
||||
#define ACCUMULATE_Z29(lhs, rhs)
|
||||
#define ACCUMULATE_Z30(lhs, rhs)
|
||||
#define ACCUMULATE_Z31(lhs, rhs)
|
||||
#endif
|
||||
|
||||
#define MOVE_LHS_PTR(ptr) "add %[" #ptr "], %[" #ptr "], %[move_lhs]\n"
|
||||
#define MOVE_RHS_PTR(ptr) "add %[" #ptr "], %[" #ptr "], %[move_rhs]\n"
|
||||
|
||||
#define PROCESS_ACCUM(reg_idx, z_reg_idx, tmp_reg, dst, p) \
|
||||
"ldr w" #reg_idx ", [%[" #dst "]]\n" \
|
||||
"saddv d" #reg_idx ", " #p ", z" #z_reg_idx ".s\n" \
|
||||
"fmov " #tmp_reg ", d" #reg_idx "\n" \
|
||||
"add x" #reg_idx ", x" #reg_idx ", " #tmp_reg "\n" \
|
||||
"str w" #reg_idx ", [%[" #dst "]], #4\n"
|
||||
|
||||
// function logic
|
||||
void ADD_SUFFIX(gemm_kernel)(const void *lhs_ptr, const void *rhs_ptr,
|
||||
int32_t *accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len) {
|
||||
int64_t run_k_depth = k_depth;
|
||||
int64_t run_sv_len = sv_len;
|
||||
int64_t run_2sv_len = 2 * sv_len;
|
||||
int64_t move_lhs = M_SIZE * sv_len;
|
||||
int64_t move_rhs = N_SIZE * sv_len;
|
||||
int32_t* dst_ptr = accum_ptr;
|
||||
ldc -= M_SIZE;
|
||||
ldc *= 4;
|
||||
asm volatile(
|
||||
// predicate for operating on lhs and rhs is always true
|
||||
"ptrue p0.b, all\n"
|
||||
// Clear accumulators
|
||||
LOAD_Z0(p0, rhs_ptr)
|
||||
"dup z16.s, #0\n"
|
||||
LOAD_Z1(p0, rhs_ptr)
|
||||
"dup z17.s, #0\n"
|
||||
LOAD_Z4(p0, lhs_ptr)
|
||||
"dup z18.s, #0\n"
|
||||
LOAD_Z5(p0, lhs_ptr)
|
||||
"dup z19.s, #0\n"
|
||||
LOAD_Z6(p0, lhs_ptr)
|
||||
"dup z20.s, #0\n"
|
||||
LOAD_Z7(p0, lhs_ptr)
|
||||
"dup z21.s, #0\n"
|
||||
LOAD_Z2(p0, rhs_ptr)
|
||||
"dup z22.s, #0\n"
|
||||
LOAD_Z3(p0, rhs_ptr)
|
||||
"dup z23.s, #0\n"
|
||||
"subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
|
||||
"dup z24.s, #0\n"
|
||||
"mov x16, %[dst_ptr]\n"
|
||||
"dup z25.s, #0\n"
|
||||
"dup z26.s, #0\n"
|
||||
"dup z27.s, #0\n"
|
||||
MOVE_LHS_PTR(lhs_ptr)
|
||||
"dup z28.s, #0\n"
|
||||
MOVE_RHS_PTR(rhs_ptr)
|
||||
"dup z29.s, #0\n"
|
||||
"dup z30.s, #0\n"
|
||||
"dup z31.s, #0\n"
|
||||
|
||||
"ble 1f\n"
|
||||
|
||||
"cmp %[run_k_depth], %[run_2sv_len]\n"
|
||||
"blt 2f\n"
|
||||
|
||||
"3:\n"
|
||||
LOAD_Z12(p0, lhs_ptr)
|
||||
ACCUMULATE_Z16(z4,z0)
|
||||
ACCUMULATE_Z17(z4,z1)
|
||||
LOAD_Z13(p0, lhs_ptr)
|
||||
ACCUMULATE_Z18(z4,z2)
|
||||
ACCUMULATE_Z19(z4,z3)
|
||||
LOAD_Z8(p0, rhs_ptr)
|
||||
ACCUMULATE_Z20(z5,z0)
|
||||
ACCUMULATE_Z21(z5,z1)
|
||||
LOAD_Z9(p0, rhs_ptr)
|
||||
ACCUMULATE_Z22(z5,z2)
|
||||
ACCUMULATE_Z23(z5,z3)
|
||||
LOAD_Z10(p0,rhs_ptr)
|
||||
ACCUMULATE_Z24(z6,z0)
|
||||
ACCUMULATE_Z25(z6,z1)
|
||||
LOAD_Z11(p0,rhs_ptr)
|
||||
ACCUMULATE_Z26(z6,z2)
|
||||
MOVE_RHS_PTR(rhs_ptr)
|
||||
"prfw pldl1keep, p0, [%[rhs_ptr], #4, MUL VL]\n"
|
||||
ACCUMULATE_Z27(z6,z3)
|
||||
LOAD_Z14(p0, lhs_ptr)
|
||||
ACCUMULATE_Z28(z7,z0)
|
||||
ACCUMULATE_Z29(z7,z1)
|
||||
LOAD_Z15(p0,lhs_ptr)
|
||||
ACCUMULATE_Z30(z7,z2)
|
||||
MOVE_LHS_PTR(lhs_ptr)
|
||||
"prfw pldl1keep, p0, [%[lhs_ptr], #4, MUL VL]\n"
|
||||
ACCUMULATE_Z31(z7,z3)
|
||||
|
||||
LOAD_Z4(p0, lhs_ptr)
|
||||
ACCUMULATE_Z16(z12,z8)
|
||||
ACCUMULATE_Z17(z12,z9)
|
||||
LOAD_Z5(p0, lhs_ptr)
|
||||
ACCUMULATE_Z18(z12,z10)
|
||||
ACCUMULATE_Z19(z12,z11)
|
||||
LOAD_Z6(p0, lhs_ptr)
|
||||
ACCUMULATE_Z20(z13,z8)
|
||||
ACCUMULATE_Z21(z13,z9)
|
||||
LOAD_Z0(p0, rhs_ptr)
|
||||
"sub %[run_k_depth], %[run_k_depth], %[run_2sv_len]\n"
|
||||
ACCUMULATE_Z22(z13,z10)
|
||||
ACCUMULATE_Z23(z13,z11)
|
||||
LOAD_Z1(p0, rhs_ptr)
|
||||
ACCUMULATE_Z24(z14,z8)
|
||||
ACCUMULATE_Z25(z14,z9)
|
||||
LOAD_Z2(p0,rhs_ptr)
|
||||
ACCUMULATE_Z26(z14,z10)
|
||||
ACCUMULATE_Z27(z14,z11)
|
||||
LOAD_Z3(p0, rhs_ptr)
|
||||
ACCUMULATE_Z28(z15,z8)
|
||||
MOVE_RHS_PTR(rhs_ptr)
|
||||
"prfw pldl1keep, p0, [%[rhs_ptr], #4, MUL VL]\n"
|
||||
ACCUMULATE_Z29(z15,z9)
|
||||
LOAD_Z7(p0, lhs_ptr)
|
||||
"cmp %[run_k_depth], %[run_2sv_len]\n"
|
||||
ACCUMULATE_Z30(z15, z10)
|
||||
MOVE_LHS_PTR(lhs_ptr)
|
||||
"prfw pldl1keep, p0, [%[lhs_ptr], #4, MUL VL]\n"
|
||||
ACCUMULATE_Z31(z15,z11)
|
||||
"bge 3b\n"
|
||||
|
||||
"cmp %[run_k_depth], #0\n"
|
||||
"ble 1f\n"
|
||||
|
||||
"2:\n"
|
||||
"subs %[run_k_depth], %[run_k_depth], %[run_sv_len]\n"
|
||||
ACCUMULATE_Z16(z4,z0)
|
||||
ACCUMULATE_Z17(z4,z1)
|
||||
ACCUMULATE_Z18(z4,z2)
|
||||
ACCUMULATE_Z19(z4,z3)
|
||||
LOAD_Z4(p0,lhs_ptr)
|
||||
ACCUMULATE_Z20(z5,z0)
|
||||
ACCUMULATE_Z21(z5,z1)
|
||||
ACCUMULATE_Z22(z5,z2)
|
||||
ACCUMULATE_Z23(z5,z3)
|
||||
LOAD_Z5(p0,lhs_ptr)
|
||||
ACCUMULATE_Z24(z6,z0)
|
||||
ACCUMULATE_Z25(z6,z1)
|
||||
ACCUMULATE_Z26(z6,z2)
|
||||
ACCUMULATE_Z27(z6,z3)
|
||||
LOAD_Z6(p0,lhs_ptr)
|
||||
ACCUMULATE_Z28(z7,z0)
|
||||
LOAD_Z0(p0,rhs_ptr)
|
||||
ACCUMULATE_Z29(z7,z1)
|
||||
LOAD_Z1(p0,rhs_ptr)
|
||||
ACCUMULATE_Z30(z7,z2)
|
||||
LOAD_Z2(p0,rhs_ptr)
|
||||
ACCUMULATE_Z31(z7,z3)
|
||||
LOAD_Z3(p0,rhs_ptr)
|
||||
MOVE_RHS_PTR(rhs_ptr)
|
||||
LOAD_Z7(p0,lhs_ptr)
|
||||
MOVE_LHS_PTR(lhs_ptr)
|
||||
"bgt 2b\n"
|
||||
|
||||
"1:\n"
|
||||
ACCUMULATE_Z16(z4,z0)
|
||||
ACCUMULATE_Z17(z4,z1)
|
||||
ACCUMULATE_Z18(z4,z2)
|
||||
ACCUMULATE_Z19(z4,z3)
|
||||
ACCUMULATE_Z20(z5,z0)
|
||||
ACCUMULATE_Z21(z5,z1)
|
||||
ACCUMULATE_Z22(z5,z2)
|
||||
ACCUMULATE_Z23(z5,z3)
|
||||
ACCUMULATE_Z24(z6,z0)
|
||||
ACCUMULATE_Z25(z6,z1)
|
||||
ACCUMULATE_Z26(z6,z2)
|
||||
ACCUMULATE_Z27(z6,z3)
|
||||
ACCUMULATE_Z28(z7,z0)
|
||||
ACCUMULATE_Z29(z7,z1)
|
||||
ACCUMULATE_Z30(z7,z2)
|
||||
ACCUMULATE_Z31(z7,z3)
|
||||
|
||||
#if (N_SIZE > 0)
|
||||
#if (M_SIZE > 0)
|
||||
PROCESS_ACCUM(0, 16, x16, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 1)
|
||||
PROCESS_ACCUM(4, 20, x17, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 2)
|
||||
PROCESS_ACCUM(8, 24, x18, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 3)
|
||||
PROCESS_ACCUM(12, 28, x17, dst_ptr, p0)
|
||||
#endif
|
||||
#endif
|
||||
"add %[dst_ptr], %[dst_ptr], %[ldc]\n"
|
||||
|
||||
#if (N_SIZE > 1)
|
||||
#if (M_SIZE > 0)
|
||||
PROCESS_ACCUM(1, 17, x16, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 1)
|
||||
PROCESS_ACCUM(5, 21, x17,dst_ptr,p0)
|
||||
#endif
|
||||
#if (M_SIZE > 2)
|
||||
PROCESS_ACCUM(9, 25, x18, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 3)
|
||||
PROCESS_ACCUM(13, 29, x17, dst_ptr, p0)
|
||||
#endif
|
||||
#endif
|
||||
"add %[dst_ptr], %[dst_ptr], %[ldc]\n"
|
||||
|
||||
|
||||
#if (N_SIZE > 2)
|
||||
#if (M_SIZE > 0)
|
||||
PROCESS_ACCUM(2, 18, x16, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 1)
|
||||
PROCESS_ACCUM(6, 22, x17,dst_ptr,p0)
|
||||
#endif
|
||||
#if (M_SIZE > 2)
|
||||
PROCESS_ACCUM(10,26,x18,dst_ptr,p0)
|
||||
#endif
|
||||
#if (M_SIZE > 3)
|
||||
PROCESS_ACCUM(14,30,x17,dst_ptr,p0)
|
||||
#endif
|
||||
#endif
|
||||
"add %[dst_ptr], %[dst_ptr], %[ldc]\n"
|
||||
|
||||
|
||||
#if (N_SIZE > 3)
|
||||
#if (M_SIZE > 0)
|
||||
PROCESS_ACCUM(3, 19, x16, dst_ptr, p0)
|
||||
#endif
|
||||
#if (M_SIZE > 1)
|
||||
PROCESS_ACCUM(7, 23, x17,dst_ptr,p0)
|
||||
#endif
|
||||
#if (M_SIZE > 2)
|
||||
PROCESS_ACCUM(11,27,x18,dst_ptr,p0)
|
||||
#endif
|
||||
#if (M_SIZE > 3)
|
||||
PROCESS_ACCUM(15,31,x17,dst_ptr,p0)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
:
|
||||
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
|
||||
[run_k_depth] "+r"(run_k_depth),
|
||||
[dst_ptr] "+wr"(dst_ptr)
|
||||
:
|
||||
[run_sv_len] "r"(run_sv_len), [run_2sv_len] "r"(run_2sv_len),
|
||||
[move_lhs] "r"(move_lhs), [move_rhs] "r"(move_rhs), [ldc] "r"(ldc),
|
||||
[accum_ptr] "r"(accum_ptr)
|
||||
:
|
||||
"cc", "memory",
|
||||
"w0","w1","w2","w3","w4","w5","w6","w7",
|
||||
"w8","w9","w10","w11","w12","w13","w14","w15",
|
||||
"x16","x17","x18","x19",
|
||||
"z0","z1","z2","z3","z4","z5","z6","z7",
|
||||
"z8","z9","z10","z11","z12","z13","z14","z15",
|
||||
"z16","z17","z18","z19","z20","z21","z22","z23",
|
||||
"z24","z25","z26","z27","z28","z29","z30","z31"
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,468 +0,0 @@
|
||||
#ifndef __GEMM_INTFOUR_KERNELS_H__
|
||||
#define __GEMM_INTFOUR_KERNELS_H__
|
||||
|
||||
#include <stdint.h>
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
typedef int8_t BLASINT8;
|
||||
typedef uint8_t BLASUINT8;
|
||||
|
||||
typedef struct {
|
||||
void (*small_kernel)(const size_t, const size_t, const size_t, const float, const void*, const size_t, const BLASINT8,
|
||||
const void*, const size_t, const BLASINT8, const float, int32_t*, const size_t, const int32_t*);
|
||||
void (**gemm_kernels)(const void*, const void*, int32_t*, size_t, int64_t, int64_t);
|
||||
void (*pack_a_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8);
|
||||
void (*pack_b_fun)(void*, const void*, size_t, size_t, size_t, const BLASINT8);
|
||||
void (*beta_func)(int32_t*, const int32_t*, float, size_t, size_t, size_t);
|
||||
void (*post_ops_func)(float, const int32_t*, int32_t*, size_t, size_t, size_t);
|
||||
size_t (*a_indexing)(size_t, size_t, size_t);
|
||||
size_t (*b_indexing)(size_t, size_t, size_t);
|
||||
size_t (*move_oc)(size_t, size_t);
|
||||
|
||||
} int_gemm_funcs;
|
||||
|
||||
#ifndef COMP_SV_LEN
|
||||
#error "COMP_SV_LEN is not defined"
|
||||
#endif
|
||||
|
||||
#define KERNEL_M_STEP 4
|
||||
#define KERNEL_N_STEP 4
|
||||
#define KERNEL_K_STEP COMP_SV_LEN
|
||||
|
||||
#define M_BLOCK 256
|
||||
#if ((M_BLOCK % KERNEL_M_STEP) != 0)
|
||||
#error "M_BLOCK % KERNEL_M_STEP != 0"
|
||||
#endif
|
||||
#define N_BLOCK 256
|
||||
#if ((N_BLOCK % KERNEL_N_STEP) != 0)
|
||||
#error "N_BLOCK % KERNEL_N_STEP != 0"
|
||||
#endif
|
||||
#define K_BLOCK 512
|
||||
#if ((K_BLOCK % KERNEL_K_STEP) != 0)
|
||||
#error "K_BLOCK % KERNEL_K_STEP != 0"
|
||||
#endif
|
||||
|
||||
#define ALIGNMENT 4096
|
||||
|
||||
#define EXTERNAL_API __attribute__((visibility("default")))
|
||||
#define UNUSED(arg) ((void)(arg))
|
||||
|
||||
// general pipeline
|
||||
void gemm_impl_8bit(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha, const void* a, size_t lda,
|
||||
const BLASINT8 oa, const void* b, size_t ldb, const BLASINT8 ob, float beta, int32_t* c, size_t ldc,
|
||||
const int32_t* oc, size_t small_switch);
|
||||
// s8 kernel
|
||||
void pack_b_s8_n(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
|
||||
const BLASINT8 ob);
|
||||
|
||||
void pack_a_s8_n(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
|
||||
const BLASINT8 oa);
|
||||
|
||||
void pack_b_s8_t(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
|
||||
const BLASINT8 ob);
|
||||
|
||||
void pack_a_s8_t(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
|
||||
const BLASINT8 oa);
|
||||
|
||||
// u8 kernels
|
||||
void pack_b_u8_n(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
|
||||
const BLASINT8 ob);
|
||||
|
||||
void pack_a_u8_n(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
|
||||
const BLASINT8 oa);
|
||||
|
||||
void pack_b_u8_t(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb,
|
||||
const BLASINT8 ob);
|
||||
|
||||
void pack_a_u8_t(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda,
|
||||
const BLASINT8 oa);
|
||||
|
||||
// beta kernels
|
||||
void beta_cf_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
|
||||
void beta_cc_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
|
||||
void beta_cr_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
|
||||
void beta_rf_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
|
||||
void beta_rc_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
|
||||
void beta_rr_s8(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size, size_t ldc);
|
||||
|
||||
void beta_cf_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
|
||||
size_t ldc);
|
||||
void beta_cc_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
|
||||
size_t ldc);
|
||||
void beta_cr_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
|
||||
size_t ldc);
|
||||
void beta_rf_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
|
||||
size_t ldc);
|
||||
void beta_rc_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
|
||||
size_t ldc);
|
||||
void beta_rr_s8_opt(int32_t* c_ptr, const int32_t* oc, float beta, size_t m_block_size, size_t n_block_size,
|
||||
size_t ldc);
|
||||
|
||||
// post-ops kernels
|
||||
void post_ops(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc);
|
||||
void post_ops_opt(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc);
|
||||
|
||||
// drivers
|
||||
void gemm_driver(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha, const void* a, size_t lda,
|
||||
const BLASINT8 oa, const void* b, size_t ldb, const BLASINT8 ob, float beta, int32_t* c, size_t ldc,
|
||||
const int32_t* oc);
|
||||
void gemm_driver_opt(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha, const void* a, size_t lda,
|
||||
const BLASINT8 oa, const void* b, size_t ldb, const BLASINT8 ob, float beta, int32_t* c,
|
||||
size_t ldc, const int32_t* oc);
|
||||
|
||||
// matrix multiplication kernels
|
||||
|
||||
// s8s8s32 kernels
|
||||
void gemm_kernel_s8s8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_s8s8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_s8s8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_s8s8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8s8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
// u8u8s32 kernels
|
||||
|
||||
void gemm_kernel_u8u8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_u8u8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_u8u8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_u8u8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8u8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
// s8u8s32 kernels
|
||||
void gemm_kernel_s8u8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_s8u8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_s8u8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_s8u8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_s8u8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
// u8s8s32 kernels
|
||||
void gemm_kernel_u8s8s32_4x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_4x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_4x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_4x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_u8s8s32_3x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_3x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_3x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_3x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_u8s8s32_2x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_2x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_2x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_2x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_u8s8s32_1x4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_1x3(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_1x2(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
void gemm_kernel_u8s8s32_1x1(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
// small kernels
|
||||
// s8s8s32 kernels
|
||||
void small_kernel_s8s8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8s8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
// s8u8s32 kernels
|
||||
|
||||
void small_kernel_s8u8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_s8u8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
// u8s8s32 kernels
|
||||
void small_kernel_u8s8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8s8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
// u8u8s32 kernels
|
||||
void small_kernel_u8u8s32_nn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_nt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_tn_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_tt_f(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_nn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_nt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_tn_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_tt_c(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_nn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_nt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_tn_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void small_kernel_u8u8s32_tt_r(const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@@ -1,158 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
#if defined(OC_FIX)
|
||||
#define OC_TYPE f
|
||||
#define OC_IDX(mi, nn) 0
|
||||
#elif defined(OC_COL)
|
||||
#define OC_TYPE c
|
||||
#define OC_IDX(mi, ni) mi
|
||||
#else // OC_ROW
|
||||
#define OC_TYPE r
|
||||
#define OC_IDX(mi, ni) ni
|
||||
#endif // OC_T
|
||||
|
||||
#if defined(TRANSA)
|
||||
#if defined TRANSB
|
||||
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, t, t, OC_TYPE)
|
||||
#elif defined(NOTRANSB)
|
||||
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, t, n, OC_TYPE)
|
||||
#else
|
||||
#error "Neither TRANSB or NOTRANSB is defined"
|
||||
#endif
|
||||
#elif defined(NOTRANSA)
|
||||
#if defined TRANSB
|
||||
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, n, t, OC_TYPE)
|
||||
#elif defined(NOTRANSB)
|
||||
#define ADD_SUFFIX(name) ADD_TRANSP_MACRO(name, n, n, OC_TYPE)
|
||||
#else
|
||||
#error "Neither TRANSB or NOTRANSB is defined"
|
||||
#endif
|
||||
#else
|
||||
#error "Neither TRANSA or NOTRANSA is defined"
|
||||
#endif
|
||||
|
||||
#if (defined (LHS_INT) && defined(RHS_INT)) || (defined (LHS_UINT) && defined(RHS_UINT))
|
||||
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) #lhs_type "dot " #out ".s, " #in1 ".b, " #in2 ".b\n"
|
||||
#else
|
||||
#if defined(LHS_INT)
|
||||
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in2 ".b," #in1 ".b\n"
|
||||
#else // LHS_UINT
|
||||
#define COMPUTE_DOT_TYPED(out, in1, in2, lhs_type, rhs_type) "usdot " #out ".s, " #in1 ".b," #in2 ".b\n"
|
||||
#endif // LHS_INT
|
||||
#endif // LHS_INT
|
||||
|
||||
#define COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE) COMPUTE_DOT_TYPED(out, in1, in2, LHS_TYPE, RHS_TYPE)
|
||||
#define COMPUTE_DOT(out, in1, in2) COMPUTE_DOT_TYPED_MACRO(out, in1, in2, LHS_TYPE, RHS_TYPE)
|
||||
|
||||
static inline double compute_dot(size_t k, const void *a, const BLASINT8* oa,
|
||||
const void *b, const BLASINT8* ob, int64_t sv_len) {
|
||||
int32_t accum = 0;
|
||||
int64_t run_k_depth = k;
|
||||
int64_t run_sv_len = sv_len;
|
||||
const void* lhs_ptr = a;
|
||||
const void* rhs_ptr = b;
|
||||
asm volatile(
|
||||
"dup z4.s, #0\n"
|
||||
"ptrue p0.b, all\n"
|
||||
"ld1b {z0.b}, p0/z, [%[oa]]\n"
|
||||
"ld1b {z1.b}, p0/z, [%[ob]]\n"
|
||||
"1:\n"
|
||||
"whilelt p1.b, xzr, %[run_k_depth]\n"
|
||||
"ld1b {z2.b}, p1/z, [%[lhs_ptr]]\n"
|
||||
"ld1b {z3.b}, p1/z, [%[rhs_ptr]]\n"
|
||||
"add z2.b, p1/m, z2.b, z0.b\n"
|
||||
"add z3.b, p1/m, z3.b, z1.b\n"
|
||||
"add %[lhs_ptr], %[lhs_ptr], %[run_sv_len]\n"
|
||||
"add %[rhs_ptr], %[rhs_ptr], %[run_sv_len]\n"
|
||||
COMPUTE_DOT(z4, z2, z3)
|
||||
"subs %[run_k_depth], %[run_k_depth], #1\n"
|
||||
"bgt 1b\n"
|
||||
"ptrue p2.s, all\n"
|
||||
"saddv d1, p2, z4.s\n"
|
||||
"fmov x2, d1\n"
|
||||
"add %[accum], %[accum], x2\n"
|
||||
: // outputs
|
||||
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
|
||||
[run_k_depth] "+wr"(run_k_depth),
|
||||
[accum] "+wr"(accum)
|
||||
: // inputs
|
||||
[run_sv_len] "r"(run_sv_len), [oa] "r"(oa), [ob] "r"(ob)
|
||||
: // clobbers
|
||||
"cc", "memory",
|
||||
"d1", "x2",
|
||||
"z0", "z1", "z2", "z3", "z4", "p0", "p1", "p2"
|
||||
);
|
||||
return (double) accum;
|
||||
}
|
||||
|
||||
#if !defined(TRANSA) || defined(TRANSB)
|
||||
// performs transposition (n0 * n1) -> (n1 * n0) assuming col major
|
||||
static inline void simplest_transpose(const void *in, void *out, size_t n0, size_t ld0, size_t n1) {
|
||||
// since we care only about size, we can use signed type always
|
||||
BLASINT8* typed_in = (BLASINT8*) in;
|
||||
BLASINT8* typed_out = (BLASINT8*) out;
|
||||
for (size_t i = 0; i < n1; ++i) {
|
||||
for (size_t j = 0; j < n0; ++j) {
|
||||
typed_out[i + j * n1] = typed_in[j + i * ld0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // !defined(TRANSA) || defined(TRANSB)
|
||||
|
||||
// A in row-major, B in col-major
|
||||
void ADD_SUFFIX(small_kernel)(const size_t m, const size_t n, const size_t k, const float alpha,
|
||||
const void *a, const size_t lda, const BLASINT8 oa,
|
||||
const void *b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t *c, const size_t ldc, const int32_t *oc) {
|
||||
double double_alpha = (double) alpha;
|
||||
// we use typed pointers only for indexing, so we don't care about signess
|
||||
#ifdef TRANSA
|
||||
BLASINT8* a_typed = (BLASINT8*) a;
|
||||
const size_t used_lda = lda;
|
||||
#else
|
||||
BLASINT8* a_typed = (BLASINT8*) aligned_alloc(128, sizeof(BLASINT8) * m * k);
|
||||
simplest_transpose(a, a_typed, m, lda, k);
|
||||
const size_t used_lda = k;
|
||||
#endif // TRANSA
|
||||
#ifndef TRANSB
|
||||
BLASINT8* b_typed = (BLASINT8*) b;
|
||||
const size_t used_ldb = ldb;
|
||||
#else // TRANSB
|
||||
BLASINT8* b_typed = (BLASINT8*) aligned_alloc(128, sizeof(BLASINT8) * k * n);
|
||||
simplest_transpose(b, b_typed, n, ldb, k);
|
||||
const size_t used_ldb = k;
|
||||
#endif // TRANSB
|
||||
BLASINT8 oa_buf[KERNEL_K_STEP];
|
||||
BLASINT8 ob_buf[KERNEL_K_STEP];
|
||||
for (size_t i = 0; i < KERNEL_K_STEP; ++i) {
|
||||
oa_buf[i] = oa;
|
||||
ob_buf[i] = ob;
|
||||
}
|
||||
// printf("\n========\n");
|
||||
for (size_t mi = 0; mi < m; ++mi) {
|
||||
for (size_t ni = 0; ni < n; ++ni) {
|
||||
// printf("mi = %lu, ni = %lu, oc_idx = %lu\n", mi, ni, OC_IDX(mi, ni));
|
||||
double tmp = compute_dot(k, a_typed + mi * used_lda, oa_buf, b_typed + ni * used_ldb, ob_buf, KERNEL_K_STEP);
|
||||
c[mi + ni * ldc] = round(tmp * double_alpha + ((double)(beta * ((float)c[mi + ni * ldc]) + oc[OC_IDX(mi, ni)])));
|
||||
}
|
||||
}
|
||||
#ifdef TRANSA
|
||||
free(a_typed);
|
||||
#endif // TRANSA
|
||||
#ifdef TRANSB
|
||||
free(b_typed);
|
||||
#endif // TRANSB
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif /* __cplusplus */
|
||||
@@ -1,140 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
#if defined(TRANSA)
|
||||
// row major
|
||||
#define INDEXING_A(row_idx, col_idx, lda) ((col_idx) * (lda) + row_idx)
|
||||
#define ADD_A_SUFF(name) ADD_PACK_A_T_SUFF(name)
|
||||
#elif defined(NOTRANSA)
|
||||
// col major
|
||||
#define INDEXING_A(row_idx, col_idx, lda) ((row_idx) * (lda) + col_idx)
|
||||
#define ADD_A_SUFF(name) ADD_PACK_A_N_SUFF(name)
|
||||
#else
|
||||
#error "Neither TRANSA or NOTRANSA is defined"
|
||||
#endif
|
||||
|
||||
#define OLD_N_SIZE 8
|
||||
#define NEW_N_SIZE 4
|
||||
|
||||
void ADD_A_SUFF(pack_a)(void* bufferA, const void* curr_a_ptr, size_t m_block_size, size_t k_block_size, size_t lda, const BLASINT8 oa) {
|
||||
LHS_INT_TYPE* bufferA_typed = (LHS_INT_TYPE*) bufferA;
|
||||
LHS_INT_TYPE* curr_a_ptr_typed = (LHS_INT_TYPE*) curr_a_ptr;
|
||||
|
||||
// printf("m_block_size:%lu ,k_block_size: %lu\n", m_block_size, k_block_size);
|
||||
|
||||
for(size_t old_split_n = 0; old_split_n < (m_block_size / OLD_N_SIZE); old_split_n++) {
|
||||
for(size_t split_k = 0; split_k < (k_block_size / (KERNEL_K_STEP * 2)); split_k++) {
|
||||
for(size_t old_idx_n = 0; old_idx_n < OLD_N_SIZE; old_idx_n++) {
|
||||
for(size_t idx_k = 0; idx_k < KERNEL_K_STEP; idx_k++) {
|
||||
size_t n_idx = old_split_n * OLD_N_SIZE + old_idx_n;
|
||||
size_t new_split_n = n_idx / NEW_N_SIZE;
|
||||
size_t new_idx_n = n_idx % NEW_N_SIZE;
|
||||
|
||||
size_t old_buff_idx =
|
||||
old_split_n * OLD_N_SIZE * (lda / 2) +
|
||||
split_k * OLD_N_SIZE * KERNEL_K_STEP +
|
||||
old_idx_n * KERNEL_K_STEP +
|
||||
idx_k;
|
||||
|
||||
uint8_t b01 = curr_a_ptr_typed[old_buff_idx];
|
||||
uint8_t b0 = b01 & 0xF0;
|
||||
uint8_t b1 = b01 << 4;
|
||||
|
||||
size_t new_buff_idx =
|
||||
new_split_n * NEW_N_SIZE * k_block_size +
|
||||
split_k * NEW_N_SIZE * (KERNEL_K_STEP * 2) +
|
||||
new_idx_n * KERNEL_K_STEP +
|
||||
idx_k;
|
||||
bufferA_typed[new_buff_idx] = b0;
|
||||
bufferA_typed[new_buff_idx + KERNEL_K_STEP * NEW_N_SIZE] = b1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for(size_t n_idx = 0; n_idx < m_block_size; n_idx++) {
|
||||
// for(size_t k_idx = 0; k_idx < k_block_size; k_idx++) {
|
||||
// size_t old_split_n = n_idx / OLD_N_SIZE;
|
||||
// size_t old_idx_n = n_idx % OLD_N_SIZE;
|
||||
// size_t new_split_n = n_idx / NEW_N_SIZE;
|
||||
// size_t new_idx_n = n_idx % NEW_N_SIZE;
|
||||
// size_t split_k = k_idx / KERNEL_K_STEP;
|
||||
// size_t idx_k = k_idx % KERNEL_K_STEP;
|
||||
|
||||
// size_t old_buff_idx =
|
||||
// old_split_n * OLD_N_SIZE * lda +
|
||||
// split_k * OLD_N_SIZE * KERNEL_K_STEP +
|
||||
// old_idx_n * KERNEL_K_STEP +
|
||||
// idx_k;
|
||||
// size_t new_buff_idx =
|
||||
// new_split_n * NEW_N_SIZE * k_block_size +
|
||||
// split_k * NEW_N_SIZE * KERNEL_K_STEP +
|
||||
// new_idx_n * KERNEL_K_STEP +
|
||||
// idx_k;
|
||||
// bufferA_typed[new_buff_idx] = curr_a_ptr_typed[old_buff_idx] + oa;
|
||||
// }
|
||||
// }
|
||||
|
||||
// size_t k_block_size_up = (k_block_size + KERNEL_K_STEP - 1) / KERNEL_K_STEP * KERNEL_K_STEP;
|
||||
// size_t k_portions = k_block_size / KERNEL_K_STEP;
|
||||
// size_t k_resid = k_block_size - KERNEL_K_STEP * k_portions;
|
||||
|
||||
// size_t m_portions = m_block_size / KERNEL_M_STEP;
|
||||
// size_t m_resid = m_block_size - KERNEL_M_STEP * m_portions;
|
||||
|
||||
// for (size_t im4 = 0; im4 < m_portions; ++im4) {
|
||||
// for (size_t ik16 = 0; ik16 < k_portions; ++ik16) {
|
||||
// for (size_t im = 0; im < KERNEL_M_STEP; ++im) {
|
||||
// for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
|
||||
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * KERNEL_M_STEP * ik16 + k_block_size_up * KERNEL_M_STEP * im4] =
|
||||
// curr_a_ptr_typed[INDEXING_A((ik16 * KERNEL_K_STEP + ik), (im4 * KERNEL_M_STEP + im), lda)] + oa;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// if (k_resid) {
|
||||
// for (size_t im = 0; im < KERNEL_M_STEP; ++im) {
|
||||
// for (size_t ik = 0; ik < k_resid; ++ik) {
|
||||
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * KERNEL_M_STEP * k_portions + k_block_size_up * KERNEL_M_STEP * im4] =
|
||||
// curr_a_ptr_typed[INDEXING_A((k_portions * KERNEL_K_STEP + ik), (im4 * KERNEL_M_STEP + im), lda)] + oa;
|
||||
// }
|
||||
// }
|
||||
// for (size_t im = 0; im < KERNEL_M_STEP; ++im) {
|
||||
// for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
|
||||
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * KERNEL_M_STEP * k_portions + k_block_size_up * KERNEL_M_STEP * im4] = 0;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// if (m_resid) {
|
||||
// for (size_t ik16 = 0; ik16 < k_portions; ++ik16) {
|
||||
// for (size_t im = 0; im < m_resid; ++im) {
|
||||
// for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
|
||||
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * m_resid * ik16 + k_block_size_up * KERNEL_M_STEP * m_portions] =
|
||||
// curr_a_ptr_typed[INDEXING_A((ik16 * KERNEL_K_STEP + ik), (m_portions * KERNEL_M_STEP + im), lda)] + oa;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// if (k_resid) {
|
||||
// for (size_t im = 0; im < m_resid; ++im) {
|
||||
// for (size_t ik = 0; ik < k_resid; ++ik) {
|
||||
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * m_resid * k_portions + k_block_size_up * KERNEL_M_STEP * m_portions] =
|
||||
// curr_a_ptr_typed[INDEXING_A((k_portions * KERNEL_K_STEP + ik), (m_portions * KERNEL_M_STEP + im), lda)] + oa;
|
||||
// }
|
||||
// }
|
||||
// for (size_t im = 0; im < m_resid; ++im) {
|
||||
// for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
|
||||
// bufferA_typed[ik + im * KERNEL_K_STEP + KERNEL_K_STEP * m_resid * k_portions + k_block_size_up * KERNEL_M_STEP * m_portions] = 0;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif /* __cplusplus */
|
||||
@@ -1,96 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
#if defined(TRANSB)
|
||||
// row major
|
||||
#define INDEXING_B(row_idx, col_idx, ldb) ((col_idx) * (ldb) + row_idx)
|
||||
#define ADD_B_SUFF(name) ADD_PACK_B_T_SUFF(name)
|
||||
#elif defined(NOTRANSB)
|
||||
// col major
|
||||
#define INDEXING_B(row_idx, col_idx, ldb) ((row_idx) * (ldb) + col_idx)
|
||||
#define ADD_B_SUFF(name) ADD_PACK_B_N_SUFF(name)
|
||||
#else
|
||||
#error "Neither TRANSB or NOTRANSB is defined."
|
||||
#endif
|
||||
|
||||
void ADD_B_SUFF(pack_b)(void* bufferB, const void* curr_b_ptr, size_t n_block_size, size_t k_block_size, size_t ldb, const BLASINT8 ob) {
|
||||
RHS_INT_TYPE* bufferB_typed = (RHS_INT_TYPE*) bufferB;
|
||||
RHS_INT_TYPE* curr_b_ptr_typed = (RHS_INT_TYPE*) curr_b_ptr;
|
||||
size_t k_block_size_up = (k_block_size + KERNEL_K_STEP - 1) / KERNEL_K_STEP * KERNEL_K_STEP;
|
||||
size_t k_portions = k_block_size / KERNEL_K_STEP;
|
||||
size_t k_resid = k_block_size - KERNEL_K_STEP * k_portions;
|
||||
|
||||
size_t n_portions = n_block_size / KERNEL_N_STEP;
|
||||
size_t n_resid = n_block_size - KERNEL_N_STEP * n_portions;
|
||||
|
||||
for (size_t in4 = 0; in4 < n_portions; ++in4) {
|
||||
for (size_t in = 0; in < KERNEL_N_STEP; ++in) {
|
||||
for (size_t ik16 = 0; ik16 < k_block_size / KERNEL_K_STEP; ++ik16) {
|
||||
for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
|
||||
bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * KERNEL_N_STEP * ik16 + k_block_size_up * KERNEL_N_STEP * in4] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * in4 + in), (KERNEL_K_STEP * ik16 + ik), ldb)] + ob;
|
||||
}
|
||||
}
|
||||
|
||||
if (k_resid) {
|
||||
for (size_t ik = 0; ik < k_resid; ++ik) {
|
||||
bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * KERNEL_N_STEP * k_portions + k_block_size_up * KERNEL_N_STEP * in4] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * in4 + in), (KERNEL_K_STEP * k_portions + ik), ldb)] + ob;
|
||||
}
|
||||
for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
|
||||
bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * KERNEL_N_STEP * k_portions + k_block_size_up * KERNEL_N_STEP * in4] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (n_resid) {
|
||||
for (size_t in = 0; in < n_resid; ++in) {
|
||||
for (size_t ik16 = 0; ik16 < k_block_size / KERNEL_K_STEP; ++ik16) {
|
||||
for (size_t ik = 0; ik < KERNEL_K_STEP; ++ik) {
|
||||
bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * n_resid * ik16 + k_block_size_up * KERNEL_N_STEP * n_portions] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * n_portions + in), (KERNEL_K_STEP * ik16 + ik), ldb)] + ob;
|
||||
}
|
||||
}
|
||||
if (k_resid) {
|
||||
for (size_t ik = 0; ik < k_resid; ++ik) {
|
||||
bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * n_resid * k_portions + k_block_size_up * KERNEL_N_STEP * n_portions] = curr_b_ptr_typed[INDEXING_B((KERNEL_N_STEP * n_portions + in), (KERNEL_K_STEP * k_portions + ik), ldb)] + ob;
|
||||
}
|
||||
for (size_t ik = k_resid; ik < KERNEL_K_STEP; ++ik) {
|
||||
bufferB_typed[ik + KERNEL_K_STEP * in + KERNEL_K_STEP * n_resid * k_portions + k_block_size_up * KERNEL_N_STEP * n_portions] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// printf("n_block_size:%lu ,k_block_size: %lu\n", n_block_size, k_block_size);
|
||||
|
||||
// for(size_t n_idx = 0; n_idx < n_block_size; n_idx++) {
|
||||
// for(size_t k_idx = 0; k_idx < k_block_size; k_idx++) {
|
||||
// size_t old_split_n = n_idx / OLD_N_SIZE;
|
||||
// size_t old_idx_n = n_idx % OLD_N_SIZE;
|
||||
// size_t new_split_n = n_idx / NEW_N_SIZE;
|
||||
// size_t new_idx_n = n_idx % NEW_N_SIZE;
|
||||
// size_t split_k = k_idx / KERNEL_K_STEP;
|
||||
// size_t idx_k = k_idx % KERNEL_K_STEP;
|
||||
|
||||
// size_t old_buff_idx =
|
||||
// old_split_n * OLD_N_SIZE * ldb +
|
||||
// split_k * OLD_N_SIZE * KERNEL_K_STEP +
|
||||
// old_idx_n * KERNEL_K_STEP +
|
||||
// idx_k;
|
||||
// size_t new_buff_idx =
|
||||
// new_split_n * NEW_N_SIZE * k_block_size +
|
||||
// split_k * NEW_N_SIZE * KERNEL_K_STEP +
|
||||
// new_idx_n * KERNEL_K_STEP +
|
||||
// idx_k;
|
||||
// bufferB_typed[new_buff_idx] = curr_b_ptr_typed[old_buff_idx] + ob;
|
||||
// }
|
||||
// }
|
||||
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif /* __cplusplus */
|
||||
@@ -1,25 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
#include "helping_macros.h"
|
||||
#include "beta_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
// Fixed OC
|
||||
void BETA_SUFF(post_ops)(float alpha, const int32_t* bufferC, int32_t* current_c_ptr, size_t m, size_t n_block, size_t ldc) {
|
||||
float* current_c_float_ptr = (float*) current_c_ptr;
|
||||
double double_alpha = (double) alpha;
|
||||
for (size_t n_idx = 0; n_idx < n_block; ++n_idx) {
|
||||
for (size_t m_idx = 0; m_idx < m; ++m_idx) {
|
||||
current_c_ptr[m_idx + n_idx * ldc] = round(((double)(current_c_float_ptr[m_idx + n_idx * ldc]))
|
||||
+ double_alpha * ((double) bufferC[m_idx + n_idx * LDC(m, ldc)]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif /* __cplusplus */
|
||||
@@ -1,35 +0,0 @@
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include "integer_gemm_kernels.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
// here we will use BLASINT8, because we care only about type size, not signness (no math operations required)
|
||||
void gemm_impl_8bit(int_gemm_funcs* arg, size_t m, size_t n, size_t k, float alpha,
|
||||
const void* a, size_t lda, const BLASINT8 oa,
|
||||
const void* b, size_t ldb, const BLASINT8 ob,
|
||||
float beta, int32_t* c, size_t ldc, const int32_t* oc, size_t small_switch) {
|
||||
|
||||
void (*small_kernel)(const size_t, const size_t, const size_t, const float,
|
||||
const void *, const size_t, const BLASINT8,
|
||||
const void *, const size_t, const BLASINT8,
|
||||
const float, int32_t *, const size_t, const int32_t *) = arg->small_kernel;
|
||||
|
||||
if (m * n * k < small_switch) { // experimentally measured constant
|
||||
small_kernel(m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
|
||||
} else {
|
||||
// Corner cases optimizations
|
||||
if ((alpha == 1.0f) && (beta == 0.0f || beta == 1.0f)) {
|
||||
gemm_driver_opt(arg, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
|
||||
} else {
|
||||
gemm_driver(arg, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif /* __cplusplus */
|
||||
@@ -1,30 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.14.1)
|
||||
|
||||
project(THREAD_TEST)
|
||||
|
||||
|
||||
set(CMAKE_CXX_FLAGS ${CFLAGS_OPT})
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
|
||||
|
||||
|
||||
message(STATUS "Current source dir: ${CMAKE_CURRENT_SOURCE_DIR}/..")
|
||||
set(BLAS_PATH "")
|
||||
message(STATUS "BLAS_PATH=${BLAS_PATH}")
|
||||
|
||||
if ("${BLAS_PATH}" STREQUAL "")
|
||||
|
||||
set(BLAS_LIB "${CMAKE_CURRENT_SOURCE_DIR}/../libint8gemm.so")
|
||||
message(STATUS "BLAS_LIB=${BLAS_LIB}")
|
||||
|
||||
# Add threading library to linker
|
||||
#find_package(Threads)
|
||||
|
||||
add_executable(integer_tester integer_gemm.cpp)
|
||||
#target_include_directories(integer_tester PUBLIC "${BLAS_PATH}")
|
||||
#target_include_directories(integer_tester PUBLIC "${BLAS_BUILD_PATH}")
|
||||
target_link_libraries(integer_tester PRIVATE ${BLAS_LIB})
|
||||
set_property(TARGET integer_tester PROPERTY CXX_STANDARD 17)
|
||||
add_test(integer_tester ${CMAKE_CURRENT_BINARY_DIR}/integer_tester)
|
||||
else ()
|
||||
message(FATAL_ERROR "Can not find int8_gemm path, pls set right path!")
|
||||
endif ()
|
||||
@@ -1,438 +0,0 @@
|
||||
#include <malloc.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
/* matrix saved in rows or cols */
|
||||
typedef enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 } CBLAS_ORDER;
|
||||
|
||||
/* matrix transpose or conjugate transpose */
|
||||
typedef enum CBLAS_TRANSPOSE {
|
||||
CblasNoTrans = 111,
|
||||
CblasTrans = 112,
|
||||
CblasConjTrans = 113, // conjugate transpose
|
||||
CblasConjNoTrans = 114
|
||||
} CBLAS_TRANSPOSE;
|
||||
|
||||
typedef CBLAS_ORDER CBLAS_LAYOUT;
|
||||
|
||||
typedef enum CBLAS_OFFSET { CblasRowOffset = 171, CblasColOffset = 172, CblasFixOffset = 173 } CBLAS_OFFSET;
|
||||
|
||||
typedef int8_t BLASINT8;
|
||||
typedef uint8_t BLASUINT8;
|
||||
|
||||
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */

// Quantized integer GEMM entry points under test. The four variants differ
// only in the signedness of the A/B element buffers encoded in the name
// (s8 = int8, u8 = uint8); C is always int32. Per the reference
// implementation in this file (ref_gemm), the expected contract is
//   C = round(alpha * sum_k (A + oa) * (B + ob) + beta * C + oc)
// where `offsetc` selects how the oc vector indexes into C.
// a/b are passed as void* so one prototype serves both signednesses.

// A signed, B signed.
void cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
                        const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
                        const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                        const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);

// A unsigned, B unsigned.
void cblas_gemm_u8u8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
                        const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
                        const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                        const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);

// A signed, B unsigned.
void cblas_gemm_s8u8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
                        const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
                        const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                        const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);

// A unsigned, B signed.
void cblas_gemm_u8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
                        const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
                        const void* a, const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                        const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc);

#ifdef __cplusplus
}
#endif /* __cplusplus */
|
||||
|
||||
namespace test {
|
||||
|
||||
namespace tools {
|
||||
|
||||
/// Minimal STL-compatible allocator that hands out storage aligned to
/// `alignment` bytes. Uses C++17 std::aligned_alloc instead of the obsolete,
/// non-portable memalign(); aligned_alloc requires the byte count to be an
/// integral multiple of the alignment, so the size is rounded up.
template <typename T, std::size_t alignment = 128>
struct aligned_allocator {
  static_assert((alignment & (alignment - 1)) == 0, "alignment must be a power of two");

  using value_type = T;
  using pointer = T*;
  using const_pointer = const T*;
  using reference = T&;
  using const_reference = const T&;
  using size_type = std::size_t;
  using difference_type = std::ptrdiff_t;

  // Legacy rebind so the allocator also works with pre-C++11-style containers.
  template <typename U>
  struct rebind {
    typedef aligned_allocator<U, alignment> other;
  };

  /// Allocates uninitialized, `alignment`-aligned storage for n objects of T.
  /// @throws std::bad_array_new_length if n * sizeof(T) overflows size_t.
  /// @throws std::bad_alloc on allocation failure.
  [[nodiscard]] T* allocate(std::size_t n) {
    if (n > std::numeric_limits<std::size_t>::max() / sizeof(T)) throw std::bad_array_new_length();
    // Round the request up to a multiple of `alignment` (required by
    // std::aligned_alloc) and never pass 0, whose result is
    // implementation-defined and could be mistaken for failure.
    std::size_t bytes = n * sizeof(T);
    bytes = ((bytes + alignment - 1) / alignment) * alignment;
    if (bytes == 0) {
      bytes = alignment;
    }
    if (auto p = static_cast<T*>(std::aligned_alloc(alignment, bytes))) {
      return p;
    }
    throw std::bad_alloc();
  }

  /// Releases storage obtained from allocate(); `n` is unused but part of the
  /// allocator API.
  void deallocate(T* p, std::size_t n) noexcept {
    (void)(n);
    std::free(p);
  }
};
|
||||
template <typename T, std::size_t alignment_1, typename U, std::size_t alignment_2>
|
||||
bool operator==(const aligned_allocator<T, alignment_1>&, const aligned_allocator<U, alignment_2>&) {
|
||||
return (alignment_1 == alignment_2) && std::is_same_v<T, U>;
|
||||
}
|
||||
|
||||
template <typename T, std::size_t alignment_1, typename U, std::size_t alignment_2>
|
||||
bool operator!=(const aligned_allocator<T, alignment_1>& lhs, const aligned_allocator<U, alignment_2>& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
/// Times func(args...): one warm-up call estimates a single invocation's
/// duration, which sizes the measured loop to roughly one second of work;
/// returns the mean seconds per call over that loop.
template <typename Func, typename... Args>
double timing(Func&& func, Args&&... args) {
  // Warm-up call, also used as the calibration sample.
  auto start_begin = std::chrono::steady_clock::now();
  std::forward<Func>(func)(std::forward<Args>(args)...);
  auto end_begin = std::chrono::steady_clock::now();

  double time_begin = std::chrono::duration_cast<std::chrono::nanoseconds>(end_begin - start_begin).count() / 1e9;

  // Aim for ~1 second of total measurement, never fewer than 3 runs.
  // Guard against time_begin == 0 (call faster than the clock resolution):
  // 1.0 / 0.0 is +inf and converting that to size_t is undefined behavior.
  std::size_t n_run = 3;
  if (time_begin > 0.0) {
    n_run = std::max<std::size_t>(static_cast<std::size_t>(1.0 / time_begin), 3);
  }

  auto start = std::chrono::steady_clock::now();
  for (std::size_t i = 0; i < n_run; ++i) {
    std::forward<Func>(func)(std::forward<Args>(args)...);
  }
  auto end = std::chrono::steady_clock::now();

  double time = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count() / 1e9;
  return time / n_run;
}
|
||||
} // namespace tools
|
||||
namespace helpers {
|
||||
/// Returns how many elements the oc offset vector needs for `offset`:
/// 1 for a fixed offset, m for a per-row vector (CblasColOffset),
/// n for a per-column vector (CblasRowOffset); 0 (with a diagnostic) otherwise.
std::size_t get_oc_size(CBLAS_OFFSET offset, std::size_t m, std::size_t n) {
  if (offset == CblasFixOffset) {
    return 1;
  }
  if (offset == CblasColOffset) {
    return m;
  }
  if (offset == CblasRowOffset) {
    return n;
  }
  std::cout << "Incorrect value of offset to the function " << __PRETTY_FUNCTION__ << std::endl;
  return 0;
}
|
||||
template <typename T>
|
||||
auto get_ab_matrix(CBLAS_LAYOUT lt, CBLAS_TRANSPOSE trans_, T&& non_trans_mtx, T&& trans_mtx) {
|
||||
if (lt == CblasColMajor) {
|
||||
if (trans_ == CblasNoTrans) {
|
||||
return non_trans_mtx.data();
|
||||
} else {
|
||||
return trans_mtx.data();
|
||||
}
|
||||
} else {
|
||||
if (trans_ == CblasNoTrans) {
|
||||
return trans_mtx.data();
|
||||
} else {
|
||||
return non_trans_mtx.data();
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Leading dimension matching get_ab_matrix's buffer choice: ld_n when
/// layout==col-major matches trans==no-trans, ld_t otherwise.
auto get_ldab(CBLAS_LAYOUT lt, CBLAS_TRANSPOSE trans_mtx, std::size_t ld_n, std::size_t ld_t) {
  const bool use_plain = (lt == CblasColMajor) == (trans_mtx == CblasNoTrans);
  return use_plain ? ld_n : ld_t;
}
|
||||
|
||||
// returns copy of the matrix
|
||||
template <typename T>
|
||||
auto get_c_matrix(CBLAS_LAYOUT lt, T&& non_trans_mtx, T&& trans_mtx) {
|
||||
if (lt == CblasColMajor) {
|
||||
return non_trans_mtx;
|
||||
} else {
|
||||
return trans_mtx;
|
||||
}
|
||||
}
|
||||
|
||||
/// Leading dimension of C for the chosen layout (ldc_n for col-major,
/// ldc_t for row-major).
auto get_ldc(CBLAS_LAYOUT lt, std::size_t ldc_n, std::size_t ldc_t) {
  return (lt == CblasColMajor) ? ldc_n : ldc_t;
}
|
||||
template <typename A_Type, typename B_Type>
|
||||
void cblas_gemm_wrapper(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
|
||||
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k, const float alpha,
|
||||
const A_Type* a, const size_t lda, const int8_t oa, const B_Type* b, const size_t ldb,
|
||||
const int8_t ob, const float beta, int32_t* c, const size_t ldc, const int32_t* oc) {
|
||||
if constexpr (std::is_same_v<A_Type, std::int8_t>) {
|
||||
if constexpr (std::is_same_v<B_Type, std::int8_t>) {
|
||||
cblas_gemm_s8s8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
|
||||
} else {
|
||||
cblas_gemm_s8u8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same_v<B_Type, std::int8_t>) {
|
||||
cblas_gemm_u8s8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
|
||||
} else {
|
||||
cblas_gemm_u8u8s32(Layout, transa, transb, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t return_oc_idx(const CBLAS_OFFSET offsetc, std::size_t mi, std::size_t ni) {
|
||||
return (offsetc == CblasFixOffset) ? 0 : ((offsetc == CblasColOffset) ? mi : ni);
|
||||
}
|
||||
} // namespace helpers
|
||||
enum class status_t { passed, failed };

/// Streams the human-readable tag ("PASSED"/"FAILED") for a test status.
std::ostream& operator<<(std::ostream& os, const status_t& st) {
  switch (st) {
    case status_t::passed:
      os << "PASSED";
      break;
    case status_t::failed:
      os << "FAILED";
      break;
  }
  return os;
}
|
||||
// column major
// Scalar reference GEMM with input offsets: for every (mi, ni) computes
//   C = round(alpha * sum_k (A + oa) * (B + ob) + beta * C + oc)
// where A, B and C are addressed col-major with leading dimensions
// lda/ldb/ldc, and the oc element used depends on the offset mode.
template <typename A_Type, typename B_Type>
void ref_gemm(const CBLAS_OFFSET offsetc, const std::size_t m, const std::size_t n, const std::size_t k,
              const float alpha, const A_Type* a, const std::size_t lda, const std::int8_t oa, const B_Type* b,
              const std::size_t ldb, const std::int8_t ob, const float beta, std::int32_t* c, const std::size_t ldc,
              const std::int32_t* oc) {
  for (std::size_t mi = 0; mi < m; ++mi) {
    for (std::size_t ni = 0; ni < n; ++ni) {
      // Offset-adjusted dot product accumulated in 32-bit integer arithmetic.
      std::int32_t tmp = 0;
      for (std::size_t ki = 0; ki < k; ++ki) {
        tmp += (a[mi + ki * lda] + oa) * (b[ki + ni * ldb] + ob);
      }
      // The exact mix of double/float casts defines the rounding behavior the
      // tested kernels are compared against — do not "simplify" it.
      c[mi + ni * ldc] = std::round(alpha * static_cast<double>(tmp) +
                                    static_cast<double>(beta * static_cast<float>(c[mi + ni * ldc])) +
                                    static_cast<float>(oc[helpers::return_oc_idx(offsetc, mi, ni)]));
    }
  }
}
|
||||
/// Fills `buffer` with reproducible pseudo-random values in [0, 64]
/// (fixed-seed generator shared across calls).
/// Note: std::uniform_int_distribution is undefined for char-sized types —
/// int8_t/uint8_t are not among the allowed IntTypes — so the values are
/// drawn as `int` and narrowed afterwards (the range always fits).
template <typename DataType>
void fill_random(DataType* buffer, std::size_t len) {
  static std::mt19937 generator(0);
  std::uniform_int_distribution<int> dist(0, 64);
  for (std::size_t i = 0; i < len; i++) {
    buffer[i] = static_cast<DataType>(dist(generator));
  }
}
|
||||
|
||||
/// Sets every element of `buffer` to the constant -8.
template <typename DataType>
void fill_const(DataType* buffer, std::size_t len) {
  DataType* const last = buffer + len;
  for (DataType* p = buffer; p != last; ++p) {
    *p = DataType{-8};
  }
}
|
||||
// performs transposition (n0 * n1) -> (n1 * n0), assuming col major
// Reads in[col + row * ld0] and writes out[row + col * ld1];
// `in` and `out` must not alias.
template <typename T>
void simplest_transpose(T* in, T* out, std::size_t n0, std::size_t n1, std::size_t ld0, std::size_t ld1) {
  for (std::size_t row = 0; row < n0; ++row) {
    for (std::size_t col = 0; col < n1; ++col) {
      out[row + col * ld1] = in[col + row * ld0];
    }
  }
}
|
||||
|
||||
template <typename DataType>
|
||||
status_t compare(DataType* ref, DataType* test, std::size_t m, std::size_t n, std::size_t ld) {
|
||||
for (std::size_t mi = 0; mi < m; ++mi) {
|
||||
for (std::size_t ni = 0; ni < n; ++ni) {
|
||||
if (ref[mi + ni * ld] != test[mi + ni * ld]) {
|
||||
return status_t::failed;
|
||||
}
|
||||
}
|
||||
}
|
||||
return status_t::passed;
|
||||
}
|
||||
/// Prints an m x n col-major matrix (leading dimension m) to stdout as
/// int32 values, one matrix row per line, followed by a blank line.
template <typename DataType>
void print_matrix(DataType* buffer, std::size_t m, std::size_t n) {
  for (std::size_t row = 0; row < m; ++row) {
    for (std::size_t col = 0; col < n; ++col) {
      std::cout << static_cast<std::int32_t>(buffer[row + col * m]) << ", ";
    }
    std::cout << std::endl;
  }
  std::cout << std::endl;
}
|
||||
// Drives one full correctness (or performance) sweep for the A/B type pair:
// builds col-major and transposed copies of random A/B and a constant C,
// computes the scalar reference once per offset mode, then exercises every
// {offset} x {layout} x {transa} x {transb} combination. Prints '+'/'-' per
// combination and returns failed if any combination mismatched.
// Env vars: LD_STRIDE pads all leading dimensions; ONLY_PERF skips checking
// and reports TFlops instead.
template <typename A_Type, typename B_Type>
status_t gemm(std::size_t m, std::size_t n, std::size_t k, float alpha, float beta) {
  // Fixed non-zero input offsets so the oa/ob handling is actually exercised.
  std::int8_t oa = 4;
  std::int8_t ob = 9;

  // Leading dimensions of the non-transposed (col-major) copies...
  std::size_t lda_n = m;
  std::size_t ldb_n = k;
  std::size_t ldc_n = m;

  // ...and of the transposed copies.
  std::size_t lda_t = k;
  std::size_t ldb_t = n;
  std::size_t ldc_t = n;

  // Padding makes ld > dim, covering strided-access code paths.
  if (std::getenv("LD_STRIDE")) {
    lda_n += 2;
    ldb_n += 7;
    ldc_n += 3;

    lda_t += 8;
    ldb_t += 3;
    ldc_t += 23;
  }

  bool only_performance = false;
  if (std::getenv("ONLY_PERF")) {
    only_performance = true;
  }
  // 128-byte aligned buffers: *_n holds the col-major data, *_t its transpose.
  std::vector<A_Type, tools::aligned_allocator<A_Type, 128>> a_n(lda_n * k);
  std::vector<A_Type, tools::aligned_allocator<A_Type, 128>> a_t(m * lda_t);
  std::vector<B_Type, tools::aligned_allocator<B_Type, 128>> b_n(ldb_n * n);
  std::vector<B_Type, tools::aligned_allocator<B_Type, 128>> b_t(k * ldb_t);
  std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_ref(ldc_n * n);
  std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_n(ldc_n * n);
  std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_t(m * ldc_t);

  // fill the whole array even if ld* > corresponding dim
  fill_random(a_n.data(), a_n.size());
  fill_random(b_n.data(), b_n.size());

  simplest_transpose(a_n.data(), a_t.data(), m, k, lda_n, lda_t);
  simplest_transpose(b_n.data(), b_t.data(), k, n, ldb_n, ldb_t);

  fill_const(c_ref.data(), ldc_n * n);
  c_n = c_ref;

  // NOTE(review): argument order here is (m, ldc_n, n, ldc_t), which
  // interleaves dims and leading dims unlike the (dim, dim, ld, ld) order
  // used by the two calls above — verify against simplest_transpose's
  // (n0, n1, ld0, ld1) signature.
  simplest_transpose(c_n.data(), c_t.data(), m, ldc_n, n, ldc_t);

  auto return_st = status_t::passed;
  double total = 0;  // accumulated TFlops across perf runs
  size_t cnt = 0;    // number of perf runs measured
  for (auto c_offset : {CblasFixOffset, CblasColOffset, CblasRowOffset}) {
    // oc length depends on the offset mode (1 / m / n); values are 0, 2, 4, ...
    std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> oc(helpers::get_oc_size(c_offset, m, n));
    for (std::size_t i = 0; i < oc.size(); ++i) {
      oc[i] = i + i;
    }

    // Reference result, computed once per offset mode on the col-major data.
    std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_ref_copy = c_ref;
    if (!only_performance) {
      ref_gemm(c_offset, m, n, k, alpha, a_n.data(), lda_n, oa, b_n.data(), ldb_n, ob, beta, c_ref_copy.data(), ldc_n,
               oc.data());
    }
    for (auto layout : {CblasColMajor, CblasRowMajor}) {
      for (auto transa : {CblasNoTrans, CblasTrans}) {
        for (auto transb : {CblasNoTrans, CblasTrans}) {
          // Fresh copy of C for this combination (get_c_matrix returns by
          // value; auto&& lifetime-extends the temporary).
          auto&& c_tested = helpers::get_c_matrix(layout, c_n, c_t);
          if (!only_performance) {
            helpers::cblas_gemm_wrapper(
                layout, transa, transb, c_offset, m, n, k, alpha, helpers::get_ab_matrix(layout, transa, a_n, a_t),
                helpers::get_ldab(layout, transa, lda_n, lda_t), oa, helpers::get_ab_matrix(layout, transb, b_n, b_t),
                helpers::get_ldab(layout, transb, ldb_n, ldb_t), ob, beta, c_tested.data(),
                helpers::get_ldc(layout, ldc_n, ldc_t), oc.data());

            // transpose c_tested to col-major if required
            auto loc_st = status_t::passed;
            if (layout == CblasRowMajor) {
              std::vector<std::int32_t, tools::aligned_allocator<std::int32_t, 128>> c_tested_n(ldc_n * n);
              // NOTE(review): as above, (n, ldc_t, m, ldc_n) mixes dims with
              // leading dims relative to simplest_transpose's signature —
              // verify (only correct for some shapes?).
              simplest_transpose(c_tested.data(), c_tested_n.data(), n, ldc_t, m, ldc_n);
              loc_st = compare(c_ref_copy.data(), c_tested_n.data(), m, n, ldc_n);
            } else {
              loc_st = compare(c_ref_copy.data(), c_tested.data(), m, n, ldc_n);
            }
            // Progress glyphs: '+' for pass, '-' for fail, per combination.
            if (loc_st != status_t::passed) {
              std::cout << "-";
              return_st = status_t::failed;
            } else {
              std::cout << "+";
            }
          } else {
            // TFlops = 2*m*n*k flops / (seconds per call) / 1e12.
            double cur = (2.0 * m * n * k) /
                         tools::timing(helpers::cblas_gemm_wrapper<A_Type, B_Type>, layout, transa, transb, c_offset, m,
                                       n, k, alpha, helpers::get_ab_matrix(layout, transa, a_n, a_t),
                                       helpers::get_ldab(layout, transa, lda_n, lda_t), oa,
                                       helpers::get_ab_matrix(layout, transb, b_n, b_t),
                                       helpers::get_ldab(layout, transb, ldb_n, ldb_t), ob, beta, c_tested.data(),
                                       helpers::get_ldc(layout, ldc_n, ldc_t), oc.data()) /
                         1e12;
            total += cur;
            ++cnt;

            std::cout << cur << ", ";
          }
        }
      }
    }
  }
  if (only_performance) {
    std::cout << "Average " << total / cnt << " TFlops";
  }
  std::cout << " ";
  return return_st;
}
|
||||
} // namespace test
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::size_t m = 128;
|
||||
std::size_t n = 128;
|
||||
std::size_t k = 128;
|
||||
float alpha = 1.0f;
|
||||
float beta = 1.0f;
|
||||
|
||||
if (argc > 1) {
|
||||
m = std::stoi(argv[1]);
|
||||
if (argc > 2) {
|
||||
n = std::stoi(argv[2]);
|
||||
if (argc > 3) {
|
||||
k = std::stoi(argv[3]);
|
||||
if (argc > 4) {
|
||||
alpha = std::stof(argv[4]);
|
||||
if (argc > 5) {
|
||||
beta = std::stof(argv[5]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::cout << "Testing matrix m = " << m << ", n = " << n << ", k = " << k << ", alpha = " << alpha
|
||||
<< ", beta = " << beta << std::endl;
|
||||
|
||||
std::cout << "\tTesting i8i8i32: " << test::gemm<std::int8_t, std::int8_t>(m, n, k, alpha, beta) << std::endl;
|
||||
std::cout << "\tTesting i8u8i32: " << test::gemm<std::int8_t, std::uint8_t>(m, n, k, alpha, beta) << std::endl;
|
||||
std::cout << "\tTesting u8i8i32: " << test::gemm<std::uint8_t, std::int8_t>(m, n, k, alpha, beta) << std::endl;
|
||||
std::cout << "\tTesting u8u8i32: " << test::gemm<std::uint8_t, std::uint8_t>(m, n, k, alpha, beta) << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
|
||||
#pragma once
/*** gemm helper ***/
#include "../../api/common.h"
#ifdef __cplusplus
extern "C" {
#endif
// Kernel tile geometry: COMP_SV_LEN is the vector length (in elements) the
// kernels compute with; tiles are M_SIZE x N_SIZE with K_SIZE accumulation
// depth per step.
#define COMP_SV_LEN 32
#define K_SIZE COMP_SV_LEN
#define M_SIZE 1
#define N_SIZE 8

// Col-major address of B(row_idx, col_idx) with leading dimension ldb.
// NOTE(review): row_idx is not parenthesized in the expansion — safe only
// for simple arguments; consider ((col_idx) * (ldb) + (row_idx)).
#define INDEXING_B(row_idx, col_idx, ldb) ((col_idx) * (ldb) + row_idx)

// Inline-asm fragment (SVE): zeroes w<reg_idx>, reduces the .s lanes of
// z<z_reg_idx> under predicate p with saddv into d<reg_idx>, moves that sum
// through tmp_reg, adds it into x<reg_idx>, and stores the low 32 bits to
// [dst] with a 4-byte post-increment.
#define PROCESS_ACCUM(reg_idx, z_reg_idx, tmp_reg, dst, p) \
  "mov w" #reg_idx \
  ", #0\n" \
  "saddv d" #reg_idx ", " #p ", z" #z_reg_idx \
  ".s\n" \
  "fmov " #tmp_reg ", d" #reg_idx \
  "\n" \
  "add x" #reg_idx ", x" #reg_idx ", " #tmp_reg \
  "\n" \
  "str w" #reg_idx ", [%[" #dst "]], #4\n"

// Inline-asm fragment (SVE) for int4 unpacking: copies z<src_reg> into
// z<dst_reg> (movprfx), left-shifts dst's bytes by `shift`, and masks src's
// bytes with z<mask_reg1>.
// NOTE(review): the mask_reg2 parameter is never used in the expansion —
// confirm whether that is intentional.
#define INT4_CP_MASK_SHIFT_1x8(src_reg, dst_reg, mask_reg1, mask_reg2, shift) \
  "movprfx z" #dst_reg ", z" #src_reg \
  "\n" \
  "lsl z" #dst_reg ".b, p0/m, z" #dst_reg ".b, #" #shift \
  "\n" \
  "and z" #src_reg ".b, p0/m, z" #src_reg ".b, z" #mask_reg1 ".b\n"

// Repack a (k x n) slice of B (leading dimension ldb) into the kernel's
// 1x8 tile layout, applying the ob input offset; _int4 variant handles
// nibble-packed data.
void pack_b_1x8(void* bufferB, const void* cur_b_ptr, size_t n, size_t k, size_t ldb, const BLASINT8 ob);
void pack_b_1x8_int4(void* bufferB, const void* cur_b_ptr, size_t n, size_t k, size_t ldb, const BLASINT8 ob);

// 1x8 int8 micro-kernel: accumulates k_depth products of lhs/rhs into the
// int32 accumulator tile at accum_ptr (leading dimension ldc), processing
// sv_len elements per vector step.
void gemm_kernel_1x8(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
                     int64_t sv_len);

// Same as gemm_kernel_1x8 but for nibble-packed (int4) rhs data.
void gemm_kernel_1x8_int4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
                          int64_t sv_len);

#ifdef __cplusplus
}
#endif
/*** gemm helper ***/
||||
Reference in New Issue
Block a user