Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-04-30 03:11:51 +00:00)
Fix race in CUDA FA for head sizes 192/128 (#1104)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -285,17 +285,17 @@ struct fattn_mma_f16_config;
 
 template <>
 struct fattn_mma_f16_config<192, 128> {
-    static constexpr int nbatch_fa = 64;
+    static constexpr int nbatch_fa = 32;
     static constexpr int nwarps_max = 4;
     static constexpr bool Q_in_reg = true;
     static constexpr int nstages_target = 1;
 
     static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 64;
+        return 96;
     }
 
     static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 64;
+        return 96;
     }
 
     static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
@@ -317,17 +317,17 @@ struct fattn_mma_f16_config<192, 128> {
 
 template <>
 struct fattn_mma_f16_config<192, 192> {
-    static constexpr int nbatch_fa = 64;
+    static constexpr int nbatch_fa = 32;
     static constexpr int nwarps_max = 4;
     static constexpr bool Q_in_reg = true;
     static constexpr int nstages_target = 1;
 
     static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 64;
+        return 96;
     }
 
     static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 64;
+        return 96;
     }
 
     static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
Reference in New Issue
Block a user