mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-10 16:30:12 +00:00
CUDA: faster FA TG for GQA models (#370)
* cuda: WIP MMA FA * Use MMA for TG also when quantized --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -321,6 +321,8 @@ if (GGML_CUDA)
|
||||
list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")
|
||||
file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "ggml-cuda/template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
|
||||
|
||||
@@ -46,10 +46,14 @@
|
||||
#define CC_VOLTA 700
|
||||
#define CC_TURING 750
|
||||
#define CC_AMPERE 800
|
||||
#define CC_ADA_LOVELACE 890
|
||||
#define CC_OFFSET_AMD 1000000
|
||||
#define CC_OFFSET_MTHREADS 0x0100000
|
||||
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
|
||||
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
|
||||
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
|
||||
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < CC_OFFSET_MTHREADS)
|
||||
#define GGML_CUDA_CC_IS_AMD(cc) (cc >= CC_OFFSET_AMD)
|
||||
|
||||
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
||||
|
||||
@@ -134,6 +138,49 @@ typedef float2 dfloat2;
|
||||
#define INT8_MMA_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
||||
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
|
||||
#define CP_ASYNC_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
|
||||
|
||||
#ifdef __CUDA_ARCH_LIST__
|
||||
constexpr bool ggml_cuda_has_arch_impl(int) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template<class ... Archs>
|
||||
constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
|
||||
return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
|
||||
}
|
||||
|
||||
constexpr bool ggml_cuda_has_arch(const int arch) {
|
||||
return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
|
||||
}
|
||||
|
||||
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
|
||||
if (cur == 0) {
|
||||
GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
template<class ... Archs>
|
||||
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
|
||||
if (first <= arch && first > cur) {
|
||||
return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
|
||||
} else {
|
||||
return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
|
||||
return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
|
||||
}
|
||||
#else
|
||||
static int ggml_cuda_highest_compiled_arch(const int arch) {
|
||||
return arch;
|
||||
}
|
||||
#endif // __CUDA_ARCH_LIST__
|
||||
|
||||
static constexpr bool fast_fp16_available(const int cc) {
|
||||
return cc >= CC_PASCAL && cc != 610;
|
||||
}
|
||||
@@ -146,6 +193,15 @@ static constexpr bool int8_mma_available(const int cc) {
|
||||
return cc < CC_OFFSET_AMD && cc >= CC_TURING;
|
||||
}
|
||||
|
||||
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
|
||||
static bool new_mma_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_TURING;
|
||||
}
|
||||
|
||||
static bool cp_async_available(const int cc) {
|
||||
return cc < CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= CC_AMPERE;
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
static __device__ void no_device_code(
|
||||
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
|
||||
|
||||
46
ggml/src/ggml-cuda/cp-async.cuh
Normal file
46
ggml/src/ggml-cuda/cp-async.cuh
Normal file
@@ -0,0 +1,46 @@
|
||||
// Simplified API for asynchronous data loading.
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
// Copies data from global to shared memory, cg == cache global.
|
||||
// Both the src and dst pointers must be aligned to 16 bit.
|
||||
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
|
||||
// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
|
||||
// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements.
|
||||
template <int preload>
|
||||
static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
|
||||
static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
|
||||
#ifdef CP_ASYNC_AVAILABLE
|
||||
#if CUDART_VERSION >= 11040
|
||||
if (preload == 256) {
|
||||
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
|
||||
: : "r"(dst), "l"(src));
|
||||
} else if (preload == 128) {
|
||||
asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
|
||||
: : "r"(dst), "l"(src));
|
||||
} else if (preload == 64) {
|
||||
asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
|
||||
: : "r"(dst), "l"(src));
|
||||
} else
|
||||
#endif // CUDART_VERSION >= 11040
|
||||
{
|
||||
asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
|
||||
: : "r"(dst), "l"(src));
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // CP_ASYNC_AVAILABLE
|
||||
}
|
||||
|
||||
// Makes each thread wait until its asynchronous data copies are done.
|
||||
// This does NOT provide any additional synchronization.
|
||||
// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
|
||||
static __device__ __forceinline__ void cp_async_wait_all() {
|
||||
#ifdef CP_ASYNC_AVAILABLE
|
||||
asm volatile("cp.async.wait_all;");
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // CP_ASYNC_AVAILABLE
|
||||
}
|
||||
@@ -862,3 +862,336 @@ void launch_fattn(
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_mma_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
const int j = blockIdx.y;
|
||||
const int c = blockIdx.z;
|
||||
const int jc = j*ncols2 + c;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
|
||||
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int channel = kbc0 / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
|
||||
|
||||
if (jt*ncols1 + j >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
|
||||
// Load the partial result that needs a fixup:
|
||||
float dst_val = 0.0f;
|
||||
float max_val = 0.0f;
|
||||
float rowsum = 0.0f;
|
||||
{
|
||||
dst_val = *dst;
|
||||
|
||||
const float2 tmp = dst_fixup[bidx0*ncols + jc];
|
||||
max_val = tmp.x;
|
||||
rowsum = tmp.y;
|
||||
}
|
||||
|
||||
|
||||
// Iterate over previous blocks and compute the combined results.
|
||||
// All CUDA blocks that get here must have a previous block that needs a fixup.
|
||||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
continue;
|
||||
}
|
||||
|
||||
const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
|
||||
|
||||
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
|
||||
|
||||
// Scale the current and new value accumulators depending on the max. values.
|
||||
const float max_val_new = fmaxf(max_val, tmp.x);
|
||||
|
||||
const float diff_val = max_val - max_val_new;
|
||||
const float diff_add = tmp.x - max_val_new;
|
||||
|
||||
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
|
||||
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
|
||||
|
||||
dst_val = scale_val*dst_val + scale_add*dst_add;
|
||||
rowsum = scale_val*rowsum + scale_add*tmp.y;
|
||||
|
||||
max_val = max_val_new;
|
||||
|
||||
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
||||
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
|
||||
break;
|
||||
}
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
}
|
||||
|
||||
// Write back final result:
|
||||
*dst = dst_val / rowsum;
|
||||
}
|
||||
|
||||
template<int D> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_mma_combine_results(
|
||||
const float * __restrict__ VKQ_parts,
|
||||
const float2 * __restrict__ VKQ_meta,
|
||||
float * __restrict__ dst,
|
||||
const int parallel_blocks) {
|
||||
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
|
||||
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
|
||||
dst += D * gridDim.z*blockIdx.x;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
|
||||
extern __shared__ float2 meta[];
|
||||
if (tid < 2*parallel_blocks) {
|
||||
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
float kqmax = meta[0].x;
|
||||
for (int l = 1; l < parallel_blocks; ++l) {
|
||||
kqmax = max(kqmax, meta[l].x);
|
||||
}
|
||||
|
||||
float VKQ_numerator = 0.0f;
|
||||
float VKQ_denominator = 0.0f;
|
||||
for (int l = 0; l < parallel_blocks; ++l) {
|
||||
const float diff = meta[l].x - kqmax;
|
||||
float KQ_max_scale = expf(diff);
|
||||
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
||||
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
||||
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
|
||||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||
}
|
||||
|
||||
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
}
|
||||
|
||||
template <int D, int ncols1, int ncols2, int KQ_stride>
|
||||
void launch_fattn_mma(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
|
||||
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
||||
|
||||
GGML_ASSERT(Q->ne[3] == 1);
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int nsm = ggml_cuda_info().devices[id].nsm;
|
||||
|
||||
ggml_cuda_pool_alloc<half> K_f16(pool);
|
||||
ggml_cuda_pool_alloc<half> V_f16(pool);
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||
|
||||
const char * K_data = (const char *) K->data;
|
||||
size_t nb11 = K->nb[1];
|
||||
size_t nb12 = K->nb[2];
|
||||
size_t nb13 = K->nb[3];
|
||||
|
||||
const char * V_data = (const char *) V->data;
|
||||
size_t nb21 = V->nb[1];
|
||||
size_t nb22 = V->nb[2];
|
||||
size_t nb23 = V->nb[3];
|
||||
|
||||
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
||||
K_f16.alloc(ggml_nelements(K));
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
|
||||
to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream);
|
||||
K_data = (char *) K_f16.ptr;
|
||||
|
||||
const size_t bs = ggml_blck_size(K->type);
|
||||
const size_t ts = ggml_type_size(K->type);
|
||||
|
||||
nb11 = nb11*bs*sizeof(half)/ts;
|
||||
nb12 = nb12*bs*sizeof(half)/ts;
|
||||
nb13 = nb13*bs*sizeof(half)/ts;
|
||||
}
|
||||
|
||||
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
|
||||
V_data = (char *) V_f16.ptr;
|
||||
|
||||
const size_t bs = ggml_blck_size(V->type);
|
||||
const size_t ts = ggml_type_size(V->type);
|
||||
|
||||
nb21 = nb21*bs*sizeof(half)/ts;
|
||||
nb22 = nb22*bs*sizeof(half)/ts;
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
}
|
||||
|
||||
int parallel_blocks = 1;
|
||||
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
|
||||
const dim3 block_dim(warp_size, nwarps, 1);
|
||||
dim3 blocks_num;
|
||||
if (stream_k) {
|
||||
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
||||
const int max_blocks = 2*nsm;
|
||||
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
|
||||
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
|
||||
|
||||
const int nblocks_stream_k = max_blocks;
|
||||
|
||||
const bool use_stream_k = cc >= CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
|
||||
|
||||
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
|
||||
blocks_num.y = 1;
|
||||
blocks_num.z = 1;
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
|
||||
} else {
|
||||
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
||||
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
||||
|
||||
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
|
||||
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
|
||||
|
||||
// parallel_blocks must not be larger than what the tensor size allows:
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
||||
|
||||
// If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
|
||||
// Test whether parallel_blocks can be set to a higher value for better efficiency.
|
||||
const int blocks_per_wave = nsm * max_blocks_per_sm;
|
||||
int nwaves_best = 0;
|
||||
int efficiency_percent_best = 0;
|
||||
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
|
||||
const int nblocks_total = ntiles_total * parallel_blocks_test;
|
||||
const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
|
||||
const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
|
||||
|
||||
// Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
|
||||
if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (efficiency_percent > efficiency_percent_best) {
|
||||
nwaves_best = nwaves;
|
||||
efficiency_percent_best = efficiency_percent;
|
||||
parallel_blocks = parallel_blocks_test;
|
||||
}
|
||||
}
|
||||
|
||||
blocks_num.x = ntiles_x;
|
||||
blocks_num.y = parallel_blocks;
|
||||
blocks_num.z = Q->ne[2]*Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
||||
}
|
||||
}
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
const uint32_t n_head = Q->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
GGML_ASSERT(block_dim.x % warp_size == 0);
|
||||
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
|
||||
(const char *) Q->data,
|
||||
K_data,
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if (stream_k) {
|
||||
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
||||
|
||||
flash_attn_mma_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
}
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
|
||||
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
||||
|
||||
flash_attn_mma_combine_results<D>
|
||||
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
||||
1047
ggml/src/ggml-cuda/fattn-mma-f16.cuh
Normal file
1047
ggml/src/ggml-cuda/fattn-mma-f16.cuh
Normal file
File diff suppressed because it is too large
Load Diff
@@ -12,10 +12,94 @@
|
||||
#include "fattn-vec-f16.cuh"
|
||||
#include "fattn-vec-f32.cuh"
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
#include "fattn-mma-f16.cuh"
|
||||
#include "fattn.cuh"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
template <int D, int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
if (Q->ne[1] <= 8/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 16/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
|
||||
}
|
||||
|
||||
template <int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const float use_gqa_opt = mask && max_bias == 0.0f;
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio == 4) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio == 2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
@@ -371,8 +455,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
@@ -389,7 +476,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
|
||||
if (!fast_fp16_available(cc)) {
|
||||
if (Q->ne[1] <= 8) {
|
||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||
@@ -398,23 +485,43 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
|
||||
if (!fp16_mma_available(cc)) {
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
if (precision == GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
||||
}
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
||||
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
||||
// So, not sure why in mainline they thought that for CC_ADA_LOVELACE or when KV cache is not f16 the vector kernels are faster.
|
||||
// On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache.
|
||||
//const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
||||
//const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
|
||||
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies;
|
||||
const bool can_use_vector_kernel = Q->ne[0] % (2*WARP_SIZE) == 0;
|
||||
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
||||
if (precision == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
return;
|
||||
} else if(Q->ne[0] <= 128) {
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
// We need this because I haven't adapted the MMA kernels to work for different
|
||||
// K and V head sizes.
|
||||
if (K->ne[0] != V->ne[0]) {
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
||||
}
|
||||
|
||||
396
ggml/src/ggml-cuda/mma_new.cuh
Normal file
396
ggml/src/ggml-cuda/mma_new.cuh
Normal file
@@ -0,0 +1,396 @@
|
||||
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
|
||||
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
|
||||
// The documentation for the PTX instructions can be found under:
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
|
||||
//
|
||||
// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
|
||||
// A is a row-major matrix with shape M x K.
|
||||
// B is a column-major matrix with shape K x N.
|
||||
// C is a column-major matrix with shape M x N.
|
||||
// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
|
||||
// Note that J is measured in physical 32 bit elements instead of logical elements.
|
||||
// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
|
||||
// All matrix tiles have ne physical 32 bit elements per warp.
|
||||
//
|
||||
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
|
||||
#if CUDART_VERSION >= 11080
|
||||
|
||||
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
|
||||
int ret = 0;
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
|
||||
: "=r"(ret) : "r"(x));
|
||||
#else
|
||||
GGML_UNUSED(x);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(INT8_MMA_AVAILABLE)
|
||||
return ret;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
|
||||
// Imagine transposing row-major matrix to column-major matrix.
|
||||
const int src_i_low = 2 * (threadIdx.x % 4);
|
||||
const int src_i_high = src_i_low + 1;
|
||||
const int src_j = threadIdx.x / 4;
|
||||
|
||||
const int src_laneid_low = src_i_low * 4 + src_j / 2;
|
||||
const int src_laneid_high = src_i_high * 4 + src_j / 2;
|
||||
|
||||
const int shift_low = ((src_j + 0) % 2) * 16;
|
||||
const int shift_high = ((src_j + 1) % 2) * 16;
|
||||
|
||||
const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
|
||||
const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
|
||||
|
||||
return ret_low | ret_high;
|
||||
}
|
||||
|
||||
#endif // CUDART_VERSION >= 11080
|
||||
|
||||
static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
|
||||
half2 ret;
|
||||
*((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
|
||||
return ret;
|
||||
}
|
||||
|
||||
namespace ggml_cuda_mma {
|
||||
|
||||
template <int I_, int J_, typename T>
|
||||
struct tile {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
T x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 8 && (J == 4 || J == 8)) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return (l / 2) * 8 + threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 8 && J == 8) {
|
||||
return 4 * l + threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return 2 * (threadIdx.x % 4) + l % 2;
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int I_, int J_>
|
||||
struct tile<I_, J_, half2> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return l * 8 + threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return (l % 2) * 8 + threadIdx.x / 4;
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return l * 4 + threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return (l / 2) * 4 + threadIdx.x % 4;
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int I, int J>
|
||||
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
||||
tile<I, J/2, half2> ret;
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
|
||||
ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
|
||||
tile<8, 8, half2> ret;
|
||||
ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
|
||||
ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <int I, int J, typename T>
|
||||
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < t.ne; ++l) {
|
||||
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
int * xi = (int *) t.x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||
: "=r"(xi[0]), "=r"(xi[1])
|
||||
: "l"(xs));
|
||||
#else
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
int * xi = (int *) t.x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||
: "=r"(xi[0]), "=r"(xi[1])
|
||||
: "l"(xs));
|
||||
#else
|
||||
load_generic(xs0, stride);
|
||||
GGML_UNUSED(t);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
int * xi = (int * ) t.x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
|
||||
: "l"(xs));
|
||||
#else
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix_trans(
|
||||
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
int * xi = (int * ) t.x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
|
||||
: "l"(xs));
|
||||
#else
|
||||
GGML_UNUSED(t);
|
||||
GGML_UNUSED(xs0);
|
||||
GGML_UNUSED(stride);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
#if __CUDA_ARCH__ >= CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
|
||||
: "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(D.x[0]), "+r"(D.x[1])
|
||||
: "r"(A.x[0]), "r"(B.x[0]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(D.x[2]), "+r"(D.x[3])
|
||||
: "r"(A.x[1]), "r"(B.x[0]));
|
||||
#endif // __CUDA_ARCH__ >= CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
#if __CUDA_ARCH__ >= CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
|
||||
: "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
|
||||
#else
|
||||
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(D.x[0]), "+r"(D.x[1])
|
||||
: "r"(A.x[0]), "r"(B.x[0]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(D.x[2]), "+r"(D.x[3])
|
||||
: "r"(A.x[1]), "r"(B.x[0]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(D.x[0]), "+r"(D.x[1])
|
||||
: "r"(A.x[2]), "r"(B.x[1]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(D.x[2]), "+r"(D.x[3])
|
||||
: "r"(A.x[3]), "r"(B.x[1]));
|
||||
#endif // __CUDA_ARCH__ >= CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
#if __CUDA_ARCH__ >= CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
|
||||
#endif // __CUDA_ARCH__ >= CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
#if __CUDA_ARCH__ >= CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
||||
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
||||
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
||||
#endif // __CUDA_ARCH__ >= CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
#if __CUDA_ARCH__ >= CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
|
||||
#endif // __CUDA_ARCH__ >= CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
#if __CUDA_ARCH__ >= CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
|
||||
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
||||
#endif // __CUDA_ARCH__ >= CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 1, 8);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 1);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 2);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 4);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 2, 4);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 2, 8);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32, 1);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32, 2);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 2);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 4);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 8);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 64, 1);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 1);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 2);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 4);
|
||||
@@ -0,0 +1,10 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 8);
|
||||
Reference in New Issue
Block a user