mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
bitnet: python + llama
@@ -1400,6 +1400,35 @@ class LlamaModel(Model):
                raise ValueError(f"Unprocessed experts: {experts}")


@Model.register("BitnetForCausalLM")
class BitnetModel(Model):
    model_arch = gguf.MODEL_ARCH.BITNET

    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
        self.gguf_writer.add_rope_scaling_factor(1.0)

    def weight_quant(self, weight):
        dtype = weight.dtype
        weight = weight.float()
        s = 1 / weight.abs().mean().clamp(min=1e-5)
        result = (weight * s).round().clamp(-1, 1) / s
        return result.type(dtype)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # transform weight into 1/0/-1 (in fp32)
        if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
                          "down_proj.weight", "up_proj.weight", "gate_proj.weight",
                          "o_proj.weight")):
            data_torch = self.weight_quant(data_torch)

        return [(self.map_tensor_name(name), data_torch)]
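
For reference, weight_quant is the BitNet b1.58 "absmean" ternarization: scale by the inverse of the mean absolute value, round to the nearest integer, clamp to {-1, 0, 1}, then undo the scale. A minimal standalone sketch of the same round trip (assumes PyTorch; here it is a free function rather than the method above):

import torch

def weight_quant(weight: torch.Tensor) -> torch.Tensor:
    dtype = weight.dtype
    weight = weight.float()
    # inverse of the mean absolute value of the tensor
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    # snap every entry to the nearest of {-1/s, 0, +1/s}
    return ((weight * s).round().clamp(-1, 1) / s).type(dtype)

w  = torch.randn(4, 4)
wq = weight_quant(w)
s  = 1 / w.abs().mean().clamp(min=1e-5)
print(torch.unique((wq * s).round()))  # subset of {-1., 0., 1.}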


@Model.register("GrokForCausalLM")
class GrokModel(Model):
    model_arch = gguf.MODEL_ARCH.GROK

@@ -149,6 +149,7 @@ class MODEL_ARCH(IntEnum):
    OLMO      = auto()
    ARCTIC    = auto()
    DEEPSEEK2 = auto()
    BITNET    = auto()


class MODEL_TENSOR(IntEnum):
@@ -200,6 +201,8 @@ class MODEL_TENSOR(IntEnum):
    ATTN_KV_B      = auto()
    ATTN_Q_A_NORM  = auto()
    ATTN_KV_A_NORM = auto()
    FFN_SUB_NORM   = auto()
    ATTN_SUB_NORM  = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -237,6 +240,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.OLMO:      "olmo",
    MODEL_ARCH.ARCTIC:    "arctic",
    MODEL_ARCH.DEEPSEEK2: "deepseek2",
    MODEL_ARCH.BITNET:    "bitnet",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -288,6 +292,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.ATTN_KV_B:      "blk.{bid}.attn_kv_b",
    MODEL_TENSOR.ATTN_Q_A_NORM:  "blk.{bid}.attn_q_a_norm",
    MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
    MODEL_TENSOR.ATTN_SUB_NORM:  "blk.{bid}.attn_sub_norm",
    MODEL_TENSOR.FFN_SUB_NORM:   "blk.{bid}.ffn_sub_norm",
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -808,6 +814,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_DOWN_SHEXP,
        MODEL_TENSOR.FFN_UP_SHEXP,
    ],
    MODEL_ARCH.BITNET: [
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_QKV,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.ATTN_ROT_EMBD,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.ATTN_SUB_NORM,
        MODEL_TENSOR.FFN_SUB_NORM,
    ],
    # TODO
}


@@ -413,6 +413,14 @@ class TensorNameMap:
        MODEL_TENSOR.ATTN_KV_A_NORM: (
            "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
        ),

        MODEL_TENSOR.ATTN_SUB_NORM: (
            "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet
        ),

        MODEL_TENSOR.FFN_SUB_NORM: (
            "model.layers.{bid}.mlp.ffn_layernorm", # bitnet
        ),
    }

    # architecture-specific block mappings

iqk-quantize.cpp (new file, 437 lines)
@@ -0,0 +1,437 @@
#include "ggml-quants.h"
#include "ggml-impl.h"
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"

#include <vector>
#include <utility>
#include <cstdint>
#include <cmath>
#include <array>
#include <algorithm>
#include <cstring>
#include <mutex>

namespace {

inline int nearest_int(float fval) {
    assert(fval <= 4194303.f);
    float val = fval + 12582912.f;
    int i; memcpy(&i, &val, sizeof(int));
    return (i & 0x007fffff) - 0x00400000;
}

struct IQ1BNData {
    IQ1BNData();
    std::vector<std::pair<int16_t, bool>> map;
    std::vector<uint16_t> rmap;
};

const IQ1BNData& get_iq1bn_data() {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);
    static IQ1BNData iq1bn;
    return iq1bn;
}

IQ1BNData::IQ1BNData() {
    map.resize(1 << 16, {int16_t(-1), false});
    uint64_t aux64;
    uint8_t * aux8 = (uint8_t *)&aux64;
    std::vector<uint64_t> values;
    values.reserve(6561);
    rmap.reserve(6561);
    for (int i = 0; i < (1 << 16); ++i) {
        bool is_good = true;
        for (int j = 0; j < 8; ++j) {
            aux8[j] = (i >> 2*j) & 3;
            if (aux8[j] == 3u) { is_good = false; break; }
        }
        if (!is_good) continue;
        auto orig = aux64;
        for (int j = 0; j < 8; ++j) aux8[j] = 2 - aux8[j];
        int k = 0;
        for (; k < int(values.size()); ++k) {
            if (values[k] == aux64) break;
        }
        if (k < int(values.size())) {
            map[i] = {k, true};
        } else {
            map[i].first = values.size();
            values.push_back(orig);
            rmap.push_back(i);
        }
    }
    printf("==================== %s: initialized %d grid points\n", __func__, int(rmap.size()));
}
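
What the constructor computes: a group of 8 ternary digits is held in 16 bits (2 bits per digit, the value 3 unused), so 3^8 = 6561 codes are valid; since a pattern and its negation can share one table entry plus a sign flag, only 3281 canonical grid points need storing, which fits the 12-bit index written into ql/qh below. A Python sketch of the same mapping (mirrors the loop above; not the shipped table):

slots  = []    # canonical 16-bit codes, in first-seen order
encode = {}    # code -> (slot index, sign-flip flag)
for i in range(1 << 16):
    digits = [(i >> 2 * j) & 3 for j in range(8)]
    if 3 in digits:          # the 2-bit value 3 never encodes a ternary digit
        continue
    neg = sum((2 - d) << 2 * j for j, d in enumerate(digits))
    if neg in encode:        # negated pattern already stored: share its slot
        encode[i] = (encode[neg][0], True)
    else:
        encode[i] = (len(slots), False)
        slots.append(i)
print(len(slots))            # 3281 == (3**8 - 1) // 2 + 1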

struct IQ1BNQuantizer {
    typedef union {
        float f;
        uint32_t i;
    } scale_t;
    constexpr static int block_size = QK_IQ1BN;
    int8_t L[QK_IQ1BN];
    void quantize_one_row(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix);
};

void IQ1BNQuantizer::quantize_one_row(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) {

    (void)imatrix;

    constexpr int Nk = block_size/8;

    const int nblock = n_per_row/QK_IQ1BN;

    const auto& iq1bn = get_iq1bn_data();

    float max_in_row = 0;
    for (int j = 0; j < n_per_row; ++j) {
        float ax = fabsf(src[j]);
        max_in_row = std::max(max_in_row, ax);
    }

    max_in_row *= 1.03125f; // i.e., round to nearest in our fp8 representation
    scale_t s;
    uint8_t u = 0;
    if (max_in_row > 1.9074e-06f && max_in_row < 0.12109f) {
        s.f = max_in_row;
        u = ((((s.i >> 23) + 132) & 0xf) << 4) | ((s.i >> 19) & 0xf);
        s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
    } else {
        // outside the allowed range. Small values we can handle via quants set to zero,
        // so we only warn about values that are too large
        if (max_in_row >= 0.12109f) {
            u = 255;
            fprintf(stderr, "%s: found scale %g, which is outside the range of our fp8 representation\n", __func__, max_in_row);
        } else {
            u = 0;
        }
    }

    for (int ib = 0; ib < nblock; ++ib) {
        std::memset(&y[ib], 0, sizeof(block_iq1_bn));
        auto xb = src + QK_IQ1BN*ib;
        for (int j = 0; j < QK_IQ1BN; ++j) {
            // ternary digit: 0 -> -1, 1 -> 0, 2 -> +1
            L[j] = fabsf(xb[j]) < 1e-6f ? 1 : xb[j] < 0 ? 0 : 2;
        }
        auto ql = y[ib].ql;
        auto qh = y[ib].qh;
        uint16_t extra = 0;
        for (int k = 0; k < Nk; ++k) {
            auto Lk = L + 8*k;
            uint16_t code = 0;
            for (int j = 0; j < 8; ++j) code |= (Lk[j] << 2*j);
            auto& val = iq1bn.map[code];
            GGML_ASSERT(val.first >= 0);
            // 12-bit grid index: low 8 bits in ql[k], high 4 bits in a nibble of qh[k/2]
            ql[k] = val.first & 255;
            qh[k/2] |= (val.first >> 8) << 4*(k%2);
            if (val.second) extra |= (1 << k);
        }

        // low byte: fp8 row scale; high byte: per-group sign bits
        y[ib].extra = u | (extra << 8);

    }
}
}
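
The row scale is stored in 8 bits: 4 exponent bits and 4 mantissa bits carved directly out of the float32 bit pattern, valid roughly between 1.9e-06 and 0.121. A round-trip sketch of just that bit math (Python, using struct; the helper names are illustrative):

import struct

def f32_bits(x: float) -> int:
    return struct.unpack('<I', struct.pack('<f', x))[0]

def bits_f32(i: int) -> float:
    return struct.unpack('<f', struct.pack('<I', i & 0xffffffff))[0]

def encode_scale(x: float) -> int:
    i = f32_bits(x)
    # 4 low exponent bits (biased by +132) and the top 4 mantissa bits
    return ((((i >> 23) + 132) & 0xf) << 4) | ((i >> 19) & 0xf)

def decode_scale(u: int) -> float:
    return bits_f32(((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19))

x = 0.0123
u = encode_scale(x * 1.03125)   # the 1.03125 factor gives round-to-nearest
print(u, decode_scale(u))       # decoded value is close to x (4 mantissa bits ~3% worst case)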

void iq1bn_init_impl(void) {
    get_iq1bn_data();
}

size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
    IQ1BNQuantizer iq1bn;
    int nblock = n_per_row/QK_IQ1BN;
    block_iq1_bn * y = (block_iq1_bn *)dst;
    for (int row = 0; row < nrows; ++row) {
        iq1bn.quantize_one_row(src + row*n_per_row, y, n_per_row, imatrix);
        y += nblock;
    }
    return sizeof(block_iq1_bn)*nblock*nrows;
}
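
Assuming block_iq1_bn is exactly the fields touched above — ql[8], qh[4], and a 16-bit extra — one block is 14 bytes for QK_IQ1BN = 64 ternary weights, i.e. 14*8/64 = 1.75 bits per weight, which is the row size quantize_iq1_bn reports.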

void quantize_row_iq1_bn_reference(const float * x, block_iq1_bn * y, int64_t k) {
    quantize_iq1_bn(x, y, 1, k, nullptr);
}

void quantize_row_iq1_bn(const float * x, void * y, int64_t k) {
    quantize_iq1_bn(x, y, 1, k, nullptr);
}

void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
    assert(k%QK_IQ1BN == 0);
    int nblock = k / QK_IQ1BN;

    IQ1BNQuantizer::scale_t s;

    for (int i = 0; i < nblock; ++i) {
        uint16_t u = x[i].extra & 0xff;
        s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
        float d = s.f;
        uint8_t extra = x[i].extra >> 8;
        auto qh = x[i].qh;
        auto ql = x[i].ql;
        for (int k = 0; k < QK_IQ1BN/8; ++k) {
            uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00);
            uint16_t val = iq1bn_grid_u16[idx];
            float dls = extra & (1 << k) ? -d : d;
            for (int j = 0; j < 8; ++j) y[j] = dls * (((val >> 2*j) & 3) - 1);
            y += 8;
        }
    }
}
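
The 12-bit grid index for group k is split across ql and qh: 8 low bits in ql[k], 4 high bits in one nibble of qh[k/2]. A quick check of the reassembly (pure Python, mirroring the idx expression above):

def group_index(ql: list[int], qh: list[int], k: int) -> int:
    # low 8 bits from ql[k]; high 4 bits from one nibble of qh[k // 2]
    return ql[k] | ((qh[k // 2] << (8 - 4 * (k % 2))) & 0x0F00)

ql = [0xAB] * 8
qh = [0x5C] * 4
print(hex(group_index(ql, qh, 0)))  # 0xcab: high nibble from the low nibble of qh[0]
print(hex(group_index(ql, qh, 1)))  # 0x5ab: high nibble from the high nibble of qh[0]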

#if __AVX__ || __AVX2__ || __AVX512F__
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
    __m128 res = _mm256_extractf128_ps(x, 1);
    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    res = _mm_add_ss(res, _mm_movehdup_ps(res));
    return _mm_cvtss_f32(res);
}
#endif

void ggml_vec_dot_iq1_bn_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {

    GGML_UNUSED(bs);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(nrc);

    static_assert(QK_IQ1BN == 64, "This dot product implementation for iq1_bn requires a block size of 64");

    const block_iq1_bn * x = (const block_iq1_bn *)vx;
    const block_q8_0   * y = (const block_q8_0 *)vy;
    int nblock = n / QK_IQ1BN;

    float sumf = 0;
    IQ1BNQuantizer::scale_t scale;

#if defined __AVX2__

    const auto m1_8   = _mm256_set1_epi8(1);
    const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000);
    const auto shuff2 = _mm256_add_epi8(shuff1, m1_8);
    const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
    const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
    const auto mask1  = _mm256_set1_epi64x(0x8040201008040201);
#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
    const auto m1_16  = _mm256_set1_epi16(1);
#endif

    __m256 acc1 = _mm256_setzero_ps();
    __m256 acc2 = _mm256_setzero_ps();

    // All scales are the same in BitNet!
    uint16_t u = x[0].extra & 0xff;
    scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);

    for (int i = 0; i < nblock; ++i) {
        // We would uncomment this if we wanted to use this implementation for a model that has per block scales
        //uint16_t u = x[i].extra & 0xff;
        //scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
        auto signs = _mm256_set1_epi8(x[i].extra >> 8);
        // signs for groups of 8 ordered as 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ...
        // To use these to sign the q8 values we need
        // 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 and the same for 4...7
        signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, mask1), mask1), m1_8);
        auto q8_1 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[2*i+0].qs), _mm256_shuffle_epi8(signs, shuff3));
        auto q8_2 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[2*i+1].qs), _mm256_shuffle_epi8(signs, shuff4));

        auto ql = x[i].ql;
        auto qh = x[i].qh;
        auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
                                      iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
        auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)],
                                      iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]);

        auto v1_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1);
        auto v1_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1);
        auto v2_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1);
        auto v2_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1);

        auto dot1 = _mm256_sub_epi8(_mm256_sign_epi8(q8_1, v1_m), _mm256_sign_epi8(q8_1, v1_p));
        auto dot2 = _mm256_sub_epi8(_mm256_sign_epi8(q8_2, v2_m), _mm256_sign_epi8(q8_2, v2_p));

#if defined __AVX512VNNI__ && defined __AVX512VL__
        dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1);
        dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot2);
#else
        dot1 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot1));
        dot2 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot2));
#endif

        // We would uncomment this if we wanted to use this implementation for a model that has per block scales
        //acc1 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+0].d)), _mm256_cvtepi32_ps(dot1), acc1);
        //acc2 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+1].d)), _mm256_cvtepi32_ps(dot2), acc2);
        // All scales are the same for BitNet!
        // This is slower
        //uint32_t aux32 = y[2*i+0].d | (y[2*i+1].d << 16);
        //auto d8 = _mm256_cvtph_ps(_mm_set1_epi32(aux32));
        //acc1 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x00), _mm256_cvtepi32_ps(dot1), acc1);
        //acc2 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x55), _mm256_cvtepi32_ps(dot2), acc2);
        acc1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i+0].d)), _mm256_cvtepi32_ps(dot1), acc1);
        acc2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i+1].d)), _mm256_cvtepi32_ps(dot2), acc2);

    }

    //sumf = hsum_float_8(_mm256_add_ps(acc1, acc2));
    sumf = scale.f * hsum_float_8(_mm256_add_ps(acc1, acc2));

#else

    for (int i = 0; i < nblock; ++i) {
        uint16_t u = x[i].extra & 0xff;
        scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
        uint8_t extra = x[i].extra >> 8;
        auto qh = x[i].qh;
        auto ql = x[i].ql;
        auto q8 = y[2*i+0].qs;
        int16_t sumi1 = 0;
        for (int k = 0; k < 4; ++k) {
            uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00);
            uint16_t val = iq1bn_grid_u16[idx];
            int16_t sl = 0;
            for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
            sumi1 += extra & (1 << k) ? -sl : sl;
            q8 += 8;
        }
        q8 = y[2*i+1].qs;
        int16_t sumi2 = 0;
        for (int k = 4; k < 8; ++k) {
            uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00);
            uint16_t val = iq1bn_grid_u16[idx];
            int16_t sl = 0;
            for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
            sumi2 += extra & (1 << k) ? -sl : sl;
            q8 += 8;
        }
        sumf += scale.f * (GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2);
    }

#endif

    *s = sumf;

}

void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {

    GGML_UNUSED(bs);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(nrc);

    static_assert(QK_IQ1BN == 64, "This dot product implementation for iq1_bn requires a block size of 64");

    const block_iq1_bn * x = (const block_iq1_bn *)vx;
    const block_q8_K64 * y = (const block_q8_K64 *)vy;
    int nblock = n / QK_IQ1BN;

    float sumf = 0;
    IQ1BNQuantizer::scale_t scale;

#if defined __AVX2__

    const auto m1_8   = _mm256_set1_epi8(1);
    const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000);
    const auto shuff2 = _mm256_add_epi8(shuff1, m1_8);
    const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
    const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
    const auto mask1  = _mm256_set1_epi64x(0x8040201008040201);
#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
    const auto m1_16  = _mm256_set1_epi16(1);
#endif

    __m256 acc = _mm256_setzero_ps();

    // All scales are the same in BitNet!
    uint16_t u = x[0].extra & 0xff;
    scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);

    for (int i = 0; i < nblock; ++i) {
        // We would uncomment this if we wanted to use this implementation for a model that has per block scales
        //uint16_t u = x[i].extra & 0xff;
        //scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
        auto signs = _mm256_set1_epi8(x[i].extra >> 8);
        // signs for groups of 8 ordered as 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ...
        // To use these to sign the q8 values we need
        // 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 and the same for 4...7
        signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, mask1), mask1), m1_8);
        auto q8_1 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[i].qs+0), _mm256_shuffle_epi8(signs, shuff3));
        auto q8_2 = _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)y[i].qs+1), _mm256_shuffle_epi8(signs, shuff4));

        auto ql = x[i].ql;
        auto qh = x[i].qh;
        auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
                                      iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
        auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)],
                                      iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]);

        auto v1_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1);
        auto v1_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1);
        auto v2_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1);
        auto v2_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1);

        auto dot1 = _mm256_sub_epi8(_mm256_sign_epi8(q8_1, v1_m), _mm256_sign_epi8(q8_1, v1_p));
        auto dot2 = _mm256_sub_epi8(_mm256_sign_epi8(q8_2, v2_m), _mm256_sign_epi8(q8_2, v2_p));

#if defined __AVX512VNNI__ && defined __AVX512VL__
        dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1);
        dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot2);
#else
        dot1 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot1));
        dot2 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot2));
#endif

        // We would uncomment this if we wanted to use this implementation for a model that has per block scales
        //acc1 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+0].d)), _mm256_cvtepi32_ps(dot1), acc1);
        //acc2 = _mm256_fmadd_ps(_mm256_set1_ps(scale.f*GGML_FP16_TO_FP32(y[2*i+1].d)), _mm256_cvtepi32_ps(dot2), acc2);
        // All scales are the same for BitNet!
        // This is slower
        //uint32_t aux32 = y[2*i+0].d | (y[2*i+1].d << 16);
        //auto d8 = _mm256_cvtph_ps(_mm_set1_epi32(aux32));
        //acc1 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x00), _mm256_cvtepi32_ps(dot1), acc1);
        //acc2 = _mm256_fmadd_ps(_mm256_permute_ps(d8, 0x55), _mm256_cvtepi32_ps(dot2), acc2);
        acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d), _mm256_cvtepi32_ps(_mm256_add_epi32(dot1, dot2)), acc);

    }

    sumf = scale.f * hsum_float_8(acc);

#else

    for (int i = 0; i < nblock; ++i) {
        uint16_t u = x[i].extra & 0xff;
        scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
        uint8_t extra = x[i].extra >> 8;
        auto qh = x[i].qh;
        auto ql = x[i].ql;
        // block_q8_K64 carries all 64 quants and a single float scale
        auto q8 = y[i].qs;
        int16_t sumi1 = 0;
        for (int k = 0; k < 4; ++k) {
            uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00);
            uint16_t val = iq1bn_grid_u16[idx];
            int16_t sl = 0;
            for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
            sumi1 += extra & (1 << k) ? -sl : sl;
            q8 += 8;
        }
        int16_t sumi2 = 0;
        for (int k = 4; k < 8; ++k) {
            uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00);
            uint16_t val = iq1bn_grid_u16[idx];
            int16_t sl = 0;
            for (int j = 0; j < 8; ++j) sl += q8[j] * (((val >> 2*j) & 3) - 1);
            sumi2 += extra & (1 << k) ? -sl : sl;
            q8 += 8;
        }
        sumf += scale.f * y[i].d * (sumi1 + sumi2);
    }

#endif

    *s = sumf;

}
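
A note on why the q8_K64 variant exists: block_q8_K64 covers 64 activations with one float scale, so each iq1_bn block pairs with exactly one activation block. That is what lets the AVX2 loop above add dot1 and dot2 before a single fmadd per block and defer the common BitNet weight scale to one multiply after the horizontal sum.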

llama.cpp (281 lines changed)
@@ -225,6 +225,7 @@ enum llm_arch {
    LLM_ARCH_OLMO,
    LLM_ARCH_ARCTIC,
    LLM_ARCH_DEEPSEEK2,
    LLM_ARCH_BITNET,
    LLM_ARCH_UNKNOWN,
};

@@ -263,6 +264,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_OLMO,      "olmo"      },
    { LLM_ARCH_ARCTIC,    "arctic"    },
    { LLM_ARCH_DEEPSEEK2, "deepseek2" },
    { LLM_ARCH_BITNET,    "bitnet"    },
    { LLM_ARCH_UNKNOWN,   "(unknown)" },
};

@@ -500,6 +502,8 @@ enum llm_tensor {
    LLM_TENSOR_ATTN_KV_B,
    LLM_TENSOR_ATTN_Q_A_NORM,
    LLM_TENSOR_ATTN_KV_A_NORM,
    LLM_TENSOR_ATTN_SUB_NORM,
    LLM_TENSOR_FFN_SUB_NORM,
};

static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@@ -1113,6 +1117,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
        },
    },
    {
        LLM_ARCH_BITNET,
        {
            { LLM_TENSOR_TOKEN_EMBD,    "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM,   "output_norm" },
            { LLM_TENSOR_ATTN_Q,        "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K,        "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V,        "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT,      "blk.%d.attn_output" },
            { LLM_TENSOR_ATTN_NORM,     "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
            { LLM_TENSOR_FFN_GATE,      "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN,      "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP,        "blk.%d.ffn_up" },
            { LLM_TENSOR_FFN_NORM,      "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_SUB_NORM,  "blk.%d.ffn_sub_norm" },
        },
    },
    {
        LLM_ARCH_UNKNOWN,
        {
@@ -2118,6 +2140,8 @@ struct llama_layer {
    struct ggml_tensor * attn_out_norm_b;
    struct ggml_tensor * attn_q_a_norm;
    struct ggml_tensor * attn_kv_a_norm;
    struct ggml_tensor * attn_sub_norm;
    struct ggml_tensor * ffn_sub_norm;

    // attention
    struct ggml_tensor * wq;
@@ -4710,6 +4734,15 @@ static void llm_load_hparams(
                    default: model.type = e_model::MODEL_UNKNOWN;
                }
            } break;
        case LLM_ARCH_BITNET:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                switch (hparams.n_layer) {
                    case 26: model.type = e_model::MODEL_3B; break;
                    default: model.type = e_model::MODEL_UNKNOWN;
                }
            } break;
        default: (void)0;
    }

@@ -6655,6 +6688,40 @@ static bool llm_load_tensors(
                    }
                }
            } break;
        case LLM_ARCH_BITNET:
            {
                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

                // output
                {
                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                }

                const uint32_t n_ff = hparams.n_ff;
                model.layers.resize(n_layer);

                for (int i = 0; i < n_layer; ++i) {
                    ggml_context * ctx_layer = ctx_for_layer(i);
                    ggml_context * ctx_split = ctx_for_layer_split(i);

                    auto & layer = model.layers[i];

                    layer.attn_norm     = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
                    layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});

                    layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
                    layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
                    layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
                    layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});

                    layer.ffn_norm     = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                    layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});

                    layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
                    layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
                    layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
                }
            } break;
        default:
            throw std::runtime_error("unknown architecture");
    }

@@ -11709,6 +11776,215 @@ struct llm_build_context {
        return gf;
    }

    struct ggml_cgraph * build_bitnet() {
        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

        const int64_t n_embd_head = hparams.n_embd_head_v;
        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

        struct ggml_tensor * cur;
        struct ggml_tensor * inpL;

        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

        // inp_pos - contains the positions
        struct ggml_tensor * inp_pos = build_inp_pos();

        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

        for (int il = 0; il < n_layer; ++il) {
            struct ggml_tensor * inpSA = inpL;

            cur = llm_build_norm(ctx0, inpL, hparams,
                    model.layers[il].attn_norm, NULL,
                    LLM_NORM_RMS, cb, il);
            cb(cur, "attn_norm", il);

            // self-attention
            {
                // compute Q and K and RoPE them
                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
                cb(Qcur, "Qcur", il);
                if (model.layers[il].bq) {
                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                    cb(Qcur, "Qcur", il);
                }

                // B1.K
                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
                cb(Kcur, "Kcur", il);
                if (model.layers[il].bk) {
                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                    cb(Kcur, "Kcur", il);
                }

                // B1.V
                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                cb(Vcur, "Vcur", il);
                if (model.layers[il].bv) {
                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                    cb(Vcur, "Vcur", il);
                }

                Qcur = ggml_rope_ext(
                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow
                );
                cb(Qcur, "Qcur", il);

                Kcur = ggml_rope_ext(
                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow
                );
                cb(Kcur, "Kcur", il);

                llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);

                const int64_t n_ctx         = cparams.n_ctx;
                const int64_t n_head        = hparams.n_head;
                const int64_t n_head_kv     = hparams.n_head_kv;
                const int64_t n_embd_head_k = hparams.n_embd_head_k;
                const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
                const int64_t n_embd_head_v = hparams.n_embd_head_v;
                const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();

                struct ggml_tensor * q_cur = Qcur;
                struct ggml_tensor * kq_mask = KQ_mask;
                float kq_scale = 1.0f/sqrtf(float(n_embd_head));
                struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm;
                struct ggml_cgraph * graph = gf;
                struct ggml_tensor * wo = model.layers[il].wo;
                struct ggml_tensor * cur_attn;
                struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
                cb(q, "q", il);

                struct ggml_tensor * k =
                    ggml_view_3d(ctx0, kv_self.k_l[il],
                            n_embd_head_k, n_kv, n_head_kv,
                            ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                            ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
                            0);
                cb(k, "k", il);

                if (cparams.flash_attn) {

                    // split cached v into n_head heads (not transposed)
                    struct ggml_tensor * v =
                        ggml_view_3d(ctx0, kv_self.v_l[il],
                                n_embd_head_v, n_kv, n_head_kv,
                                ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
                                ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
                                0);
                    cb(v, "v", il);

                    cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);

                    cur_attn = ggml_reshape_2d(ctx0, cur_attn, n_embd_head_v*n_head, n_tokens);
                } else {
                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
                    cb(kq, "kq", il);

                    kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
                    cb(kq, "kq_soft_max_ext", il);

                    GGML_ASSERT(kv_self.size == n_ctx);

                    // split cached v into n_head heads
                    struct ggml_tensor * v =
                        ggml_view_3d(ctx0, kv_self.v_l[il],
                                n_kv, n_embd_head_v, n_head_kv,
                                ggml_element_size(kv_self.v_l[il])*n_ctx,
                                ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
                                0);
                    cb(v, "v", il);

                    struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
                    cb(kqv, "kqv", il);

                    struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
                    cb(kqv_merged, "kqv_merged", il);

                    cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
                    cb(cur_attn, "kqv_merged_cont", il);
                }

                cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
                        attn_sub_norm, NULL,
                        LLM_NORM_RMS, cb, il);
                cb(cur_attn, "attn_sub_norm", il);

                ggml_build_forward_expand(graph, cur_attn);

                cur = ggml_mul_mat(ctx0, wo, cur_attn);

                cb(cur, "kqv_out", il);
            }

            if (il == n_layer - 1) {
                // skip computing output for unused tokens
                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
                cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
            cb(ffn_inp, "ffn_inp", il);

            // feed-forward network
            if (model.layers[il].ffn_gate_inp == nullptr) {
                cur = llm_build_norm(ctx0, ffn_inp, hparams,
                        model.layers[il].ffn_norm, NULL,
                        LLM_NORM_RMS, cb, il);
                cb(cur, "ffn_norm", il);

                struct ggml_tensor * tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
                cb(tmp, "ffn_up", il);

                cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
                cb(cur, "ffn_gate", il);

                cur = ggml_silu(ctx0, cur);
                cb(cur, "ffn_silu", il);

                cur = ggml_mul(ctx0, cur, tmp);
                cb(cur, "ffn_gate_par", il);

                cur = llm_build_norm(ctx0, cur, hparams,
                        model.layers[il].ffn_sub_norm, NULL,
                        LLM_NORM_RMS, cb, il);
                cb(cur, "ffn_sub_norm", il);

                cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
                cb(cur, "ffn_down", il);
            }
            cur = ggml_add(ctx0, cur, ffn_inp);
            cb(cur, "l_out", il);

            // input for next layer
            inpL = cur;
        }

        cur = inpL;

        cur = llm_build_norm(ctx0, cur, hparams,
                model.output_norm, NULL,
                LLM_NORM_RMS, cb, -1);
        cb(cur, "result_norm", -1);

        // lm_head
        cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
        cb(cur, "result_output", -1);

        ggml_build_forward_expand(gf, cur);
        return gf;
    }
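
For orientation, the graph built above follows the BitNet layer layout: an extra RMS sub-norm on the attention output before the output projection, and another on the gated FFN hidden state before ffn_down. A schematic PyTorch sketch of one block (illustrative names and shapes, assuming a recent PyTorch with nn.RMSNorm; attend stands in for the whole wq/wk/wv + RoPE + KV cache + softmax attention, not the ggml API):

import torch
import torch.nn as nn
import torch.nn.functional as F

class BitnetBlock(nn.Module):
    def __init__(self, n_embd: int, n_ff: int):
        super().__init__()
        self.attn_norm     = nn.RMSNorm(n_embd)
        self.attn_sub_norm = nn.RMSNorm(n_embd)  # extra norm on attention output, before wo
        self.ffn_norm      = nn.RMSNorm(n_embd)
        self.ffn_sub_norm  = nn.RMSNorm(n_ff)    # extra norm on the gated hidden, before ffn_down
        self.wo       = nn.Linear(n_embd, n_embd, bias=False)  # ternary weights in the GGUF
        self.ffn_up   = nn.Linear(n_embd, n_ff,   bias=False)
        self.ffn_gate = nn.Linear(n_embd, n_ff,   bias=False)
        self.ffn_down = nn.Linear(n_ff,   n_embd, bias=False)

    def forward(self, x: torch.Tensor, attend) -> torch.Tensor:
        # attention: norm -> attention -> sub-norm -> output projection -> residual
        h = attend(self.attn_norm(x))
        x = x + self.wo(self.attn_sub_norm(h))
        # FFN: norm -> up/gate -> SiLU gating -> sub-norm -> down projection -> residual
        t = self.ffn_norm(x)
        h = F.silu(self.ffn_gate(t)) * self.ffn_up(t)
        return x + self.ffn_down(self.ffn_sub_norm(h))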

};

static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -11932,6 +12208,10 @@ static struct ggml_cgraph * llama_build_graph(
            {
                result = llm.build_deepseek2();
            } break;
        case LLM_ARCH_BITNET:
            {
                result = llm.build_bitnet();
            } break;
        default:
            GGML_ASSERT(false);
    }

@@ -16760,6 +17040,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
        case LLM_ARCH_BERT:
        case LLM_ARCH_NOMIC_BERT:
        case LLM_ARCH_STABLELM:
        case LLM_ARCH_BITNET:
        case LLM_ARCH_QWEN:
        case LLM_ARCH_QWEN2:
        case LLM_ARCH_QWEN2MOE: