Enable CUDA graphs for MoE models + GPT-OSS support (#689)

* gpt-oss: common

* gpt-oss: attention sinks, swiglu_oai

* gpt-oss: WIP llama

Model loads and runs (CPU only), but PPL is much too high
(~1500 for the 1st batch vs ~200 in mainline).
Is it because of SWA, because of the vocab, or did I introduce a bug somewhere?

* gpt-oss: CPU seems to be working

It was SWA that was missing in the previous commit.

There are issues with EOG tokens, so their handling still needs to be added.

* CUDA: ADD_ID

Just a copy from mainline

* gpt-oss: Seems to be working on CUDA

* gpt-oss: add sinks to the attn-vec kernels

* CUDA: add head size of 64 to new mma

Haven't turned it on yet, but observe slightly better PP and slightly
worse TG performance with that.

* gpt-oss: add ability to use -fmoe (only CUDA for now)

* Move row sums to the right place

* Add sinks to iqk flash attention

* gpt_oss: Implement -fmoe on the CPU

* Simdify swiglu_oai

Turning it off for now as performance becomes more variable,
so perhaps I'm running into thermal throttling more often
because of making the CPU work too hard.

* llama: factor out model loader

* Builds successfully

* It runs, but mmap does not work

* Fix llama_mmap so mmap works

* Minor

* Fix CUDA after latest changes

* Attempt to use CUDA graphs with MoE models - not working

* CUDA graphs WIP - still not working

* CUDA graphs - seems to be working

Likely not all MLA variants are working.
I no longer remember why I added the q8_0 cpy that
transposes the tensor, but if it is really needed, it is now
missing. Also missing is q6_0.

* Make q8_0 cache work for DeepSeek models with CUDA graphs

* cuda: cpy for q6_0

* Fix llama_mmap on non-Linux platforms

* Adding forgotten file

* Iterating on Windows build failures

* cuda: re-add q8_0 -> q8_0 transpose

so mla = 2 can be used with CUDA graphs and q8_0 cache.

* Disable graphs without -fmoe

* Minor

* Turn graphs on by default

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Kawrakow committed 2025-08-15 09:18:07 +03:00 (committed by GitHub)
parent c00335684c, commit 633e0617b0
56 changed files with 8720 additions and 5115 deletions

View File

@@ -24,9 +24,9 @@ class common_chat_msg_parser {
    std::string prelude;
    std::vector<common_string_range> groups;
};

    common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);

    // Accessors
    const std::string & input() const { return input_; }
    size_t pos() const { return pos_; }
@@ -42,7 +42,7 @@ class common_chat_msg_parser {
        }
        pos_ = pos;
    }

    void move_back(size_t n) {
        if (pos_ < n) {
            throw std::runtime_error("Can't move back that far!");
@@ -56,46 +56,46 @@ class common_chat_msg_parser {
    // Content manipulation
    void add_content(const std::string & content);
    void add_reasoning_content(const std::string & reasoning_content);

    // Tool call manipulation
    void add_tool_call(const common_chat_tool_call & tool_call);
    bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
    bool add_tool_call(const json & tool_call);
    bool add_tool_calls(const json & arr);
    void clear_tools();

    // Parsing utilities
    std::string consume_rest();
    bool try_consume_literal(const std::string & literal);
    void consume_literal(const std::string & literal);
    bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);

    // Regex-based parsing methods (new)
    std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
    find_regex_result consume_regex(const common_regex & regex);
    std::optional<find_regex_result> try_consume_regex(const common_regex & regex);

    // Progressive parsing primitives (for Phase 4)
    std::optional<find_regex_result> try_find_literal(const std::string & literal);
    bool consume_spaces();
    void set_healing_marker(const std::string & marker);

    // Main parsing entry point
    void parse();

    // Finishing
    void finish();

    // Result extraction
    common_chat_msg result_and_reset();

    // Advanced JSON parsing (following original llama.cpp patterns)
    struct consume_json_result {
        json value;
        bool is_partial;
    };
    std::optional<common_json> try_consume_json();
    common_json consume_json();
    consume_json_result consume_json_with_dumped_args(
@@ -112,8 +112,8 @@ private:
    void parse_kimi_k2_format();
    void parse_deepseek_r1_format();
    void parse_generic_format();

    // JSON parsing utilities (enhanced streaming support)
    struct json_parse_result {
        json value;
@@ -121,11 +121,11 @@ private:
        bool is_partial;
        std::string healing_marker;
    };

    // Partial detection utilities
    bool detect_partial_function_call(const std::string& content);
    void handle_partial_detection();

    // Legacy find_literal for compatibility
    std::optional<find_regex_result> try_find_literal_legacy(const std::string & literal);
};
@@ -133,4 +133,4 @@ private:
// Main parsing function (public API)
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);

// Content-only parsing for fallback scenarios (static internal function)

View File

@@ -220,7 +220,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
    // Check for the new tools array format first (no DeepSeek markers)
    auto original_pos = builder.pos();

    // First, try the tools array format for content like "function\n```json\n{"tools": [...]}"
    if (builder.try_find_regex(function_regex_simple)) {
        builder.move_to(original_pos);
@@ -231,7 +231,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
            // Fall through to try standard DeepSeek patterns
        }
    }

    // If tools array format didn't work, try XML-wrapped format
    builder.move_to(original_pos);
    try {
@@ -240,7 +240,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
    } catch (const common_chat_msg_partial_exception&) {
        // Fall through to try standard DeepSeek patterns
    }

    // If XML wrapper format didn't work, try standard DeepSeek patterns
    builder.move_to(original_pos);
    try {
@@ -278,7 +278,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
            throw; // Re-throw for partial mode
        }
    }

    // Add any remaining content (critical for responses without tool calls)
    builder.add_content(builder.consume_rest());
}
@@ -286,19 +286,19 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
// Parse DeepSeek R1 tools array format following original llama.cpp parse_prefixed_json_tool_call_array pattern
static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) {
    static const common_regex prefix("function\n```json\n");

    if (auto res = builder.try_find_regex(prefix)) {
        // Parse JSON and manually process tools array to convert arguments to strings
        auto json_result = builder.try_consume_json();
        if (!json_result) {
            throw common_chat_msg_partial_exception("invalid JSON");
        }

        // DeepSeek R1 format has "tools" array, manually process each tool
        if (json_result->json.contains("tools") && json_result->json.at("tools").is_array()) {
            // Manually create tool calls array with string arguments (following original pattern)
            json tools_with_dumped_args = json::array();
            for (const auto& tool : json_result->json.at("tools")) {
@@ -310,15 +310,15 @@ static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) {
                    tools_with_dumped_args.push_back(formatted_tool);
                }
            }

            if (!builder.add_tool_calls(tools_with_dumped_args) || !json_result->healing_marker.marker.empty()) {
                throw common_chat_msg_partial_exception("incomplete tool call array");
            }
        } else {
            throw common_chat_msg_partial_exception("tools key not found or not array");
        }

        // Consume closing ```
        builder.try_consume_regex(common_regex("```"));
    } else {
@@ -326,41 +326,41 @@ static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) {
    }
}

// Parse DeepSeek R1 XML-wrapped format following original Hermes-2-Pro pattern
static void parse_deepseek_r1_xml_wrapped(common_chat_msg_parser & builder) {
    // Pattern for: <tool_call>\nfunction</think>FunctionName\n```json\n{...}\n```\n</tool_call>
    static const common_regex xml_pattern(
        "<tool_call>\\s*" // Opening XML tag
        "function</think>([^\\n]+)" // Function name after "function</think>"
        "\\s*```json\\s*" // JSON block start
    );

    if (auto res = builder.try_find_regex(xml_pattern)) {
        // Extract function name from capture group
        std::string function_name = builder.str(res->groups[1]);

        // Parse JSON arguments
        auto json_result = builder.try_consume_json();
        if (!json_result) {
            throw common_chat_msg_partial_exception("invalid JSON in XML wrapper");
        }

        // Create single tool call following original pattern
        json tool_call;
        tool_call["name"] = function_name;
        tool_call["arguments"] = json_result->json.dump(); // Convert to string

        json tool_calls_array = json::array();
        tool_calls_array.push_back(tool_call);

        if (!builder.add_tool_calls(tool_calls_array) || !json_result->healing_marker.marker.empty()) {
            throw common_chat_msg_partial_exception("incomplete XML wrapped tool call");
        }

        // Consume closing ```\n</tool_call>
        builder.try_consume_regex(common_regex("```\\s*</tool_call>"));
    } else {
@@ -384,6 +384,15 @@ static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
    builder.add_content(kimi_k2::clean_content(builder.input()));
}

+static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
+    // TODO @ngxson : this won't work with --special enabled, we should fix that
+    builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
+
+    if (!builder.syntax().enable_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+}
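For reference, a minimal sketch of how this parser is driven, using the public common_chat_parse API declared in chat.h; the channel markers are the ones hard-coded above, but the transcript text is made up for illustration:

    // Hedged usage sketch, not part of the diff.
    common_chat_syntax syntax;
    syntax.format = COMMON_CHAT_FORMAT_GPT_OSS;
    common_chat_msg msg = common_chat_parse(
        "<|channel|>analysis<|message|>Check the sum."
        "<|start|>assistant<|channel|>final<|message|>The answer is 4.",
        /*is_partial=*/false, syntax);
    // Expected split: msg.reasoning_content holds the analysis channel,
    // msg.content holds the final channel.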
// Main parsing dispatch function
static void common_chat_parse(common_chat_msg_parser & builder) {
    switch (builder.syntax().format) {
@@ -399,6 +408,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
        case COMMON_CHAT_FORMAT_KIMI_K2:
            common_chat_parse_kimi_k2(builder);
            break;
+        case COMMON_CHAT_FORMAT_GPT_OSS:
+            common_chat_parse_gpt_oss(builder);
+            break;
        default:
            throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
    }
@@ -432,6 +444,19 @@ const char* common_chat_format_name(common_chat_format format) {
        case COMMON_CHAT_FORMAT_GENERIC: return "generic";
        case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "deepseek_r1";
        case COMMON_CHAT_FORMAT_KIMI_K2: return "kimi_k2";
+        case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
        default: return "unknown";
    }
}

+const char * common_reasoning_format_name(common_reasoning_format format) {
+    switch (format) {
+        case COMMON_REASONING_FORMAT_NONE: return "none";
+        case COMMON_REASONING_FORMAT_AUTO: return "auto";
+        case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
+        case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
+        default:
+            throw std::runtime_error("Unknown reasoning format");
+    }
+}

View File

@@ -13,20 +13,20 @@ struct common_chat_templates;
struct common_string_range {
    size_t begin;
    size_t end;

    common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
        if (begin > end) {
            throw std::runtime_error("Invalid range");
        }
    }

    // prevent default ctor
    common_string_range() = delete;

    bool empty() const {
        return begin == end;
    }

    bool operator==(const common_string_range & other) const {
        return begin == other.begin && end == other.end;
    }
@@ -40,7 +40,7 @@ struct common_chat_tool_call {
    bool operator==(const common_chat_tool_call & other) const {
        return name == other.name && arguments == other.arguments && id == other.id;
    }

    bool operator!=(const common_chat_tool_call & other) const {
        return !(*this == other);
    }
@@ -65,10 +65,10 @@ struct common_chat_msg {
    std::string tool_call_id;

    bool empty() const {
        return content.empty() && content_parts.empty() && tool_calls.empty() &&
            reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
    }

    void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
        for (auto i = 0u; i < tool_calls.size(); i++) {
            if (ids_cache.size() <= i) {
@@ -91,7 +91,7 @@ struct common_chat_msg {
            && tool_name == other.tool_name
            && tool_call_id == other.tool_call_id;
    }

    bool operator!=(const common_chat_msg & other) const {
        return !(*this == other);
    }
@@ -110,7 +110,7 @@ struct common_chat_msg_diff {
        && tool_call_index == other.tool_call_index
        && tool_call_delta == other.tool_call_delta;
    }

    bool operator!=(const common_chat_msg_diff & other) const {
        return !(*this == other);
    }
@@ -132,18 +132,20 @@ enum common_chat_format {
    COMMON_CHAT_FORMAT_CONTENT_ONLY,
    COMMON_CHAT_FORMAT_GENERIC,
    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+    COMMON_CHAT_FORMAT_GPT_OSS,
    COMMON_CHAT_FORMAT_KIMI_K2, // Our custom format (keep last for backward compatibility)
};

enum common_reasoning_format {
    COMMON_REASONING_FORMAT_NONE,
+    COMMON_REASONING_FORMAT_AUTO,
    COMMON_REASONING_FORMAT_DEEPSEEK,
    COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY,
};

struct common_chat_syntax {
    common_chat_format format = COMMON_CHAT_FORMAT_KIMI_K2;
-    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
+    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; //COMMON_REASONING_FORMAT_NONE;
    // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
    bool reasoning_in_content = false;
    bool thinking_forced_open = false;
@@ -165,11 +167,12 @@ class common_chat_msg_partial_exception : public std::runtime_error {
// Format detection from chat template
common_chat_format common_chat_format_detect(const std::string & chat_template);
const char* common_chat_format_name(common_chat_format format);
+const char* common_reasoning_format_name(common_reasoning_format format);

// Main parsing function (entry point for original llama.cpp compatibility)
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);

// Forward declare parser class
class common_chat_msg_parser;

// Format-specific parsing functions (accessible from chat-parser)

View File

@@ -61,7 +61,7 @@ int main(int argc, char ** argv) {
    const llama_vocab * vocab = llama_get_vocab(ctx);
-    llama_token bos = llama_token_bos_impl(*vocab);
+    llama_token bos = vocab->token_bos();
    //llama_token eos = llama_token_eos_impl(*vocab);
    const unsigned int n_vocab = llama_n_vocab(model);

View File

@@ -132,7 +132,7 @@ set (GGML_CUDA_MIN_BATCH_OFFLOAD "32" CACHE STRING
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
-option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF)
+option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ON)
option(GGML_IQK_FLASH_ATTENTION "ggml: enable the IQK FlashAttention CPU kernels" ON)
option(GGML_IQK_FA_ALL_QUANTS "ggml: compile all quants for IQK FlashAttention" OFF)

View File

@@ -325,6 +325,16 @@
    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

+#define GGML_TENSOR_TERNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

#define GGML_TENSOR_BINARY_OP_LOCALS01 \
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
@@ -571,6 +581,7 @@ extern "C" {
        GGML_OP_DUP,
        GGML_OP_ADD,
+        GGML_OP_ADD_ID,
        GGML_OP_ADD1,
        GGML_OP_ACC,
        GGML_OP_SUB,
@@ -674,6 +685,7 @@ extern "C" {
        GGML_UNARY_OP_HARDSWISH,
        GGML_UNARY_OP_HARDSIGMOID,
        GGML_UNARY_OP_SWIGLU,
+        GGML_UNARY_OP_SWIGLU_OAI,
        GGML_UNARY_OP_COUNT,
    };
@@ -1028,6 +1040,13 @@ extern "C" {
            struct ggml_tensor * b,
            enum ggml_type type);

+    // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
+    GGML_API struct ggml_tensor * ggml_add_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b,
+            struct ggml_tensor * ids);
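The comment above pins down the semantics; a scalar reference, assuming contiguous f32 tensors where a is [ne0, ne1, ne2], b holds one bias row per expert, and ids is [ne1, ne2]:

    // Hedged reference implementation of ggml_add_id's math, not part of the diff.
    static void add_id_ref(const float * a, const float * b, const int32_t * ids,
                           float * dst, int64_t ne0, int64_t ne1, int64_t ne2) {
        for (int64_t i2 = 0; i2 < ne2; ++i2) {
            for (int64_t i1 = 0; i1 < ne1; ++i1) {
                const float * b_row = b + (int64_t) ids[i2*ne1 + i1] * ne0; // bias row picked per (i1, i2)
                const int64_t off = (i2*ne1 + i1)*ne0;
                for (int64_t i0 = 0; i0 < ne0; ++i0) {
                    dst[off + i0] = a[off + i0] + b_row[i0];
                }
            }
        }
    }

This is what gpt-oss needs for its per-expert FFN biases: each row of activations gets the bias row of the expert it was routed to.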
    GGML_API struct ggml_tensor * ggml_add1(
            struct ggml_context * ctx,
            struct ggml_tensor * a,
@@ -1268,6 +1287,13 @@ extern "C" {
            struct ggml_context * ctx,
            struct ggml_tensor * a);

+    GGML_API struct ggml_tensor * ggml_swiglu_oai(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b,
+            float alpha,
+            float limit);
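A scalar sketch of the gpt-oss activation this is meant to implement, as I understand the mainline kernel (the CUDA paths in this commit pass alpha = 1.702f and limit = 7.0f):

    // Hedged per-element reference for swiglu_oai(gate, up), not part of the diff.
    #include <algorithm>
    #include <cmath>
    static inline float swiglu_oai_ref(float gate, float up, float alpha, float limit) {
        gate = std::min(gate, limit);                  // clamp gate from above only
        up   = std::clamp(up, -limit, limit);          // clamp up symmetrically
        const float glu = gate / (1.0f + std::exp(-alpha*gate)); // gate * sigmoid(alpha*gate)
        return glu * (up + 1.0f);                      // note the +1 on the linear branch
    }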
// a - x // a - x
// b - dy // b - dy
GGML_API struct ggml_tensor * ggml_silu_back( GGML_API struct ggml_tensor * ggml_silu_back(
@@ -1370,6 +1396,16 @@ extern "C" {
struct ggml_tensor * ids, struct ggml_tensor * ids,
enum ggml_unary_op op); enum ggml_unary_op op);
GGML_API struct ggml_tensor * ggml_moe_up_gate_ext(
struct ggml_context * ctx,
struct ggml_tensor * a_up,
struct ggml_tensor * a_gate,
struct ggml_tensor * b,
struct ggml_tensor * ids,
struct ggml_tensor * a_up_b,
struct ggml_tensor * a_gate_b,
enum ggml_unary_op op);
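A hypothetical call site (tensor names made up) for a gpt-oss style MoE FFN with per-expert up/gate biases, where ids comes from the router's top-k selection:

    // Hedged usage sketch, not part of the diff.
    struct ggml_tensor * cur_ffn = ggml_moe_up_gate_ext(ctx,
            ffn_up_exps, ffn_gate_exps, cur, selected_experts,
            ffn_up_exps_b, ffn_gate_exps_b, GGML_UNARY_OP_SWIGLU_OAI);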
    // A: m columns, n rows,
    // B: p columns, n rows,
    // result is m columns, p rows
@@ -1662,6 +1698,11 @@ extern "C" {
            float scale,
            float max_bias);

+    GGML_API void ggml_soft_max_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
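As I understand the gpt-oss attention sinks, a sink is an extra learned per-head logit that joins the softmax normalization without contributing an output column, so it can soak up attention mass. A scalar sketch for one row of scores:

    // Hedged reference, not part of the diff: softmax of n scores plus a sink logit.
    #include <algorithm>
    #include <cmath>
    static void softmax_row_with_sink(const float * s, int n, float sink, float * p) {
        float m = sink;
        for (int i = 0; i < n; ++i) m = std::max(m, s[i]);
        float denom = std::exp(sink - m);               // the sink's share of the mass
        for (int i = 0; i < n; ++i) denom += std::exp(s[i] - m);
        for (int i = 0; i < n; ++i) p[i] = std::exp(s[i] - m) / denom; // row sums to < 1
    }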
    GGML_API struct ggml_tensor * ggml_soft_max_back(
            struct ggml_context * ctx,
            struct ggml_tensor * a,
@@ -1998,6 +2039,10 @@ extern "C" {
            struct ggml_tensor * a,
            enum ggml_prec prec);

+    GGML_API void ggml_flash_attn_ext_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);

    // TODO: needs to be adapted to ggml_flash_attn_ext
    GGML_API struct ggml_tensor * ggml_flash_attn_back(
            struct ggml_context * ctx,

View File

@@ -43,6 +43,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
        case GGML_OP_DIAG_MASK_ZERO:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
        case GGML_OP_ADD1:
        case GGML_OP_SUB:
        case GGML_OP_MUL:

View File

@@ -37,6 +37,8 @@
#include "ggml-cuda/unary.cuh" #include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh" #include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/add-id.cuh"
#include "ggml-cuda/graph.cuh"
#include <algorithm> #include <algorithm>
#include <array> #include <array>
@@ -49,6 +51,7 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <condition_variable>
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
#include <stdarg.h> #include <stdarg.h>
@@ -77,6 +80,7 @@ GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback,
#define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__) #define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
#define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__) #define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
#define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) #define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
#define GGML_CUDA_LOG_DEBUG(...) ggml_cuda_log(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
GGML_ATTRIBUTE_FORMAT(2, 3) GGML_ATTRIBUTE_FORMAT(2, 3)
static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) { static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) {
@@ -444,6 +448,35 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device)); return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
} }
static std::mutex ggml_cuda_lock;
static std::condition_variable ggml_cuda_lock_cv;
static std::atomic<int> ggml_cuda_lock_counter;
ggml_backend_cuda_context::ggml_backend_cuda_context(int device) :
device(device), name(GGML_CUDA_NAME + std::to_string(device)) {
}
ggml_backend_cuda_context::~ggml_backend_cuda_context() {
std::unique_lock<std::mutex> lock(ggml_cuda_lock);
ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
if (copy_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(copy_event));
}
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
if (streams[i][j] != nullptr) {
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
}
}
if (cublas_handles[i] != nullptr) {
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
}
}
}
// cuda buffer // cuda buffer
struct ggml_backend_cuda_buffer_context { struct ggml_backend_cuda_buffer_context {
@@ -2220,6 +2253,24 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin
    }
}

+//static __global__ void k_quick_add(uint32_t n, uint32_t n_per_row, const float * src1, const float * src2, float * dst) {
+//
+//    for (uint32_t j = threadIdx.x; j < n; j += blockDim.x) {
+//        dst[j] = src1[j] + src2[j % n_per_row];
+//    }
+//}
+
+static __global__ void k_quick_add(uint32_t n_per_row, const float * src1, const float * src2, float * dst) {
+    uint32_t row = blockIdx.x;
+    const float * src1_row = src1 + row*n_per_row;
+    float * dst_row = dst + row*n_per_row;
+    for (uint32_t j = threadIdx.x; j < n_per_row; j += blockDim.x) {
+        dst_row[j] = src1_row[j] + src2[j];
+    }
+}

static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
        const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
        ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
@@ -2270,7 +2321,7 @@ static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n
    return is_ser;
}

-static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * ids = dst->src[2];
@@ -2319,7 +2370,25 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
0, src0->ne[1], 1, src1_padded_col_size, stream); 0, src0->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
return; if (next && next->op == GGML_OP_MUL_MAT_ID && next->src[0]->type == src0->type && src1 == next->src[1] &&
ggml_are_same_shape(src0, next->src[0]) &&
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
ggml_backend_buffer_is_cuda(next->buffer) &&
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer)) {
ggml_backend_cuda_buffer_context * next_src0_ctx = (ggml_backend_cuda_buffer_context *) next->src[0]->buffer->context;
ggml_backend_cuda_buffer_context * next_dst_ctx = (ggml_backend_cuda_buffer_context *) next->buffer->context;
if (next_src0_ctx->device == device_id &&
next_dst_ctx->device == device_id) {
local_dst.data = next->data;
ggml_cuda_op_mul_mat_vec_q_id(ctx, next->src[0], &local_src1, ids, &local_dst,
(const char *)next->src[0]->data, nullptr, src1_quantized.get(), (float *)next->data,
0, src0->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
return true;
}
}
return false;
} }
} }
@@ -2356,7 +2425,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
dst_row.nb[2] = nb1; dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1; dst_row.nb[3] = nb1;
if (ne12 == 1) { if (false && ne12 == 1) {
std::vector<char> ids_host(ggml_nbytes(ids)); std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data; const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
@@ -2442,6 +2511,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
} }
} }
} }
return false;
} }
static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) { static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
@@ -2470,6 +2540,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
src0_2_ctx->device == device_id && src0_2_ctx->device == device_id &&
src1_ctx->device == device_id && src1_ctx->device == device_id &&
dst_ctx->device == device_id) { dst_ctx->device == device_id) {
//printf("%s(%s, %s): %ld x %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, src0_1->name, src0_2->name,
// src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], ids->ne[0], ids->ne[1], ids->ne[2]);
// Fast TG path // Fast TG path
const int64_t n_ids = ids->ne[0]; const int64_t n_ids = ids->ne[0];
auto stream = ctx.stream(device_id, 0); auto stream = ctx.stream(device_id, 0);
@@ -2505,12 +2577,26 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
0, src0_1->ne[1], 1, src1_padded_col_size, stream); 0, src0_1->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
if (dst->src[4]) {
ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[4]->data,
(const int32_t *)ids->data, (float *)local_dst.data,
local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
local_dst.nb[1], local_dst.nb[2], dst->src[4]->nb[1], ids->nb[2], stream);
}
local_dst.data = dst_gate_contiguous.get(); local_dst.data = dst_gate_contiguous.get();
ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst, ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst,
(const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(), (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(),
0, src0_2->ne[1], 1, src1_padded_col_size, stream); 0, src0_2->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
if (dst->src[5]) {
ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[5]->data,
(const int32_t *)ids->data, (float *)local_dst.data,
local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
local_dst.nb[1], local_dst.nb[2], dst->src[5]->nb[1], ids->nb[2], stream);
}
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
ggml_backend_buffer_is_cuda(next->src[0]->buffer) && ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) && !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
@@ -2518,8 +2604,15 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
ggml_backend_buffer_is_cuda(next->buffer) && ggml_backend_buffer_is_cuda(next->buffer) &&
((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) { ((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids, auto unary_op = (ggml_unary_op)dst->op_params[0];
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get(), dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
} else {
ggml_fused_mul_unary(ctx, unary_op, dst->ne[0]*n_ids,
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get());
}
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING); const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING);
@@ -2555,8 +2648,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
            return true;
        } else {
            CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
-            ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
-                    (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
+            auto unary_op = (ggml_unary_op)dst->op_params[0];
+            if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
+                ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
+                        (float *)dst->data, dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
+            } else {
+                ggml_fused_mul_unary(ctx, unary_op, ggml_nelements(dst),
+                        (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
+            }
            CUDA_CHECK(cudaGetLastError());
            return false;
        }
@@ -2624,7 +2723,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
        final_src.nb[3] = final_src.nb[2];
    }

-    if (ne12 == 1) {
+    if (false && ne12 == 1) {
        ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
        ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
        if (fuse_down) {
@@ -2761,6 +2860,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
            }
            CUDA_CHECK(cudaGetLastError());

+            if (dst->src[4]) {
+                dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
+                dim3 grid_dims(num_src1_rows);
+                k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
+                        (const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
            dst_row.data = dst_gate_contiguous.get();
            if (use_quantized_src1) {
                ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
@@ -2770,8 +2877,24 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
            }
            CUDA_CHECK(cudaGetLastError());

-            ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
-                    (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
+            if (dst->src[5]) {
+                dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
+                dim3 grid_dims(num_src1_rows);
+                k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
+                        (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            auto unary_op = (ggml_unary_op)dst->op_params[0];
+            if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
+                ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
+                        (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0],
+                        1.702f, 7.0f, stream);
+            } else {
+                ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
+                        (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
+                        (float *)dst_gate_contiguous.get());
+            }
            CUDA_CHECK(cudaGetLastError());

            if (fuse_down) {
@@ -2851,6 +2974,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_ADD:
            ggml_cuda_op_add(ctx, dst);
            break;
+        case GGML_OP_ADD_ID:
+            ggml_cuda_op_add_id(ctx, dst);
+            break;
        case GGML_OP_MULTI_ADD:
            ggml_cuda_op_multi_add(ctx, dst);
            break;
@@ -2877,6 +3003,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                case GGML_UNARY_OP_SWIGLU:
                    ggml_cuda_op_swiglu(ctx, dst);
                    break;
+                case GGML_UNARY_OP_SWIGLU_OAI:
+                    ggml_cuda_op_swiglu_oai(ctx, dst);
+                    break;
                case GGML_UNARY_OP_GELU_QUICK:
                    ggml_cuda_op_gelu_quick(ctx, dst);
                    break;
@@ -2938,7 +3067,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
            }
            break;
        case GGML_OP_MUL_MAT_ID:
-            ggml_cuda_mul_mat_id(ctx, dst);
+            skip_next = ggml_cuda_mul_mat_id(ctx, dst, next);
            break;
        case GGML_OP_MOE_FUSED_UP_GATE:
            skip_next = ggml_cuda_up_gate_unary(ctx, dst, next);
@@ -3119,6 +3248,105 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
    GGML_UNUSED(backend);
}

+#ifdef USE_CUDA_GRAPH
+static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+    bool use_cuda_graph) {
+
+    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
+
+    const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
+    const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
+    const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
+    const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
+    const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        ggml_tensor * node = cgraph->nodes[i];
+
+        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+            continue;
+        }
+
+        if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
+            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_MUL_MAT_ID && (node->ne[2] != 1 || node->src[2]->ne[0] != 1)) {
+            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_CUDA_LOG_DEBUG("%s(%s): disabling CUDA graphs due to unsupported node type %ld %ld\n",
+                    __func__, node->src[0]->name, node->ne[2], node->src[2]->ne[0]);
+#endif
+        }
+
+        if (node->op == GGML_OP_MOE_FUSED_UP_GATE) {
+            auto src0_1 = node->src[0];
+            auto src0_2 = node->src[1];
+            auto src1 = node->src[2];
+            if (src1->ne[1] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || src1->type != GGML_TYPE_F32 ||
+                !ggml_is_quantized(src0_1->type) || !ggml_is_quantized(src0_2->type)) {
+                use_cuda_graph = false;
+            } else {
+                if (i < cgraph->n_nodes-1) {
+                    auto next = cgraph->nodes[i+1];
+                    if (next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type)) {
+                        ++i;
+                    }
+                }
+            }
+        }
+
+        if (node->op == GGML_OP_ADD &&
+            node->src[1] && node->src[1]->ne[1] > 1 &&
+            (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
+            (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
+            strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
+            strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
+            strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
+            // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
+            // by means of matching node names. See
+            // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
+            // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
+            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
+            use_cuda_graph = false;
+#ifndef NDEBUG
+            GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+        }
+
+        if (node->op == GGML_OP_CPY) {
+            // Store the pointers which are updated for each token, such that these can be sent
+            // to the device and accessed using indirection from CUDA graph
+            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
+
+            // store a pointer to each copy op CUDA kernel to identify it later
+            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+            if (!ptr) {
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
+#endif
+            }
+        }
+
+        if (!use_cuda_graph) {
+            break;
+        }
+    }
+
+    if (use_cuda_graph) {
+        cuda_ctx->cuda_graph->use_cpy_indirection = true;
+        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
+        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
+    }
+
+    return use_cuda_graph;
+}
+
static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
    graph_node_properties->node_address = node->data;
    graph_node_properties->node_op = node->op;
@@ -3129,6 +3357,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
    }
+    memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
}

static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
@@ -3160,9 +3389,246 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
            return false;
        }
    }

+    if (node->op == GGML_OP_SCALE &&
+        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+        return false;
+    }

    return true;
}
+static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
+
+    bool cuda_graph_update_required = false;
+
+    if (cuda_ctx->cuda_graph->instance == nullptr) {
+        cuda_graph_update_required = true;
+    }
+
+    // Check if the graph size has changed
+    if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+        cuda_graph_update_required = true;
+        cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+    }
+
+    // Loop over nodes in GGML graph to determine if CUDA graph update is required
+    // and store properties to allow this comparison for the next token
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        bool has_matching_properties = true;
+        if (!cuda_graph_update_required) {
+            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        }
+        if (!has_matching_properties) {
+            cuda_graph_update_required = true;
+        }
+        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+    }
+
+    return cuda_graph_update_required;
+}
+
+static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
+
+#if CUDART_VERSION >= 12000
+    cudaGraphExecUpdateResultInfo result_info;
+    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+#else
+    cudaGraphNode_t errorNode;
+    cudaGraphExecUpdateResult result_info;
+    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+#endif // CUDART_VERSION >= 12000
+
+    if (stat == cudaErrorGraphExecUpdateFailure) {
+#ifndef NDEBUG
+        GGML_CUDA_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
+#endif
+        // The pre-existing graph exec cannot be updated due to violated constraints
+        // so instead clear error and re-instantiate
+        (void)cudaGetLastError();
+        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+        cuda_ctx->cuda_graph->instance = nullptr;
+        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+    } else {
+        GGML_ASSERT(stat == cudaSuccess);
+    }
+}
+#endif
+static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+    bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
+
+    // flag used to determine whether it is an integrated_gpu
+    // TODO
+    const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated;
+
+    while (!graph_evaluated_or_captured) {
+        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+        // With the use of CUDA graphs, the execution will be performed by the graph launch.
+        if (!use_cuda_graph || cuda_graph_update_required) {
+
+            for (int i = 0; i < cgraph->n_nodes; i++) {
+                ggml_tensor * node = cgraph->nodes[i];
+                ggml_tensor * next = i < cgraph->n_nodes-1 ? cgraph->nodes[i+1] : nullptr;
+
+                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+                    continue;
+                }
+
+#if 0
+                static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
+                if (!disable_fusion) {
+                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+                        ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                        i++;
+                        continue;
+                    }
+                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+                        i += 2;
+                        ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
+                        continue;
+                    }
+                }
+#endif
+
+#ifndef NDEBUG
+                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    if (node->src[j] != nullptr) {
+                        assert(node->src[j]->buffer);
+                        //assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
+                        //       ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
+                    }
+                }
+#else
+                GGML_UNUSED(integrated);
+#endif // NDEBUG
+
+                bool skip_next = false;
+                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next);
+                if (!ok) {
+                    GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                }
+                GGML_ASSERT(ok);
+                if (skip_next) ++i;
+            }
+        }
+
+#ifdef USE_CUDA_GRAPH
+        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+            if (cuda_ctx->cuda_graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
+                cuda_ctx->cuda_graph->graph = nullptr;
+            }
+
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+            graph_evaluated_or_captured = true; // CUDA graph has been captured
+
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
+                ggml_cuda_lock_cv.notify_all();
+            }
+        } else {
+            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
+        }
+    }
+
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        }
+        if (cuda_graph_update_required) { // Update graph executable
+            update_cuda_graph_executable(cuda_ctx);
+        }
+        // Launch graph
+        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+#else
+        graph_evaluated_or_captured = true;
+#endif // USE_CUDA_GRAPH
+    }
+}
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_device(cuda_ctx->device);
#ifdef USE_CUDA_GRAPH
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
// Objects required for CUDA Graph
if (cuda_ctx->cuda_graph == nullptr) {
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
}
bool use_cuda_graph = true;
bool cuda_graph_update_required = false;
if (cuda_ctx->cuda_graph->graph == nullptr) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
#endif
}
}
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
// or previous graph capture failure.
// Also disable for multi-gpu for now. TO DO investigate
if (disable_cuda_graphs_due_to_env
|| cuda_ctx->cuda_graph->disable_due_to_gpu_arch
|| cuda_ctx->cuda_graph->disable_due_to_too_many_updates
|| cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
use_cuda_graph = false;
}
if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) {
cuda_ctx->cuda_graph->number_consecutive_updates++;
} else {
cuda_ctx->cuda_graph->number_consecutive_updates = 0;
}
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
#endif
}
}
if (use_cuda_graph && cuda_graph_update_required) {
// Start CUDA graph capture
{
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
}
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
}
if (!use_cuda_graph) {
cuda_ctx->cuda_graph->use_cpy_indirection = false;
}
#else
bool use_cuda_graph = false;
bool cuda_graph_update_required = false;
#endif // USE_CUDA_GRAPH
bool graph_evaluated_or_captured = false;
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
return GGML_STATUS_SUCCESS;
}
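// Note (illustrative, not part of this change): graph capture can be disabled at
// runtime via the environment variable checked above, e.g.
//
//   GGML_CUDA_DISABLE_GRAPHS=1 ./bin/llama-server -m model.gguf
//
// (the binary name depends on the build). Capture is also self-disabling: after 4
// consecutive evaluations that each required a graph update (see the heuristic
// above), graphs are turned off for the remainder of the run.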
/*
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
@@ -3431,6 +3897,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
return GGML_STATUS_SUCCESS;
}
*/
GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
@@ -3440,6 +3907,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_SWIGLU:
case GGML_UNARY_OP_SWIGLU_OAI:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
@@ -3629,6 +4097,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_MULTI_ADD:
case GGML_OP_MUL:
case GGML_OP_DIV:


@@ -0,0 +1,72 @@
#include "add-id.cuh"
static __global__ void add_id_kernel(
const float * src0, const float * src1, const int32_t * src2, float * dst,
int64_t ne0, int64_t ne1,
size_t nb01, size_t nb02,
size_t nb11,
size_t nb21
) {
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21);
const size_t nb1 = ne0 * sizeof(float);
const size_t nb2 = ne1 * nb1;
float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02);
const float * src1_row = (const float *)((char *)src1 + i11*nb11);
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_I32);
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb20 == sizeof(int32_t));
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
const int32_t * src2_d = (const int32_t *)src2->data;
float * dst_d = (float *)dst->data;
int threads = std::min((int)ne00, 768); // cols
dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
src0_d, src1_d, src2_d, dst_d,
ne0, ne1,
nb01, nb02,
nb11,
nb21
);
}
void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst,
int64_t ne00, int64_t ne01, int64_t ne02,
int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream) {
int threads = std::min((int)ne00, 768); // cols
dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
add_id_kernel<<<blocks, threads, 0, stream>>>(
src0, src1, src2, dst,
ne0, ne1,
nb01, nb02,
nb11,
nb21
);
}
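// Reference semantics of ADD_ID (illustrative CPU sketch, assuming contiguous
// tensors; not part of this diff): for every token i2 and used-expert slot i1,
// the bias row of src1 selected by src2 is added to the corresponding src0 row.
//
//   for (int64_t i2 = 0; i2 < ne02; ++i2) {           // tokens
//       for (int64_t i1 = 0; i1 < ne01; ++i1) {       // experts used per token
//           const int32_t e = src2[i2*ne01 + i1];     // expert id -> bias row of src1
//           for (int64_t i0 = 0; i0 < ne00; ++i0) {   // columns
//               dst[(i2*ne01 + i1)*ne00 + i0] = src0[(i2*ne01 + i1)*ne00 + i0] + src1[e*ne00 + i0];
//           }
//       }
//   }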


@@ -0,0 +1,8 @@
#include "common.cuh"
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst,
int64_t ne00, int64_t ne01, int64_t ne02,
int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream);


@@ -108,6 +108,23 @@ static const char * cu_get_error_str(CUresult err) {
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
do { \
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
const int id = ggml_cuda_get_device(); \
if (!shared_memory_limit_raised[id]) { \
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
shared_memory_limit_raised[id] = true; \
} \
} while (0)
#else
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
do { \
GGML_UNUSED(nbytes); \
} while (0)
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
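// Usage sketch (illustrative; `my_kernel` and `smem_bytes` are hypothetical):
// raise the per-kernel dynamic shared memory limit once per device before a
// launch that needs more than the default opt-in size.
//
//   CUDA_SET_SHARED_MEMORY_LIMIT(my_kernel, smem_bytes);
//   my_kernel<<<grid, block, smem_bytes, stream>>>(/* args */);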
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
@@ -808,37 +825,7 @@ struct ggml_tensor_extra_gpu {
#define USE_CUDA_GRAPH
#endif
struct ggml_cuda_graph;
void * node_address;
ggml_op node_op;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
};
struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
~ggml_cuda_graph() {
if (instance != nullptr) {
CUDA_CHECK(cudaGraphExecDestroy(instance));
}
if (graph != nullptr) {
CUDA_CHECK(cudaGraphDestroy(graph));
}
}
cudaGraph_t graph = nullptr;
cudaGraphExec_t instance = nullptr;
size_t num_nodes = 0;
std::vector<cudaGraphNode_t> nodes;
std::vector<cudaKernelNodeParams> params;
bool disable_due_to_gpu_arch = false;
bool disable_due_to_too_many_updates = false;
bool disable_due_to_failed_graph_capture = false;
int number_consecutive_updates = 0;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
std::vector<char **> updated_kernel_arg;
#endif
};
struct ggml_backend_cuda_context {
int device;
@@ -850,26 +837,9 @@ struct ggml_backend_cuda_context {
std::unique_ptr<ggml_cuda_graph> cuda_graph;
explicit ggml_backend_cuda_context(int device);
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {
}
~ggml_backend_cuda_context();
if (copy_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(copy_event));
}
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
if (streams[i][j] != nullptr) {
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
}
}
if (cublas_handles[i] != nullptr) {
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
}
}
}
cudaStream_t stream(int device, int stream) {
if (streams[device][stream] == nullptr) {


@@ -0,0 +1,262 @@
#pragma once
#include "ggml-common.h"
template<typename src_t, typename dst_t>
static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
if constexpr (std::is_same_v<src_t, dst_t>) {
*dst = *src;
} else {
*dst = float(*src);
}
}
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
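// Usage (illustrative): nearest-neighbour lookup into an ascending int8 grid,
// e.g. for the iq4_nl codebook used below:
//
//   const uint8_t idx = best_index_int8(16, kvalues_iq4nl, x*id); // x*id = scaled input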
static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_0; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK4_0/2 + j]*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
y->qs[j] = xi0;
y->qs[j] |= xi1 << 4;
}
}
static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
float vmin = FLT_MAX;
float vmax = -FLT_MAX;
for (int j = 0; j < QK4_1; ++j) {
const float v = x[j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
y->dm.x = d;
y->dm.y = vmin;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (x[0 + j] - vmin)*id;
const float x1 = (x[QK4_1/2 + j] - vmin)*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
y->qs[j] = xi0;
y->qs[j] |= xi1 << 4;
}
}
static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK5_0; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK5_0/2 + j]*id;
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
memcpy(y->qh, &qh, sizeof(qh));
}
static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
float min = x[0];
float max = x[0];
for (int j = 1; j < QK5_1; ++j) {
const float v = x[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
y->dm.x = d;
y->dm.y = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (x[0 + j] - min)*id;
const float x1 = (x[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
memcpy(y->qh, &qh, sizeof(qh));
}
static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = x[j];
amax = fmaxf(amax, fabsf(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = x[j]*id;
y->qs[j] = roundf(x0);
}
}
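// Worked example (illustrative): for a block with amax = 2.54,
//   d = 2.54 / 127 = 0.02, id = 50.0,
// so x = 1.27 is stored as q = roundf(1.27 * 50.0) = 64 and dequantizes back to
// q * d = 1.28, i.e. within half a quantization step.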
static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_NL; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
y->qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = x[0 + j]*x[0 + j];
const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
y->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
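// Note (illustrative): the loop above refines the initial guess
// d = vmax / kvalues_iq4nl[0] with a weighted least-squares fit. With weights
// w_j = x_j^2 and codebook values v_j, the scale minimizing
// sum_j w_j * (x_j - d*v_j)^2 is d* = (sum_j w_j*v_j*x_j) / (sum_j w_j*v_j^2),
// which is exactly sumqx/sumq2 as accumulated above.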
static __device__ void quantize_f32_q6_0_block(const float * __restrict__ xi, block_q6_0 * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK6_0; ++j) {
const float v = xi[j];
const float av = fabsf(xi[j]);
if (amax < av) {
amax = av;
vmax = v;
}
}
const float d = vmax / -32;
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
memset(y->qh, 0, QK6_0/4);
for (int j = 0; j < QK6_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK6_0/2 + j]*id;
const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f));
const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f));
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
y->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4)));
}
}
// Wrapper functions for cpy.cu compatibility
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
}
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);
}
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);
}
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
}
static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) {
quantize_f32_q6_0_block((const float *)cxi, (block_q6_0 *)cdsti);
}
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
}
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
}
template<typename src_t, typename dst_t>
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
convert_flt((const src_t *)cxi, (dst_t *)cdsti);
}

File diff suppressed because it is too large


@@ -1,9 +1,11 @@
#include "common.cuh" #include "common.cuh"
#define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_CPY_BLOCK_SIZE 64
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1); void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);


@@ -86,6 +86,24 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q6_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q6_0 * x = (const block_q6_0 *) vx;
const dfloat d = x[ib].d;
const uint8_t h = x[ib].qh[iqs%8] >> 2*(iqs/8);
v.x = ((x[ib].qs[iqs] & 0xf) | ((h & 0x3) << 4));
v.y = ((x[ib].qs[iqs] >> 4) | ((h & 0xc) << 2));
#ifdef GGML_CUDA_F16
v = __hsub2(v, {32.0f, 32.0f});
v = __hmul2(v, {d, d});
#else
v.x = (v.x - 32.0f) * d;
v.y = (v.y - 32.0f) * d;
#endif // GGML_CUDA_F16
}
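// Bit layout (illustrative): quantize_f32_q6_0_block stores each 6-bit value as
// a low nibble in qs plus 2 high bits packed four-per-byte in qh; the shifts
// above invert that packing before the common (q - 32) * d dequantization.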
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q8_0 * x = (const block_q8_0 *) vx;


@@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -747,6 +748,7 @@ void launch_fattn(
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
ggml_tensor * KQV = dst;
@@ -837,6 +839,7 @@ void launch_fattn(
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
sinks ? ((const char *) sinks->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, softcap, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
@@ -1008,7 +1011,8 @@ void launch_fattn_mma(
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
ggml_tensor * KQV = dst;
@@ -1162,6 +1166,7 @@ void launch_fattn_mma(
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
sinks ? ((const char *)sinks->data) : nullptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],


@@ -425,6 +425,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half2 * const __restrict__ mask_h2,
const float * const __restrict__ sinks_f,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
@@ -584,6 +585,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}
// If attention sinks are used, potentially re-scale if KQ_max is small.
// Also add the sink as a value to KQ_rowsum; this is done after synchronization of KQ_rowsum,
// so it is done unconditionally for every thread.
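// Streaming-softmax view (illustrative): with running max m, sink logit s and
// new max m' = max(m, s), the state is rescaled as
//   rowsum' = rowsum * exp(m - m') + exp(s - m'),  VKQ' = VKQ * exp(m - m'),
// i.e. the sink contributes one extra logit to the denominator but no value vector.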
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
float KQ_max_scale[cols_per_thread];
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
const float sink = sinks_f[jc % ncols2];
const float KQ_max_new = fmaxf(KQ_max[col], sink);
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
KQ_max_scale[col] = expf(KQ_max_diff);
KQ_max[col] = KQ_max_new;
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
const float KQ_max_add = expf(sink - KQ_max_new);
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
}
if (ntiles == 1) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
} else {
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
#pragma unroll
for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) {
#pragma unroll
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
}
}
}
}
}
// Write VKQ accumulators to shared memory in column-major format.
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// Also for np > 1 the combination is done via these values in shared memory.
@@ -823,6 +870,7 @@ static __global__ void flash_attn_mma_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -896,6 +944,7 @@ static __global__ void flash_attn_mma_ext_f16(
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
const float * sinks_f = sinks ? (const float *) sinks + channel * ncols2 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
@@ -906,12 +955,12 @@ static __global__ void flash_attn_mma_ext_f16(
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}
@@ -934,6 +983,7 @@ static __global__ void flash_attn_mma_ext_f16(
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
@@ -943,10 +993,10 @@ static __global__ void flash_attn_mma_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);


@@ -43,37 +43,37 @@ struct fattn_mma_f16_config;
// Perhaps the 256 head size needs a closer look
// to see if this implementation is better.
//
template <>
struct fattn_mma_f16_config< 64, 64> {
static constexpr int nbatch_fa = 64;
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;

static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}

static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 32;
}

static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}

static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 32;
}

static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}

static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 32;
}
};
//
//template <>
//struct fattn_mma_f16_config< 80, 80> {
@@ -493,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
} else {
constexpr bool use_cp_async = nstages == 1;
if (ncols2 > 1 || mask_h2) {
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
}
}
@@ -576,7 +576,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
float KQ_rowsum_add[cols_per_thread] = {0.0f};
if constexpr (ntiles == 1) {
if (ncols2 > 1 || mask_h2) {
#pragma unroll
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
@@ -818,6 +818,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half2 * const __restrict__ mask_h2,
const float * const __restrict__ sinks_f,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
@@ -975,6 +976,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
__syncthreads();
}
// If attention sinks are used, potentially re-scale if KQ_max is small.
// Also add the sink as a value to KQ_rowsum; this is done after synchronization of KQ_rowsum,
// so it is done unconditionally for every thread.
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
float KQ_max_scale[cols_per_thread];
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
const float sink = sinks_f[jc % ncols2];
const float KQ_max_new = fmaxf(KQ_max[col], sink);
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
KQ_max_scale[col] = expf(KQ_max_diff);
KQ_max[col] = KQ_max_new;
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
const float KQ_max_add = expf(sink - KQ_max_new);
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
}
if (ntiles == 1) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
} else {
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
#pragma unroll
for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
#pragma unroll
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
}
}
}
}
}
// Finally, sum up partial KQ rowsums.
// The partial sums are spread across 8/4 threads each, does not need full reduce.
{
@@ -1222,7 +1269,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}
#else
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(sinks_f);
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
@@ -1239,6 +1286,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -1323,6 +1371,7 @@ static __global__ void flash_attn_ext_f16(
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1335,12 +1384,12 @@ static __global__ void flash_attn_ext_f16(
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}
@@ -1362,6 +1411,7 @@ static __global__ void flash_attn_ext_f16(
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1373,7 +1423,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -1535,7 +1585,8 @@ static void launch_fattn_new_mma(
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
ggml_tensor * KQV = dst;
@@ -1709,6 +1760,7 @@ static void launch_fattn_new_mma(
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
sinks ? ((const char *)sinks->data) : nullptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, logit_softcap, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
@@ -1853,6 +1905,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
if (use_gqa_opt && gqa_ratio % 16 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 16>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 8 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
return;
}
@@ -1878,8 +1935,6 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
@@ -1888,6 +1943,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
if (K->ne[0] == 64 && V->ne[0] == 64) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
return;
}
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
GGML_ASSERT(gqa_ratio % 16 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);


@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,


@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,


@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -71,6 +72,7 @@ static __global__ void flash_attn_vec_ext_f16(
V += nb22*(blockIdx.y / gqa_ratio);
const half * maskh = (const half *) mask + ne11*ic0;
const float * sinksf = (const float *) (sinks);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
@@ -270,6 +272,39 @@ static __global__ void flash_attn_vec_ext_f16(
__syncthreads();
}
if (sinksf) {
const half sink = __float2half(sinksf[blockIdx.y]);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
half kqmax_new_j = kqmax_shared[j][threadIdx.y];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
kqmax[j] = kqmax_new_j;
const half val = hexp(sink - kqmax[j]);
kqsum[j] = kqsum[j]*KQ_max_scale;
if (tid == 0) {
kqsum[j] += val;
}
VKQ[j] *= __half2half2(KQ_max_scale);
}
__syncthreads();
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum(kqsum[j]);


@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -69,6 +70,7 @@ static __global__ void flash_attn_vec_ext_f32(
K += nb12*(blockIdx.y / gqa_ratio);
V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const float * sinksf = (const float *) (sinks);
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
@@ -254,6 +256,39 @@ static __global__ void flash_attn_vec_ext_f32(
__syncthreads();
}
if (sinksf) {
const float sink = sinksf[blockIdx.y];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
float kqmax_new_j = kqmax_shared[j][threadIdx.y];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
kqmax[j] = kqmax_new_j;
const float val = expf(sink - kqmax[j]);
kqsum[j] = kqsum[j]*KQ_max_scale;
if (tid == 0) {
kqsum[j] += val;
}
VKQ[j] *= KQ_max_scale;
}
__syncthreads();
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum(kqsum[j]);


@@ -5,6 +5,8 @@
// SPDX-License-Identifier: MIT
//

// TODO: attention sinks !!!

#include "common.cuh"
#include "fattn-common.cuh"
@@ -22,6 +24,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -93,6 +96,7 @@ static __global__ void flash_attn_ext_f16(
const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
const int stride_Q = nb01 / sizeof(float);
const int stride_K = nb11 / sizeof(half);


@@ -539,7 +539,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
return;
}
// As mentioned above, the new-new MMA is slower than the new MMA.
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
//ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
}


@@ -0,0 +1,41 @@
#pragma once
struct ggml_graph_node_properties {
void * node_address;
ggml_op node_op;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};
struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
~ggml_cuda_graph() {
if (instance != nullptr) {
CUDA_CHECK(cudaGraphExecDestroy(instance));
}
if (graph != nullptr) {
CUDA_CHECK(cudaGraphDestroy(graph));
}
}
cudaGraph_t graph = nullptr;
cudaGraphExec_t instance = nullptr;
size_t num_nodes = 0;
std::vector<cudaGraphNode_t> nodes;
std::vector<cudaKernelNodeParams> params;
bool disable_due_to_gpu_arch = false;
bool disable_due_to_too_many_updates = false;
bool disable_due_to_failed_graph_capture = false;
int number_consecutive_updates = 0;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
bool use_cpy_indirection = false;
std::vector<char *> cpy_dest_ptrs;
char ** dest_ptrs_d;
int dest_ptrs_size = 0;
// Index to allow each cpy kernel to be aware of its position within the graph
// relative to other cpy nodes.
int graph_cpynode_index = -1;
#endif
};
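// How the indirection is used (illustrative, based on the declarations above):
// instead of baking copy destinations into the captured graph, each cpy kernel
// reads its destination from dest_ptrs_d at its graph_cpynode_index. When e.g.
// the KV cache location changes between tokens, only this pointer table needs
// to be refreshed on the device (ggml_cuda_cpy_dest_ptrs_copy) and the captured
// graph can be replayed without a full re-capture.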


@@ -19,7 +19,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
}
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32_nosinks(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
@@ -124,7 +124,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
}
template<typename T>
static void soft_max_f32_cuda_nosinks(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
@@ -142,39 +142,40 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) {
case 32:
soft_max_f32_nosinks<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 64:
soft_max_f32_nosinks<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 128:
soft_max_f32_nosinks<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 256:
soft_max_f32_nosinks<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 512:
soft_max_f32_nosinks<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 1024:
soft_max_f32_nosinks<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 2048:
soft_max_f32_nosinks<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 4096:
soft_max_f32_nosinks<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
default:
soft_max_f32_nosinks<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
soft_max_f32_nosinks<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
}
}
#if 0
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@@ -205,13 +206,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (use_f16) {
const half * src1_dd = (const half *)src1_d;
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
} else {
const float * src1_dd = (const float *)src1_d;
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
}
}
#endif
void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
@@ -241,10 +243,283 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds
if (use_f16) {
const half * src1_dd = (const half *)src1_d;
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
} else {
const float * src1_dd = (const float *)src1_d;
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
}
}
struct soft_max_params {
int64_t nheads;
uint32_t n_head_log2;
int64_t ncols;
int64_t nrows_x;
int64_t nrows_y;
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
int64_t nb11;
int64_t nb12;
int64_t nb13;
int64_t ne12;
int64_t ne13;
float scale;
float max_bias;
float m0;
float m1;
};
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(
const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
const int tid = threadIdx.x;
const int64_t i03 = blockIdx.z;
const int64_t i02 = blockIdx.y;
const int64_t i01 = blockIdx.x;
    // TODO: non-contiguous inputs/outputs
const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
const int64_t i11 = i01;
const int64_t i12 = i02 % p.ne12;
const int64_t i13 = i03 % p.ne13;
x += int64_t(rowx)*ncols;
mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
dst += int64_t(rowx)*ncols;
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
// shared memory buffer to cache values between iterations:
float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
float max_val = sinks ? sinks[i02] : -INFINITY;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
break;
}
const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
}
// find the max value in the block
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = -INFINITY;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = max_val;
}
__syncthreads();
max_val = buf_iw[lane_id];
max_val = warp_reduce_max(max_val);
}
float tmp = 0.0f; // partial sum
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
break;
}
const float val = expf(vals[col] - max_val);
tmp += val;
vals[col] = val;
}
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__syncthreads();
if (warp_id == 0) {
buf_iw[lane_id] = 0.0f;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = tmp;
}
__syncthreads();
tmp = buf_iw[lane_id];
tmp = warp_reduce_sum(tmp);
}
if (sinks) {
tmp += expf(sinks[i02] - max_val);
}
const float inv_sum = 1.0f / tmp;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
return;
}
dst[col] = vals[col] * inv_sum;
}
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
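
For orientation, this is the per-row math the kernel above implements once scale, mask and ALiBi slope have been folded in. The sketch below is illustrative only (soft_max_ref and its sink argument are made-up names, not ggml API): the sink takes part in the running maximum and in the normalizer, but gets no output slot, so each row sums to slightly less than 1.

#include <cmath>
#include <algorithm>

// Scalar sketch of the sink-augmented softmax (hypothetical helper).
static void soft_max_ref(const float * x, float * dst, int ncols, float sink) {
    // the sink seeds the maximum, matching `max_val = sinks ? sinks[i02] : -INFINITY`
    float max_val = sink;
    for (int i = 0; i < ncols; ++i) max_val = std::max(max_val, x[i]);
    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        dst[i] = std::exp(x[i] - max_val);
        sum += dst[i];
    }
    // the sink contributes to the normalizer but has no output slot
    sum += std::exp(sink - max_val);
    for (int i = 0; i < ncols; ++i) dst[i] /= sum;
}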
template<int... Ns, typename T>
static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
{
const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
auto launch_kernel = [=](auto I) -> bool {
constexpr int ncols = decltype(I)::value;
constexpr int block = (ncols > 1024 ? 1024 : ncols);
if (p.ncols == ncols) {
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, sinks, dst, p);
return true;
}
return false;
};
// unary fold over launch_kernel
if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
return;
}
//default case
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
}
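
The `(launch_kernel(std::integral_constant<int, Ns>{}) || ...)` line is a C++17 unary fold over ||: it tries each compile-time column count in order and short-circuits at the first match. A minimal standalone sketch of the pattern, with made-up names:

#include <cstdio>

// try_one<N> succeeds only when the runtime value matches its compile-time N
template <int N> static bool try_one(int ncols) {
    if (ncols != N) return false;
    std::printf("specialized kernel for %d columns\n", N);
    return true;
}

template <int... Ns> static void dispatch(int ncols) {
    if ((try_one<Ns>(ncols) || ...)) return; // fold expression, stops at first hit
    std::printf("generic fallback kernel\n");
}

int main() {
    dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(256); // specialized
    dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(100); // fallback
}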
template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
int nth = WARP_SIZE;
const int64_t ncols_x = params.ncols;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(params.ne01, params.ne02, params.ne03);
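    // shared memory budget: the padded row (ncols floats) cached in `vals`
    // plus WARP_SIZE floats for the inter-warp reduction buffer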
const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
if (nbytes_shared <= smpbo) {
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else {
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
}
}
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const float * src0_d = (const float *) src0->data;
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
const void * src2_d = src2 ? (const void *) src2->data : nullptr;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
const int64_t ne00 = src0->ne[0];
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
const int64_t nb11 = src1 ? src1->nb[1] : 1;
const int64_t nb12 = src1 ? src1->nb[2] : 1;
const int64_t nb13 = src1 ? src1->nb[3] : 1;
const int64_t ne12 = src1 ? src1->ne[2] : 1;
const int64_t ne13 = src1 ? src1->ne[3] : 1;
const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
soft_max_params params = {};
params.nheads = src0->ne[2];
params.n_head_log2 = n_head_log2;
params.ncols = ne00;
params.nrows_x = nrows_x;
params.nrows_y = nrows_y;
params.ne00 = src0->ne[0];
params.ne01 = src0->ne[1];
params.ne02 = src0->ne[2];
params.ne03 = src0->ne[3];
params.nb11 = nb11;
params.nb12 = nb12;
params.nb13 = nb13;
params.ne12 = ne12;
params.ne13 = ne13;
params.scale = scale;
params.max_bias = max_bias;
params.m0 = m0;
params.m1 = m1;
if (use_f16) {
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
} else {
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
}
}
@@ -470,3 +470,83 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
template <typename T>
static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
// perform base op and multiply with gate (either offset in same tensor or a separate one)
const int64_t j0 = (i / n) * o0 + (i % n);
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
float xi = x[j0];
float gi = g[j1];
xi = fminf(xi, limit);
gi = fmaxf(fminf(gi, limit), -limit);
float out_glu = xi / (1.0f + expf(-xi * alpha));
out_glu = out_glu * (1.0f + gi);
dst[i] = out_glu;
}
template <typename T>
static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
const int64_t num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
swiglu_oai_kernel<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
}
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
void * src0_d = src0->data;
void * src1_d = src1 ? src1->data : src0->data;
const int64_t src0_o = src0->nb[1];
const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
void * dst_d = dst->data;
const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == dst->type);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
GGML_ASSERT(src1->ne[0] == nc);
GGML_ASSERT(src0->type == src1->type);
}
//const int32_t swapped = ((const int32_t *) dst->op_params)[1];
const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1);
const float * op_params = (const float *)dst->op_params;
const float alpha = op_params[2];
const float limit = op_params[3];
float * src0_p = (float *) src0_d;
float * src1_p = (float *) src1_d;
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc,
src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
}
void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n,
const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
swiglu_oai_cuda(x, g, dst, k, n, o0, o1, alpha, limit, stream);
}
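
The activation itself is easy to state in scalar form. The helper below is an illustrative re-statement of the kernel math (swiglu_oai_ref is a made-up name): both halves are clamped by `limit`, x goes through an alpha-scaled sigmoid gate, and the result is multiplied by (1 + g).

#include <cmath>
#include <algorithm>

// Scalar sketch of swiglu_oai (hypothetical helper, mirrors the CUDA kernel).
static float swiglu_oai_ref(float x, float g, float alpha, float limit) {
    x = std::min(x, limit);                   // clamp the activation from above
    g = std::max(std::min(g, limit), -limit); // clamp the gate to [-limit, limit]
    const float out_glu = x / (1.0f + std::exp(-x * alpha));
    return out_glu * (1.0f + g);
}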
@@ -47,3 +47,9 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
        int64_t nelements, const float * x, const float * y, float * z);
void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n,
const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream);
@@ -2823,7 +2823,6 @@ inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v; }
inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];    }
inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;       }
@@ -2834,6 +2833,19 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i]; }
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i]; }
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
int i = 0;
#if defined(__AVX2__)
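    // vectorized path: 8 floats per iteration; the scalar loop below picks up the tail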
for (; i + 7 < n; i += 8) {
__m256 vx = _mm256_loadu_ps(x + i);
__m256 vy = _mm256_loadu_ps(y + i);
__m256 vz = _mm256_add_ps(vx, vy);
_mm256_storeu_ps(z + i, vz);
}
#endif
for (; i < n; ++i) z[i] = x[i] + y[i];
}
static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
@@ -4004,6 +4016,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"DUP", "DUP",
"ADD", "ADD",
"ADD_ID",
"ADD1", "ADD1",
"ACC", "ACC",
"SUB", "SUB",
@@ -4092,13 +4105,14 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK", "CROSS_ENTROPY_LOSS_BACK",
}; };
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none", "none",
"x", "x",
"x+y", "x+y",
"x[i]+y",
"x+y", "x+y",
"view(x,nb,offset)+=y->x", "view(x,nb,offset)+=y->x",
"x-y", "x-y",
@@ -4187,7 +4201,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)", "cross_entropy_loss_back(x,y)",
}; };
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4207,9 +4221,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"HARDSWISH", "HARDSWISH",
"HARDSIGMOID", "HARDSIGMOID",
"SWIGLU", "SWIGLU",
"SWIGLU_OAI",
}; };
static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14"); static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -5917,6 +5932,29 @@ struct ggml_tensor * ggml_add_cast(
    return ggml_add_cast_impl(ctx, a, b, type);
}
// ggml_add_id
struct ggml_tensor * ggml_add_id(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * ids) {
GGML_ASSERT(a->ne[0] == b->ne[0]);
GGML_ASSERT(a->ne[1] == ids->ne[0]);
GGML_ASSERT(a->ne[2] == ids->ne[1]);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
result->op = GGML_OP_ADD_ID;
result->src[0] = a;
result->src[1] = b;
result->src[2] = ids;
return result;
}
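
A hedged usage sketch (wrapper and tensor names are placeholders, not from this commit): ggml_add_id adds, to each row of a, the bias row selected by the corresponding entry of ids, which is exactly what the MoE path needs for per-expert biases.

#include "ggml.h"

// add a per-expert bias row, chosen by ids, to each row of the activations
static struct ggml_tensor * add_expert_bias(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,     // [n_embd, n_tok]    F32 activations
        struct ggml_tensor  * bias,  // [n_embd, n_expert] F32 bias table
        struct ggml_tensor  * ids) { // [n_tok]            I32 expert id per row
    // row i of the result is a[i] + bias[ids[i]]
    return ggml_add_id(ctx, a, bias, ids);
}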
// ggml_add1

static struct ggml_tensor * ggml_add1_impl(
@@ -6662,6 +6700,36 @@ struct ggml_tensor * ggml_swiglu(
    return result;
}
struct ggml_tensor * ggml_swiglu_oai(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float alpha,
float limit) {
GGML_ASSERT(ggml_is_contiguous_1(a));
if (b) {
GGML_ASSERT(ggml_is_contiguous_1(b));
GGML_ASSERT(ggml_are_same_shape(a, b));
GGML_ASSERT(a->type == b->type);
}
int64_t ne[4] = {a->ne[0]/2, a->ne[1], a->ne[2], a->ne[3]};
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
result->op = GGML_OP_UNARY;
result->grad = NULL;
result->src[0] = a;
result->src[1] = b;
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU_OAI);
ggml_set_op_params_f32(result, 2, alpha);
ggml_set_op_params_f32(result, 3, limit);
return result;
}
// ggml_silu_back

struct ggml_tensor * ggml_silu_back(
@@ -7017,6 +7085,66 @@ struct ggml_tensor * ggml_moe_up_gate(
    result->src[1] = as_gate;
    result->src[2] = b;
    result->src[3] = ids;
result->src[4] = NULL;
result->src[5] = NULL;
ggml_set_op_params_i32(result, 0, (int32_t) op);
return result;
}
struct ggml_tensor * ggml_moe_up_gate_ext(
struct ggml_context * ctx,
struct ggml_tensor * as_up,
struct ggml_tensor * as_gate,
struct ggml_tensor * b,
struct ggml_tensor * ids,
struct ggml_tensor * as_up_b,
struct ggml_tensor * as_gate_b,
enum ggml_unary_op op) {
if (!as_up_b && !as_gate_b) {
return ggml_moe_up_gate(ctx, as_up, as_gate, b, ids, op);
}
if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) {
struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids);
if (as_up_b) {
result_up = ggml_add_id(ctx, result_up, as_up_b, ids);
}
struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids);
if (as_gate_b) {
result_gate = ggml_add_id(ctx, result_gate, as_gate_b, ids);
}
return ggml_fused_mul_unary(ctx, result_gate, result_up, op);
}
GGML_ASSERT(!ggml_is_transposed(as_up));
GGML_ASSERT(!ggml_is_transposed(as_gate));
GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert)
GGML_ASSERT(b->ne[3] == 1); // b is 3d
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
GGML_ASSERT(as_up->ne[0] == b->ne[0]); // can_mul_mat
GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
GGML_ASSERT(as_up->ne[1] == as_up_b->ne[0]);
GGML_ASSERT(as_gate->ne[1] == as_gate_b->ne[0]);
bool is_node = false;
const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_MOE_FUSED_UP_GATE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = as_up;
result->src[1] = as_gate;
result->src[2] = b;
result->src[3] = ids;
result->src[4] = as_up_b;
result->src[5] = as_gate_b;
    ggml_set_op_params_i32(result, 0, (int32_t) op);
@@ -7970,6 +8098,22 @@ struct ggml_tensor * ggml_soft_max_ext(
    return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
}
void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks) {
if (!sinks) {
a->src[2] = NULL;
return;
}
GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);
GGML_ASSERT(a->src[2] == NULL);
GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
GGML_ASSERT(sinks->type == GGML_TYPE_F32);
a->src[2] = sinks;
}
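
A hedged sketch of how a graph might attach sinks (placeholder names; the surrounding attention code is not shown in this commit): the sinks tensor holds one logit per head and is stored as src[2] of the softmax node, where the CPU and CUDA kernels pick it up.

#include "ggml.h"

static struct ggml_tensor * masked_softmax_with_sinks(
        struct ggml_context * ctx,
        struct ggml_tensor  * kq,      // [n_kv, n_tok, n_head] attention scores
        struct ggml_tensor  * kq_mask, // optional F16/F32 mask, may be NULL
        struct ggml_tensor  * sinks,   // [n_head] F32, may be NULL
        float scale) {
    struct ggml_tensor * probs = ggml_soft_max_ext(ctx, kq, kq_mask, scale, 0.0f);
    ggml_soft_max_add_sinks(probs, sinks); // no-op detach when sinks == NULL
    return probs;
}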
// ggml_soft_max_back

static struct ggml_tensor * ggml_soft_max_back_impl(
@@ -8833,6 +8977,22 @@ void ggml_flash_attn_ext_set_prec(
    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
}
void ggml_flash_attn_ext_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks) {
if (!sinks) {
a->src[4] = NULL;
return;
}
GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
GGML_ASSERT(a->src[4] == NULL);
GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
GGML_ASSERT(sinks->type == GGML_TYPE_F32);
a->src[4] = sinks;
}
// ggml_flash_attn_back

struct ggml_tensor * ggml_flash_attn_back(
@@ -11497,6 +11657,77 @@ static void ggml_compute_forward_multi_add(
    }
}
// ggml_compute_forward_add_id
static void ggml_compute_forward_add_id_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_I32);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_nrows(src0);
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
// src1 indices
const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
GGML_ASSERT(i11 >= 0 && i11 < ne11);
ggml_vec_add_f32(ne0,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
(float *) ((char *) src1->data + i11*nb11));
}
}
static void ggml_compute_forward_add_id(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_add_id_f32(params, dst);
} break;
default:
{
GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
}
}
}
// ggml_compute_forward_add1

static void ggml_compute_forward_add1_f32(
@@ -13760,6 +13991,93 @@ static void ggml_compute_forward_swiglu(
    }
}
// ggml_compute_forward_swiglu_oai
static void ggml_compute_forward_swiglu_oai_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1);
const float alpha = ggml_get_op_params_f32(dst, 2);
const float limit = ggml_get_op_params_f32(dst, 3);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
for (int k = 0; k < nc; k++) {
const float x = MIN(src0_p[k], limit);
const float y = MAX(MIN(src1_p[k], limit), -limit);
const float out_glu = x / (1.f + expf(alpha * (-x)));
dst_p[k] = out_glu * (y + 1.f);
}
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = dst_p[k];
GGML_UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_swiglu_oai(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_swiglu_oai_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_fused_mul_unary

static void ggml_compute_forward_fused_mul_unary_f32(
@@ -15167,6 +15485,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
    const struct ggml_tensor * src1 = dst->src[2];
    const struct ggml_tensor * ids  = dst->src[3];
const struct ggml_tensor * up_b = dst->src[4];
const struct ggml_tensor * gate_b = dst->src[5];
    const struct ggml_tensor * src0_1 = dst->src[0];
    const struct ggml_tensor * src0_2 = dst->src[1];
    const struct ggml_tensor * src0   = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works
@@ -15191,6 +15511,9 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
    GGML_ASSERT(nb2 <= nb3);
    GGML_ASSERT(ne13 == 1);
const size_t nb41 = up_b ? up_b->nb[1] : 0;
    const size_t nb51 = gate_b ? gate_b->nb[1] : 0;
    // row groups
    const int n_ids = ids->ne[0]; // n_expert_used
    const int n_as  = ne02;       // n_expert
@@ -15278,6 +15601,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
            const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
            const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
const char * up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL;
const char * gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL;
            const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
            const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -15288,6 +15613,7 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
            if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0],
                        type, src0_1_cur, src0_2_cur, nb01,
                        vec_dot_type, (const char *)wdata, row_size,
up_b_cur, gate_b_cur,
                        (float *)dst->data, nb1, nb2,
                        matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
@@ -16645,6 +16971,7 @@ static void ggml_compute_forward_soft_max_f32(
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src2 = dst->src[2];
    assert(ggml_is_contiguous(dst));
    assert(ggml_are_same_shape(src0, dst));
@@ -16662,6 +16989,13 @@ static void ggml_compute_forward_soft_max_f32(
    GGML_TENSOR_UNARY_OP_LOCALS
const int64_t nb11 = src1 ? src1->nb[1] : 1;
const int64_t nb12 = src1 ? src1->nb[2] : 1;
const int64_t nb13 = src1 ? src1->nb[3] : 1;
const int64_t ne12 = src1 ? src1->ne[2] : 1;
const int64_t ne13 = src1 ? src1->ne[3] : 1;
    //const int64_t ne11 = src1 ? src1->ne[1] : 1;

    // TODO: is this supposed to be ceil instead of floor?
@@ -16673,67 +17007,80 @@ static void ggml_compute_forward_soft_max_f32(
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    const int nc = src0->ne[0];

    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;

    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

    // sinks
    const float * sk = src2 ? (float *)((char *) src2->data) : NULL;

    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                const int64_t i11 = i01;
                const int64_t i12 = i02%ne12;
                const int64_t i13 = i03%ne13;

                // ALiBi
                const uint32_t h = i02; // head
                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;

                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);

                // broadcast the mask across rows
                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;

                ggml_vec_cpy_f32  (ne00, wp, sp);
                ggml_vec_scale_f32(ne00, wp, scale);
                if (mp_f32) {
                    if (use_f16) {
                        for (int i = 0; i < ne00; ++i) {
                            wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
                        }
                    } else {
                        for (int i = 0; i < ne00; ++i) {
                            wp[i] += slope*mp_f32[i];
                        }
                    }
                }

#ifndef NDEBUG
                for (int i = 0; i < ne00; ++i) {
                    //printf("p[%d] = %f\n", i, p[i]);
                    assert(!isnan(wp[i]));
                }
#endif

                float max = -INFINITY;
                ggml_vec_max_f32(ne00, &max, wp);

                // if we have sinks, make a correction as if they were included in the softmax
                if (sk) {
                    max = MAX(max, sk[i02]);
                }

                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
                assert(sum > 0.0);

                if (sk) {
                    sum += (ggml_float) expf(sk[i02] - max);
                }

                sum = 1.0/sum;
                ggml_vec_scale_f32(ne00, dp, sum);

#ifndef NDEBUG
                for (int i = 0; i < ne00; ++i) {
                    assert(!isnan(dp[i]));
                    assert(!isinf(dp[i]));
                }
#endif
            }
        }
    }
@@ -16755,7 +17102,6 @@ static void ggml_compute_forward_soft_max(
    }
}

// ggml_compute_forward_soft_max_back

static void ggml_compute_forward_soft_max_back_f32(
@@ -18308,12 +18654,14 @@ static void ggml_compute_forward_argsort_thresh(
static void ggml_compute_forward_flash_attn_ext_f16(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * q     = dst->src[0];
    const struct ggml_tensor * k     = dst->src[1];
    const struct ggml_tensor * v     = dst->src[2];
    const struct ggml_tensor * mask  = dst->src[3];
    const struct ggml_tensor * sinks = dst->src[4];
    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
@@ -18383,6 +18731,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
    }

#if GGML_USE_IQK_MULMAT
    // Attention sinks (if any) are passed through to the iqk FA implementation
    if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias,
                q->ne[3], q->ne[2], q->nb[3], q->nb[2],
                k->ne[3], k->ne[2], k->nb[3], k->nb[2],
@@ -18390,7 +18739,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                dst->ne[2], dst->ne[1], dst->nb[1],
                k->type, v->type,
                Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
                q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL,
                scale, softcap, (float *)dst->data,
                params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
@@ -18447,6 +18796,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
    ggml_vec_dot_t  const kq_vec_dot = type_traits[k->type].vec_dot;
    ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
    const int64_t Dkv = MAX(Dk, Dv);

    // loop over n_batch and n_head
@@ -18552,6 +18904,22 @@
            }
        }
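        // Fold the sink into the streaming-softmax state: M is the running max
        // of the logits and S the running sum of exp(logit - M). If the sink s
        // exceeds M, the accumulated V row is first rescaled by exp(M - s);
        // otherwise exp(s - M) is simply added to the normalizer.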
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
ms = expf(M - s);
ggml_vec_scale_f32(Dv, VKQ32, ms);
} else {
vs = expf(s - M);
}
S = S*ms + vs;
}
        // V /= S
        const float S_inv = 1.0f/S;
        ggml_vec_scale_f32(Dv, VKQ32, S_inv);
@@ -18571,17 +18939,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
static void ggml_compute_forward_flash_attn_ext(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {
    switch (dst->op_params[3]) {
        case GGML_PREC_DEFAULT:
        case GGML_PREC_F32:
            {
                // uses F32 accumulators
                ggml_compute_forward_flash_attn_ext_f16(params, dst);
            } break;
        default:
            {
@@ -19350,6 +19714,10 @@ static void ggml_compute_forward_unary(
            {
                ggml_compute_forward_swiglu(params, dst);
            } break;
case GGML_UNARY_OP_SWIGLU_OAI:
{
ggml_compute_forward_swiglu_oai(params, dst);
} break;
        case GGML_UNARY_OP_HARDSWISH:
            {
                ggml_compute_forward_hardswish(params, dst);
@@ -19898,6 +20266,10 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
            {
                ggml_compute_forward_add(params, tensor);
            } break;
case GGML_OP_ADD_ID:
{
ggml_compute_forward_add_id(params, tensor);
} break;
        case GGML_OP_ADD1:
            {
                ggml_compute_forward_add1(params, tensor);
@@ -20136,7 +20508,7 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
            } break;
        case GGML_OP_FLASH_ATTN_EXT:
            {
                ggml_compute_forward_flash_attn_ext(params, tensor);
            } break;
        case GGML_OP_FLASH_ATTN_BACK:
            {
@@ -20486,6 +20858,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                    src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
                }
            } break;
case GGML_OP_ADD_ID:
{
GGML_ABORT("fatal error"); // TODO: implement
} break;
        case GGML_OP_ADD1:
            {
                if (src0->grad) {
@@ -21719,6 +22095,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_DUP:
        case GGML_OP_CONT:
        case GGML_OP_ADD:
case GGML_OP_ADD_ID:
        case GGML_OP_ADD1:
        case GGML_OP_ACC:
        case GGML_OP_MULTI_ADD:
@@ -21758,6 +22135,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_UNARY_OP_GELU_QUICK:
        case GGML_UNARY_OP_SILU:
        case GGML_UNARY_OP_SWIGLU:
case GGML_UNARY_OP_SWIGLU_OAI:
            {
                n_tasks = n_threads;
            } break;
@@ -21952,6 +22330,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                }
            } break;
        case GGML_OP_ADD:
case GGML_OP_ADD_ID:
        case GGML_OP_ADD1:
            {
                if (ggml_is_quantized(node->src[0]->type)) {
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_128_128) {
        if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
        if (nk%64 == 0) {
            iqk_flash_helper_T<128, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                    q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
            return true;
        }
        iqk_flash_helper_T<128, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
#endif
    if (nk%128 == 0) {
        return iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
    }
    if (nk%64 == 0) {
        return iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
    }
    return iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
            q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_192_128) {
        if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
        if (nk%64 == 0) {
            iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                    q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
            return true;
        }
        iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
#endif
    if (nk%128 == 0) {
        return iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
    }
    if (nk%64 == 0) {
        return iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
    }
    return iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
            q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_256_256) {
        if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
        if (nk%64 == 0) {
            iqk_flash_helper_T<256, 256, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                    q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
            return true;
        }
        iqk_flash_helper_T<256, 256, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
#endif
    if (nk%128 == 0) {
        return iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
    }
    if (nk%64 == 0) {
        return iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
    }
    return iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
            q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
@@ -9,7 +9,8 @@ namespace {
template <int step_k, typename KHelper, typename VHelper>
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
        int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
        const float * q, const char * mask, float scale, float softcap, float * qkv,
        const float * sinkf, float * M, float * S) {
    auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
        nq1 -= n;
        if (nq1 == 0) return true;
@@ -21,29 +22,29 @@ inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
    };
    if (nq1 >= 16) {
        int n_step = nq1/16;
        FlashAttn<576, 512, 16, step_k> fa(scale, softcap, sinkf);
        fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
        if (update(16*n_step)) return;
    }
    if (nq1 >= 8) {
        int n_step = nq1/8;
        FlashAttn<576, 512, 8, step_k> fa(scale, softcap, sinkf);
        fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
        if (update(8*n_step)) return;
    }
    if (nq1 >= 4) {
        int n_step = nq1/4;
        FlashAttn<576, 512, 4, step_k> fa(scale, softcap, sinkf);
        fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
        if (update(4*n_step)) return;
    }
    if (nq1 >= 2) {
        int n_step = nq1/2;
        FlashAttn<576, 512, 2, step_k> fa(scale, softcap, sinkf);
        fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
        if (update(2*n_step)) return;
    }
    FlashAttn<576, 512, 1, step_k> fa(scale, softcap, sinkf);
    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
}
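
The helper above greedily tiles the nq1 query rows into blocks of 16, 8, 4, 2 and finally 1, instantiating one FlashAttn specialization per block size so each tile runs a fully unrolled kernel. A standalone sketch of that decomposition, with made-up names:

#include <cstdio>
#include <initializer_list>

// peel off as many full blocks of each size as possible, largest first
static void tile_rows(int nq) {
    for (int block : {16, 8, 4, 2, 1}) {
        const int n_step = nq / block;
        if (n_step > 0) {
            std::printf("%d block(s) of %d rows\n", n_step, block);
            nq -= n_step * block; // remainder falls through to smaller blocks
        }
        if (nq == 0) break;
    }
}

int main() { tile_rows(23); } // 23 = 16 + 4 + 2 + 1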
@@ -51,37 +52,37 @@ template <int step_k>
inline bool iqk_deepseek_helper(ggml_type type_k,
        int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
        const float * q, const char * k, const char * v, const char * mask,
        float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
    if (type_k == GGML_TYPE_Q8_0) {
        HelperQ80 kh((const char *)k, stride_k);
        HelperQ80 vh((const char *)v, stride_v);
        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
    if (type_k == GGML_TYPE_Q8_0_R8) {
        HelperQ80R8<576> kh((const char *)k, stride_k);
        HelperQ80 vh((const char *)v, stride_v);
        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
    if (type_k == GGML_TYPE_Q6_0) {
        HelperQ60 kh((const char *)k, stride_k);
        HelperQ60 vh((const char *)v, stride_v);
        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
#if GGML_IQK_FA_ALL_QUANTS
    if (type_k == GGML_TYPE_Q8_KV) {
        HelperQ8KV<576> kh((const char *)k, stride_k);
        HelperQ8KV<512> vh((const char *)v, stride_v);
        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
#endif
    if (type_k == GGML_TYPE_F16) {
        HelperF16 kh((const char *)k, stride_k);
        HelperF16 vh((const char *)v, stride_v);
        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
#ifdef __AVX512BF16__
@@ -89,10 +90,10 @@ inline bool iqk_deepseek_helper(ggml_type type_k,
        HelperBF16<576, step_k> kh((const char *)k, stride_k);
        HelperBF16<512, step_k> vh((const char *)v, stride_v);
        if (nq1 % 8 == 0) {
            FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap, sinkf);
            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
        } else {
            FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap, sinkf);
            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
        }
        return true;
@@ -113,7 +114,7 @@ IQK_FA_CASE(iqk_fa_576_512) {
    }
    stride_q /= sizeof(float); // q stride as float
    return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
            q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, sinkf, M, S);
}
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_64_64) {
        if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
        if (nk%64 == 0) {
            iqk_flash_helper_T<64, 64, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                    q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
            return true;
        }
        iqk_flash_helper_T<64, 64, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
#endif
    if (nk%128 == 0) {
        return iqk_flash_helper_T<64, 64, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
    }
    if (nk%64 == 0) {
        return iqk_flash_helper_T<64, 64, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
    }
    return iqk_flash_helper_T<64, 64, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
            q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_96_96) {
        if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
        if (nk%64 == 0) {
            iqk_flash_helper_T<96, 96, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                    q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
            return true;
        }
        iqk_flash_helper_T<96, 96, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
        return true;
    }
#endif
    if (nk%128 == 0) {
        return iqk_flash_helper_T<96, 96, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
    }
    if (nk%64 == 0) {
        return iqk_flash_helper_T<96, 96, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
    }
    return iqk_flash_helper_T<96, 96, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
            q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}

View File

@@ -1141,10 +1141,25 @@ struct FlashQKV {
} }
template <typename FMS> template <typename FMS>
inline void normalize_and_store_1row(const FMS& fms, int j, const qkv_cache_t * R, float * qkv) const { inline void normalize_and_store_1row(const FMS& fms, int j, qkv_cache_t * R, float * qkv, const float * sinkf) const {
static_assert(q_step == FMS::q_step); static_assert(q_step == FMS::q_step);
GGML_ASSERT(fms.S[j] > 0); float S = fms.S[j];
auto norm = F16::set1(1/fms.S[j]); if (sinkf) {
float s = *sinkf;
if (s > fms.M[j]) {
float m = expf(fms.M[j] - s);
auto vm = F16::set1(m);
for (int i = 0; i < D/F16::block_size; ++i) {
auto Ri = R + F16::block_size*i;
F16::store(Ri, F16::mul(vm, F16::load(Ri)));
}
S = S*m + 1;
} else {
S += expf(s - fms.M[j]);
}
}
GGML_ASSERT(S > 0);
auto norm = F16::set1(1/S);
//auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
for (int i = 0; i < D/F16::block_size; ++i) { for (int i = 0; i < D/F16::block_size; ++i) {
auto r = F16::load(R + F16::block_size*i); auto r = F16::load(R + F16::block_size*i);
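The sink update above is the standard online-softmax trick with one extra logit whose value vector is zero: the sink can only change the normalization, never the accumulated V rows. A minimal scalar sketch of the same fold (illustrative only; the kernel does this over F16 SIMD blocks):
static void fold_sink(float s, float & M, float & S, float * R, int Dv) {
    // Treat the sink as one more logit s joining an accumulator with
    // running max M, running sum S, and unnormalized output row R.
    if (s > M) {
        float m = expf(M - s);                    // rescale old contributions
        for (int i = 0; i < Dv; ++i) R[i] *= m;
        S = S*m + 1;                              // exp(s - s) == 1 for the sink itself
        M = s;                                    // the kernel skips this; M is not reused
    } else {
        S += expf(s - M);                         // sink only grows the denominator
    }
}
After the fold, the row is normalized by 1/S exactly as before.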
@@ -1153,7 +1168,7 @@ struct FlashQKV {
} }
template <typename FMS> template <typename FMS>
inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const { inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, const float * sinkf, float * M, float * S) {
static_assert(q_step == FMS::q_step); static_assert(q_step == FMS::q_step);
if (M && S) { if (M && S) {
std::memcpy(M, fms.M, nq1*sizeof(float)); std::memcpy(M, fms.M, nq1*sizeof(float));
@@ -1173,7 +1188,7 @@ struct FlashQKV {
} else { } else {
auto R = qkv_cache; auto R = qkv_cache;
for (int j = 0; j < nq1; ++j) { for (int j = 0; j < nq1; ++j) {
normalize_and_store_1row(fms, j, R, qkv); normalize_and_store_1row(fms, j, R, qkv, sinkf);
qkv += stride_qkv; qkv += stride_qkv;
R += D; R += D;
} }
@@ -1181,7 +1196,7 @@ struct FlashQKV {
} }
template <typename FMS> template <typename FMS>
inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, float * M, float * S) const { inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, const float * sinkf, float * M, float * S) {
static_assert(q_step == FMS::q_step); static_assert(q_step == FMS::q_step);
if (M && S) { if (M && S) {
std::memcpy(M, fms.M, q_step*sizeof(float)); std::memcpy(M, fms.M, q_step*sizeof(float));
@@ -1201,7 +1216,7 @@ struct FlashQKV {
} else { } else {
auto R = qkv_cache; auto R = qkv_cache;
for (int j = 0; j < q_step; ++j) { for (int j = 0; j < q_step; ++j) {
normalize_and_store_1row(fms, j, R, qkv); normalize_and_store_1row(fms, j, R, qkv, sinkf);
qkv += stride_qkv; qkv += stride_qkv;
R += D; R += D;
} }
@@ -1332,7 +1347,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
FlashMS<q_step, k_step>& fms, FlashMS<q_step, k_step>& fms,
FlashQKV<Dv, q_step, k_step>& fqkv, FlashQKV<Dv, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv, const float * q, const char * mask, float * qkv,
float * M, float * S) { const float * sinkf, float * M, float * S) {
#ifdef __aarch64__ #ifdef __aarch64__
float16_t q_f16[Dk*q_step]; float16_t q_f16[Dk*q_step];
#endif #endif
@@ -1356,7 +1371,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
vh.next_block(k_step); vh.next_block(k_step);
mr += k_step*sizeof(ggml_half); mr += k_step*sizeof(ggml_half);
} }
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
q += q_step*stride_q; q += q_step*stride_q;
mask += q_step*stride_m; mask += q_step*stride_m;
@@ -1383,7 +1398,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
vh.next_block(k_step); vh.next_block(k_step);
mr += k_step*sizeof(ggml_half); mr += k_step*sizeof(ggml_half);
} }
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S);
} }
} }
@@ -1392,7 +1407,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
FlashMS<q_step, k_step>& fms, FlashMS<q_step, k_step>& fms,
FlashQKV<Dv, q_step, k_step>& fqkv, FlashQKV<Dv, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv, const float * q, const char * mask, float * qkv,
float * M, float * S, char * qptr) { const float * sinkf, float * M, float * S, char * qptr) {
auto q8 = (typename KHelper::block_q8 *)qptr; auto q8 = (typename KHelper::block_q8 *)qptr;
if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80>) { if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
if (nq1 == q_step) { if (nq1 == q_step) {
@@ -1412,7 +1427,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
vh.next_block(k_step); vh.next_block(k_step);
mr += k_step*sizeof(ggml_half); mr += k_step*sizeof(ggml_half);
} }
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
return; return;
} }
} }
@@ -1449,10 +1464,10 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
} }
#if FA_TIMING #if FA_TIMING
t1 = Perf::cur_time(); t1 = Perf::cur_time();
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
perf.accum_nolock(3, t1); perf.accum_nolock(3, t1);
#else #else
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
#endif #endif
q += q_step*stride_q; q += q_step*stride_q;
@@ -1474,7 +1489,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
vh.next_block(k_step); vh.next_block(k_step);
mr += k_step*sizeof(ggml_half); mr += k_step*sizeof(ggml_half);
} }
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S);
} }
#if FA_TIMING #if FA_TIMING
Perf::instance().add(perf); Perf::instance().add(perf);
@@ -1504,7 +1519,7 @@ struct FlashAttn {
static_assert(k_step%F16::block_size == 0); static_assert(k_step%F16::block_size == 0);
static_assert(q_step <= 4 || q_step%4 == 0); static_assert(q_step <= 4 || q_step%4 == 0);
FlashAttn(float scale, float softcap) : fms(scale, softcap) {} FlashAttn(float scale, float softcap, const float * sinkf) : fms(scale, softcap), sinkf(sinkf) {}
template <typename KHelper, typename VHelper> template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
@@ -1533,7 +1548,7 @@ struct FlashAttn {
HelperQ80R8<Dk> khr4(nk1, kh); HelperQ80R8<Dk> khr4(nk1, kh);
#endif #endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>( compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr);
return; return;
} }
@@ -1547,29 +1562,30 @@ struct FlashAttn {
HelperQ8KVR8<Dk> khr4(nk1, kh); HelperQ8KVR8<Dk> khr4(nk1, kh);
#endif #endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>( compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr);
return; return;
} }
#endif #endif
} }
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>( compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr);
} }
else { else {
typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)]; typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>( compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8); kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, (char *)q8);
} }
} }
else { else {
compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>( compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S);
} }
} }
FlashMS<q_step, k_step> fms; FlashMS<q_step, k_step> fms;
FlashQKV<Dv, q_step, k_step> fqkv; FlashQKV<Dv, q_step, k_step> fqkv;
const float * sinkf;
}; };
@@ -1927,7 +1943,7 @@ struct FlashAttnBF16 {
static_assert(k_step%32 == 0); static_assert(k_step%32 == 0);
static_assert(q_step <= 4 || q_step%4 == 0); static_assert(q_step <= 4 || q_step%4 == 0);
FlashAttnBF16(float scale, float softcap) : fms(scale, softcap) {} FlashAttnBF16(float scale, float softcap, const float * sinkf) : fms(scale, softcap), sinkf(sinkf) {}
template <typename KHelper, typename VHelper> template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
@@ -1967,7 +1983,7 @@ struct FlashAttnBF16 {
#if FA_TIMING #if FA_TIMING
t1 = Perf::cur_time(); t1 = Perf::cur_time();
#endif #endif
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
#if FA_TIMING #if FA_TIMING
perf.accum_nolock(4, t1); perf.accum_nolock(4, t1);
#endif #endif
@@ -1990,7 +2006,7 @@ struct FlashAttnBF16 {
vh.next_block(k_step); vh.next_block(k_step);
mr += k_step*sizeof(ggml_half); mr += k_step*sizeof(ggml_half);
} }
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S);
} }
#if FA_TIMING #if FA_TIMING
Perf::instance().add(perf); Perf::instance().add(perf);
@@ -1999,12 +2015,14 @@ struct FlashAttnBF16 {
FlashMS<q_step, k_step> fms; FlashMS<q_step, k_step> fms;
FlashQKV<Dv, q_step, k_step> fqkv; FlashQKV<Dv, q_step, k_step> fqkv;
const float * sinkf;
}; };
#endif #endif
template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper> template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { const float * q, const char * mask, float scale, float softcap, float * qkv,
const float * sinkf, float * M, float * S) {
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
nq1 -= n; nq1 -= n;
@@ -2018,48 +2036,48 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
if (nk1 >= 512) { if (nk1 >= 512) {
if (nq1 >= 128) { if (nq1 >= 128) {
int n_step = nq1/128; int n_step = nq1/128;
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap); FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(128*n_step)) return; if (update(128*n_step)) return;
} }
if (nq1 >= 64) { if (nq1 >= 64) {
int n_step = nq1/64; int n_step = nq1/64;
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap); FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(64*n_step)) return; if (update(64*n_step)) return;
} }
if (nq1 >= 32) { if (nq1 >= 32) {
int n_step = nq1/32; int n_step = nq1/32;
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap); FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(32*n_step)) return; if (update(32*n_step)) return;
} }
if (nq1 >= 16) { if (nq1 >= 16) {
int n_step = nq1/16; int n_step = nq1/16;
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap); FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(16*n_step)) return; if (update(16*n_step)) return;
} }
} }
if (nq1 >= 8) { if (nq1 >= 8) {
int n_step = nq1/8; int n_step = nq1/8;
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap); FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(8*n_step)) return; if (update(8*n_step)) return;
} }
else if (nq1 >= 4) { else if (nq1 >= 4) {
int n_step = nq1/4; int n_step = nq1/4;
FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap); FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(4*n_step)) return; if (update(4*n_step)) return;
} }
else if (nq1 >= 2) { else if (nq1 >= 2) {
int n_step = nq1/2; int n_step = nq1/2;
FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap); FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(2*n_step)) return; if (update(2*n_step)) return;
} }
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap); FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
} }
@@ -2067,26 +2085,26 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
template <int Dk, int Dv, int k_step> template <int Dk, int Dv, int k_step>
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask, const float * q, const char * k, const char * v, const char * mask,
float scale, float softcap, float * qkv, float * M, float * S) { float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
HelperBF16<Dk, k_step> kh(k, stride_k); HelperBF16<Dk, k_step> kh(k, stride_k);
HelperBF16<Dv, k_step> vh(v, stride_v); HelperBF16<Dv, k_step> vh(v, stride_v);
if (nk1 >= 4096) { if (nk1 >= 4096) {
if (nq1 >= 64) { if (nq1 >= 64) {
FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap); FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return; return;
} }
else if (nq1 >= 16) { else if (nq1 >= 16) {
FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap); FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return; return;
} }
} }
if (nq1 >= 8) { if (nq1 >= 8) {
FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap); FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
} else { } else {
FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap); FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
} }
} }
@@ -2096,43 +2114,43 @@ template <int Dk, int Dv, int k_step, typename KHelper>
inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v, inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv, int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * v, const char * mask, const float * q, const char * v, const char * mask,
float scale, float softcap, float * qkv, float * M, float * S) { float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
switch (type_v) { switch (type_v) {
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
HelperF16 vh(v, stride_v); HelperF16 vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
#ifdef __AVX512BF16__ #ifdef __AVX512BF16__
case GGML_TYPE_BF16: { case GGML_TYPE_BF16: {
HelperBF16<Dv, k_step> vh(v, stride_v); HelperBF16<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
#endif #endif
case GGML_TYPE_Q8_0: { case GGML_TYPE_Q8_0: {
HelperQ80 vh(v, stride_v); HelperQ80 vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
case GGML_TYPE_Q8_KV: { case GGML_TYPE_Q8_KV: {
HelperQ8KV<Dv> vh(v, stride_v); HelperQ8KV<Dv> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
case GGML_TYPE_Q6_0: { case GGML_TYPE_Q6_0: {
HelperQ60 vh(v, stride_v); HelperQ60 vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
#if GGML_IQK_FA_ALL_QUANTS #if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: { case GGML_TYPE_Q4_0: {
HelperQ40 vh(v, stride_v); HelperQ40 vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
case GGML_TYPE_Q4_1: { case GGML_TYPE_Q4_1: {
HelperQ41 vh(v, stride_v); HelperQ41 vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
case GGML_TYPE_IQ4_NL: { case GGML_TYPE_IQ4_NL: {
HelperIQ4nl vh(v, stride_v); HelperIQ4nl vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
#endif #endif
default: return false; default: return false;
@@ -2144,42 +2162,42 @@ template <int Dk, int Dv, int k_step>
inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask, const float * q, const char * k, const char * v, const char * mask,
float scale, float softcap, float * qkv, float * M, float * S) { float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
bool result = false; bool result = false;
switch (type_k) { switch (type_k) {
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
HelperF16 kh(k, stride_k); HelperF16 kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
case GGML_TYPE_Q8_0: { case GGML_TYPE_Q8_0: {
HelperQ80 kh(k, stride_k); HelperQ80 kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
case GGML_TYPE_Q8_0_R8: { case GGML_TYPE_Q8_0_R8: {
HelperQ80R8<Dk> kh(k, stride_k); HelperQ80R8<Dk> kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
case GGML_TYPE_Q6_0: { case GGML_TYPE_Q6_0: {
HelperQ60 kh(k, stride_k); HelperQ60 kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
#if GGML_IQK_FA_ALL_QUANTS #if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q8_KV: { case GGML_TYPE_Q8_KV: {
HelperQ8KV<Dk> kh(k, stride_k); HelperQ8KV<Dk> kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
case GGML_TYPE_Q4_0: { case GGML_TYPE_Q4_0: {
HelperQ40 kh(k, stride_k); HelperQ40 kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
case GGML_TYPE_Q4_1: { case GGML_TYPE_Q4_1: {
HelperQ41 kh(k, stride_k); HelperQ41 kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
case GGML_TYPE_IQ4_NL: { case GGML_TYPE_IQ4_NL: {
HelperIQ4nl kh(k, stride_k); HelperIQ4nl kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break; } break;
#endif #endif
default: break; default: break;
@@ -2194,7 +2212,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,\ int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,\
const float * q, const void * k, const void * v, const void * mask,\ const float * q, const void * k, const void * v, const void * mask,\
float scale, float softcap,\ float scale, float softcap,\
float * qkv, float * M, float * S) float * qkv, const float * sinkf, float * M, float * S)
IQK_FA_CASE(iqk_fa_576_512); IQK_FA_CASE(iqk_fa_576_512);
IQK_FA_CASE(iqk_fa_192_128); IQK_FA_CASE(iqk_fa_192_128);

View File

@@ -66,6 +66,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
const void * sinks, // attention sinks. If not null, assumed to be fp32, one value per head
float scale, // scale applied before softmax float scale, // scale applied before softmax
float softcap, // if > 0, a "soft-cap" operation is applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q)) float * qkv, // v*softmax(scale*(k*q))
@@ -139,7 +140,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
auto work_this_thread = (float *)(result_buffer + ith*size_thread); auto work_this_thread = (float *)(result_buffer + ith*size_thread);
if (!iqk_flash_attn_impl(int_type_k, int_type_v, if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv, Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, nullptr, 0,
scale, softcap, scale, softcap,
work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false; work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false;
@@ -182,51 +183,6 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) { if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
auto result_size = (Dv + 16)*rk2*sizeof(float); auto result_size = (Dv + 16)*rk2*sizeof(float);
int gcd = simple_gcd(nek2, nth); int gcd = simple_gcd(nek2, nth);
if (false && gcd > 1) {
int nth_g = nth/gcd;
int ith_g = ith%nth_g;
int nek1_32 = nek1/32;
int nek1_pt = (nek1_32 + nth_g - 1)/nth_g;
int ith_mid = nth_g;
if (nek1_pt*nth_g > nek1_32) {
ith_mid = nek1_32 - nth_g*(nek1_pt - 1);
}
nek1_pt *= 32;
int nek1_mid = ith_mid*nek1_pt;
int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32;
for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) {
int ik01 = ith_g < ith_mid ? ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread;
auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size);
auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv,
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
}
barrier(barrier_data);
for (int iq2 = ith; iq2 < neq2; iq2 += nth) {
int ik02 = iq2/rk2;
int il = iq2 - ik02*rk2;
auto Racc = qkv + iq2*nb1/sizeof(float);
float M = -INFINITY, S = 0;
for (int ig = 0; ig < nth_g; ++ig) {
int istep_k = ik02*nth_g + ig;
auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
const float * R = this_result + il*Dv;
const float * Mj = this_result + Dv*rk2;
const float * Sj = Mj + rk2;
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
}
float norm = S > 0 ? 1/S : 1;
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
}
return true;
}
int nth_k = nth/gcd; int nth_k = nth/gcd;
int nek2_k = nek2/gcd; int nek2_k = nek2/gcd;
int nchunk = nek2_k*nek1/32; int nchunk = nek2_k*nek1/32;
@@ -259,7 +215,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
if (!iqk_flash_attn_impl(int_type_k, int_type_v, if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv, Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv,
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, nullptr, 0,
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false; scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
} }
@@ -281,6 +237,16 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
const float * Sj = Mj + rk2; const float * Sj = Mj + rk2;
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
} }
if (sinks) {
float s = ((const float *)sinks)[iq2];
if (s > M) {
float m = expf(M - s);
for (int i = 0; i < Dv; ++i) Racc[i] *= m;
S = S*m + 1;
} else {
S += expf(s - M);
}
}
float norm = S > 0 ? 1/S : 1; float norm = S > 0 ? 1/S : 1;
for (int i = 0; i < Dv; ++i) Racc[i] *= norm; for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
} }
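Written out, the reduce above computes per head, with the partial logits l_i already merged into (M, S) and a single sink logit s (a sketch of the intended math, not new code):
    out = (sum_i exp(l_i - M') * v_i) / (exp(s - M') + sum_i exp(l_i - M')),   M' = max(M, s)
i.e. the sink adds a term to the softmax denominator but contributes no value row, which is why only the scale of Racc and the sum S change.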
@@ -306,6 +272,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int counter = 0; int counter = 0;
for (int64_t iq3 = 0; iq3 < neq3; iq3++) { for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
for (int64_t iq2 = 0; iq2 < neq2; iq2++) { for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
auto sinksf = sinks ? (const float *)sinks + iq2 : nullptr;
if (counter++ % (nth/ntg) == ith/ntg) { if (counter++ % (nth/ntg) == ith/ntg) {
int iq1 = (ith%ntg)*neq1g; int iq1 = (ith%ntg)*neq1g;
int this_neq1 = std::min(neq1g, neq1-iq1); int this_neq1 = std::min(neq1g, neq1-iq1);
@@ -314,7 +281,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q), (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
(const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3), (const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3),
(const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3), (const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3),
(const void *)((const char *)mask + iq1*stride_m), (const void *)((const char *)mask + iq1*stride_m), sinksf, 1,
scale, softcap, scale, softcap,
(float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false; (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
} }

View File

@@ -23,6 +23,8 @@ bool iqk_flash_attn_impl(int type_k, // type of k
const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
const float * sinksf, // attention sinks
int nsinks, // number of sinks

float scale, // scale applied before softmax float scale, // scale applied before softmax
float softcap, // if > 0, a "soft-cap" operation is applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q)) float * qkv, // v*softmax(scale*(k*q))

View File

@@ -120,16 +120,21 @@ struct MulMat {
funcs[n_left-1](n, vx, bx, info, nrc_x); funcs[n_left-1](n, vx, bx, info, nrc_x);
} }
} }
inline void gelu(int n, const float * src, float * dst); inline static void gelu(int n, const float * src, float * dst);
inline void relu(int n, const float * src, float * dst); inline static void relu(int n, const float * src, float * dst);
inline void silu(int n, const float * src, float * dst); inline static void silu(int n, const float * src, float * dst);
inline void activate(ggml_unary_op op, int n, const float * src, float * dst) { inline static void swiglu_oai(int n, const float * src, float * dst);
inline static void clamp_oai(int n, float * x); // in-place: clamp to [-limit, limit], then add 1
inline static void activate(ggml_unary_op op, int n, const float * src, float * dst) {
if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst); if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst);
else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst); else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst);
else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst); else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst);
else if (op == GGML_UNARY_OP_SWIGLU_OAI) swiglu_oai(n, src, dst);
else GGML_ABORT("fatal error"); else GGML_ABORT("fatal error");
} }
inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, DataInfo& info, int nrc_x, int nrc_y, int unary_op) { inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx,
const float * up_b, const float * gate_b,
DataInfo& info, int nrc_x, int nrc_y, int unary_op) {
#ifdef __aarch64__ #ifdef __aarch64__
constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small) constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small)
#else #else
@@ -137,6 +142,29 @@ struct MulMat {
#endif #endif
auto op = ggml_unary_op(unary_op); auto op = ggml_unary_op(unary_op);
float tmp[k_x_step*16]; float tmp[k_x_step*16];
auto process = [&tmp, n, op, vx_gate, vx_up, gate_b, up_b, bx, xstep = k_x_step] (mul_mat_t func, const DataInfo& this_info, int ix, int this_nrc_x, int ny) {
func(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny; ++ky) {
if (gate_b) {
auto b = gate_b + ix;
auto x = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) x[j] += b[j];
}
activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*xstep);
}
func(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny; ++ky) {
auto result = this_info.dst_row(ky);
if (up_b) {
auto b = up_b + ix;
for (int j = 0; j < this_nrc_x; ++j) result[j] += b[j];
}
if (op == GGML_UNARY_OP_SWIGLU_OAI) {
clamp_oai(this_nrc_x, result);
}
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*xstep + j];
}
};
if (func16 && nrc_y >= 16) { if (func16 && nrc_y >= 16) {
int n_step = (nrc_y - info.cur_y)/16; int n_step = (nrc_y - info.cur_y)/16;
for (int ix = 0; ix < nrc_x; ix += k_x_step) { for (int ix = 0; ix < nrc_x; ix += k_x_step) {
@@ -144,15 +172,7 @@ struct MulMat {
this_info.s += ix; this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
for (int iy = 0; iy < n_step; ++iy) { for (int iy = 0; iy < n_step; ++iy) {
func16(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); process(func16, this_info, ix, this_nrc_x, 16);
for (int ky = 0; ky < 16; ++ky) {
activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
}
func16(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < 16; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
this_info.cur_y += 16; this_info.cur_y += 16;
} }
} }
@@ -175,23 +195,11 @@ struct MulMat {
this_info.s += ix; this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
for (int iy = 0; iy < my1; ++iy) { for (int iy = 0; iy < my1; ++iy) {
funcs[ny1-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); process(funcs[ny1-1], this_info, ix, this_nrc_x, ny1);
for (int ky = 0; ky < ny1; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
funcs[ny1-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny1; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
this_info.cur_y += ny1; this_info.cur_y += ny1;
} }
for (int iy = 0; iy < my2; ++iy) { for (int iy = 0; iy < my2; ++iy) {
funcs[ny2-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); process(funcs[ny2-1], this_info, ix, this_nrc_x, ny2);
for (int ky = 0; ky < ny2; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
funcs[ny2-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny2; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
this_info.cur_y += ny2; this_info.cur_y += ny2;
} }
} }
@@ -203,13 +211,7 @@ struct MulMat {
this_info.s += ix; this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
for (int iy = 0; iy < n_step; ++iy) { for (int iy = 0; iy < n_step; ++iy) {
funcs[ny-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); process(funcs[ny-1], this_info, ix, this_nrc_x, ny);
for (int ky = 0; ky < ny; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
funcs[ny-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
this_info.cur_y += ny; this_info.cur_y += ny;
} }
} }
@@ -222,13 +224,7 @@ struct MulMat {
auto this_info = info; auto this_info = info;
this_info.s += ix; this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
funcs[n_left-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); process(funcs[n_left-1], this_info, ix, this_nrc_x, n_left);
for (int ky = 0; ky < n_left; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
funcs[n_left-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < n_left; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
} }
} }
} }
@@ -731,6 +727,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
int typeA, const void * Aup, const void * Agate, long strideA, int typeA, const void * Aup, const void * Agate, long strideA,
int typeB, const void * B, long strideB, int typeB, const void * B, long strideB,
const char * up_b_c, const char * gate_b_c,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;
@@ -774,7 +771,9 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) { if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
GGML_ABORT("Fatal error"); GGML_ABORT("Fatal error");
} }
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op); auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr;
auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr;
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op);
} }
return true; return true;
@@ -795,7 +794,10 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
nrc_x *= num_rows; nrc_x *= num_rows;
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny, unary_op); auto up_b = up_b_c ? (const float *)up_b_c + first_x : nullptr;
auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x : nullptr;
mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx,
up_b, gate_b, info, nrc_x, Ny, unary_op);
return true; return true;
} }
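For illustration, a hypothetical call of the extended entry point (parameter names reuse the declaration above; the two bias pointers are the new arguments and may be nullptr for experts without FFN biases, as in the pre-existing callers):
iqk_moe_fused_up_gate(Nx, Ny, ne00, ne11, GGML_UNARY_OP_SWIGLU_OAI,
                      typeA, Aup, Agate, strideA,
                      typeB, B, strideB,
                      (const char *)up_bias, (const char *)gate_bias,
                      C, nb1, nb2, vrow_mapping, ith, nth);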
@@ -993,6 +995,46 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
namespace { namespace {
// TODO: these swiglu_oai constants shouldn't be hard coded
constexpr float k_swiglu_oai_alpha = 1.702f;
constexpr float k_swiglu_oai_limit = 7.f;
void MulMat::swiglu_oai(int n, const float * x, float * y) {
// int i = 0;
//#if defined __AVX512F__ && defined __AVX512DQ__
// {
// auto max = _mm512_set1_ps(k_swiglu_oai_limit);
// auto alpha = _mm512_set1_ps(-k_swiglu_oai_alpha);
// for (; i + 15 < n; i += 16) {
// auto xc = v_clamp_max(_mm512_loadu_ps(x + i), max);
// _mm512_storeu_ps(y + i, v_silu_oai(xc, alpha));
// }
// }
//#endif
//#if defined __AVX2__ && defined __FMA__
// if (i + 7 < n) {
// auto max = _mm256_set1_ps(k_swiglu_oai_limit);
// auto alpha = _mm256_set1_ps(-k_swiglu_oai_alpha);
// for (; i + 7 < n; i += 8) {
// auto xc = v_clamp_max(_mm256_loadu_ps(x + i), max);
// _mm256_storeu_ps(y + i, v_silu_oai(xc, alpha));
// }
// }
//#endif
// for (; i < n; ++i) {
// auto xi = std::min(x[i], k_swiglu_oai_limit);
// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
// }
for (int i = 0; i < n; ++i) {
auto xi = std::min(x[i], k_swiglu_oai_limit);
y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
}
}
void MulMat::clamp_oai(int n, float * x) {
for (int i = 0; i < n; ++i) x[i] = 1.f + std::max(std::min(x[i], k_swiglu_oai_limit), -k_swiglu_oai_limit);
}
#if defined(__ARM_NEON) && defined(__aarch64__) #if defined(__ARM_NEON) && defined(__aarch64__)
void MulMat::gelu(int n, const float * x, float * y) { void MulMat::gelu(int n, const float * x, float * y) {
constexpr float GELU_COEF_A = 0.044715f; constexpr float GELU_COEF_A = 0.044715f;
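Putting the two new helpers together: in the fused up-gate path the gate projection goes through swiglu_oai() and the up projection through clamp_oai() before the elementwise product. A minimal scalar sketch of one output row (assuming any expert biases were already added, as the process lambda does):
static void swiglu_oai_row(int n, const float * gate, const float * up, float * out) {
    for (int j = 0; j < n; ++j) {
        float g   = std::min(gate[j], k_swiglu_oai_limit);
        float act = g / (1.0f + expf(-g * k_swiglu_oai_alpha));   // swiglu_oai
        float u   = 1.0f + std::max(std::min(up[j], k_swiglu_oai_limit),
                                    -k_swiglu_oai_limit);         // clamp_oai
        out[j] = u * act;
    }
}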
@@ -1040,6 +1082,37 @@ void MulMat::gelu(int n, const float * x, float * y) {
for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i]))); for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i])));
} }
//void MulMat::swiglu_oai(int n, const float * x, float * y) {
// int i = 0;
//#if defined __AVX512F__ && defined __AVX512DQ__
// {
// auto limit = _mm512_set1_ps(k_swiglu_oai_limit);
// auto alpha = _mm512_set1_ps(k_swiglu_oai_alpha);
// for (; i + 15 < n; i += 16) {
// auto xi = _mm512_loadu_ps(x + i);
// auto mask = _mm512_cmp
//
// }
// __m512 c1 = _mm512_set1_ps(GELU_COEF_A);
// __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI);
// for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_gelu(_mm512_loadu_ps(x + i), c1, c2));
// }
//#endif
//#if defined __AVX2__ && defined __FMA__
// if (i + 7 < n) {
// __m256 c1 = _mm256_set1_ps(GELU_COEF_A);
// __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI);
// for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2));
//
// }
//#endif
// for (; i < n; ++i) {
// auto xi = std::min(x[i], k_swiglu_oai_limit);
// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
// }
//}
void MulMat::silu(int n, const float * x, float * y) { void MulMat::silu(int n, const float * x, float * y) {
int i = 0; int i = 0;
#if defined __AVX512F__ && defined __AVX512DQ__ #if defined __AVX512F__ && defined __AVX512DQ__
@@ -1188,6 +1261,8 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
const float * sinksf, // attention sinks. If not null, one fp32 value per head
[[maybe_unused]] int nsinks, // number of sinks (callers currently pass 0 or 1)
float scale, // scale applied before softmax float scale, // scale applied before softmax
float softcap, // if > 0, a "soft-cap" operation is applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q)) float * qkv, // v*softmax(scale*(k*q))
@@ -1197,32 +1272,32 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
if (Dk == 576 && Dv == 512) { if (Dk == 576 && Dv == 512) {
return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, M, S); q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
} }
if (Dk == 192 && Dv == 128) { if (Dk == 192 && Dv == 128) {
return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, M, S); q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
} }
if (Dk == 256 && Dv == 256) { if (Dk == 256 && Dv == 256) {
return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, M, S); q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
} }
if (Dk == 128 && Dv == 128) { if (Dk == 128 && Dv == 128) {
return iqk_fa_128_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, return iqk_fa_128_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, M, S); q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
} }
if (Dk == 96 && Dv == 96) { if (Dk == 96 && Dv == 96) {
return iqk_fa_96_96(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, return iqk_fa_96_96(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, M, S); q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
} }
if (Dk == 64 && Dv == 64) { if (Dk == 64 && Dv == 64) {
return iqk_fa_64_64(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, return iqk_fa_64_64(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, M, S); q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
} }
return false; return false;

View File

@@ -32,6 +32,7 @@ IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
int typeA, const void * Aup, const void * Agate, long strideA, int typeA, const void * Aup, const void * Agate, long strideA,
int typeB, const void * B, long strideB, int typeB, const void * B, long strideB,
const char * up_b, const char * gate_b,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
IQK_API int iqk_dequant_type(int type, int Ny); IQK_API int iqk_dequant_type(int type, int Ny);
@@ -57,6 +58,7 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
const void * sinks, // attention sinks. If not null, assumed to be fp32, one value per head
float scale, // scale applied before softmax float scale, // scale applied before softmax
float softcap, // if > 0, a "soft-cap" operation is applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q)) float * qkv, // v*softmax(scale*(k*q))

View File

@@ -61,6 +61,13 @@ static inline float32x4_t v_silu(float32x4_t x) {
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
return vdivq_f32(x, one_plus_exp_neg_x); return vdivq_f32(x, one_plus_exp_neg_x);
} }
static inline float32x4_t v_silu_oai(float32x4_t x, float32x4_t alpha) {
const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t neg_x = vmulq_f32(alpha, x);
const float32x4_t exp_neg_x = v_expf(neg_x);
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
return vdivq_f32(x, one_plus_exp_neg_x);
}
static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) { static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
const float32x4_t one = vdupq_n_f32(1.0f); const float32x4_t one = vdupq_n_f32(1.0f);
float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x)); float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
@@ -131,6 +138,17 @@ static inline __m512 v_silu(__m512 x) {
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
return _mm512_div_ps(x, one_plus_exp_neg_x); return _mm512_div_ps(x, one_plus_exp_neg_x);
} }
static inline __m512 v_silu_oai(__m512 x, __m512 alpha) {
const __m512 one = _mm512_set1_ps(1);
const __m512 neg_x = _mm512_mul_ps(alpha, x);
const __m512 exp_neg_x = v_expf(neg_x);
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
return _mm512_div_ps(x, one_plus_exp_neg_x);
}
static inline __m512 v_clamp_max(__m512 x, __m512 max) {
auto mask = _mm512_cmp_ps_mask(x, max, _CMP_GT_OQ);
return _mm512_mask_blend_ps(mask, x, max);
}
#endif // __AVX512__ #endif // __AVX512__
#if defined(__AVX2__) && defined(__FMA__) #if defined(__AVX2__) && defined(__FMA__)
@@ -195,12 +213,23 @@ static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) {
} }
static inline __m256 v_silu(__m256 x) { static inline __m256 v_silu(__m256 x) {
const __m256 one = _mm256_set1_ps(1); const __m256 one = _mm256_set1_ps(1);
const __m256 zero = _mm256_setzero_ps(); const __m256 zero = _mm256_setzero_ps();
const __m256 neg_x = _mm256_sub_ps(zero, x); const __m256 neg_x = _mm256_sub_ps(zero, x);
const __m256 exp_neg_x = v_expf(neg_x); const __m256 exp_neg_x = v_expf(neg_x);
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
return _mm256_div_ps(x, one_plus_exp_neg_x); return _mm256_div_ps(x, one_plus_exp_neg_x);
} }
static inline __m256 v_silu_oai(__m256 x, __m256 alpha) {
const __m256 one = _mm256_set1_ps(1);
const __m256 neg_x = _mm256_mul_ps(alpha, x);
const __m256 exp_neg_x = v_expf(neg_x);
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
return _mm256_div_ps(x, one_plus_exp_neg_x);
}
static inline __m256 v_clamp_max(__m256 x, __m256 max) {
auto mask = _mm256_cmp_ps(x, max, _CMP_GT_OQ);
return _mm256_or_ps(_mm256_and_ps(mask, max), _mm256_andnot_ps(mask, x));
}
#endif // __AVX2__ #endif // __AVX2__
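v_silu_oai and v_clamp_max currently have no active caller; they exist so the disabled SIMD body of MulMat::swiglu_oai can be re-enabled. A sketch of that AVX2 loop, mirroring the commented-out code (it would live under the same __AVX2__/__FMA__ guard; the constants stand in for k_swiglu_oai_limit = 7.f and k_swiglu_oai_alpha = 1.702f):
static inline void swiglu_oai_avx2(int n, const float * x, float * y) {
    int i = 0;
    auto max   = _mm256_set1_ps(7.f);       // k_swiglu_oai_limit
    auto alpha = _mm256_set1_ps(-1.702f);   // -k_swiglu_oai_alpha
    for (; i + 7 < n; i += 8) {
        auto xc = v_clamp_max(_mm256_loadu_ps(x + i), max);
        _mm256_storeu_ps(y + i, v_silu_oai(xc, alpha));
    }
    for (; i < n; ++i) {                    // scalar tail
        float xi = std::min(x[i], 7.f);
        y[i] = xi / (1.0f + expf(-xi * 1.702f));
    }
}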

View File

@@ -70,50 +70,52 @@ extern "C" {
typedef int32_t llama_seq_id; typedef int32_t llama_seq_id;
enum llama_vocab_type { enum llama_vocab_type {
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming
}; };
// pre-tokenization types // pre-tokenization types
//enum llama_vocab_pre_type {
// LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
// LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
// LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
// LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
// LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
// LLAMA_VOCAB_PRE_TYPE_MPT = 5,
// LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
// LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
// LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
// LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
// LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
// LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
// LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
// LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
// LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
// LLAMA_VOCAB_PRE_TYPE_PORO = 15,
// LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
// LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
// LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
// LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
// LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
// LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
// LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
// LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, //llama.cpp lists this as 28
// LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
// LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
// LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
// LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
// LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
// LLAMA_VOCAB_PRE_TYPE_FALCON_3 = 34,
// LLAMA_VOCAB_PRE_TYPE_FALCON_E = 35,
// LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 36, //llama.cpp lists this as 35
// LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 37, //llama.cpp lists this as 36
// LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 38, //llama.cpp lists this as 37
//};
// note: these values should be synchronized with ggml_rope
// TODO: maybe move this enum to ggml.h (ggml_rope_type)


@@ -17,6 +17,8 @@ add_library(llama
llama-vocab.cpp
llama-grammar.cpp
llama-sampling.cpp
llama-mmap.cpp
llama-model-loader.cpp
unicode.h
unicode.cpp
unicode-data.cpp

src/llama-arch.h (new file, 288 lines)

@@ -0,0 +1,288 @@
#pragma once
#include <string>
enum llm_arch {
LLM_ARCH_LLAMA,
LLM_ARCH_LLAMA4,
LLM_ARCH_DECI,
LLM_ARCH_FALCON,
LLM_ARCH_BAICHUAN,
LLM_ARCH_GROK,
LLM_ARCH_GPT2,
LLM_ARCH_GPTJ,
LLM_ARCH_GPTNEOX,
LLM_ARCH_MPT,
LLM_ARCH_STARCODER,
LLM_ARCH_REFACT,
LLM_ARCH_BERT,
LLM_ARCH_NOMIC_BERT,
LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
LLM_ARCH_QWEN,
LLM_ARCH_QWEN2,
LLM_ARCH_QWEN2MOE,
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_PHI2,
LLM_ARCH_PHI3,
LLM_ARCH_PLAMO,
LLM_ARCH_CODESHELL,
LLM_ARCH_ORION,
LLM_ARCH_INTERNLM2,
LLM_ARCH_MINICPM,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
LLM_ARCH_DBRX,
LLM_ARCH_OLMO,
LLM_ARCH_OPENELM,
LLM_ARCH_ARCTIC,
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,
LLM_ARCH_BITNET,
LLM_ARCH_BITNET_25,
LLM_ARCH_BITNET_B158,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
LLM_ARCH_JAIS,
LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_COHERE2,
LLM_ARCH_DOTS1,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_OPENAI_MOE,
LLM_ARCH_UNKNOWN,
};
enum llm_kv {
LLM_KV_GENERAL_TYPE,
LLM_KV_GENERAL_ARCHITECTURE,
LLM_KV_GENERAL_QUANTIZATION_VERSION,
LLM_KV_GENERAL_ALIGNMENT,
LLM_KV_GENERAL_NAME,
LLM_KV_GENERAL_AUTHOR,
LLM_KV_GENERAL_VERSION,
LLM_KV_GENERAL_URL,
LLM_KV_GENERAL_DESCRIPTION,
LLM_KV_GENERAL_LICENSE,
LLM_KV_GENERAL_SOURCE_URL,
LLM_KV_GENERAL_SOURCE_HF_REPO,
LLM_KV_VOCAB_SIZE,
LLM_KV_CONTEXT_LENGTH,
LLM_KV_EMBEDDING_LENGTH,
LLM_KV_BLOCK_COUNT,
LLM_KV_LEADING_DENSE_BLOCK_COUNT,
LLM_KV_FEED_FORWARD_LENGTH,
LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
LLM_KV_USE_PARALLEL_RESIDUAL,
LLM_KV_TENSOR_DATA_LAYOUT,
LLM_KV_EXPERT_COUNT,
LLM_KV_EXPERT_USED_COUNT,
LLM_KV_EXPERT_SHARED_COUNT,
LLM_KV_EXPERT_WEIGHTS_SCALE,
LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC,
LLM_KV_NEXTN_PREDICT_LAYERS,
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
LLM_KV_FINAL_LOGIT_SOFTCAPPING,
LLM_KV_SWIN_NORM,
LLM_KV_RESCALE_EVERY_N_LAYERS,
LLM_KV_TIME_MIX_EXTRA_DIM,
LLM_KV_TIME_DECAY_EXTRA_DIM,
LLM_KV_RESIDUAL_SCALE,
LLM_KV_EMBEDDING_SCALE,
LLM_KV_TOKEN_SHIFT_COUNT,
LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
LLM_KV_ATTENTION_CLAMP_KQV,
LLM_KV_ATTENTION_KEY_LENGTH,
LLM_KV_ATTENTION_VALUE_LENGTH,
LLM_KV_ATTENTION_LAYERNORM_EPS,
LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
LLM_KV_ATTENTION_CAUSAL,
LLM_KV_ATTENTION_Q_LORA_RANK,
LLM_KV_ATTENTION_KV_LORA_RANK,
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SCALE,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,
LLM_KV_ROPE_SCALING_ATTN_FACTOR,
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
LLM_KV_ROPE_SCALING_FINETUNED,
LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
LLM_KV_SPLIT_NO,
LLM_KV_SPLIT_COUNT,
LLM_KV_SPLIT_TENSORS_COUNT,
LLM_KV_SSM_INNER_SIZE,
LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_PRE,
LLM_KV_TOKENIZER_LIST,
LLM_KV_TOKENIZER_TOKEN_TYPE,
LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,
LLM_KV_TOKENIZER_SCORES,
LLM_KV_TOKENIZER_MERGES,
LLM_KV_TOKENIZER_BOS_ID,
LLM_KV_TOKENIZER_EOS_ID,
LLM_KV_TOKENIZER_UNK_ID,
LLM_KV_TOKENIZER_SEP_ID,
LLM_KV_TOKENIZER_PAD_ID,
LLM_KV_TOKENIZER_CLS_ID,
LLM_KV_TOKENIZER_MASK_ID,
LLM_KV_TOKENIZER_ADD_BOS,
LLM_KV_TOKENIZER_ADD_EOS,
LLM_KV_TOKENIZER_ADD_SEP,
LLM_KV_TOKENIZER_ADD_PREFIX,
LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV,
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
LLM_KV_TOKENIZER_FIM_PRE_ID,
LLM_KV_TOKENIZER_FIM_SUF_ID,
LLM_KV_TOKENIZER_FIM_MID_ID,
LLM_KV_TOKENIZER_FIM_PAD_ID,
LLM_KV_TOKENIZER_FIM_REP_ID,
LLM_KV_TOKENIZER_FIM_SEP_ID,
LLM_KV_TOKENIZER_PREFIX_ID,
LLM_KV_TOKENIZER_SUFFIX_ID,
LLM_KV_TOKENIZER_MIDDLE_ID,
LLM_KV_TOKENIZER_EOT_ID,
LLM_KV_TOKENIZER_EOM_ID,
LLM_KV_ADAPTER_TYPE,
LLM_KV_ADAPTER_LORA_ALPHA,
};
struct LLM_KV {
LLM_KV(llm_arch arch, const char* suffix = nullptr);
llm_arch arch;
const char* suffix;
std::string operator()(llm_kv kv) const;
};
enum llm_tensor {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_TOKEN_EMBD_NORM,
LLM_TENSOR_TOKEN_TYPES,
LLM_TENSOR_POS_EMBD,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_ROPE_FREQS,
LLM_TENSOR_ROPE_FACTORS_LONG,
LLM_TENSOR_ROPE_FACTORS_SHORT,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_NORM_2,
LLM_TENSOR_ATTN_OUT_NORM,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_ATTN_SINKS,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_POST_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_ACT,
LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility
LLM_TENSOR_FFN_GATE_EXP,
LLM_TENSOR_FFN_UP_EXP,
LLM_TENSOR_FFN_NORM_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_UP_SHEXP,
LLM_TENSOR_FFN_EXP_PROBS_B,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_LAYER_OUT_NORM,
LLM_TENSOR_SSM_IN,
LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_X,
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_ATTN_Q_A,
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
LLM_TENSOR_FFN_SUB_NORM,
LLM_TENSOR_DEC_ATTN_NORM,
LLM_TENSOR_DEC_ATTN_Q,
LLM_TENSOR_DEC_ATTN_K,
LLM_TENSOR_DEC_ATTN_V,
LLM_TENSOR_DEC_ATTN_OUT,
LLM_TENSOR_DEC_ATTN_REL_B,
LLM_TENSOR_DEC_CROSS_ATTN_NORM,
LLM_TENSOR_DEC_CROSS_ATTN_Q,
LLM_TENSOR_DEC_CROSS_ATTN_K,
LLM_TENSOR_DEC_CROSS_ATTN_V,
LLM_TENSOR_DEC_CROSS_ATTN_OUT,
LLM_TENSOR_DEC_CROSS_ATTN_REL_B,
LLM_TENSOR_DEC_FFN_NORM,
LLM_TENSOR_DEC_FFN_GATE,
LLM_TENSOR_DEC_FFN_DOWN,
LLM_TENSOR_DEC_FFN_UP,
LLM_TENSOR_DEC_OUTPUT_NORM,
LLM_TENSOR_ENC_ATTN_NORM,
LLM_TENSOR_ENC_ATTN_Q,
LLM_TENSOR_ENC_ATTN_K,
LLM_TENSOR_ENC_ATTN_V,
LLM_TENSOR_ENC_ATTN_OUT,
LLM_TENSOR_ENC_ATTN_REL_B,
LLM_TENSOR_ENC_FFN_NORM,
LLM_TENSOR_ENC_FFN_GATE,
LLM_TENSOR_ENC_FFN_DOWN,
LLM_TENSOR_ENC_FFN_UP,
LLM_TENSOR_ENC_OUTPUT_NORM,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,
LLM_TENSOR_NEXTN_HNORM,
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
};
llm_arch llm_arch_from_string(const std::string & name);
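For illustration, a hedged sketch of how these declarations are meant to be used (the actual key strings live in llama-arch.cpp, which this view does not show):

// Hypothetical usage; exact key strings depend on llama-arch.cpp.
LLM_KV kv(LLM_ARCH_LLAMA);
std::string key = kv(LLM_KV_CONTEXT_LENGTH); // e.g. "llama.context_length"
llm_arch arch = llm_arch_from_string("llama"); // presumably LLM_ARCH_LLAMA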


@@ -486,9 +486,9 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
-const std::string & piece = vocab->cache_token_to_piece.at(id);
const std::string & piece = vocab->token_to_piece(id);
-if (llama_token_is_eog_impl(*vocab, id)) {
if (vocab->is_eog(id)) {
if (!allow_eog) {
candidates->data[i].logit = -INFINITY;
}
@@ -511,7 +511,7 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
const int64_t t_start_sample_us = ggml_time_us();
-if (llama_token_is_eog_impl(*vocab, token)) {
if (vocab->is_eog(token)) {
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
return;
@@ -520,7 +520,7 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
const std::string & piece = vocab->cache_token_to_piece.at(token); const std::string & piece = vocab->token_to_piece(token);
// Note terminating 0 in decoded string // Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8); const auto decoded = decode_utf8(piece, grammar->partial_utf8);
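The pattern in this hunk (and in the sampling hunk further down) is mechanical: free internal functions over llama_vocab become methods on it, e.g.:

bool eog_old = llama_token_is_eog_impl(*vocab, token); // old internal API
bool eog_new = vocab->is_eog(token);                   // new member function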


@@ -10,6 +10,11 @@
#define LLAMA_API_INTERNAL
#include "llama.h"
#include <stdexcept>
#include <climits>
#include <cstdarg>
#include <vector>
#include <cinttypes>
#include <cstring>
#ifdef __GNUC__
#ifdef __MINGW32__
@@ -33,6 +38,7 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
#define LLAMA_LOG_INFO(...)  llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
#define LLAMA_LOG_WARN(...)  llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
#define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
//
// helpers
@@ -166,3 +172,49 @@ struct ring_buffer {
size_t pos = 0;
std::vector<T> data;
};
LLAMA_ATTRIBUTE_FORMAT(1, 2)
static std::string format(const char * fmt, ...) {
va_list ap;
va_list ap2;
va_start(ap, fmt);
va_copy(ap2, ap);
int size = vsnprintf(NULL, 0, fmt, ap);
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
std::vector<char> buf(size + 1);
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
GGML_ASSERT(size2 == size);
va_end(ap2);
va_end(ap);
return std::string(buf.data(), size);
}
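format() is the standard two-pass vsnprintf idiom: the first call with a NULL buffer only measures the output, the second writes into a buffer of exactly that size. Illustrative use:

std::string msg = format("tensor '%s' has %d dims", "blk.0.attn_q.weight", 2); // hypothetical values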
static std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) {
char buf[256];
snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0));
for (size_t i = 1; i < ne.size(); i++) {
snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i));
}
return buf;
}
static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
char buf[256];
snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]);
}
return buf;
}
template <typename T>
struct no_init {
T value;
no_init() { /* do nothing */ }
};
struct gguf_context;
std::string gguf_kv_to_str(const gguf_context * ctx_gguf, int i);
ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer);

src/llama-mmap.cpp (new file, 650 lines)

@@ -0,0 +1,650 @@
#include "llama-mmap.h"
#include "llama-impl.h"
#include "ggml.h"
#include <cstring>
#include <climits>
#include <stdexcept>
#include <cerrno>
#include <algorithm>
#include <fstream>
#include <sstream>
#ifdef __has_include
#if __has_include(<unistd.h>)
#include <unistd.h>
#if defined(_POSIX_MAPPED_FILES)
#include <sys/mman.h>
#include <fcntl.h>
#endif
#if defined(_POSIX_MEMLOCK_RANGE)
#include <sys/resource.h>
#endif
#endif
#endif
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#ifndef PATH_MAX
#define PATH_MAX MAX_PATH
#endif
#include <io.h>
#endif
#if defined(__APPLE__)
#include <TargetConditionals.h>
#endif
// TODO: consider moving to llama-impl.h if needed in more places
#if defined(_WIN32)
static std::string llama_format_win_err(DWORD err) {
LPSTR buf;
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
if (!size) {
return "FormatMessageA failed";
}
std::string ret(buf, size);
LocalFree(buf);
return ret;
}
#endif
// llama_file
struct llama_file::impl {
#if defined(_WIN32)
HANDLE fp_win32;
std::string GetErrorMessageWin32(DWORD error_code) const {
std::string ret;
LPSTR lpMsgBuf = NULL;
DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
if (!bufLen) {
ret = format("Win32 error code: %lx", error_code);
} else {
ret = lpMsgBuf;
LocalFree(lpMsgBuf);
}
return ret;
}
impl(const char * fname, const char * mode) {
fp = ggml_fopen(fname, mode);
if (fp == NULL) {
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
}
fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp));
seek(0, SEEK_END);
size = tell();
seek(0, SEEK_SET);
}
size_t tell() const {
LARGE_INTEGER li;
li.QuadPart = 0;
BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT);
if (!ret) {
throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
}
return li.QuadPart;
}
void seek(size_t offset, int whence) const {
static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN");
static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT");
static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END");
LARGE_INTEGER li;
li.QuadPart = offset;
BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence);
if (!ret) {
throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
}
}
void read_raw(void * ptr, size_t len) const {
size_t bytes_read = 0;
while (bytes_read < len) {
size_t chunk_size = std::min<size_t>(len - bytes_read, 64*1024*1024);
DWORD chunk_read = 0;
BOOL result = ReadFile(fp_win32, reinterpret_cast<char*>(ptr) + bytes_read, chunk_size, &chunk_read, NULL);
if (!result) {
throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
}
if (chunk_read < chunk_size || chunk_read == 0) {
throw std::runtime_error("unexpectedly reached end of file");
}
bytes_read += chunk_read;
}
}
uint32_t read_u32() const {
uint32_t val;
read_raw(&val, sizeof(val));
return val;
}
void write_raw(const void * ptr, size_t len) const {
size_t bytes_written = 0;
while (bytes_written < len) {
size_t chunk_size = std::min<size_t>(len - bytes_written, 64*1024*1024);
DWORD chunk_written = 0;
BOOL result = WriteFile(fp_win32, reinterpret_cast<char const*>(ptr) + bytes_written, chunk_size, &chunk_written, NULL);
if (!result) {
throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
}
if (chunk_written < chunk_size || chunk_written == 0) {
throw std::runtime_error("unexpectedly failed to write bytes");
}
bytes_written += chunk_written;
}
}
void write_u32(uint32_t val) const {
write_raw(&val, sizeof(val));
}
~impl() {
if (fp) {
std::fclose(fp);
}
}
#else
impl(const char * fname, const char * mode) {
fp = ggml_fopen(fname, mode);
if (fp == NULL) {
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
}
seek(0, SEEK_END);
size = tell();
seek(0, SEEK_SET);
}
size_t tell() const {
// TODO: this ifdef is never true?
#ifdef _WIN32
__int64 ret = _ftelli64(fp);
#else
long ret = std::ftell(fp);
#endif
if (ret == -1) {
throw std::runtime_error(format("ftell error: %s", strerror(errno)));
}
return (size_t) ret;
}
void seek(size_t offset, int whence) const {
// TODO: this ifdef is never true?
#ifdef _WIN32
int ret = _fseeki64(fp, (__int64) offset, whence);
#else
int ret = std::fseek(fp, (long) offset, whence);
#endif
if (ret != 0) {
throw std::runtime_error(format("seek error: %s", strerror(errno)));
}
}
void read_raw(void * ptr, size_t len) const {
if (len == 0) {
return;
}
errno = 0;
std::size_t ret = std::fread(ptr, len, 1, fp);
if (ferror(fp)) {
throw std::runtime_error(format("read error: %s", strerror(errno)));
}
if (ret != 1) {
throw std::runtime_error("unexpectedly reached end of file");
}
}
uint32_t read_u32() const {
uint32_t ret;
read_raw(&ret, sizeof(ret));
return ret;
}
void write_raw(const void * ptr, size_t len) const {
if (len == 0) {
return;
}
errno = 0;
size_t ret = std::fwrite(ptr, len, 1, fp);
if (ret != 1) {
throw std::runtime_error(format("write error: %s", strerror(errno)));
}
}
void write_u32(uint32_t val) const {
write_raw(&val, sizeof(val));
}
~impl() {
if (fp) {
std::fclose(fp);
}
}
#endif
FILE * fp;
size_t size;
};
llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique<impl>(fname, mode)) {}
llama_file::~llama_file() = default;
size_t llama_file::tell() const { return pimpl->tell(); }
size_t llama_file::size() const { return pimpl->size; }
int llama_file::file_id() const {
#ifdef _WIN32
return _fileno(pimpl->fp);
#else
#if defined(fileno)
return fileno(pimpl->fp);
#else
return ::fileno(pimpl->fp);
#endif
#endif
}
void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); }
// llama_mmap
struct llama_mmap::impl {
#ifdef _POSIX_MAPPED_FILES
std::vector<std::pair<size_t, size_t>> mapped_fragments;
impl(struct llama_file * file, size_t prefetch, bool numa, bool use_thp) {
size = file->size();
int fd = file->file_id();
int flags = MAP_SHARED;
if (numa) { prefetch = 0; }
#ifdef __linux__
if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) {
LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n",
strerror(errno));
}
if (prefetch) { flags |= MAP_POPULATE; }
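// If requested, try explicit huge pages first: map an anonymous MAP_HUGETLB
// region rounded up to the huge page size and fill it from the file with
// pread(); on failure, fall through to the regular file-backed mmap below.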
if (use_thp) {
size_t huge = get_default_huge_page_size();
auto size = huge*((file->size() + huge - 1)/huge);
addr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB, -1, 0);
if (addr != MAP_FAILED) {
printf("%s: using THP with page size %zu MiB ", __func__, huge/(1024*1024));
fflush(stdout);
size_t tot = 0;
while (tot < file->size()) {
auto n_read = pread(fd, static_cast<char*>(addr) + tot, file->size() - tot, tot);
if (n_read < 0) throw std::runtime_error(format("Reading into mapped huge pages failed at %zu (%s)", tot, strerror(errno)));
printf("."); fflush(stdout);
tot += n_read;
}
printf(" done\n");
mapped_fragments.emplace_back(0, file->size());
mapped_page_size = huge;
return;
}
else {
fprintf(stderr, "%s: mmap with huge page size %zu MiB failed (%s)\n", __func__, huge/(1024*1024), strerror(errno));
}
}
#endif
addr = mmap(NULL, file->size(), PROT_READ, flags, fd, 0);
if (addr == MAP_FAILED) {
throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
}
if (prefetch > 0) {
if (posix_madvise(addr, std::min(file->size(), prefetch), POSIX_MADV_WILLNEED)) {
LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
strerror(errno));
}
}
if (numa) {
if (posix_madvise(addr, file->size(), POSIX_MADV_RANDOM)) {
LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n",
strerror(errno));
}
}
mapped_fragments.emplace_back(0, file->size());
}
#ifdef __linux__
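// Parses the "Hugepagesize:" line of /proc/meminfo (reported in KiB);
// falls back to 2 MiB if it cannot be read.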
static int get_default_huge_page_size() {
int pg_size = 2048;
std::ifstream in("/proc/meminfo");
if (in) {
std::string line;
while (true) {
std::getline(in, line);
if (in.fail()) break;
if (auto pos = line.find("Hugepagesize:"); pos != std::string::npos) {
std::istringstream str(line.data() + pos + 13);
int aux;
str >> aux;
if (!str.fail()) pg_size = aux;
break;
}
}
}
return pg_size * 1024;
}
#endif
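// Rounds *first up and *last down to page boundaries so that partial pages
// are never passed to munmap().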
static void align_range(size_t * first, size_t * last, size_t page_size) {
size_t offset_in_page = *first & (page_size - 1);
size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page;
*first += offset_to_page;
*last = *last & ~(page_size - 1);
if (*last <= *first) {
*last = *first;
}
}
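// Unmaps the page-aligned range [first, last) and updates the bookkeeping
// list, splitting any mapped fragment that straddles the freed hole.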
void unmap_fragment(size_t first, size_t last) {
int page_size = mapped_page_size > 0 ? mapped_page_size : sysconf(_SC_PAGESIZE);
align_range(&first, &last, page_size);
size_t len = last - first;
if (len == 0) {
return;
}
GGML_ASSERT(first % page_size == 0);
GGML_ASSERT(last % page_size == 0);
GGML_ASSERT(last > first);
void * next_page_start = (uint8_t *) addr + first;
if (munmap(next_page_start, len)) {
LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
}
std::vector<std::pair<size_t, size_t>> new_mapped_fragments;
for (const auto & frag : mapped_fragments) {
if (frag.first < first && frag.second > last) {
new_mapped_fragments.emplace_back(frag.first, first);
new_mapped_fragments.emplace_back(last, frag.second);
} else if (frag.first < first && frag.second > first) {
new_mapped_fragments.emplace_back(frag.first, first);
} else if (frag.first < last && frag.second > last) {
new_mapped_fragments.emplace_back(last, frag.second);
} else if (frag.first >= first && frag.second <= last) {
} else {
new_mapped_fragments.push_back(frag);
}
}
mapped_fragments = std::move(new_mapped_fragments);
}
~impl() {
for (const auto & frag : mapped_fragments) {
if (munmap((char *) addr + frag.first, frag.second - frag.first)) {
LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
}
}
}
#elif defined(_WIN32)
impl(struct llama_file * file, size_t prefetch, bool numa, [[maybe_unused]] bool use_thp) {
GGML_UNUSED(numa);
size = file->size();
HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id());
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
if (hMapping == NULL) {
DWORD error = GetLastError();
throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
}
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
DWORD error = GetLastError();
CloseHandle(hMapping);
if (addr == NULL) {
throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
}
if (prefetch > 0) {
#if _WIN32_WINNT >= 0x602
BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG);
HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll");
pPrefetchVirtualMemory = (decltype(pPrefetchVirtualMemory))(void *) GetProcAddress(hKernel32, "PrefetchVirtualMemory");
if (pPrefetchVirtualMemory) {
WIN32_MEMORY_RANGE_ENTRY range;
range.VirtualAddress = addr;
range.NumberOfBytes = (SIZE_T) std::min(size, prefetch);
if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n");
#endif
}
}
void unmap_fragment(size_t first, size_t last) {
GGML_UNUSED(first);
GGML_UNUSED(last);
}
~impl() {
if (!UnmapViewOfFile(addr)) {
LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
impl(struct llama_file * file, size_t prefetch, bool numa, [[maybe_unused]] bool use_thp) {
GGML_UNUSED(file);
GGML_UNUSED(prefetch);
GGML_UNUSED(numa);
throw std::runtime_error("mmap not supported");
}
void unmap_fragment(size_t first, size_t last) {
GGML_UNUSED(first);
GGML_UNUSED(last);
throw std::runtime_error("mmap not supported");
}
#endif
void * addr;
size_t size;
size_t mapped_page_size = 0;
};
llama_mmap::llama_mmap(struct llama_file * file, size_t prefetch, bool numa, bool use_thp) :
pimpl(std::make_unique<impl>(file, prefetch, numa, use_thp)) {}
llama_mmap::~llama_mmap() = default;
size_t llama_mmap::size() const { return pimpl->size; }
void * llama_mmap::addr() const { return pimpl->addr; }
void llama_mmap::unmap_fragment(size_t first, size_t last) { pimpl->unmap_fragment(first, last); }
#if defined(_POSIX_MAPPED_FILES) || defined(_WIN32)
const bool llama_mmap::SUPPORTED = true;
#else
const bool llama_mmap::SUPPORTED = false;
#endif
// llama_mlock
struct llama_mlock::impl {
#ifdef _POSIX_MEMLOCK_RANGE
static size_t lock_granularity() {
return (size_t) sysconf(_SC_PAGESIZE);
}
bool raw_lock(const void * addr, size_t size) const {
if (!mlock(addr, size)) {
return true;
}
#ifdef __APPLE__
#define MLOCK_SUGGESTION \
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n"
#else
#define MLOCK_SUGGESTION \
"Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n"
#endif
char* errmsg = std::strerror(errno);
bool suggest = (errno == ENOMEM);
#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX)
// visionOS/tvOS don't support RLIMIT_MEMLOCK
// Skip resource limit checks on visionOS/tvOS
suggest = false;
#else
struct rlimit lock_limit;
if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
suggest = false;
}
if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
suggest = false;
}
#endif
LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
return false;
}
static void raw_unlock(void * addr, size_t size) {
if (munlock(addr, size)) {
LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno));
}
}
#elif defined(_WIN32)
static size_t lock_granularity() {
SYSTEM_INFO si;
GetSystemInfo(&si);
return (size_t) si.dwPageSize;
}
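// VirtualLock fails if the working set is too small; on the first failure the
// working set is grown by the requested length plus 1 MiB of slack and the
// lock is retried once.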
bool raw_lock(void * ptr, size_t len) const {
for (int tries = 1; ; tries++) {
if (VirtualLock(ptr, len)) {
return true;
}
if (tries == 2) {
LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
len, size, llama_format_win_err(GetLastError()).c_str());
return false;
}
SIZE_T min_ws_size, max_ws_size;
if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
return false;
}
size_t increment = len + 1048576;
min_ws_size += increment;
max_ws_size += increment;
if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
return false;
}
}
}
static void raw_unlock(void * ptr, size_t len) {
if (!VirtualUnlock(ptr, len)) {
LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
static size_t lock_granularity() {
return (size_t) 65536;
}
bool raw_lock(const void * addr, size_t len) const {
LLAMA_LOG_WARN("warning: mlock not supported on this system\n");
return false;
}
static void raw_unlock(const void * addr, size_t len) {}
#endif
impl() : addr(NULL), size(0), failed_already(false) {}
void init(void * ptr) {
GGML_ASSERT(addr == NULL && size == 0);
addr = ptr;
}
void grow_to(size_t target_size) {
GGML_ASSERT(addr);
if (failed_already) {
return;
}
size_t granularity = lock_granularity();
target_size = (target_size + granularity - 1) & ~(granularity - 1);
if (target_size > size) {
if (raw_lock((uint8_t *) addr + size, target_size - size)) {
size = target_size;
} else {
failed_already = true;
}
}
}
void * addr;
size_t size;
bool failed_already;
};
llama_mlock::llama_mlock() : pimpl(std::make_unique<impl>()) {}
llama_mlock::~llama_mlock() = default;
void llama_mlock::init(void * ptr) { pimpl->init(ptr); }
void llama_mlock::grow_to(size_t target_size) { pimpl->grow_to(target_size); }
#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32)
const bool llama_mlock::SUPPORTED = true;
#else
const bool llama_mlock::SUPPORTED = false;
#endif
size_t llama_path_max() {
return PATH_MAX;
}

src/llama-mmap.h (new file, 68 lines)

@@ -0,0 +1,68 @@
#pragma once
#include <cstdint>
#include <memory>
#include <vector>
struct llama_file;
struct llama_mmap;
struct llama_mlock;
using llama_files = std::vector<std::unique_ptr<llama_file>>;
using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
struct llama_file {
llama_file(const char * fname, const char * mode);
~llama_file();
size_t tell() const;
size_t size() const;
int file_id() const; // fileno overload
void seek(size_t offset, int whence) const;
void read_raw(void * ptr, size_t len) const;
uint32_t read_u32() const;
void write_raw(const void * ptr, size_t len) const;
void write_u32(uint32_t val) const;
private:
struct impl;
std::unique_ptr<impl> pimpl;
};
struct llama_mmap {
llama_mmap(const llama_mmap &) = delete;
llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false, bool use_thp = false);
~llama_mmap();
size_t size() const;
void * addr() const;
void unmap_fragment(size_t first, size_t last);
static const bool SUPPORTED;
private:
struct impl;
std::unique_ptr<impl> pimpl;
};
struct llama_mlock {
llama_mlock();
~llama_mlock();
void init(void * ptr);
void grow_to(size_t target_size);
static const bool SUPPORTED;
private:
struct impl;
std::unique_ptr<impl> pimpl;
};
size_t llama_path_max();
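A minimal usage sketch of this API (illustrative; the path and sizes are hypothetical):

llama_file file("model.gguf", "rb");
llama_mmap mapping(&file, /*prefetch=*/file.size(), /*numa=*/false, /*use_thp=*/false);
const void * data = mapping.addr();
// ... consume tensor data; ranges that are done can be released early:
mapping.unmap_fragment(0, 1 << 20);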

src/llama-model-loader.cpp (new file, 1082 lines; diff too large to display)

src/llama-model-loader.h (new file, 169 lines)

@@ -0,0 +1,169 @@
#pragma once
#include "llama.h"
#include "llama-impl.h"
#include "llama-mmap.h"
#include "llama-arch.h"
#include <cstdint>
#include <cstddef>
#include <stdexcept>
#include <unordered_map>
#include <vector>
enum llama_fver {
GGUF_FILE_VERSION_V1 = 1,
GGUF_FILE_VERSION_V2 = 2,
GGUF_FILE_VERSION_V3 = 3,
};
static const char * llama_file_version_name(llama_fver version) {
switch (version) {
case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)";
case GGUF_FILE_VERSION_V2: return "GGUF V2";
case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)";
}
return "unknown";
}
using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
struct llama_model_loader {
int n_kv = 0;
int n_tensors = 0;
int n_created = 0;
int64_t n_elements = 0;
size_t n_bytes = 0;
bool use_mmap = false;
bool check_tensors;
bool repack_tensors = false;
bool use_thp = false;
llama_files files;
llama_ftype ftype;
llama_fver fver;
llama_mmaps mappings;
// Holds information on a model weight
struct llama_tensor_weight {
uint16_t idx; // source file index
size_t offs; // tensor data offset in the original file
ggml_tensor * tensor;
llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size()) {
throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
}
}
};
std::vector<llama_tensor_weight> weights;
std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
const llama_model_tensor_buft_override * tensor_buft_overrides;
gguf_context * meta = NULL;
std::vector<ggml_context *> contexts;
std::string arch_name;
LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp,
const llama_model_kv_override * param_overrides_p,
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);
~llama_model_loader();
template<typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type
get_arr_n(const std::string & key, T & result, const bool required = true);
template<typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type
get_arr_n(const enum llm_kv kid, T & result, const bool required = true);
template<typename T>
bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true);
template<typename T, size_t N_MAX>
bool get_arr(const std::string & key, std::array<T, N_MAX> & result, const bool required = true);
template<typename T>
bool get_arr(const enum llm_kv kid, T & result, const bool required = true);
template<typename T>
bool get_key(const std::string & key, T & result, const bool required = true);
template<typename T>
bool get_key(const enum llm_kv kid, T & result, const bool required = true);
// get array of n <= N_MAX elements, or a single element repeated n times
template<typename T, size_t N_MAX>
bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, const bool required = true);
template<typename T>
bool get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required = true);
const std::string& get_arch_name() const { return arch_name; }
enum llm_arch get_arch() const { return llm_kv.arch; }
const char * get_tensor_name(int i) const;
const llama_tensor_weight * get_weight(const char * name) const;
const llama_tensor_weight * get_weight(int i) const {
return get_weight(get_tensor_name(i));
}
const llama_tensor_weight & require_weight(const char * name) const;
struct ggml_tensor * get_tensor_meta(const char * name) const;
struct ggml_tensor * require_tensor_meta(const char * name) const;
struct ggml_tensor * get_tensor_meta(int i) const {
return get_tensor_meta(get_tensor_name(i));
}
struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated);
const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const;
static const int TENSOR_NOT_REQUIRED = 1 << 0;
static const int TENSOR_DUPLICATED = 1 << 1;
static const int TENSOR_SKIP = 1 << 2;
struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, int flags = 0);
struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base,
const std::string & name, const std::vector<int64_t> & ne, size_t offset, bool required = true);
void done_getting_tensors() const;
void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr, bool use_thp = false);
void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const;
// for backwards compatibility, does not support ggml-backend
void load_data_for(struct ggml_tensor * cur) const;
size_t size_done = 0;
size_t size_data = 0;
std::vector<std::pair<size_t, size_t>> mmaps_used;
// Returns false if cancelled by progress_callback
bool load_all_data(
struct ggml_context * ctx,
llama_buf_map & bufs_mmap,
llama_mlocks * lmlocks,
llama_progress_callback progress_callback,
void * progress_callback_user_data);
};
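A hedged sketch of the intended call sequence, reconstructed from the declarations above (parameter values, the ggml_context, and the dimensions are assumptions):

llama_model_loader ml("model.gguf", /*use_mmap=*/true, /*check_tensors=*/false,
                      /*repack_tensors=*/false, /*use_thp=*/false,
                      /*param_overrides_p=*/nullptr, /*param_tensor_buft_overrides_p=*/nullptr);
uint32_t n_ctx = 0;
ml.get_key(LLM_KV_CONTEXT_LENGTH, n_ctx);
// assumes an existing ggml_context * ctx and dimensions n_embd, n_vocab:
ggml_tensor * tok_embd = ml.create_tensor(ctx, "token_embd.weight", {n_embd, n_vocab});
ml.done_getting_tensors();
ml.init_mappings();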


@@ -734,7 +734,7 @@ llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_da
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
static void get_overlapping_token_sequences(const llama_vocab& vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
for (llama_token token_id = 0; token_id < (llama_token)vocab.n_tokens(); token_id++) {
-std::string word = llama_detokenize(vocab, { token_id }, true);
auto word = vocab.detokenize( { token_id }, true);
if (word.find(str) != std::string::npos) {
token_sequences.emplace(token_id, std::vector<llama_token>());
}
@@ -751,7 +751,8 @@ static void get_overlapping_token_sequences(const llama_vocab& vocab, const std:
}
}
if (match) {
-std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
auto tokenization = vocab.tokenize(str.substr(i), false, false);
//std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
tokenization.resize(max_tail_len);
}

(file diff suppressed because it is too large)


@@ -1,155 +1,178 @@
#pragma once
-#include "llama-impl.h"
#include "llama.h"
#include <string>
#include <vector>
-#include <unordered_map>
#include <memory>
-#include <map>
// pre-tokenization types
enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37,
LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38,
};
struct LLM_KV;
struct llama_model_loader;
struct llama_vocab {
-using id = llama_token;
-using token = std::string;
-using tattr = llama_token_attr;
struct token_data {
-token text;
std::string text;
float score;
-tattr attr;
llama_token_attr attr;
};
-enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
llama_vocab();
-enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
~llama_vocab();
-int max_token_len = 0; // used for optimizing longest token search
void load(llama_model_loader & ml, const LLM_KV & kv);
std::string get_tokenizer_model() const;
std::string get_tokenizer_pre() const;
enum llama_vocab_type get_type() const;
enum llama_vocab_pre_type get_pre_type() const;
uint32_t n_tokens() const;
uint32_t n_token_types() const;
-std::unordered_map<token, id> token_to_id;
std::string type_name() const;
-std::vector<token_data> id_to_token;
-std::vector<id> cache_special_tokens;
bool is_normal (llama_token id) const;
-std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
bool is_unknown (llama_token id) const;
bool is_control (llama_token id) const;
bool is_byte (llama_token id) const;
bool is_user_defined(llama_token id) const;
bool is_unused (llama_token id) const;
bool is_eog (llama_token id) const;
-std::map<std::pair<std::string, std::string>, int> bpe_ranks;
uint8_t token_to_byte(llama_token id) const;
llama_token byte_to_token(uint8_t ch) const;
-// default LLaMA special tokens
llama_token text_to_token(const std::string & text) const;
-id special_bos_id = 1;
-id special_eos_id = 2;
-id special_unk_id = 0;
-id special_sep_id = -1;
-id special_pad_id = -1;
-id special_cls_id = -1;
-id special_mask_id = -1;
-id linefeed_id = 13;
const token_data & get_token_data(llama_token id) const;
-// fim tokens
const char * token_get_text (llama_token id) const;
-llama_token special_fim_pre_id = -1;
float token_get_score(llama_token id) const;
-llama_token special_fim_suf_id = -1;
llama_token_attr token_get_attr (llama_token id) const;
-llama_token special_fim_mid_id = -1;
-llama_token special_fim_pad_id = -1;
-llama_token special_fim_rep_id = -1; // repo
-llama_token special_fim_sep_id = -1; // file separator
-id special_prefix_id = -1;
llama_token token_bos() const;
-id special_suffix_id = -1;
llama_token token_eos() const;
-id special_middle_id = -1;
llama_token token_eot() const;
-id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
llama_token token_eom() const;
-id special_eom_id = -1;
llama_token token_unk() const;
llama_token token_sep() const;
llama_token token_nl () const;
llama_token token_pad() const;
llama_token token_mask() const;
-// tokenizer flags
llama_token token_prefix() const;
-bool tokenizer_add_space_prefix = false;
llama_token token_middle() const;
-bool tokenizer_add_bos = false;
llama_token token_suffix() const;
-bool tokenizer_add_eos = false;
-bool tokenizer_ignore_merges = false;
-bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces
-bool tokenizer_remove_extra_whitespaces = false;
-bool tokenizer_escape_whitespaces = true;
-bool tokenizer_treat_whitespace_as_suffix = false;
-std::vector<char> precompiled_charsmap;
llama_token token_fim_pre() const;
llama_token token_fim_suf() const;
llama_token token_fim_mid() const;
llama_token token_fim_pad() const;
llama_token token_fim_rep() const;
llama_token token_fim_sep() const;
bool get_add_space_prefix () const;
bool get_add_bos () const;
bool get_add_eos () const;
bool get_add_sep () const;
bool get_ignore_merges () const;
bool get_clean_spaces () const;
bool get_remove_extra_whitespaces () const;
bool get_escape_whitespaces () const;
bool get_treat_whitespace_as_suffix() const;
int max_token_len() const;
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
std::vector<std::string> get_bpe_merges() const;
std::vector<char> get_precompiled_charsmap() const;
int32_t tokenize(
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special) const;
std::vector<llama_token> tokenize(
const std::string & raw_text,
bool add_special,
bool parse_special = false) const;
// does not write null-terminator to buf
int32_t token_to_piece(
llama_token token,
char * buf,
int32_t length,
int32_t lstrip,
bool special) const;
// use cached data
const std::string & token_to_piece(llama_token token) const;
int32_t detokenize(
const llama_token * tokens,
int32_t n_tokens,
char * text,
int32_t text_len_max,
bool remove_special,
bool unparse_special) const;
std::string detokenize(
const std::vector<llama_token> & tokens,
bool special) const;
void print_info() const;
private:
struct impl;
std::unique_ptr<impl> pimpl;
};
const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
-//
-// internal API
-//
-// TODO: rename to llama_tokenize_impl
-// TODO: This should probably be in llama.h
-std::vector<llama_vocab::id> llama_tokenize_internal(
-const llama_vocab & vocab,
-std::string raw_text,
-bool add_special,
-bool parse_special = false);
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
-int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab);
-int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eot_impl (const struct llama_vocab & vocab);
-llama_token llama_token_eom_impl (const struct llama_vocab & vocab);
-int32_t llama_tokenize_impl(
-const struct llama_vocab & vocab,
-const char * text,
-int32_t text_len,
-llama_token * tokens,
-int32_t n_tokens_max,
-bool add_special,
-bool parse_special);
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(
-const struct llama_vocab & vocab,
-llama_token token,
-char * buf,
-int32_t length,
-int32_t lstrip,
-bool special);
-int32_t llama_detokenize_impl(
-const struct llama_vocab & vocab,
-const llama_token * tokens,
-int32_t n_tokens,
-char * text,
-int32_t text_len_max,
-bool remove_special,
-bool unparse_special);
-std::string llama_detokenize(
-const struct llama_vocab& vocab,
-const std::vector<llama_token>& tokens,
-bool special);
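With the member-function API, a tokenize/detokenize round trip is direct. Illustrative sketch:

// Assumes a llama_vocab obtained from a loaded model.
std::string roundtrip(const llama_vocab & vocab, const std::string & text) {
    std::vector<llama_token> toks = vocab.tokenize(text, /*add_special=*/false);
    return vocab.detokenize(toks, /*special=*/false);
}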

(file diff suppressed because it is too large)


@@ -5,20 +5,19 @@
#include "unicode.h" #include "unicode.h"
#include "unicode-data.h" #include "unicode-data.h"
#include <algorithm>
#include <cassert>
#include <codecvt>
#include <cstddef>
#include <cstdint>
#include <locale>
#include <map>
#include <regex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
-#include <locale>
-#include <codecvt>
-#include <iostream>
size_t unicode_len_utf8(char src) {
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -26,7 +25,7 @@ size_t unicode_len_utf8(char src) {
return lookup[highbits];
}
static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
std::string result;
for (size_t i = 0; i < cps.size(); ++i) {
result.append(unicode_cpt_to_utf8(cps[i]));
@@ -34,7 +33,7 @@ static std::string unicode_cpts_to_utf8(const std::vector<uint32_t>& cps) {
return result;
}
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
assert(offset < utf8.size());
if (!(utf8[offset + 0] & 0x80)) {
auto result = utf8[offset + 0];
@@ -45,7 +44,7 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
throw std::invalid_argument("invalid character");
}
if (!(utf8[offset + 0] & 0x20)) {
if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) {
throw std::invalid_argument("invalid character");
}
auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
@@ -53,7 +52,7 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
return result;
}
if (!(utf8[offset + 0] & 0x10)) {
if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) {
throw std::invalid_argument("invalid character");
}
auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
@@ -61,7 +60,7 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
return result;
}
if (!(utf8[offset + 0] & 0x08)) {
if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
throw std::invalid_argument("invalid character");
}
auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
@@ -71,15 +70,15 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
throw std::invalid_argument("failed to convert utf8 to codepoint"); throw std::invalid_argument("failed to convert utf8 to codepoint");
} }
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) { //static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
// std::vector<uint16_t> result; // std::vector<uint16_t> result;
// if (/* 0x0000 <= cp && */ cp <= 0xffff) { // if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
// result.emplace_back(cp); // result.emplace_back(cpt);
// return result; // return result;
// } // }
// if (0x10000 <= cp && cp <= 0x10ffff) { // if (0x10000 <= cpt && cpt <= 0x10ffff) {
// result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); // result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
// result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); // result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
// return result; // return result;
// } // }
// throw std::invalid_argument("failed to convert codepoint to utf16"); // throw std::invalid_argument("failed to convert codepoint to utf16");
@@ -120,14 +119,14 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
// return result;
//}
-static std::vector<codepoint_flags> unicode_cpt_flags_array() {
static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
-std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
-assert(unicode_ranges_flags.front().first == 0);
assert (unicode_ranges_flags.begin()[0].first == 0);
-assert(unicode_ranges_flags.back().first == MAX_CODEPOINTS);
assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
-const auto range_ini = unicode_ranges_flags[i - 1]; // codepoint_ini, flags
const auto range_ini = unicode_ranges_flags.begin()[i-1]; // codepoint_ini, flags
-const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
const auto range_end = unicode_ranges_flags.begin()[i]; // codepoint_end, flags
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
cpt_flags[cpt] = range_ini.second;
}
@@ -145,7 +144,7 @@ static std::vector<codepoint_flags> unicode_cpt_flags_array() {
cpt_flags[p.second].is_uppercase = true;
}
for (auto &range : unicode_ranges_nfd) { // start, last, nfd
cpt_flags[range.nfd].is_nfd = true;
}
@@ -200,55 +199,38 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
return map; return map;
} }
static inline bool is_valid_utf8(const std::string& str) { static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
int remaining_bytes = 0; // 当前多字节字符剩余的字节数
for (unsigned char c : str) {
if (remaining_bytes == 0) {
if ((c & 0x80) == 0x00) continue; // 1字节字符
else if ((c & 0xE0) == 0xC0) remaining_bytes = 1; // 2字节
else if ((c & 0xF0) == 0xE0) remaining_bytes = 2; // 3字节
else if ((c & 0xF8) == 0xF0) remaining_bytes = 3; // 4字节
else return false; // 非法起始字节
}
else {
// 检查后续字节是否为10xxxxxx
if ((c & 0xC0) != 0x80)
{
return false;
}
remaining_bytes--;
}
}
return (remaining_bytes == 0); // 确保多字节字符完整
}
static inline std::wstring unicode_wstring_from_utf8(const std::string& s) {
#if defined(__clang__) #if defined(__clang__)
// disable C++17 deprecation warning for std::codecvt_utf8 // disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push # pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations" # pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif #endif
bool isvalid = is_valid_utf8(s);
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv; std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
#if defined(__clang__) #if defined(__clang__)
# pragma clang diagnostic pop # pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif #endif
return conv.from_bytes(s); return conv.from_bytes(s);
} }
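// Illustrative sketch (not part of this patch): std::wregex cannot consume
// UTF-8 directly, which is the only reason this conversion exists. A caller
// typically converts a UTF-8 pattern once and reuses the compiled expression:
static void example_compile_wide_regex() {
    const std::wstring wpattern = unicode_wstring_from_utf8("\\s+(?!\\S)|\\s+");
    const std::wregex  wexpr(wpattern); // throws std::regex_error on a malformed pattern
    (void) wexpr;
}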
static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string>& bpe_words) { static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
std::vector<std::string> bpe_encoded_words; std::vector<std::string> bpe_encoded_words;
for (const auto& word : bpe_words) { for (const auto & word : bpe_words) {
std::string text_utf; std::string text_utf;
auto utf_word = unicode_cpts_from_utf8(word); auto utf_word = unicode_cpts_from_utf8(word);
for (size_t i = 0; i < utf_word.size(); ++i) { for (size_t i = 0; i < utf_word.size(); ++i) {
text_utf += unicode_cpt_to_utf8(utf_word[i]); text_utf += unicode_cpt_to_utf8(utf_word[i]);
} }
std::string encoded_token; std::string encoded_token;
for (char& c : text_utf) { for (char & c : text_utf) {
encoded_token += unicode_byte_to_utf8(c); encoded_token += unicode_byte_to_utf8(c);
} }
bpe_encoded_words.emplace_back(encoded_token); bpe_encoded_words.emplace_back(encoded_token);
@@ -257,7 +239,7 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
} }
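// Illustrative sketch (not part of this patch): byte-level BPE, as used by
// GPT-2 style tokenizers, remaps every raw byte to a printable codepoint so
// that merge tables can be stored as plain text. That is what the loop above
// does per word; a single word is encoded like this (the U+0120 mapping for
// space is the well-known "Ġ" convention):
static std::string example_byte_encode(const std::string & word) {
    std::string encoded;
    for (unsigned char c : word) {
        encoded += unicode_byte_to_utf8(c); // e.g. 0x20 (space) -> U+0120 "Ġ"
    }
    return encoded;
}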
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& text, const std::vector<size_t>& offsets) { static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets; // store the offset of each word std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
@@ -271,16 +253,16 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
start = offset_end; start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
auto _get_cpt = [&](const size_t pos) -> uint32_t { auto _get_cpt = [&] (const size_t pos) -> uint32_t {
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
}; };
auto _get_flags = [&](const size_t pos) -> codepoint_flags { auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
}; };
size_t _prev_end = offset_ini; size_t _prev_end = offset_ini;
auto _add_token = [&](const size_t end) -> size_t { auto _add_token = [&] (const size_t end) -> size_t {
assert(_prev_end <= end && end <= offset_end); assert(_prev_end <= end && end <= offset_end);
size_t len = end - _prev_end; size_t len = end - _prev_end;
if (len > 0) { if (len > 0) {
@@ -296,29 +278,29 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
return len; return len;
}; };
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/) { for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const uint32_t cpt = _get_cpt(pos); const uint32_t cpt = _get_cpt(pos);
const auto flags = _get_flags(pos); const auto flags = _get_flags(pos);
// regex: 's|'t|'re|'ve|'m|'ll|'d // regex: 's|'t|'re|'ve|'m|'ll|'d
if (cpt == '\'' && pos + 1 < offset_end) { if (cpt == '\'' && pos+1 < offset_end) {
uint32_t cpt_next = _get_cpt(pos + 1); uint32_t cpt_next = _get_cpt(pos+1);
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
pos += _add_token(pos + 2); pos += _add_token(pos+2);
continue; continue;
} }
if (pos + 2 < offset_end) { if (pos+2 < offset_end) {
uint32_t cpt_next_next = _get_cpt(pos + 2); uint32_t cpt_next_next = _get_cpt(pos+2);
if ((cpt_next == 'r' && cpt_next_next == 'e') || if ((cpt_next == 'r' && cpt_next_next == 'e') ||
(cpt_next == 'v' && cpt_next_next == 'e') || (cpt_next == 'v' && cpt_next_next == 'e') ||
(cpt_next == 'l' && cpt_next_next == 'l')) { (cpt_next == 'l' && cpt_next_next == 'l')) {
pos += _add_token(pos + 3); pos += _add_token(pos+3);
continue; continue;
} }
} }
} }
auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags); auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
// regex: <space>?\p{L}+ // regex: <space>?\p{L}+
if (flags2.is_letter) { if (flags2.is_letter) {
pos += (cpt == ' '); pos += (cpt == ' ');
@@ -348,12 +330,12 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
} }
size_t num_whitespaces = 0; size_t num_whitespaces = 0;
while (_get_flags(pos + num_whitespaces).is_whitespace) { while (_get_flags(pos+num_whitespaces).is_whitespace) {
num_whitespaces++; num_whitespaces++;
} }
// regex: \s+(?!\S) // regex: \s+(?!\S)
if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) { if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
pos += num_whitespaces - 1; pos += num_whitespaces - 1;
_add_token(pos); _add_token(pos);
continue; continue;
@@ -374,6 +356,207 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
return bpe_offsets; return bpe_offsets;
} }
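// Illustrative sketch (not part of this patch): the splitter returns chunk
// lengths in codepoints, not substrings. On "Hello world" the " ?\p{L}+"
// branch fires twice, yielding {5, 6} for "Hello" and " world":
static void example_gpt2_split() {
    const std::string text = "Hello world";
    const auto cpts = unicode_cpts_from_utf8(text);
    const auto offsets = unicode_regex_split_custom_gpt2(text, { cpts.size() });
    assert(offsets.size() == 2 && offsets[0] == 5 && offsets[1] == 6);
}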
// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
const auto cpts = unicode_cpts_from_utf8(text);
size_t start = 0;
for (auto offset : offsets) {
const size_t offset_ini = start;
const size_t offset_end = start + offset;
assert(offset_end <= cpts.size());
start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
};
size_t _prev_end = offset_ini;
auto _add_token = [&] (const size_t end) -> size_t {
assert(_prev_end <= end && end <= offset_end);
size_t len = end - _prev_end;
if (len > 0) {
bpe_offsets.push_back(len);
}
_prev_end = end;
//if (len > 0) {
// std::string s = "";
// for(size_t p = end-len; p < end; p++)
// s += unicode_cpt_to_utf8(cpts[p]);
// printf(">>> '%s'\n", s.c_str());
//}
return len;
};
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const uint32_t cpt = _get_cpt(pos);
const auto flags = _get_flags(pos);
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
if (cpt == '\'' && pos+1 < offset_end) {
uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
pos += _add_token(pos+2);
continue;
}
if (pos+2 < offset_end) {
uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
(cpt_next == 'v' && cpt_next_next == 'e') ||
(cpt_next == 'l' && cpt_next_next == 'l')) {
pos += _add_token(pos+3);
continue;
}
}
}
// regex: [^\r\n\p{L}\p{N}]?\p{L}+
if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
pos++;
while (_get_flags(pos).is_letter) {
pos++;
}
_add_token(pos);
continue;
}
}
// regex: \p{N}{1,3}
if (flags.is_number) {
size_t ini = pos;
while (_get_flags(pos).is_number) {
if (++pos - ini >= 3 ) {
_add_token(pos);
ini = pos;
}
}
_add_token(pos);
continue;
}
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
pos += (cpt == ' ');
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
flags2 = _get_flags(++pos);
}
uint32_t cpt2 = _get_cpt(pos);
while (cpt2 == '\r' || cpt2 == '\n') {
cpt2 = _get_cpt(++pos);
}
_add_token(pos);
continue;
}
size_t num_whitespaces = 0;
size_t last_end_r_or_n = 0;
while (_get_flags(pos+num_whitespaces).is_whitespace) {
uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
if (cpt2 == '\r' || cpt2 == '\n') {
last_end_r_or_n = pos + num_whitespaces + 1;
}
num_whitespaces++;
}
// regex: \s*[\r\n]+
if (last_end_r_or_n > 0) {
pos = last_end_r_or_n;
_add_token(pos);
continue;
}
// regex: \s+(?!\S)
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
pos += num_whitespaces - 1;
_add_token(pos);
continue;
}
// regex: \s+
if (num_whitespaces > 0) {
pos += num_whitespaces;
_add_token(pos);
continue;
}
// no matches
_add_token(++pos);
}
}
return bpe_offsets;
}
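// Illustrative sketch (not part of this patch): unlike GPT-2, the LLAMA3
// pattern caps digit runs at three via \p{N}{1,3}, so a long number is emitted
// in groups of up to three digits from the left:
static void example_llama3_digit_split() {
    const std::string text = "12345";
    const auto cpts = unicode_cpts_from_utf8(text);
    const auto offsets = unicode_regex_split_custom_llama3(text, { cpts.size() });
    assert(offsets.size() == 2 && offsets[0] == 3 && offsets[1] == 2); // "123" + "45"
}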
// use std::wregex to split the text
static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
std::wregex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
std::wcregex_iterator end;
int64_t start_idx = 0;
while (it != end) {
std::wcmatch match = *it;
if (match.position() > start_idx) {
bpe_offsets.emplace_back(match.position() - start_idx);
}
bpe_offsets.emplace_back(match.length());
start_idx = match.position() + match.length();
++it;
}
if (start_idx < (int64_t) offset) {
bpe_offsets.emplace_back(offset - start_idx);
}
start += offset;
}
return bpe_offsets;
}
// use std::regex to split the text
static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
std::regex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
std::cregex_iterator end;
int64_t start_idx = 0;
while (it != end) {
std::cmatch match = *it;
if (match.position() > start_idx) {
bpe_offsets.emplace_back(match.position() - start_idx);
}
bpe_offsets.emplace_back(match.length());
start_idx = match.position() + match.length();
++it;
}
if (start_idx < (int64_t) offset) {
bpe_offsets.emplace_back(offset - start_idx);
}
start += offset;
}
return bpe_offsets;
}
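// Illustrative sketch (not part of this patch): both STL overloads consume and
// produce the same format -- a flat list of chunk lengths -- which is what lets
// unicode_regex_split apply several expressions in sequence. Note that this
// narrow overload indexes bytes, so it is only safe on text where one byte is
// one codepoint (e.g. the "collapsed" representation built further below):
static void example_chained_split(const std::string & ascii_text) {
    std::vector<size_t> offsets = { ascii_text.size() };              // one chunk: the whole text
    offsets = unicode_regex_split_stl(ascii_text, "\\s+",   offsets); // pass 1: split around whitespace
    offsets = unicode_regex_split_stl(ascii_text, "[0-9]+", offsets); // pass 2: refine each chunk around digits
    (void) offsets;
}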
// K2 system regex patterns (from tokenization_kimi.py): // K2 system regex patterns (from tokenization_kimi.py):
// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+ // [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) { static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
@@ -394,8 +577,8 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
}; };
auto _get_flags = [&] (const size_t pos) -> codepoint_flags { auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
}; };
size_t _prev_end = offset_ini; size_t _prev_end = offset_ini;
@@ -546,220 +729,17 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
return bpe_offsets; return bpe_offsets;
} }
// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string& text, const std::vector<size_t>& offsets) {
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
const auto cpts = unicode_cpts_from_utf8(text);
size_t start = 0;
for (auto offset : offsets) {
const size_t offset_ini = start;
const size_t offset_end = start + offset;
assert(offset_end <= cpts.size());
start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
auto _get_cpt = [&](const size_t pos) -> uint32_t {
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&](const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
};
size_t _prev_end = offset_ini;
auto _add_token = [&](const size_t end) -> size_t {
assert(_prev_end <= end && end <= offset_end);
size_t len = end - _prev_end;
if (len > 0) {
bpe_offsets.push_back(len);
}
_prev_end = end;
//if (len > 0) {
// std::string s = "";
// for(size_t p = end-len; p < end; p++)
// s += unicode_cpt_to_utf8(cpts[p]);
// printf(">>> '%s'\n", s.c_str());
//}
return len;
};
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/) {
const uint32_t cpt = _get_cpt(pos);
const auto flags = _get_flags(pos);
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
if (cpt == '\'' && pos + 1 < offset_end) {
uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
pos += _add_token(pos + 2);
continue;
}
if (pos + 2 < offset_end) {
uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
(cpt_next == 'v' && cpt_next_next == 'e') ||
(cpt_next == 'l' && cpt_next_next == 'l')) {
pos += _add_token(pos + 3);
continue;
}
}
}
// regex: [^\r\n\p{L}\p{N}]?\p{L}+
if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
if (flags.is_letter || _get_flags(pos + 1).is_letter) { // one or more letters
pos++;
while (_get_flags(pos).is_letter) {
pos++;
}
_add_token(pos);
continue;
}
}
// regex: \p{N}{1,3}
if (flags.is_number) {
size_t ini = pos;
while (_get_flags(pos).is_number) {
if (++pos - ini >= 3) {
_add_token(pos);
ini = pos;
}
}
_add_token(pos);
continue;
}
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
pos += (cpt == ' ');
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
flags2 = _get_flags(++pos);
}
uint32_t cpt2 = _get_cpt(pos);
while (cpt2 == '\r' || cpt2 == '\n') {
cpt2 = _get_cpt(++pos);
}
_add_token(pos);
continue;
}
size_t num_whitespaces = 0;
size_t last_end_r_or_n = 0;
while (_get_flags(pos + num_whitespaces).is_whitespace) {
uint32_t cpt2 = _get_cpt(pos + num_whitespaces);
if (cpt2 == '\r' || cpt2 == '\n') {
last_end_r_or_n = pos + num_whitespaces + 1;
}
num_whitespaces++;
}
// regex: \s*[\r\n]+
if (last_end_r_or_n > 0) {
pos = last_end_r_or_n;
_add_token(pos);
continue;
}
// regex: \s+(?!\S)
if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
pos += num_whitespaces - 1;
_add_token(pos);
continue;
}
// regex: \s+
if (num_whitespaces > 0) {
pos += num_whitespaces;
_add_token(pos);
continue;
}
// no matches
_add_token(++pos);
}
}
return bpe_offsets;
}
// use std::wregex to split the text
static std::vector<size_t> unicode_regex_split_stl(const std::wstring& wtext, const std::wstring& regex_expr, const std::vector<size_t>& offsets) {
std::wregex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
std::wcregex_iterator end;
int64_t start_idx = 0;
while (it != end) {
std::wcmatch match = *it;
if (match.position() > start_idx) {
bpe_offsets.emplace_back(match.position() - start_idx);
}
bpe_offsets.emplace_back(match.length());
start_idx = match.position() + match.length();
++it;
}
if (start_idx < (int64_t)offset) {
bpe_offsets.emplace_back(offset - start_idx);
}
start += offset;
}
return bpe_offsets;
}
// use std::regex to split the text
static std::vector<size_t> unicode_regex_split_stl(const std::string& text, const std::string& regex_expr, const std::vector<size_t>& offsets) {
std::regex expr(regex_expr);
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
size_t start = 0;
for (auto offset : offsets) {
std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
std::cregex_iterator end;
int64_t start_idx = 0;
while (it != end) {
std::cmatch match = *it;
if (match.position() > start_idx) {
bpe_offsets.emplace_back(match.position() - start_idx);
}
bpe_offsets.emplace_back(match.length());
start_idx = match.position() + match.length();
++it;
}
if (start_idx < (int64_t)offset) {
bpe_offsets.emplace_back(offset - start_idx);
}
start += offset;
}
return bpe_offsets;
}
static std::vector<size_t> unicode_regex_split_custom(const std::string& text, const std::string& regex_expr, const std::vector<size_t>& offsets) {
std::vector<size_t> bpe_offsets; std::vector<size_t> bpe_offsets;
if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") { if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets); bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
} } else if (
else if ( regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" || regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
} } else if (regex_expr == "\\p{Han}+") {
else if (regex_expr == "\\p{Han}+") {
// K2's first pattern - handle all K2 patterns together // K2's first pattern - handle all K2 patterns together
bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets); bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
} }
@@ -771,71 +751,100 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string& text, c
// interface // interface
// //
std::string unicode_cpt_to_utf8(uint32_t cp) { std::string unicode_cpt_to_utf8(uint32_t cpt) {
std::string result; std::string result;
if (/* 0x00 <= cp && */ cp <= 0x7f) { if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
result.push_back(cp); result.push_back(cpt);
return result; return result;
} }
if (0x80 <= cp && cp <= 0x7ff) { if (0x80 <= cpt && cpt <= 0x7ff) {
result.push_back(0xc0 | ((cp >> 6) & 0x1f)); result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
result.push_back(0x80 | (cp & 0x3f)); result.push_back(0x80 | (cpt & 0x3f));
return result; return result;
} }
if (0x800 <= cp && cp <= 0xffff) { if (0x800 <= cpt && cpt <= 0xffff) {
result.push_back(0xe0 | ((cp >> 12) & 0x0f)); result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
result.push_back(0x80 | ((cp >> 6) & 0x3f)); result.push_back(0x80 | ((cpt >> 6) & 0x3f));
result.push_back(0x80 | (cp & 0x3f)); result.push_back(0x80 | (cpt & 0x3f));
return result; return result;
} }
if (0x10000 <= cp && cp <= 0x10ffff) { if (0x10000 <= cpt && cpt <= 0x10ffff) {
result.push_back(0xf0 | ((cp >> 18) & 0x07)); result.push_back(0xf0 | ((cpt >> 18) & 0x07));
result.push_back(0x80 | ((cp >> 12) & 0x3f)); result.push_back(0x80 | ((cpt >> 12) & 0x3f));
result.push_back(0x80 | ((cp >> 6) & 0x3f)); result.push_back(0x80 | ((cpt >> 6) & 0x3f));
result.push_back(0x80 | (cp & 0x3f)); result.push_back(0x80 | (cpt & 0x3f));
return result; return result;
} }
throw std::invalid_argument("invalid codepoint"); throw std::invalid_argument("invalid codepoint");
} }
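// Worked example (not part of this patch): the branches above implement the
// standard UTF-8 layout, 1..4 bytes for codepoints up to 0x10FFFF. The Euro
// sign U+20AC takes the 3-byte branch:
//
//   0xe0 | ((0x20AC >> 12) & 0x0f) = 0xE2
//   0x80 | ((0x20AC >>  6) & 0x3f) = 0x82
//   0x80 | ( 0x20AC        & 0x3f) = 0xAC
//
// so unicode_cpt_to_utf8(0x20AC) == "\xE2\x82\xAC".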
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t>& cpts) { std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
auto comp = [](const uint32_t cpt, const range_nfd& range) { auto comp = [] (const uint32_t cpt, const range_nfd & range) {
return cpt < range.first; return cpt < range.first;
}; };
std::vector<uint32_t> result(cpts.size()); std::vector<uint32_t> result(cpts.size());
for (size_t i = 0; i < cpts.size(); ++i) { for (size_t i = 0; i < cpts.size(); ++i) {
const uint32_t cpt = cpts[i]; const uint32_t cpt = cpts[i];
auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1; auto it = std::upper_bound(unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - 1;
result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt; result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
} }
return result; return result;
} }
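// Illustrative sketch (not part of this patch, and assuming unicode_ranges_nfd
// is sorted by range start, which the upper_bound-minus-one lookup requires):
// a codepoint inside a range is replaced by that range's single nfd codepoint,
// everything else passes through unchanged:
//
//   unicode_cpts_normalize_nfd({ 0x00E9 }) -> { 0x0065 }   // 'é' collapses to 'e'
//   unicode_cpts_normalize_nfd({ 0x0041 }) -> { 0x0041 }   // 'A' has no mapping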
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string& utf8) { std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
std::vector<uint32_t> result; std::vector<uint32_t> result;
result.reserve(utf8.size()); result.reserve(utf8.size());
size_t offset = 0; size_t offset = 0;
while (offset < utf8.size()) { while (offset < utf8.size()) {
result.push_back(unicode_cpt_from_utf8(utf8, offset)); try {
result.push_back(unicode_cpt_from_utf8(utf8, offset));
}
catch (const std::invalid_argument & /*ex*/) {
// Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
++offset;
result.emplace_back(0xFFFD); // replacement character
}
} }
return result; return result;
} }
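// Illustrative sketch (not part of this patch): with the try/catch added above,
// a stray continuation byte no longer aborts tokenization -- it decodes to the
// U+FFFD replacement character and scanning resumes at the next byte:
static void example_lossy_decode() {
    const std::string bad = "ab\x80" "cd"; // 0x80 is not a valid lead byte
    const auto cpts = unicode_cpts_from_utf8(bad);
    assert(cpts.size() == 5 && cpts[2] == 0xFFFD); // 'a', 'b', U+FFFD, 'c', 'd'
}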
codepoint_flags unicode_cpt_flags(const uint32_t cp) { unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED); static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
static const auto cpt_flags = unicode_cpt_flags_array(); static const auto cpt_flags = unicode_cpt_flags_array();
return cp < cpt_flags.size() ? cpt_flags[cp] : undef; return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
} }
codepoint_flags unicode_cpt_flags(const std::string& utf8) { unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED); static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
if (utf8.empty()) { if (utf8.empty()) {
return undef; // undefined return undef; // undefined
} }
size_t offset = 0; size_t offset = 0;
return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset)); return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
}
std::string unicode_byte_to_utf8(uint8_t byte) {
static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
return map.at(byte);
}
uint8_t unicode_utf8_to_byte(const std::string & utf8) {
static std::unordered_map<std::string, uint8_t> map = unicode_utf8_to_byte_map();
return map.at(utf8);
}
uint32_t unicode_tolower(uint32_t cpt) {
// binary search
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
[](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
return pair.first < value;
});
if (it != unicode_map_lowercase.end() && it->first == cpt) {
return it->second;
}
return cpt; // Return the original code point if no lowercase mapping is found
} }
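// Illustrative sketch (not part of this patch): the lookup above assumes
// unicode_map_lowercase is a vector of (upper, lower) pairs sorted by the first
// element -- the precondition for std::lower_bound. Behavior is unchanged from
// the old map-based version:
//
//   unicode_tolower('A')    == 'a'
//   unicode_tolower(0x0394) == 0x03B4   // GREEK CAPITAL LETTER DELTA -> small delta
//   unicode_tolower('!')    == '!'      // no mapping: returned unchanged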
bool unicode_cpt_is_han(uint32_t cpt) { bool unicode_cpt_is_han(uint32_t cpt) {
@@ -870,53 +879,37 @@ bool unicode_cpt_is_han(uint32_t cpt) {
return false; return false;
} }
std::string unicode_byte_to_utf8(uint8_t byte) { std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
return map.at(byte);
}
uint8_t unicode_utf8_to_byte(const std::string& utf8) {
static std::unordered_map<std::string, uint8_t> map = unicode_utf8_to_byte_map();
return map.at(utf8);
}
uint32_t unicode_tolower(uint32_t cp) {
auto it = unicode_map_lowercase.find(cp);
return it == unicode_map_lowercase.end() ? cp : it->second;
}
std::vector<std::string> unicode_regex_split(const std::string& text, const std::vector<std::string>& regex_exprs) {
// unicode categories // unicode categories
static const std::map<std::string, int> k_ucat_enum = { static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", codepoint_flags::NUMBER }, { "\\p{N}", unicode_cpt_flags::NUMBER },
{ "\\p{L}", codepoint_flags::LETTER }, { "\\p{L}", unicode_cpt_flags::LETTER },
{ "\\p{P}", codepoint_flags::PUNCTUATION }, { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
{ "\\p{M}", codepoint_flags::ACCENT_MARK }, { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
{ "\\p{S}", codepoint_flags::SYMBOL }, { "\\p{S}", unicode_cpt_flags::SYMBOL },
}; };
static const std::map<int, int> k_ucat_cpt = { static const std::map<int, int> k_ucat_cpt = {
{ codepoint_flags::NUMBER, 0xD1 }, { unicode_cpt_flags::NUMBER, 0xD1 },
{ codepoint_flags::LETTER, 0xD2 }, { unicode_cpt_flags::LETTER, 0xD2 },
{ codepoint_flags::PUNCTUATION, 0xD3 }, { unicode_cpt_flags::PUNCTUATION, 0xD3 },
{ codepoint_flags::ACCENT_MARK, 0xD4 }, { unicode_cpt_flags::ACCENT_MARK, 0xD4 },
{ codepoint_flags::SYMBOL, 0xD5 }, { unicode_cpt_flags::SYMBOL, 0xD5 },
}; };
static const std::map<int, std::string> k_ucat_map = { static const std::map<int, std::string> k_ucat_map = {
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9 { unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z { unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}i { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
{ codepoint_flags::ACCENT_MARK, "" }, // no sub-128 codepoints { unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
{ codepoint_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`| { unicode_cpt_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
}; };
// compute collapsed codepoints only if needed by at least one regex // compute collapsed codepoints only if needed by at least one regex
bool need_collapse = false; bool need_collapse = false;
for (auto& regex_expr : regex_exprs) { for (const auto & regex_expr : regex_exprs) {
// search for unicode categories // search for unicode categories
for (const auto& ucat : k_ucat_enum) { for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) { if (std::string::npos != regex_expr.find(ucat.first)) {
need_collapse = true; need_collapse = true;
break; break;
@@ -927,7 +920,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
const auto cpts = unicode_cpts_from_utf8(text); const auto cpts = unicode_cpts_from_utf8(text);
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
// ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935 // ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935
std::string text_collapsed; std::string text_collapsed;
if (need_collapse) { if (need_collapse) {
// collapse all unicode categories // collapse all unicode categories
@@ -940,25 +933,23 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
continue; continue;
} }
const auto flags = unicode_cpt_flags(cpts[i]); const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
if (flags.is_whitespace) { if (flags.is_whitespace) {
//NOTE: C++ std::regex \s does not match 0x85, but Rust and Python regexes do. //NOTE: C++ std::regex \s does not match 0x85, but Rust and Python regexes do.
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback //text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
text_collapsed[i] = (char)0x0B; // <vertical tab> as whitespace fallback text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
} } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
text_collapsed[i] = k_ucat_cpt.at(flags.category_flag()); text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
} } else {
else { text_collapsed[i] = (char) 0xD0; // fallback
text_collapsed[i] = (char)0xD0; // fallback
} }
} }
} }
std::vector<size_t> bpe_offsets = { cpts.size() }; std::vector<size_t> bpe_offsets = { cpts.size() };
for (auto& regex_expr : regex_exprs) { for (const auto & regex_expr : regex_exprs) {
// first, see if we have an efficient custom regex implementation // first, see if we have an efficient custom regex implementation
auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets); auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
@@ -972,7 +963,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
// with the corresponding collapsed representation // with the corresponding collapsed representation
bool use_collapsed = false; bool use_collapsed = false;
for (auto& ucat : k_ucat_enum) { for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) { if (std::string::npos != regex_expr.find(ucat.first)) {
use_collapsed = true; use_collapsed = true;
break; break;
@@ -1031,15 +1022,14 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
//printf("text_collapsed: %s\n", text_collapsed.c_str()); //printf("text_collapsed: %s\n", text_collapsed.c_str());
//printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str()); //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
} } else {
else {
// no unicode category used, we can use std::wregex directly // no unicode category used, we can use std::wregex directly
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
// std::wregex \s does not match non-ASCII whitespace, so 0x0B is used as a fallback // std::wregex \s does not match non-ASCII whitespace, so 0x0B is used as a fallback
std::wstring wtext(cpts.begin(), cpts.end()); std::wstring wtext(cpts.begin(), cpts.end());
for (size_t i = 0; i < wtext.size(); ++i) { for (size_t i = 0; i < wtext.size(); ++i) {
if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) { if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
wtext[i] = 0x0B; wtext[i] = 0x0B;
} }
} }
@@ -1048,8 +1038,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
//printf("regex_expr: %s\n", regex_expr.c_str()); //printf("regex_expr: %s\n", regex_expr.c_str());
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets); bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
} }
} } catch (std::regex_error & e) {
catch (std::regex_error& e) {
fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str()); fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
fprintf(stderr, "Regex error: %s\n", e.what()); fprintf(stderr, "Regex error: %s\n", e.what());
throw std::runtime_error("Failed to process regex"); throw std::runtime_error("Failed to process regex");
@@ -1060,7 +1049,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
size_t start = 0; size_t start = 0;
for (size_t& offset : bpe_offsets) { for (size_t & offset : bpe_offsets) {
bpe_words.emplace_back(); bpe_words.emplace_back();
for (size_t i = start; i < start + offset; ++i) { for (size_t i = start; i < start + offset; ++i) {
bpe_words.back() += unicode_cpt_to_utf8(cpts[i]); bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);


@@ -4,9 +4,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
// TODO: prefix all symbols with "llama_" struct unicode_cpt_flags {
struct codepoint_flags {
enum { enum {
UNDEFINED = 0x0001, UNDEFINED = 0x0001,
NUMBER = 0x0002, // regex: \p{N} NUMBER = 0x0002, // regex: \p{N}
@@ -35,7 +33,7 @@ struct codepoint_flags {
uint16_t is_nfd : 1; uint16_t is_nfd : 1;
// decode from uint16 // decode from uint16
inline codepoint_flags(const uint16_t flags=0) { inline unicode_cpt_flags(const uint16_t flags = 0) {
*reinterpret_cast<uint16_t*>(this) = flags; *reinterpret_cast<uint16_t*>(this) = flags;
} }
@@ -50,19 +48,20 @@ struct codepoint_flags {
size_t unicode_len_utf8(char src); size_t unicode_len_utf8(char src);
std::string unicode_cpt_to_utf8(uint32_t cp); std::string unicode_cpt_to_utf8 (uint32_t cpt);
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset); uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8); std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts); std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
codepoint_flags unicode_cpt_flags(const uint32_t cp); unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
codepoint_flags unicode_cpt_flags(const std::string & utf8); unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
std::string unicode_byte_to_utf8(uint8_t byte); std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8); uint8_t unicode_utf8_to_byte(const std::string & utf8);
uint32_t unicode_tolower(uint32_t cp); uint32_t unicode_tolower(uint32_t cpt);
bool unicode_cpt_is_han(uint32_t cpt); bool unicode_cpt_is_han(uint32_t cpt);
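// Illustrative sketch (not part of this patch): after the rename, the type
// (unicode_cpt_flags) and its lookups (unicode_cpt_flags_from_cpt /
// unicode_cpt_flags_from_utf8) no longer share one overloaded name. Typical
// classification code now reads:
//
//   for (uint32_t cpt : unicode_cpts_from_utf8(text)) {
//       const unicode_cpt_flags flags = unicode_cpt_flags_from_cpt(cpt);
//       if (flags.is_letter || flags.is_number) { /* ... */ }
//   }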