mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Vulkan: a fresh start (#608)
* It compiles * Seems to be working with coopmat * Vulkan needs f32 precision for flash attention * Vulkan: fix u_batch > 4096/n_active_experts for coopmat1. Without this fix we get an assert. We get the same assert in mainline too. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -10,7 +10,7 @@ extern "C" {
|
||||
#define GGML_VK_NAME "Vulkan"
|
||||
#define GGML_VK_MAX_DEVICES 16
|
||||
|
||||
GGML_API GGML_CALL void ggml_vk_instance_init(void);
|
||||
//GGML_API GGML_CALL void ggml_vk_instance_init(void);
|
||||
|
||||
// backend API
|
||||
GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num);
|
||||
|
||||
@@ -884,6 +884,15 @@ extern "C" {
|
||||
GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
|
||||
GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
|
||||
|
||||
// returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok)
|
||||
GGML_API bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor);
|
||||
|
||||
// true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
|
||||
GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
|
||||
|
||||
// true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
|
||||
GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
|
||||
|
||||
GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
|
||||
GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
|
||||
|
||||
|
||||
@@ -748,6 +748,16 @@ static size_t ggml_hash_find_or_insert(struct ggml_hash_set * hash_set, struct g
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) {
|
||||
assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
|
||||
return ((const int32_t *)(tensor->op_params))[i];
|
||||
}
|
||||
|
||||
static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) {
|
||||
assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
|
||||
return ((const float *)(tensor->op_params))[i];
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4682,6 +4682,24 @@ GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
|
||||
return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
|
||||
}
|
||||
|
||||
GGML_CALL bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
|
||||
return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
|
||||
}
|
||||
|
||||
GGML_CALL bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
|
||||
return
|
||||
tensor->nb[0] > tensor->nb[2] &&
|
||||
tensor->nb[1] > tensor->nb[0] &&
|
||||
tensor->nb[2] == ggml_type_size(tensor->type);
|
||||
}
|
||||
|
||||
GGML_CALL bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
|
||||
return
|
||||
tensor->ne[0] == ggml_blck_size(tensor->type) ||
|
||||
tensor->nb[0] == ggml_type_size(tensor->type);
|
||||
}
|
||||
|
||||
|
||||
static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
@@ -5195,16 +5213,6 @@ static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params,
|
||||
memcpy(tensor->op_params, params, params_size);
|
||||
}
|
||||
|
||||
static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) {
|
||||
assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
|
||||
return ((const int32_t *)(tensor->op_params))[i];
|
||||
}
|
||||
|
||||
static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) {
|
||||
assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
|
||||
return ((const float *)(tensor->op_params))[i];
|
||||
}
|
||||
|
||||
static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
|
||||
assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
|
||||
((int32_t *)(tensor->op_params))[i] = value;
|
||||
|
||||
@@ -27,5 +27,5 @@ endif()
|
||||
set(TARGET vulkan-shaders-gen)
|
||||
add_executable(${TARGET} vulkan-shaders-gen.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
|
||||
|
||||
105
ggml/src/vulkan-shaders/conv2d_dw.comp
Normal file
105
ggml/src/vulkan-shaders/conv2d_dw.comp
Normal file
@@ -0,0 +1,105 @@
|
||||
#version 450
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint ne;
|
||||
uint batches;
|
||||
uint channels;
|
||||
uint dst_w;
|
||||
uint dst_h;
|
||||
uint src_w;
|
||||
uint src_h;
|
||||
uint knl_w;
|
||||
uint knl_h;
|
||||
int stride_x;
|
||||
int stride_y;
|
||||
int pad_x;
|
||||
int pad_y;
|
||||
int dilation_x;
|
||||
int dilation_y;
|
||||
} p;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
|
||||
layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
|
||||
uint i0 = idx / p.dst_w;
|
||||
uint dst_x = idx - i0 * p.dst_w;
|
||||
uint i1 = i0 / p.dst_h;
|
||||
uint dst_y = i0 - i1 * p.dst_h;
|
||||
uint n = i1 / p.channels;
|
||||
uint c = i1 - n * p.channels;
|
||||
|
||||
uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
|
||||
uint knl_i = c * p.knl_h * p.knl_w;
|
||||
|
||||
FLOAT_TYPE sum = 0.0;
|
||||
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
|
||||
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
|
||||
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
|
||||
continue;
|
||||
}
|
||||
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
|
||||
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
|
||||
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
|
||||
continue;
|
||||
}
|
||||
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
|
||||
FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
|
||||
sum = fma(v, k, sum);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
|
||||
uint i0 = idx / p.channels;
|
||||
uint c = idx - i0 * p.channels;
|
||||
uint i1 = i0 / p.dst_w;
|
||||
uint dst_x = i0 - i1 * p.dst_w;
|
||||
uint n = i1 / p.dst_h;
|
||||
uint dst_y = i1 - n * p.dst_h;
|
||||
|
||||
uint src_i = n * p.channels * p.src_h * p.src_w;
|
||||
uint src_row = p.src_w * p.channels;
|
||||
uint knl_row = p.knl_w * p.channels;
|
||||
|
||||
FLOAT_TYPE sum = 0.0;
|
||||
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
|
||||
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
|
||||
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
|
||||
continue;
|
||||
}
|
||||
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
|
||||
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
|
||||
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
|
||||
continue;
|
||||
}
|
||||
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
|
||||
FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]);
|
||||
sum = fma(v, k, sum);
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
void main() {
|
||||
uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
FLOAT_TYPE result =
|
||||
#ifdef WHCN
|
||||
conv_2d_dw_whcn(idx);
|
||||
#else
|
||||
conv_2d_dw_cwhn(idx);
|
||||
#endif
|
||||
dst_data[idx] = D_TYPE(result);
|
||||
}
|
||||
|
||||
@@ -6,17 +6,25 @@ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bi
|
||||
#endif // RTE16
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_unary_head.comp"
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
// 16 invocations needed for init_iq4nl_shmem
|
||||
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
|
||||
#if defined(SET_ROWS) && QUANT_K == 1
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
const uint BLOCK_SIZE = 512;
|
||||
#else
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
||||
const uint BLOCK_SIZE = 32;
|
||||
#endif
|
||||
|
||||
layout (binding = 0) readonly buffer S {float data_s[];};
|
||||
|
||||
#if defined(SET_ROWS)
|
||||
#include "generic_binary_head.comp"
|
||||
layout (binding = 1) readonly buffer C {uvec2 data_i[];};
|
||||
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
|
||||
#else
|
||||
#include "generic_unary_head.comp"
|
||||
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
void quantize(uint dst_idx, uint src_idx)
|
||||
@@ -221,15 +229,56 @@ void quantize(uint dst_idx, uint src_idx)
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
void quantize(uint dst_idx, uint src_idx)
|
||||
{
|
||||
data_q[dst_idx] = A_TYPE(data_s[src_idx]);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_BF16)
|
||||
void quantize(uint dst_idx, uint src_idx)
|
||||
{
|
||||
data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(SET_ROWS)
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
if (gl_LocalInvocationIndex.x != 0) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
|
||||
const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint i00, i01, i02, i03;
|
||||
get_indices(idx, i00, i01, i02, i03);
|
||||
|
||||
uint i12 = fastmod(i03, p.ne12);
|
||||
uint i11 = fastmod(i02, p.ne11);
|
||||
uint i10 = i01;
|
||||
|
||||
uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
|
||||
|
||||
uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
|
||||
uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
|
||||
|
||||
quantize(dst_idx, src0_idx);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
@@ -240,3 +289,5 @@ void main() {
|
||||
|
||||
quantize(dst_idx, src_idx);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -100,6 +100,10 @@ void main() {
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
uint32_t m_offset = 0;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
||||
}
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
@@ -145,13 +149,13 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
if (p.mask != 0) {
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
@@ -24,6 +24,8 @@ layout (push_constant) uniform parameter {
|
||||
uint32_t nev2;
|
||||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
uint32_t nem2;
|
||||
uint32_t nem3;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
@@ -34,14 +36,12 @@ layout (push_constant) uniform parameter {
|
||||
uint32_t nb21;
|
||||
uint32_t nb22;
|
||||
uint32_t nb23;
|
||||
uint32_t nb31;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
uint32_t mask_n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
|
||||
uint32_t k_num;
|
||||
} p;
|
||||
|
||||
#define MASK_ENABLE_BIT (1<<16)
|
||||
#define N_LOG2_MASK 0xFFFF
|
||||
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
|
||||
{
|
||||
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||
|
||||
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
||||
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
||||
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
|
||||
|
||||
const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
|
||||
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
|
||||
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
@@ -125,6 +125,10 @@ void main() {
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
uint32_t m_offset = 0;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
||||
}
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
@@ -178,12 +182,12 @@ void main() {
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (p.mask != 0) {
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
@@ -130,6 +130,11 @@ void main() {
|
||||
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
|
||||
}
|
||||
|
||||
uint32_t m_offset = 0;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
|
||||
}
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
@@ -148,14 +153,14 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
if (p.mask != 0) {
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
#define BLOCK_SIZE 32
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {float data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||
@@ -12,48 +12,80 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||
layout (push_constant) uniform parameter {
|
||||
uint D;
|
||||
uint N;
|
||||
uint ne3;
|
||||
uint k_num;
|
||||
} p;
|
||||
|
||||
shared float tmpsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
// Each workgroup handles a row
|
||||
const uint n = gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint iq3 = gl_WorkGroupID.z;
|
||||
|
||||
uint D = p.D;
|
||||
uint N = p.N;
|
||||
uint k_num = p.k_num;
|
||||
|
||||
uint l_offset = D * N * k_num + n;
|
||||
uint m_offset = D * N * k_num + N + n;
|
||||
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
|
||||
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
|
||||
uint lm_stride = N * 2;
|
||||
|
||||
// Compute the max m value for the row
|
||||
float m_max = -1.0/0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
|
||||
float m = data_a[m_offset + (k + tid) * lm_stride];
|
||||
m_max = max(m_max, m);
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
tmpsh[tid] = m_max;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
m_max = max(m_max, tmpsh[tid + s]);
|
||||
tmpsh[tid] = m_max;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
m_max = tmpsh[0];
|
||||
|
||||
barrier();
|
||||
|
||||
// Compute L based on m_max
|
||||
float L = 0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
float l = data_a[l_offset + k * lm_stride];
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
|
||||
float l = data_a[l_offset + (k + tid) * lm_stride];
|
||||
float m = data_a[m_offset + (k + tid) * lm_stride];
|
||||
L += exp(m - m_max) * l;
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
tmpsh[tid] = L;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
L += tmpsh[tid + s];
|
||||
tmpsh[tid] = L;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
L = tmpsh[0];
|
||||
|
||||
L = 1.0 / L;
|
||||
|
||||
// D dimension is split across workgroups in the y dimension
|
||||
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
|
||||
// Scale and sum the O contributions based on m_max and store the result to memory
|
||||
for (uint d = tid; d < D; d += BLOCK_SIZE) {
|
||||
if (d < D) {
|
||||
float O = 0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
uint o_offset = D * N * k + D * n + d;
|
||||
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
O += exp(m - m_max) * data_a[o_offset];
|
||||
}
|
||||
O *= L;
|
||||
data_d[D * n + d] = O;
|
||||
data_d[iq3 * D * N + D * n + d] = O;
|
||||
}
|
||||
}
|
||||
|
||||
13
ggml/src/vulkan-shaders/geglu.comp
Normal file
13
ggml/src/vulkan-shaders/geglu.comp
Normal file
@@ -0,0 +1,13 @@
|
||||
#version 450
|
||||
|
||||
#include "glu_head.comp"
|
||||
|
||||
const float GELU_COEF_A = 0.044715f;
|
||||
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
float op(float a, float b) {
|
||||
const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
|
||||
return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
|
||||
}
|
||||
|
||||
#include "glu_main.comp"
|
||||
27
ggml/src/vulkan-shaders/geglu_erf.comp
Normal file
27
ggml/src/vulkan-shaders/geglu_erf.comp
Normal file
@@ -0,0 +1,27 @@
|
||||
#version 450
|
||||
|
||||
#include "glu_head.comp"
|
||||
|
||||
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
||||
// ref: https://www.johndcook.com/blog/python_erf/
|
||||
const float p_erf = 0.3275911f;
|
||||
const float a1_erf = 0.254829592f;
|
||||
const float a2_erf = -0.284496736f;
|
||||
const float a3_erf = 1.421413741f;
|
||||
const float a4_erf = -1.453152027f;
|
||||
const float a5_erf = 1.061405429f;
|
||||
|
||||
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
||||
|
||||
float op(float a, float b) {
|
||||
const float a_div_sqr2 = a * SQRT_2_INV;
|
||||
const float sign_x = sign(a_div_sqr2);
|
||||
const float x = abs(a_div_sqr2);
|
||||
const float t = 1.0f / (1.0f + p_erf * x);
|
||||
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
||||
const float erf_approx = sign_x * y;
|
||||
|
||||
return 0.5f * a * (1.0f + erf_approx) * b;
|
||||
}
|
||||
|
||||
#include "glu_main.comp"
|
||||
11
ggml/src/vulkan-shaders/geglu_quick.comp
Normal file
11
ggml/src/vulkan-shaders/geglu_quick.comp
Normal file
@@ -0,0 +1,11 @@
|
||||
#version 450
|
||||
|
||||
#include "glu_head.comp"
|
||||
|
||||
const float GELU_QUICK_COEF = -1.702f;
|
||||
|
||||
float op(float a, float b) {
|
||||
return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
|
||||
}
|
||||
|
||||
#include "glu_main.comp"
|
||||
39
ggml/src/vulkan-shaders/gelu_erf.comp
Normal file
39
ggml/src/vulkan-shaders/gelu_erf.comp
Normal file
@@ -0,0 +1,39 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
||||
// ref: https://www.johndcook.com/blog/python_erf/
|
||||
const float p_erf = 0.3275911f;
|
||||
const float a1_erf = 0.254829592f;
|
||||
const float a2_erf = -0.284496736f;
|
||||
const float a3_erf = 1.421413741f;
|
||||
const float a4_erf = -1.453152027f;
|
||||
const float a5_erf = 1.061405429f;
|
||||
|
||||
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float a = float(data_a[i]);
|
||||
const float a_div_sqr2 = a * SQRT_2_INV;
|
||||
const float sign_x = sign(a_div_sqr2);
|
||||
const float x = abs(a_div_sqr2);
|
||||
const float t = 1.0f / (1.0f + p_erf * x);
|
||||
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
||||
const float erf_approx = sign_x * y;
|
||||
|
||||
data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx));
|
||||
}
|
||||
15
ggml/src/vulkan-shaders/glu_head.comp
Normal file
15
ggml/src/vulkan-shaders/glu_head.comp
Normal file
@@ -0,0 +1,15 @@
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint N;
|
||||
uint ne00;
|
||||
uint ne20;
|
||||
uint mode;
|
||||
} p;
|
||||
29
ggml/src/vulkan-shaders/glu_main.comp
Normal file
29
ggml/src/vulkan-shaders/glu_main.comp
Normal file
@@ -0,0 +1,29 @@
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.N) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row = i / p.ne20;
|
||||
const uint col = i - row * p.ne20;
|
||||
|
||||
if (p.mode == 0) {
|
||||
// Default
|
||||
const uint offset = p.ne00 / 2;
|
||||
const uint idx = row * p.ne00 + col;
|
||||
|
||||
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
|
||||
} else if (p.mode == 1) {
|
||||
// Swapped
|
||||
const uint offset = p.ne00 / 2;
|
||||
const uint idx = row * p.ne00 + col;
|
||||
|
||||
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
|
||||
} else {
|
||||
// Split
|
||||
const uint idx = row * p.ne00 + col;
|
||||
|
||||
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
|
||||
}
|
||||
}
|
||||
41
ggml/src/vulkan-shaders/l2_norm.comp
Normal file
41
ggml/src/vulkan-shaders/l2_norm.comp
Normal file
@@ -0,0 +1,41 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#define BLOCK_SIZE 512
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
|
||||
sum[tid] += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
barrier();
|
||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
sum[tid] += sum[tid + s];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[4096];
|
||||
uint _ne1;
|
||||
#ifdef COOPMAT
|
||||
shared uint _ne1_sh;
|
||||
#endif
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
@@ -172,7 +177,47 @@ void main() {
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint _ne1 = 0;
|
||||
#ifdef COOPMAT
|
||||
// Spread the search across all elements in the first subgroup
|
||||
if (gl_SubgroupID == 0) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_SubgroupInvocationID;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += subgroupBallotBitCount(ballot);
|
||||
iter &= 15;
|
||||
}
|
||||
_ne1_sh = _ne1;
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
_ne1 = _ne1_sh;
|
||||
#else
|
||||
_ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
||||
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
||||
@@ -183,6 +228,7 @@ void main() {
|
||||
}
|
||||
|
||||
barrier();
|
||||
#endif
|
||||
|
||||
// Workgroup has no work
|
||||
if (ic * BN >= _ne1) return;
|
||||
@@ -500,10 +546,9 @@ void main() {
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx % 128) / 4;
|
||||
const int i8 = 2 * int(idx % 4);
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||
const uint ib8 = idx % 32;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qh = data_a[ib].qh[ib32];
|
||||
@@ -512,22 +557,16 @@ void main() {
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
||||
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
const vec2 v = dl * (vec2(gvec) + delta);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
[[unroll]] for (int k = 0; k < 8; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
||||
}
|
||||
#elif defined(DATA_A_IQ1_M)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib8 = (idx % 128) / 4;
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib8 = idx % 32;
|
||||
const uint ib16 = ib8 / 2;
|
||||
const int i8 = 2 * int(idx % 4);
|
||||
|
||||
const uint16_t[4] scales = data_a[ib].scales;
|
||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||
@@ -538,21 +577,17 @@ void main() {
|
||||
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
const vec2 v = dl * (vec2(gvec) + delta);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
[[unroll]] for (int k = 0; k < 8; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
||||
}
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx / 4) % 4;
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||
const uint ib8 = idx % 4;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
|
||||
@@ -562,63 +597,81 @@ void main() {
|
||||
data_a[ib].qs[8*ib32 + 6],
|
||||
data_a[ib].qs[8*ib32 + 7]
|
||||
));
|
||||
const float db = d * 0.25 * (0.5 + (signs >> 28));
|
||||
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
|
||||
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
const uint sign = sign7 | (bitCount(sign7) << 7);
|
||||
const uvec2 grid = iq2xxs_grid[qs];
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ2_XS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx / 4) % 4; // 0..3
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||
const uint ib8 = idx % 4; // 0..3
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||
const float db = d * 0.25 * (0.5 + scale);
|
||||
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
||||
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
|
||||
const uint sign7 = qs >> 9;
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
const uint sign = sign7 | (bitCount(sign7) << 7);
|
||||
const uvec2 grid = iq2xs_grid[qs & 511];
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ2_S)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib8 = (idx % 128) / 4; // 0..31
|
||||
const uint ib32 = ib8 / 4; // 0..7
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib8 = idx % 32; // 0..31
|
||||
const uint ib32 = ib8 / 4; // 0..7
|
||||
|
||||
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||
const uint qs = data_a[ib].qs[ib8];
|
||||
const uint qh = data_a[ib].qh[ib32];
|
||||
const uint qhshift = 2 * (ib8 % 4);
|
||||
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
|
||||
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const float db = d * 0.25 * (0.5 + scale);
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
|
||||
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
||||
const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ3_XXS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = (idx % 128) / 2; // 0..63
|
||||
const uint ib = idx / 64; // 4 values per idx
|
||||
const uint iqs = idx % 64; // 0..63
|
||||
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
@@ -631,33 +684,36 @@ void main() {
|
||||
));
|
||||
const float db = d * 0.5 * (0.5 + (signs >> 28));
|
||||
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
|
||||
const uint grid = iq3xxs_grid[qs];
|
||||
const vec4 v = db * vec4(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
||||
#elif defined(DATA_A_IQ3_S)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = (idx % 128) / 2; // 0..63
|
||||
const uint ib = idx / 64; // 4 values per idx
|
||||
const uint iqs = idx % 64; // 0..63
|
||||
const uint iqh = iqs / 8;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qs = data_a[ib].qs[iqs];
|
||||
const uint qh = data_a[ib].qh[iqh];
|
||||
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
|
||||
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
|
||||
const uint scale = data_a[ib].scales[iqs / 16];
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
|
||||
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
|
||||
const vec4 v = db * vec4(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
||||
#elif defined(DATA_A_IQ4_XS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
@@ -162,17 +162,32 @@ void main() {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
|
||||
for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_SubgroupInvocationID;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
|
||||
}
|
||||
_ne1 += subgroupBallotBitCount(ballot);
|
||||
iter &= 15;
|
||||
}
|
||||
_ne1_sh = _ne1;
|
||||
}
|
||||
@@ -414,17 +429,31 @@ void main() {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
||||
}
|
||||
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
#ifdef MUL_MAT_ID
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
#else
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
#endif
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
#ifdef MUL_MAT_ID
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
#else
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
#endif
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert from ACC_TYPE to D_TYPE
|
||||
|
||||
9
ggml/src/vulkan-shaders/reglu.comp
Normal file
9
ggml/src/vulkan-shaders/reglu.comp
Normal file
@@ -0,0 +1,9 @@
|
||||
#version 450
|
||||
|
||||
#include "glu_head.comp"
|
||||
|
||||
float op(float a, float b) {
|
||||
return max(a, 0.0f) * b;
|
||||
}
|
||||
|
||||
#include "glu_main.comp"
|
||||
@@ -1,11 +1,13 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_unary_head.comp"
|
||||
#include "generic_binary_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#define BLOCK_SIZE 512
|
||||
|
||||
layout (constant_id = 1) const bool do_multiply = false;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
||||
@@ -25,6 +27,7 @@ void main() {
|
||||
const uint stride_sample = p.nb03;
|
||||
|
||||
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
|
||||
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
||||
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
@@ -46,7 +49,13 @@ void main() {
|
||||
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
||||
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||
if (do_multiply) {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
46
ggml/src/vulkan-shaders/roll.comp
Normal file
46
ggml/src/vulkan-shaders/roll.comp
Normal file
@@ -0,0 +1,46 @@
|
||||
#version 450
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_unary_head.comp"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
uint wrap_idx(int i, uint ne) {
|
||||
if (i < 0) {
|
||||
return i + ne;
|
||||
} else if (i >= ne) {
|
||||
return i - ne;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
|
||||
const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
|
||||
const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L);
|
||||
const uint i2_offset = i2*p.ne11*p.ne10;
|
||||
const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L);
|
||||
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
|
||||
|
||||
const uint p1 = floatBitsToUint(p.param1);
|
||||
const uint p2 = floatBitsToUint(p.param2);
|
||||
const int s0 = int(p1 >> 16) - 0x8000;
|
||||
const int s1 = int(p1 & 0xFFFF) - 0x8000;
|
||||
const int s2 = int(p2 >> 16) - 0x8000;
|
||||
const int s3 = int(p2 & 0xFFFF) - 0x8000;
|
||||
|
||||
const uint i00 = wrap_idx(int(i0) - s0, p.ne10);
|
||||
const uint i01 = wrap_idx(int(i1) - s1, p.ne11);
|
||||
const uint i02 = wrap_idx(int(i2) - s2, p.ne12);
|
||||
const uint i03 = wrap_idx(int(i3) - s3, p.ne13);
|
||||
|
||||
const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
|
||||
const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;
|
||||
|
||||
data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
|
||||
}
|
||||
@@ -14,21 +14,19 @@ void main() {
|
||||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
const uint i = row_dst*ne0 + i0;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
|
||||
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
|
||||
const int sec_w = p.sections[1] + p.sections[0];
|
||||
const uint sector = (i0 / 2) % sect_dims;
|
||||
|
||||
@@ -13,21 +13,19 @@ void main() {
|
||||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
const uint i = row_dst*ne0 + i0;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
|
||||
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
@@ -13,21 +13,19 @@ void main() {
|
||||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
const uint i = row_dst*ne0 + i0;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + 0] = data_a[ix + 0];
|
||||
data_d[idst + 1] = data_a[ix + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
@@ -18,7 +18,7 @@ void main() {
|
||||
continue;
|
||||
}
|
||||
|
||||
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
|
||||
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2));
|
||||
idx += num_threads;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,14 @@ layout (push_constant) uniform parameter
|
||||
{
|
||||
uint KX;
|
||||
uint KY;
|
||||
uint ne00;
|
||||
uint ne01;
|
||||
uint ne02;
|
||||
uint ne12;
|
||||
uint ne13;
|
||||
uint nb11;
|
||||
uint nb12;
|
||||
uint nb13;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
@@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE];
|
||||
void soft_max(uint num_iters) {
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0;
|
||||
|
||||
const uint32_t i03 = rowx / (p.ne01 * p.ne02);
|
||||
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
|
||||
const uint32_t i01 = rowx % p.ne01;
|
||||
|
||||
uint rowy_start = 0;
|
||||
if (p.KY > 0) {
|
||||
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
|
||||
}
|
||||
|
||||
if (rowx >= p.nrows_x) {
|
||||
return;
|
||||
@@ -41,7 +57,7 @@ void soft_max(uint num_iters) {
|
||||
|
||||
// ALiBi
|
||||
if (p.max_bias > 0.0f) {
|
||||
const uint h = rowx/p.KY; // head index
|
||||
const uint h = (rowx / p.ne01) % p.ne02; // head index
|
||||
|
||||
const float base = h < p.n_head_log2 ? p.m0 : p.m1;
|
||||
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
|
||||
@@ -67,7 +83,7 @@ void soft_max(uint num_iters) {
|
||||
|
||||
FLOAT_TYPE b = FLOAT_TYPE(0);
|
||||
if (p.KY > 0 && col < p.KX) {
|
||||
b = data_b[rowy * p.KX + col];
|
||||
b = data_b[rowy_start + col];
|
||||
}
|
||||
|
||||
FLOAT_TYPE v = a * p.scale + slope * b;
|
||||
@@ -111,7 +127,7 @@ void soft_max(uint num_iters) {
|
||||
if (idx < DATA_CACHE_SIZE) {
|
||||
val = exp(data_cache[idx] - max_val);
|
||||
} else {
|
||||
val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
|
||||
val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
|
||||
}
|
||||
sum += val;
|
||||
if (idx < DATA_CACHE_SIZE) {
|
||||
|
||||
9
ggml/src/vulkan-shaders/swiglu.comp
Normal file
9
ggml/src/vulkan-shaders/swiglu.comp
Normal file
@@ -0,0 +1,9 @@
|
||||
#version 450
|
||||
|
||||
#include "glu_head.comp"
|
||||
|
||||
float op(float a, float b) {
|
||||
return a / (1.0f + exp(-a)) * b;
|
||||
}
|
||||
|
||||
#include "glu_main.comp"
|
||||
@@ -3,6 +3,7 @@
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint ne; uint a_offset; uint d_offset;
|
||||
uint ne00; uint ne01;
|
||||
uint nb00; uint nb01; uint nb02; uint nb03;
|
||||
uint ne10; uint ne11; uint ne12; uint ne13;
|
||||
float sf0; float sf1; float sf2; float sf3;
|
||||
@@ -15,6 +16,61 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
|
||||
#define NEAREST 0
|
||||
#define BILINEAR 1
|
||||
#define ALIGN_CORNERS (1 << 8)
|
||||
|
||||
layout (constant_id = 0) const uint scale_mode = 0;
|
||||
|
||||
float fetch_nearest(uint i10, uint i11, uint i12, uint i13) {
|
||||
const uint i00 = uint(i10 / p.sf0);
|
||||
const uint i01 = uint(i11 / p.sf1);
|
||||
const uint i02 = uint(i12 / p.sf2);
|
||||
const uint i03 = uint(i13 / p.sf3);
|
||||
|
||||
return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00];
|
||||
}
|
||||
|
||||
float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {
|
||||
const uint i02 = uint(i12 / p.sf2);
|
||||
const uint i03 = uint(i13 / p.sf3);
|
||||
const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;
|
||||
|
||||
const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00];
|
||||
const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00];
|
||||
const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00];
|
||||
const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00];
|
||||
|
||||
return
|
||||
v00 * (1.0-d.x) * (1.0-d.y) +
|
||||
v01 * d.x * (1.0-d.y) +
|
||||
v10 * (1.0-d.x) * d.y +
|
||||
v11 * d.x * d.y;
|
||||
}
|
||||
|
||||
float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
|
||||
const ivec2 ne0 = ivec2(p.ne00, p.ne01);
|
||||
|
||||
const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5;
|
||||
const vec2 c0f = floor(c);
|
||||
const vec2 d = c - c0f;
|
||||
const ivec2 c0 = max(ivec2(c0f), 0);
|
||||
const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1);
|
||||
|
||||
return fetch_bilinear(c0, c1, d, i12, i13);
|
||||
}
|
||||
|
||||
float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) {
|
||||
const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1);
|
||||
const vec2 c0f = floor(c);
|
||||
const vec2 d = c - c0f;
|
||||
const ivec2 c0 = ivec2(c0f);
|
||||
const ivec2 c1 = c0 + 1;
|
||||
|
||||
return fetch_bilinear(c0, c1, d, i12, i13);
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
@@ -27,10 +83,18 @@ void main() {
|
||||
const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
|
||||
const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
|
||||
|
||||
const uint i00 = uint(i10 / p.sf0);
|
||||
const uint i01 = uint(i11 / p.sf1);
|
||||
const uint i02 = uint(i12 / p.sf2);
|
||||
const uint i03 = uint(i13 / p.sf3);
|
||||
float result;
|
||||
switch (scale_mode) {
|
||||
case NEAREST:
|
||||
result = fetch_nearest(i10, i11, i12, i13);
|
||||
break;
|
||||
case BILINEAR:
|
||||
result = interpolate_bilinear(i10, i11, i12, i13);
|
||||
break;
|
||||
case BILINEAR | ALIGN_CORNERS:
|
||||
result = interpolate_bilinear_align_corners(i10, i11, i12, i13);
|
||||
break;
|
||||
}
|
||||
|
||||
data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
|
||||
data_d[p.d_offset + idx] = D_TYPE(result);
|
||||
}
|
||||
|
||||
@@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
std::string load_vec_quant = "2";
|
||||
if ((tname == "q4_0") || (tname == "q4_1"))
|
||||
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
|
||||
load_vec_quant = "8";
|
||||
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
|
||||
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
|
||||
load_vec_quant = "4";
|
||||
|
||||
if (tname == "bf16") {
|
||||
@@ -497,9 +497,9 @@ void process_shaders() {
|
||||
// Norms
|
||||
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||
@@ -518,6 +518,11 @@ void process_shaders() {
|
||||
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
}
|
||||
|
||||
for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
||||
string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
}
|
||||
|
||||
auto get_type_str = [](bool f16) {
|
||||
return f16 ? "float16_t" : "float";
|
||||
};
|
||||
@@ -572,17 +577,10 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("fused_mul_gelu_f16", "fused_mul_gelu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("fused_mul_gelu_f32", "fused_mul_gelu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("fused_mul_silu_f16", "fused_mul_silu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("fused_mul_silu_f32", "fused_mul_silu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("fused_mul_relu_f16", "fused_mul_relu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("fused_mul_relu_f32", "fused_mul_relu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("gelu_erf_f32", "gelu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
@@ -594,6 +592,17 @@ void process_shaders() {
|
||||
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
@@ -637,8 +646,30 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
||||
|
||||
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
// ============================== ik_llama.cpp
|
||||
//
|
||||
string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("fused_mul_gelu_f16", "fused_mul_gelu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("fused_mul_gelu_f32", "fused_mul_gelu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("fused_mul_silu_f16", "fused_mul_silu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("fused_mul_silu_f32", "fused_mul_silu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("fused_mul_relu_f16", "fused_mul_relu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("fused_mul_relu_f32", "fused_mul_relu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
//
|
||||
// ============================== end ik_llama.cpp
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
}
|
||||
|
||||
91
ggml/src/vulkan-shaders/wkv7.comp
Normal file
91
ggml/src/vulkan-shaders/wkv7.comp
Normal file
@@ -0,0 +1,91 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#define BLOCK_SIZE 64
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(push_constant) uniform Parameters {
|
||||
uint B;
|
||||
uint T;
|
||||
uint C;
|
||||
uint H;
|
||||
};
|
||||
|
||||
layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
|
||||
layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
|
||||
layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
|
||||
layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
|
||||
layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
|
||||
layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
|
||||
layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
|
||||
layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
|
||||
|
||||
shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint head_size = BLOCK_SIZE;
|
||||
const uint batch_id = gl_WorkGroupID.x / H;
|
||||
const uint head_id = gl_WorkGroupID.x % H;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
if (batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
A_TYPE state[BLOCK_SIZE];
|
||||
[[unroll]] for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i];
|
||||
}
|
||||
|
||||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
barrier();
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
barrier();
|
||||
|
||||
A_TYPE sa = 0.0;
|
||||
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
|
||||
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
||||
sa += dot(s_vec, a_vec);
|
||||
}
|
||||
|
||||
const A_TYPE v_val = v[t];
|
||||
A_TYPE y = 0.0;
|
||||
|
||||
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
|
||||
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
||||
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
||||
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
vec4 kv = k_vec * v_val;
|
||||
s_vec = s_vec * w_vec + kv + sa * b_vec;
|
||||
y += dot(r_vec, s_vec);
|
||||
|
||||
state[j] = s_vec.x;
|
||||
state[j+1] = s_vec.y;
|
||||
state[j+2] = s_vec.z;
|
||||
state[j+3] = s_vec.w;
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user