mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
Enable CUDA graphs for MoE models + GPT-OSS support (#689)
* gpt-oss: common * gpt-oss: attention sinks, swiglu_oai * gpt-oss: WIP llama. Model loads and runs (CPU only), but PPL is much too high (~1500 for 1st batch vs ~200 in mainline). Is it because of SWA, because of vocab, or did I introduce a bug somewhere? * gpt-oss: CPU seems to be working. It was the SWA that was missing in the previous commit. There are issues with EOG tokens, so this still needs to be added. * CUDA: ADD_ID. Just a copy from mainline * gpt-oss: Seems to be working on CUDA * gpt-oss: add sinks to the attn-vec kernels * CUDA: add head size of 64 to new mma. Haven't turned it on yet, but observe slightly better PP and slightly worse TG performance with that. * gpt-oss: add ability to use -fmoe (only CUDA for now) * Move row sums to the right place * Add sinks to iqk flash attention * gpt-oss: Implement -fmoe on the CPU * Simdify swiglu_oai. Turning it off for now as performance becomes more variable, so perhaps I'm running into thermal throttling more often because of making the CPU work too hard. * llama: factor out model loader * Builds successfully * It runs, but mmap does not work * Fix llama_mmap so mmap works * Minor * Fix CUDA after latest changes * Attempt to use CUDA graphs with MoE models - not working * CUDA graphs WIP - still not working * CUDA graphs - seems to be working. Likely not all MLA variants are working. I no longer remember why I added the q8_0 cpy that transposes the tensor, but if really needed, this is now missing. Also missing is q6_0. * Make q8_0 cache work for DeepSeek models with CUDA graphs * cuda: cpy for q6_0 * Fix llama_mmap on non-Linux platforms * Adding forgotten file * Iterating on Windows build failures * cuda: re-add q8_0 -> q8_0 transpose so mla = 2 can be used with CUDA graphs and q8_0 cache. * Disable graphs without -fmoe * Minor * Turn graphs on by default --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -24,9 +24,9 @@ class common_chat_msg_parser {
|
|||||||
std::string prelude;
|
std::string prelude;
|
||||||
std::vector<common_string_range> groups;
|
std::vector<common_string_range> groups;
|
||||||
};
|
};
|
||||||
|
|
||||||
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||||
|
|
||||||
// Accessors
|
// Accessors
|
||||||
const std::string & input() const { return input_; }
|
const std::string & input() const { return input_; }
|
||||||
size_t pos() const { return pos_; }
|
size_t pos() const { return pos_; }
|
||||||
@@ -42,7 +42,7 @@ class common_chat_msg_parser {
|
|||||||
}
|
}
|
||||||
pos_ = pos;
|
pos_ = pos;
|
||||||
}
|
}
|
||||||
|
|
||||||
void move_back(size_t n) {
|
void move_back(size_t n) {
|
||||||
if (pos_ < n) {
|
if (pos_ < n) {
|
||||||
throw std::runtime_error("Can't move back that far!");
|
throw std::runtime_error("Can't move back that far!");
|
||||||
@@ -56,46 +56,46 @@ class common_chat_msg_parser {
|
|||||||
// Content manipulation
|
// Content manipulation
|
||||||
void add_content(const std::string & content);
|
void add_content(const std::string & content);
|
||||||
void add_reasoning_content(const std::string & reasoning_content);
|
void add_reasoning_content(const std::string & reasoning_content);
|
||||||
|
|
||||||
// Tool call manipulation
|
// Tool call manipulation
|
||||||
void add_tool_call(const common_chat_tool_call & tool_call);
|
void add_tool_call(const common_chat_tool_call & tool_call);
|
||||||
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
|
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
|
||||||
bool add_tool_call(const json & tool_call);
|
bool add_tool_call(const json & tool_call);
|
||||||
bool add_tool_calls(const json & arr);
|
bool add_tool_calls(const json & arr);
|
||||||
void clear_tools();
|
void clear_tools();
|
||||||
|
|
||||||
// Parsing utilities
|
// Parsing utilities
|
||||||
std::string consume_rest();
|
std::string consume_rest();
|
||||||
bool try_consume_literal(const std::string & literal);
|
bool try_consume_literal(const std::string & literal);
|
||||||
void consume_literal(const std::string & literal);
|
void consume_literal(const std::string & literal);
|
||||||
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
|
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
|
||||||
|
|
||||||
// Regex-based parsing methods (new)
|
// Regex-based parsing methods (new)
|
||||||
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
|
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
|
||||||
find_regex_result consume_regex(const common_regex & regex);
|
find_regex_result consume_regex(const common_regex & regex);
|
||||||
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
|
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
|
||||||
|
|
||||||
// Progressive parsing primitives (for Phase 4)
|
// Progressive parsing primitives (for Phase 4)
|
||||||
std::optional<find_regex_result> try_find_literal(const std::string & literal);
|
std::optional<find_regex_result> try_find_literal(const std::string & literal);
|
||||||
bool consume_spaces();
|
bool consume_spaces();
|
||||||
void set_healing_marker(const std::string & marker);
|
void set_healing_marker(const std::string & marker);
|
||||||
|
|
||||||
|
|
||||||
// Main parsing entry point
|
// Main parsing entry point
|
||||||
void parse();
|
void parse();
|
||||||
|
|
||||||
// Finishing
|
// Finishing
|
||||||
void finish();
|
void finish();
|
||||||
|
|
||||||
// Result extraction
|
// Result extraction
|
||||||
common_chat_msg result_and_reset();
|
common_chat_msg result_and_reset();
|
||||||
|
|
||||||
// Advanced JSON parsing (following original llama.cpp patterns)
|
// Advanced JSON parsing (following original llama.cpp patterns)
|
||||||
struct consume_json_result {
|
struct consume_json_result {
|
||||||
json value;
|
json value;
|
||||||
bool is_partial;
|
bool is_partial;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::optional<common_json> try_consume_json();
|
std::optional<common_json> try_consume_json();
|
||||||
common_json consume_json();
|
common_json consume_json();
|
||||||
consume_json_result consume_json_with_dumped_args(
|
consume_json_result consume_json_with_dumped_args(
|
||||||
@@ -112,8 +112,8 @@ private:
|
|||||||
void parse_kimi_k2_format();
|
void parse_kimi_k2_format();
|
||||||
void parse_deepseek_r1_format();
|
void parse_deepseek_r1_format();
|
||||||
void parse_generic_format();
|
void parse_generic_format();
|
||||||
|
|
||||||
|
|
||||||
// JSON parsing utilities (enhanced streaming support)
|
// JSON parsing utilities (enhanced streaming support)
|
||||||
struct json_parse_result {
|
struct json_parse_result {
|
||||||
json value;
|
json value;
|
||||||
@@ -121,11 +121,11 @@ private:
|
|||||||
bool is_partial;
|
bool is_partial;
|
||||||
std::string healing_marker;
|
std::string healing_marker;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Partial detection utilities
|
// Partial detection utilities
|
||||||
bool detect_partial_function_call(const std::string& content);
|
bool detect_partial_function_call(const std::string& content);
|
||||||
void handle_partial_detection();
|
void handle_partial_detection();
|
||||||
|
|
||||||
// Legacy find_literal for compatibility
|
// Legacy find_literal for compatibility
|
||||||
std::optional<find_regex_result> try_find_literal_legacy(const std::string & literal);
|
std::optional<find_regex_result> try_find_literal_legacy(const std::string & literal);
|
||||||
};
|
};
|
||||||
@@ -133,4 +133,4 @@ private:
|
|||||||
// Main parsing function (public API)
|
// Main parsing function (public API)
|
||||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||||
|
|
||||||
// Content-only parsing for fallback scenarios (static internal function)
|
// Content-only parsing for fallback scenarios (static internal function)
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
|
|||||||
|
|
||||||
// Check for the new tools array format first (no DeepSeek markers)
|
// Check for the new tools array format first (no DeepSeek markers)
|
||||||
auto original_pos = builder.pos();
|
auto original_pos = builder.pos();
|
||||||
|
|
||||||
// First, try the tools array format for content like "function\n```json\n{"tools": [...]}"
|
// First, try the tools array format for content like "function\n```json\n{"tools": [...]}"
|
||||||
if (builder.try_find_regex(function_regex_simple)) {
|
if (builder.try_find_regex(function_regex_simple)) {
|
||||||
builder.move_to(original_pos);
|
builder.move_to(original_pos);
|
||||||
@@ -231,7 +231,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
|
|||||||
// Fall through to try standard DeepSeek patterns
|
// Fall through to try standard DeepSeek patterns
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If tools array format didn't work, try XML-wrapped format
|
// If tools array format didn't work, try XML-wrapped format
|
||||||
builder.move_to(original_pos);
|
builder.move_to(original_pos);
|
||||||
try {
|
try {
|
||||||
@@ -240,7 +240,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
|
|||||||
} catch (const common_chat_msg_partial_exception&) {
|
} catch (const common_chat_msg_partial_exception&) {
|
||||||
// Fall through to try standard DeepSeek patterns
|
// Fall through to try standard DeepSeek patterns
|
||||||
}
|
}
|
||||||
|
|
||||||
// If XML wrapper format didn't work, try standard DeepSeek patterns
|
// If XML wrapper format didn't work, try standard DeepSeek patterns
|
||||||
builder.move_to(original_pos);
|
builder.move_to(original_pos);
|
||||||
try {
|
try {
|
||||||
@@ -278,7 +278,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
|
|||||||
throw; // Re-throw for partial mode
|
throw; // Re-throw for partial mode
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add any remaining content (critical for responses without tool calls)
|
// Add any remaining content (critical for responses without tool calls)
|
||||||
builder.add_content(builder.consume_rest());
|
builder.add_content(builder.consume_rest());
|
||||||
}
|
}
|
||||||
@@ -286,19 +286,19 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
|
|||||||
// Parse DeepSeek R1 tools array format following original llama.cpp parse_prefixed_json_tool_call_array pattern
|
// Parse DeepSeek R1 tools array format following original llama.cpp parse_prefixed_json_tool_call_array pattern
|
||||||
static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) {
|
static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) {
|
||||||
static const common_regex prefix("function\n```json\n");
|
static const common_regex prefix("function\n```json\n");
|
||||||
|
|
||||||
|
|
||||||
if (auto res = builder.try_find_regex(prefix)) {
|
if (auto res = builder.try_find_regex(prefix)) {
|
||||||
// Parse JSON and manually process tools array to convert arguments to strings
|
// Parse JSON and manually process tools array to convert arguments to strings
|
||||||
auto json_result = builder.try_consume_json();
|
auto json_result = builder.try_consume_json();
|
||||||
if (!json_result) {
|
if (!json_result) {
|
||||||
throw common_chat_msg_partial_exception("invalid JSON");
|
throw common_chat_msg_partial_exception("invalid JSON");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// DeepSeek R1 format has "tools" array, manually process each tool
|
// DeepSeek R1 format has "tools" array, manually process each tool
|
||||||
if (json_result->json.contains("tools") && json_result->json.at("tools").is_array()) {
|
if (json_result->json.contains("tools") && json_result->json.at("tools").is_array()) {
|
||||||
|
|
||||||
// Manually create tool calls array with string arguments (following original pattern)
|
// Manually create tool calls array with string arguments (following original pattern)
|
||||||
json tools_with_dumped_args = json::array();
|
json tools_with_dumped_args = json::array();
|
||||||
for (const auto& tool : json_result->json.at("tools")) {
|
for (const auto& tool : json_result->json.at("tools")) {
|
||||||
@@ -310,15 +310,15 @@ static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) {
|
|||||||
tools_with_dumped_args.push_back(formatted_tool);
|
tools_with_dumped_args.push_back(formatted_tool);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (!builder.add_tool_calls(tools_with_dumped_args) || !json_result->healing_marker.marker.empty()) {
|
if (!builder.add_tool_calls(tools_with_dumped_args) || !json_result->healing_marker.marker.empty()) {
|
||||||
throw common_chat_msg_partial_exception("incomplete tool call array");
|
throw common_chat_msg_partial_exception("incomplete tool call array");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
throw common_chat_msg_partial_exception("tools key not found or not array");
|
throw common_chat_msg_partial_exception("tools key not found or not array");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Consume closing ```
|
// Consume closing ```
|
||||||
builder.try_consume_regex(common_regex("```"));
|
builder.try_consume_regex(common_regex("```"));
|
||||||
} else {
|
} else {
|
||||||
@@ -326,41 +326,41 @@ static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse DeepSeek R1 XML-wrapped format following original Hermes-2-Pro pattern
|
// Parse DeepSeek R1 XML-wrapped format following original Hermes-2-Pro pattern
|
||||||
static void parse_deepseek_r1_xml_wrapped(common_chat_msg_parser & builder) {
|
static void parse_deepseek_r1_xml_wrapped(common_chat_msg_parser & builder) {
|
||||||
|
|
||||||
// Pattern for: <tool_call>\nfunction</think>FunctionName\n```json\n{...}\n```\n</tool_call>
|
// Pattern for: <tool_call>\nfunction</think>FunctionName\n```json\n{...}\n```\n</tool_call>
|
||||||
static const common_regex xml_pattern(
|
static const common_regex xml_pattern(
|
||||||
"<tool_call>\\s*" // Opening XML tag
|
"<tool_call>\\s*" // Opening XML tag
|
||||||
"function</think>([^\\n]+)" // Function name after "function</think>"
|
"function</think>([^\\n]+)" // Function name after "function</think>"
|
||||||
"\\s*```json\\s*" // JSON block start
|
"\\s*```json\\s*" // JSON block start
|
||||||
);
|
);
|
||||||
|
|
||||||
if (auto res = builder.try_find_regex(xml_pattern)) {
|
if (auto res = builder.try_find_regex(xml_pattern)) {
|
||||||
|
|
||||||
// Extract function name from capture group
|
// Extract function name from capture group
|
||||||
std::string function_name = builder.str(res->groups[1]);
|
std::string function_name = builder.str(res->groups[1]);
|
||||||
|
|
||||||
// Parse JSON arguments
|
// Parse JSON arguments
|
||||||
auto json_result = builder.try_consume_json();
|
auto json_result = builder.try_consume_json();
|
||||||
if (!json_result) {
|
if (!json_result) {
|
||||||
throw common_chat_msg_partial_exception("invalid JSON in XML wrapper");
|
throw common_chat_msg_partial_exception("invalid JSON in XML wrapper");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Create single tool call following original pattern
|
// Create single tool call following original pattern
|
||||||
json tool_call;
|
json tool_call;
|
||||||
tool_call["name"] = function_name;
|
tool_call["name"] = function_name;
|
||||||
tool_call["arguments"] = json_result->json.dump(); // Convert to string
|
tool_call["arguments"] = json_result->json.dump(); // Convert to string
|
||||||
|
|
||||||
json tool_calls_array = json::array();
|
json tool_calls_array = json::array();
|
||||||
tool_calls_array.push_back(tool_call);
|
tool_calls_array.push_back(tool_call);
|
||||||
|
|
||||||
|
|
||||||
if (!builder.add_tool_calls(tool_calls_array) || !json_result->healing_marker.marker.empty()) {
|
if (!builder.add_tool_calls(tool_calls_array) || !json_result->healing_marker.marker.empty()) {
|
||||||
throw common_chat_msg_partial_exception("incomplete XML wrapped tool call");
|
throw common_chat_msg_partial_exception("incomplete XML wrapped tool call");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Consume closing ```\n</tool_call>
|
// Consume closing ```\n</tool_call>
|
||||||
builder.try_consume_regex(common_regex("```\\s*</tool_call>"));
|
builder.try_consume_regex(common_regex("```\\s*</tool_call>"));
|
||||||
} else {
|
} else {
|
||||||
@@ -384,6 +384,15 @@ static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
|
|||||||
builder.add_content(kimi_k2::clean_content(builder.input()));
|
builder.add_content(kimi_k2::clean_content(builder.input()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
|
||||||
|
// TODO @ngxson : this won't work with --special enabled, we should fix that
|
||||||
|
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
|
||||||
|
if (!builder.syntax().enable_tool_calls) {
|
||||||
|
builder.add_content(builder.consume_rest());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Main parsing dispatch function
|
// Main parsing dispatch function
|
||||||
static void common_chat_parse(common_chat_msg_parser & builder) {
|
static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||||
switch (builder.syntax().format) {
|
switch (builder.syntax().format) {
|
||||||
@@ -399,6 +408,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
|||||||
case COMMON_CHAT_FORMAT_KIMI_K2:
|
case COMMON_CHAT_FORMAT_KIMI_K2:
|
||||||
common_chat_parse_kimi_k2(builder);
|
common_chat_parse_kimi_k2(builder);
|
||||||
break;
|
break;
|
||||||
|
case COMMON_CHAT_FORMAT_GPT_OSS:
|
||||||
|
common_chat_parse_gpt_oss(builder);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
||||||
}
|
}
|
||||||
@@ -432,6 +444,19 @@ const char* common_chat_format_name(common_chat_format format) {
|
|||||||
case COMMON_CHAT_FORMAT_GENERIC: return "generic";
|
case COMMON_CHAT_FORMAT_GENERIC: return "generic";
|
||||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "deepseek_r1";
|
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "deepseek_r1";
|
||||||
case COMMON_CHAT_FORMAT_KIMI_K2: return "kimi_k2";
|
case COMMON_CHAT_FORMAT_KIMI_K2: return "kimi_k2";
|
||||||
|
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
|
||||||
default: return "unknown";
|
default: return "unknown";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char * common_reasoning_format_name(common_reasoning_format format) {
|
||||||
|
switch (format) {
|
||||||
|
case COMMON_REASONING_FORMAT_NONE: return "none";
|
||||||
|
case COMMON_REASONING_FORMAT_AUTO: return "auto";
|
||||||
|
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
|
||||||
|
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unknown reasoning format");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,20 +13,20 @@ struct common_chat_templates;
|
|||||||
struct common_string_range {
|
struct common_string_range {
|
||||||
size_t begin;
|
size_t begin;
|
||||||
size_t end;
|
size_t end;
|
||||||
|
|
||||||
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
|
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
|
||||||
if (begin > end) {
|
if (begin > end) {
|
||||||
throw std::runtime_error("Invalid range");
|
throw std::runtime_error("Invalid range");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// prevent default ctor
|
// prevent default ctor
|
||||||
common_string_range() = delete;
|
common_string_range() = delete;
|
||||||
|
|
||||||
bool empty() const {
|
bool empty() const {
|
||||||
return begin == end;
|
return begin == end;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator==(const common_string_range & other) const {
|
bool operator==(const common_string_range & other) const {
|
||||||
return begin == other.begin && end == other.end;
|
return begin == other.begin && end == other.end;
|
||||||
}
|
}
|
||||||
@@ -40,7 +40,7 @@ struct common_chat_tool_call {
|
|||||||
bool operator==(const common_chat_tool_call & other) const {
|
bool operator==(const common_chat_tool_call & other) const {
|
||||||
return name == other.name && arguments == other.arguments && id == other.id;
|
return name == other.name && arguments == other.arguments && id == other.id;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator!=(const common_chat_tool_call & other) const {
|
bool operator!=(const common_chat_tool_call & other) const {
|
||||||
return !(*this == other);
|
return !(*this == other);
|
||||||
}
|
}
|
||||||
@@ -65,10 +65,10 @@ struct common_chat_msg {
|
|||||||
std::string tool_call_id;
|
std::string tool_call_id;
|
||||||
|
|
||||||
bool empty() const {
|
bool empty() const {
|
||||||
return content.empty() && content_parts.empty() && tool_calls.empty() &&
|
return content.empty() && content_parts.empty() && tool_calls.empty() &&
|
||||||
reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
||||||
for (auto i = 0u; i < tool_calls.size(); i++) {
|
for (auto i = 0u; i < tool_calls.size(); i++) {
|
||||||
if (ids_cache.size() <= i) {
|
if (ids_cache.size() <= i) {
|
||||||
@@ -91,7 +91,7 @@ struct common_chat_msg {
|
|||||||
&& tool_name == other.tool_name
|
&& tool_name == other.tool_name
|
||||||
&& tool_call_id == other.tool_call_id;
|
&& tool_call_id == other.tool_call_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator!=(const common_chat_msg & other) const {
|
bool operator!=(const common_chat_msg & other) const {
|
||||||
return !(*this == other);
|
return !(*this == other);
|
||||||
}
|
}
|
||||||
@@ -110,7 +110,7 @@ struct common_chat_msg_diff {
|
|||||||
&& tool_call_index == other.tool_call_index
|
&& tool_call_index == other.tool_call_index
|
||||||
&& tool_call_delta == other.tool_call_delta;
|
&& tool_call_delta == other.tool_call_delta;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator!=(const common_chat_msg_diff & other) const {
|
bool operator!=(const common_chat_msg_diff & other) const {
|
||||||
return !(*this == other);
|
return !(*this == other);
|
||||||
}
|
}
|
||||||
@@ -132,18 +132,20 @@ enum common_chat_format {
|
|||||||
COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||||
COMMON_CHAT_FORMAT_GENERIC,
|
COMMON_CHAT_FORMAT_GENERIC,
|
||||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||||
|
COMMON_CHAT_FORMAT_GPT_OSS,
|
||||||
COMMON_CHAT_FORMAT_KIMI_K2, // Our custom format (keep last for backward compatibility)
|
COMMON_CHAT_FORMAT_KIMI_K2, // Our custom format (keep last for backward compatibility)
|
||||||
};
|
};
|
||||||
|
|
||||||
enum common_reasoning_format {
|
enum common_reasoning_format {
|
||||||
COMMON_REASONING_FORMAT_NONE,
|
COMMON_REASONING_FORMAT_NONE,
|
||||||
|
COMMON_REASONING_FORMAT_AUTO,
|
||||||
COMMON_REASONING_FORMAT_DEEPSEEK,
|
COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||||
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY,
|
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_chat_syntax {
|
struct common_chat_syntax {
|
||||||
common_chat_format format = COMMON_CHAT_FORMAT_KIMI_K2;
|
common_chat_format format = COMMON_CHAT_FORMAT_KIMI_K2;
|
||||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; //COMMON_REASONING_FORMAT_NONE;
|
||||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||||
bool reasoning_in_content = false;
|
bool reasoning_in_content = false;
|
||||||
bool thinking_forced_open = false;
|
bool thinking_forced_open = false;
|
||||||
@@ -165,11 +167,12 @@ class common_chat_msg_partial_exception : public std::runtime_error {
|
|||||||
// Format detection from chat template
|
// Format detection from chat template
|
||||||
common_chat_format common_chat_format_detect(const std::string & chat_template);
|
common_chat_format common_chat_format_detect(const std::string & chat_template);
|
||||||
const char* common_chat_format_name(common_chat_format format);
|
const char* common_chat_format_name(common_chat_format format);
|
||||||
|
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||||
|
|
||||||
// Main parsing function (entry point for original llama.cpp compatibility)
|
// Main parsing function (entry point for original llama.cpp compatibility)
|
||||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||||
|
|
||||||
// Forward declare parser class
|
// Forward declare parser class
|
||||||
class common_chat_msg_parser;
|
class common_chat_msg_parser;
|
||||||
|
|
||||||
// Format-specific parsing functions (accessible from chat-parser)
|
// Format-specific parsing functions (accessible from chat-parser)
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
|
|
||||||
const llama_vocab * vocab = llama_get_vocab(ctx);
|
const llama_vocab * vocab = llama_get_vocab(ctx);
|
||||||
llama_token bos = llama_token_bos_impl(*vocab);
|
llama_token bos = vocab->token_bos();
|
||||||
//llama_token eos = llama_token_eos_impl(*vocab);
|
//llama_token eos = llama_token_eos_impl(*vocab);
|
||||||
|
|
||||||
const unsigned int n_vocab = llama_n_vocab(model);
|
const unsigned int n_vocab = llama_n_vocab(model);
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ set (GGML_CUDA_MIN_BATCH_OFFLOAD "32" CACHE STRING
|
|||||||
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
|
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
|
||||||
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
|
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
|
||||||
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
|
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
|
||||||
option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF)
|
option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ON)
|
||||||
|
|
||||||
option(GGML_IQK_FLASH_ATTENTION "ggml: enable the IQK FlashAttention CPU kernels" ON)
|
option(GGML_IQK_FLASH_ATTENTION "ggml: enable the IQK FlashAttention CPU kernels" ON)
|
||||||
option(GGML_IQK_FA_ALL_QUANTS "ggml: compile all quants for IQK FlashAttention" OFF)
|
option(GGML_IQK_FA_ALL_QUANTS "ggml: compile all quants for IQK FlashAttention" OFF)
|
||||||
|
|||||||
@@ -325,6 +325,16 @@
|
|||||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
|
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
|
||||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||||
|
|
||||||
|
#define GGML_TENSOR_TERNARY_OP_LOCALS \
|
||||||
|
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
|
||||||
|
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
|
||||||
|
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
|
||||||
|
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
|
||||||
|
GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
|
||||||
|
GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
|
||||||
|
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
|
||||||
|
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||||
|
|
||||||
#define GGML_TENSOR_BINARY_OP_LOCALS01 \
|
#define GGML_TENSOR_BINARY_OP_LOCALS01 \
|
||||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
|
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
|
||||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
|
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
|
||||||
@@ -571,6 +581,7 @@ extern "C" {
|
|||||||
|
|
||||||
GGML_OP_DUP,
|
GGML_OP_DUP,
|
||||||
GGML_OP_ADD,
|
GGML_OP_ADD,
|
||||||
|
GGML_OP_ADD_ID,
|
||||||
GGML_OP_ADD1,
|
GGML_OP_ADD1,
|
||||||
GGML_OP_ACC,
|
GGML_OP_ACC,
|
||||||
GGML_OP_SUB,
|
GGML_OP_SUB,
|
||||||
@@ -674,6 +685,7 @@ extern "C" {
|
|||||||
GGML_UNARY_OP_HARDSWISH,
|
GGML_UNARY_OP_HARDSWISH,
|
||||||
GGML_UNARY_OP_HARDSIGMOID,
|
GGML_UNARY_OP_HARDSIGMOID,
|
||||||
GGML_UNARY_OP_SWIGLU,
|
GGML_UNARY_OP_SWIGLU,
|
||||||
|
GGML_UNARY_OP_SWIGLU_OAI,
|
||||||
|
|
||||||
GGML_UNARY_OP_COUNT,
|
GGML_UNARY_OP_COUNT,
|
||||||
};
|
};
|
||||||
@@ -1028,6 +1040,13 @@ extern "C" {
|
|||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
enum ggml_type type);
|
enum ggml_type type);
|
||||||
|
|
||||||
|
// dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
|
||||||
|
GGML_API struct ggml_tensor * ggml_add_id(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b,
|
||||||
|
struct ggml_tensor * ids);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_add1(
|
GGML_API struct ggml_tensor * ggml_add1(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
@@ -1268,6 +1287,13 @@ extern "C" {
|
|||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_swiglu_oai(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b,
|
||||||
|
float alpha,
|
||||||
|
float limit);
|
||||||
|
|
||||||
// a - x
|
// a - x
|
||||||
// b - dy
|
// b - dy
|
||||||
GGML_API struct ggml_tensor * ggml_silu_back(
|
GGML_API struct ggml_tensor * ggml_silu_back(
|
||||||
@@ -1370,6 +1396,16 @@ extern "C" {
|
|||||||
struct ggml_tensor * ids,
|
struct ggml_tensor * ids,
|
||||||
enum ggml_unary_op op);
|
enum ggml_unary_op op);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_moe_up_gate_ext(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a_up,
|
||||||
|
struct ggml_tensor * a_gate,
|
||||||
|
struct ggml_tensor * b,
|
||||||
|
struct ggml_tensor * ids,
|
||||||
|
struct ggml_tensor * a_up_b,
|
||||||
|
struct ggml_tensor * a_gate_b,
|
||||||
|
enum ggml_unary_op op);
|
||||||
|
|
||||||
// A: m columns, n rows,
|
// A: m columns, n rows,
|
||||||
// B: p columns, n rows,
|
// B: p columns, n rows,
|
||||||
// result is m columns, p rows
|
// result is m columns, p rows
|
||||||
@@ -1662,6 +1698,11 @@ extern "C" {
|
|||||||
float scale,
|
float scale,
|
||||||
float max_bias);
|
float max_bias);
|
||||||
|
|
||||||
|
// Takes a tensor a and a sinks tensor and returns nothing.
// NOTE(review): presumably a is an already-built soft-max node that gets the attention-sinks
// tensor attached as an extra source - confirm against the implementation.
GGML_API void ggml_soft_max_add_sinks(
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * sinks);
|
||||||
|
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
@@ -1998,6 +2039,10 @@ extern "C" {
|
|||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
enum ggml_prec prec);
|
enum ggml_prec prec);
|
||||||
|
|
||||||
|
// Takes a tensor a and a sinks tensor and returns nothing.
// NOTE(review): presumably a is an already-built flash_attn_ext node that gets the
// attention-sinks tensor attached as an extra source - confirm against the implementation.
GGML_API void ggml_flash_attn_ext_add_sinks(
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * sinks);
|
||||||
|
|
||||||
// TODO: needs to be adapted to ggml_flash_attn_ext
|
// TODO: needs to be adapted to ggml_flash_attn_ext
|
||||||
GGML_API struct ggml_tensor * ggml_flash_attn_back(
|
GGML_API struct ggml_tensor * ggml_flash_attn_back(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
|
|||||||
case GGML_OP_DIAG_MASK_ZERO:
|
case GGML_OP_DIAG_MASK_ZERO:
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
|
case GGML_OP_ADD_ID:
|
||||||
case GGML_OP_ADD1:
|
case GGML_OP_ADD1:
|
||||||
case GGML_OP_SUB:
|
case GGML_OP_SUB:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
|
|||||||
@@ -37,6 +37,8 @@
|
|||||||
#include "ggml-cuda/unary.cuh"
|
#include "ggml-cuda/unary.cuh"
|
||||||
#include "ggml-cuda/upscale.cuh"
|
#include "ggml-cuda/upscale.cuh"
|
||||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||||
|
#include "ggml-cuda/add-id.cuh"
|
||||||
|
#include "ggml-cuda/graph.cuh"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
@@ -49,6 +51,7 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
#include <condition_variable>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdarg.h>
|
#include <stdarg.h>
|
||||||
@@ -77,6 +80,7 @@ GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback,
|
|||||||
#define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
#define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
||||||
#define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
#define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
||||||
#define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
#define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||||
|
#define GGML_CUDA_LOG_DEBUG(...) ggml_cuda_log(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||||
|
|
||||||
GGML_ATTRIBUTE_FORMAT(2, 3)
|
GGML_ATTRIBUTE_FORMAT(2, 3)
|
||||||
static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) {
|
static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) {
|
||||||
@@ -444,6 +448,35 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
|
|||||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// File-wide synchronisation between CUDA graph capture and context teardown.
// ggml_cuda_lock_counter is decremented (while holding ggml_cuda_lock) after a stream
// capture has ended, and ggml_cuda_lock_cv is notified when it drops to zero; the
// ggml_backend_cuda_context destructor waits on the condition variable until the counter
// is zero before destroying its events, streams and cuBLAS handles.
// NOTE(review): the matching increment is expected where a capture begins - it is not
// visible in this part of the file.
static std::mutex ggml_cuda_lock;
|
||||||
|
static std::condition_variable ggml_cuda_lock_cv;
|
||||||
|
static std::atomic<int> ggml_cuda_lock_counter;
|
||||||
|
|
||||||
|
// Creates the backend context for CUDA device `device`. The context name is
// GGML_CUDA_NAME followed by the decimal device index.
ggml_backend_cuda_context::ggml_backend_cuda_context(int device)
    : device(device),
      name(GGML_CUDA_NAME + std::to_string(device)) {}
|
||||||
|
|
||||||
|
// Releases the CUDA resources owned by this context.
//
// The destructor first blocks until ggml_cuda_lock_counter is zero (the counter is
// decremented when a CUDA graph stream capture ends), and keeps ggml_cuda_lock held
// for the whole teardown. It then destroys, when present, the copy event, every
// stream of every device slot and the per-device cuBLAS handles.
ggml_backend_cuda_context::~ggml_backend_cuda_context() {

    std::unique_lock<std::mutex> lock(ggml_cuda_lock);
    const auto counter_is_zero = [] {
        return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0;
    };
    ggml_cuda_lock_cv.wait(lock, counter_is_zero);

    if (copy_event != nullptr) {
        CUDA_CHECK(cudaEventDestroy(copy_event));
    }

    for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
        for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
            if (streams[i][j] == nullptr) {
                continue; // this stream slot was never created
            }
            CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
        }

        if (cublas_handles[i] == nullptr) {
            continue; // no cuBLAS handle was created for this device slot
        }
        CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
    }
}
|
||||||
|
|
||||||
// cuda buffer
|
// cuda buffer
|
||||||
|
|
||||||
struct ggml_backend_cuda_buffer_context {
|
struct ggml_backend_cuda_buffer_context {
|
||||||
@@ -2220,6 +2253,24 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//static __global__ void k_quick_add(uint32_t n, uint32_t n_per_row, const float * src1, const float * src2, float * dst) {
|
||||||
|
//
|
||||||
|
// for (uint32_t j = threadIdx.x; j < n; j += blockDim.x) {
|
||||||
|
// dst[j] = src1[j] + src2[j % n_per_row];
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
// Adds one bias row to every row of a row-major float matrix:
//
//     dst[row][j] = src1[row][j] + src2[j]      for j in [0, n_per_row)
//
// Launch layout: one thread block per matrix row (blockIdx.x selects the row); the
// threads of a block stride over the columns of that row, so any 1-D block size works.
// The call sites in this file pass the same buffer for src1 and dst (in-place update),
// so these pointers must not be marked __restrict__.
//
//   n_per_row - number of floats in one row (also the length of src2)
//   src1      - input matrix, rows stored back to back
//   src2      - bias row shared by all rows
//   dst       - output matrix, same layout as src1
static __global__ void k_quick_add(uint32_t n_per_row, const float * src1, const float * src2, float * dst) {

    // Compute the row offset in 64 bits: row * n_per_row can exceed UINT32_MAX for
    // large row counts and would silently wrap in 32-bit arithmetic.
    const size_t row_offset = (size_t)blockIdx.x * n_per_row;
    const float * src1_row = src1 + row_offset;
    float       * dst_row  = dst  + row_offset;

    for (uint32_t j = threadIdx.x; j < n_per_row; j += blockDim.x) {
        dst_row[j] = src1_row[j] + src2[j];
    }
}
|
||||||
|
|
||||||
static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
|
static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
|
||||||
const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
|
const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
|
||||||
ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
|
ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
|
||||||
@@ -2270,7 +2321,7 @@ static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n
|
|||||||
return is_ser;
|
return is_ser;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const ggml_tensor * src1 = dst->src[1];
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
const ggml_tensor * ids = dst->src[2];
|
const ggml_tensor * ids = dst->src[2];
|
||||||
@@ -2319,7 +2370,25 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
0, src0->ne[1], 1, src1_padded_col_size, stream);
|
0, src0->ne[1], 1, src1_padded_col_size, stream);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
return;
|
if (next && next->op == GGML_OP_MUL_MAT_ID && next->src[0]->type == src0->type && src1 == next->src[1] &&
|
||||||
|
ggml_are_same_shape(src0, next->src[0]) &&
|
||||||
|
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
|
||||||
|
ggml_backend_buffer_is_cuda(next->buffer) &&
|
||||||
|
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer)) {
|
||||||
|
ggml_backend_cuda_buffer_context * next_src0_ctx = (ggml_backend_cuda_buffer_context *) next->src[0]->buffer->context;
|
||||||
|
ggml_backend_cuda_buffer_context * next_dst_ctx = (ggml_backend_cuda_buffer_context *) next->buffer->context;
|
||||||
|
if (next_src0_ctx->device == device_id &&
|
||||||
|
next_dst_ctx->device == device_id) {
|
||||||
|
local_dst.data = next->data;
|
||||||
|
ggml_cuda_op_mul_mat_vec_q_id(ctx, next->src[0], &local_src1, ids, &local_dst,
|
||||||
|
(const char *)next->src[0]->data, nullptr, src1_quantized.get(), (float *)next->data,
|
||||||
|
0, src0->ne[1], 1, src1_padded_col_size, stream);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2356,7 +2425,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
dst_row.nb[2] = nb1;
|
dst_row.nb[2] = nb1;
|
||||||
dst_row.nb[3] = nb1;
|
dst_row.nb[3] = nb1;
|
||||||
|
|
||||||
if (ne12 == 1) {
|
if (false && ne12 == 1) {
|
||||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||||
const char * ids_dev = (const char *) ids->data;
|
const char * ids_dev = (const char *) ids->data;
|
||||||
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
|
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
|
||||||
@@ -2442,6 +2511,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
|
static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
|
||||||
@@ -2470,6 +2540,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
|||||||
src0_2_ctx->device == device_id &&
|
src0_2_ctx->device == device_id &&
|
||||||
src1_ctx->device == device_id &&
|
src1_ctx->device == device_id &&
|
||||||
dst_ctx->device == device_id) {
|
dst_ctx->device == device_id) {
|
||||||
|
//printf("%s(%s, %s): %ld x %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, src0_1->name, src0_2->name,
|
||||||
|
// src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], ids->ne[0], ids->ne[1], ids->ne[2]);
|
||||||
// Fast TG path
|
// Fast TG path
|
||||||
const int64_t n_ids = ids->ne[0];
|
const int64_t n_ids = ids->ne[0];
|
||||||
auto stream = ctx.stream(device_id, 0);
|
auto stream = ctx.stream(device_id, 0);
|
||||||
@@ -2505,12 +2577,26 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
|||||||
0, src0_1->ne[1], 1, src1_padded_col_size, stream);
|
0, src0_1->ne[1], 1, src1_padded_col_size, stream);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
if (dst->src[4]) {
|
||||||
|
ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[4]->data,
|
||||||
|
(const int32_t *)ids->data, (float *)local_dst.data,
|
||||||
|
local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
|
||||||
|
local_dst.nb[1], local_dst.nb[2], dst->src[4]->nb[1], ids->nb[2], stream);
|
||||||
|
}
|
||||||
|
|
||||||
local_dst.data = dst_gate_contiguous.get();
|
local_dst.data = dst_gate_contiguous.get();
|
||||||
ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst,
|
ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst,
|
||||||
(const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(),
|
(const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(),
|
||||||
0, src0_2->ne[1], 1, src1_padded_col_size, stream);
|
0, src0_2->ne[1], 1, src1_padded_col_size, stream);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
if (dst->src[5]) {
|
||||||
|
ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[5]->data,
|
||||||
|
(const int32_t *)ids->data, (float *)local_dst.data,
|
||||||
|
local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
|
||||||
|
local_dst.nb[1], local_dst.nb[2], dst->src[5]->nb[1], ids->nb[2], stream);
|
||||||
|
}
|
||||||
|
|
||||||
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
|
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
|
||||||
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
|
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
|
||||||
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
|
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
|
||||||
@@ -2518,8 +2604,15 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
|||||||
ggml_backend_buffer_is_cuda(next->buffer) &&
|
ggml_backend_buffer_is_cuda(next->buffer) &&
|
||||||
((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) {
|
((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) {
|
||||||
|
|
||||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids,
|
auto unary_op = (ggml_unary_op)dst->op_params[0];
|
||||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
|
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
|
||||||
|
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||||
|
(float *)dst_gate_contiguous.get(), dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
|
||||||
|
} else {
|
||||||
|
ggml_fused_mul_unary(ctx, unary_op, dst->ne[0]*n_ids,
|
||||||
|
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||||
|
(float *)dst_gate_contiguous.get());
|
||||||
|
}
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING);
|
const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING);
|
||||||
@@ -2555,8 +2648,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
|||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
|
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
|
||||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
|
auto unary_op = (ggml_unary_op)dst->op_params[0];
|
||||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
|
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
|
||||||
|
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||||
|
(float *)dst->data, dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
|
||||||
|
} else {
|
||||||
|
ggml_fused_mul_unary(ctx, unary_op, ggml_nelements(dst),
|
||||||
|
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
|
||||||
|
}
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -2624,7 +2723,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
|||||||
final_src.nb[3] = final_src.nb[2];
|
final_src.nb[3] = final_src.nb[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ne12 == 1) {
|
if (false && ne12 == 1) {
|
||||||
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
|
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
|
||||||
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
|
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
|
||||||
if (fuse_down) {
|
if (fuse_down) {
|
||||||
@@ -2761,6 +2860,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
|||||||
}
|
}
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
if (dst->src[4]) {
|
||||||
|
dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
|
||||||
|
dim3 grid_dims(num_src1_rows);
|
||||||
|
k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
|
||||||
|
(const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
}
|
||||||
|
|
||||||
dst_row.data = dst_gate_contiguous.get();
|
dst_row.data = dst_gate_contiguous.get();
|
||||||
if (use_quantized_src1) {
|
if (use_quantized_src1) {
|
||||||
ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
|
ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
|
||||||
@@ -2770,8 +2877,24 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
|||||||
}
|
}
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
|
if (dst->src[5]) {
|
||||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
|
dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
|
||||||
|
dim3 grid_dims(num_src1_rows);
|
||||||
|
k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
|
||||||
|
(const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto unary_op = (ggml_unary_op)dst->op_params[0];
|
||||||
|
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
|
||||||
|
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||||
|
(float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0],
|
||||||
|
1.702f, 7.0f, stream);
|
||||||
|
} else {
|
||||||
|
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
|
||||||
|
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||||
|
(float *)dst_gate_contiguous.get());
|
||||||
|
}
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
if (fuse_down) {
|
if (fuse_down) {
|
||||||
@@ -2851,6 +2974,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
ggml_cuda_op_add(ctx, dst);
|
ggml_cuda_op_add(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_ADD_ID:
|
||||||
|
ggml_cuda_op_add_id(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_OP_MULTI_ADD:
|
case GGML_OP_MULTI_ADD:
|
||||||
ggml_cuda_op_multi_add(ctx, dst);
|
ggml_cuda_op_multi_add(ctx, dst);
|
||||||
break;
|
break;
|
||||||
@@ -2877,6 +3003,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||||||
case GGML_UNARY_OP_SWIGLU:
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
ggml_cuda_op_swiglu(ctx, dst);
|
ggml_cuda_op_swiglu(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_UNARY_OP_SWIGLU_OAI:
|
||||||
|
ggml_cuda_op_swiglu_oai(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
ggml_cuda_op_gelu_quick(ctx, dst);
|
ggml_cuda_op_gelu_quick(ctx, dst);
|
||||||
break;
|
break;
|
||||||
@@ -2938,7 +3067,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
ggml_cuda_mul_mat_id(ctx, dst);
|
skip_next = ggml_cuda_mul_mat_id(ctx, dst, next);
|
||||||
break;
|
break;
|
||||||
case GGML_OP_MOE_FUSED_UP_GATE:
|
case GGML_OP_MOE_FUSED_UP_GATE:
|
||||||
skip_next = ggml_cuda_up_gate_unary(ctx, dst, next);
|
skip_next = ggml_cuda_up_gate_unary(ctx, dst, next);
|
||||||
@@ -3119,6 +3248,105 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
|||||||
GGML_UNUSED(backend);
|
GGML_UNUSED(backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef USE_CUDA_GRAPH
|
||||||
|
|
||||||
|
// Scans the ggml graph to decide whether it can be executed through a CUDA graph and
// refreshes the list of copy-destination pointers that are accessed via indirection.
//
//   cuda_ctx       - backend context; cuda_graph->cpy_dest_ptrs is cleared and refilled here
//   cgraph         - the ggml graph that is about to be evaluated
//   use_cuda_graph - the caller's current decision
//
// Returns use_cuda_graph, downgraded to false as soon as an unsupported node is found
// (the scan stops at that node). When the result is true, copy indirection is enabled
// and the collected destination pointers are uploaded on the context's stream.
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
    bool use_cuda_graph) {

    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();

    const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
    const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
    const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
    const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
    const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";

    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_tensor * node = cgraph->nodes[i];

        // Empty nodes and pure view/layout ops are ignored by this scan.
        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
            continue;
        }

        if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
#ifndef NDEBUG
            GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
#endif
        }

        if (node->op == GGML_OP_MUL_MAT_ID && (node->ne[2] != 1 || node->src[2]->ne[0] != 1)) {
            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
#ifndef NDEBUG
            // ne[] is int64_t; cast to long long so the format is correct on LLP64 targets too.
            GGML_CUDA_LOG_DEBUG("%s(%s): disabling CUDA graphs due to unsupported node type %lld %lld\n",
                    __func__, node->src[0]->name, (long long)node->ne[2], (long long)node->src[2]->ne[0]);
#endif
        }

        if (node->op == GGML_OP_MOE_FUSED_UP_GATE) {
            auto src0_1 = node->src[0];
            auto src0_2 = node->src[1];
            auto src1   = node->src[2];
            const bool single_f32_row = src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 && src1->type == GGML_TYPE_F32;
            if (!single_f32_row || !ggml_is_quantized(src0_1->type) || !ggml_is_quantized(src0_2->type)) {
                use_cuda_graph = false;
#ifndef NDEBUG
                GGML_CUDA_LOG_DEBUG("%s(%s): disabling CUDA graphs due to unsupported fused up/gate node\n", __func__, node->name);
#endif
            } else if (i < cgraph->n_nodes-1) {
                // A directly following quantized MUL_MAT_ID is not inspected by this scan.
                // NOTE(review): presumably because ggml_cuda_up_gate_unary can execute that node
                // itself (it reports skip_next) - confirm the two conditions stay in sync.
                auto next = cgraph->nodes[i+1];
                if (next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type)) {
                    ++i;
                }
            }
        }

        if (node->op == GGML_OP_ADD &&
            node->src[1] && node->src[1]->ne[1] > 1 &&
            (!node->src[0] || node->src[0]->name != gemma3n_per_layer_proj_src0_name) &&
            node->src[1]->name != gemma3n_per_layer_proj_src1_name &&
            strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
            strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
            strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
            // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
            // by means of matching node names. See
            // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
            // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
            use_cuda_graph = false;
#ifndef NDEBUG
            GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%lld %lld %lld %lld]\n", __func__, node->name,
                    (long long)node->ne[0], (long long)node->ne[1], (long long)node->ne[2], (long long)node->ne[3]);
#endif
        }

        if (node->op == GGML_OP_CPY) {

            // Store the pointers which are updated for each token, such that these can be sent
            // to the device and accessed using indirection from CUDA graph
            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);

            // store a pointer to each copy op CUDA kernel to identify it later
            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
            if (!ptr) {
                use_cuda_graph = false;
#ifndef NDEBUG
                GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
            }
        }

        if (!use_cuda_graph) {
            break;
        }
    }

    if (use_cuda_graph) {
        cuda_ctx->cuda_graph->use_cpy_indirection = true;
        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
    }

    return use_cuda_graph;
}
|
||||||
|
|
||||||
static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
||||||
graph_node_properties->node_address = node->data;
|
graph_node_properties->node_address = node->data;
|
||||||
graph_node_properties->node_op = node->op;
|
graph_node_properties->node_op = node->op;
|
||||||
@@ -3129,6 +3357,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
|
|||||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||||
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
|
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
|
||||||
}
|
}
|
||||||
|
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
||||||
@@ -3160,9 +3389,246 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (node->op == GGML_OP_SCALE &&
|
||||||
|
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reports whether the CUDA graph must be updated for this ggml graph, and records the
// current node properties so the next call can compare against them.
//
// An update is required when no executable graph instance exists yet, when the number
// of nodes differs from the stored property count, or when a node no longer matches
// its stored properties. Properties of every node are refreshed on each call.
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {

    auto & stored_props = cuda_ctx->cuda_graph->ggml_graph_properties;

    bool update_needed = cuda_ctx->cuda_graph->instance == nullptr;

    // Check if the graph size has changed
    if (stored_props.size() != (size_t)cgraph->n_nodes) {
        update_needed = true;
        stored_props.resize(cgraph->n_nodes);
    }

    for (int idx = 0; idx < cgraph->n_nodes; idx++) {
        ggml_tensor * cur = cgraph->nodes[idx];

        // Only compare while no mismatch is known; afterwards the result cannot change.
        if (!update_needed && !ggml_graph_node_has_matching_properties(cur, &stored_props[idx])) {
            update_needed = true;
        }

        // Always store the current properties for the comparison on the next token.
        set_ggml_graph_node_properties(cur, &stored_props[idx]);
    }

    return update_needed;
}
|
||||||
|
|
||||||
|
// Brings the instantiated CUDA graph in line with the freshly captured graph.
//
// cudaGraphExecUpdate is tried first. If it reports cudaErrorGraphExecUpdateFailure,
// the pending error is cleared, the old executable is destroyed and a new one is
// instantiated from the captured graph. Any other non-success status trips the assert.
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {

#if CUDART_VERSION >= 12000
    cudaGraphExecUpdateResultInfo update_info;
    const cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &update_info);
#else
    cudaGraphNode_t failed_node;
    cudaGraphExecUpdateResult update_info;
    const cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &failed_node, &update_info);
#endif // CUDART_VERSION >= 12000

    if (stat != cudaErrorGraphExecUpdateFailure) {
        // Either the in-place update worked, or something unexpected happened.
        GGML_ASSERT(stat == cudaSuccess);
        return;
    }

#ifndef NDEBUG
    GGML_CUDA_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
#endif

    // The pre-existing graph exec cannot be updated due to violated constraints
    // so instead clear error and re-instantiate
    (void)cudaGetLastError();
    CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
    cuda_ctx->cuda_graph->instance = nullptr;
    CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Executes the ggml compute graph on the context's CUDA stream and, when CUDA graphs are in use,
// finishes the stream capture and launches the resulting executable graph.
//
//   cuda_ctx                     backend context providing the stream and the CUDA graph state
//   cgraph                       ggml graph whose nodes are dispatched via ggml_cuda_compute_forward
//   graph_evaluated_or_captured  in/out flag; the evaluation loop repeats until it becomes true
//   use_cuda_graph               whether the CUDA graph path is active for this call (only read here)
//   cuda_graph_update_required   whether the graph has to be (re)captured in this call (only read here)
//
// The stream capture itself is started by the caller; this function only ends it.
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
    bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
    // Integrated-GPU flag. Currently a constant; the device query is left commented out.
    // TODO
    const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated;

    while (!graph_evaluated_or_captured) {
        // Nodes are dispatched one by one only when CUDA graphs are off or a capture is in progress.
        // When an up-to-date CUDA graph exists, the work is performed by the graph launch below.
        if (!use_cuda_graph || cuda_graph_update_required) {

            for (int i = 0; i < cgraph->n_nodes; i++) {
                ggml_tensor * node = cgraph->nodes[i];
                ggml_tensor * next = nullptr;
                if (i + 1 < cgraph->n_nodes) {
                    next = cgraph->nodes[i+1];
                }

                // Empty tensors and the ops listed below are not dispatched.
                bool nothing_to_do = ggml_is_empty(node);
                if (!nothing_to_do) {
                    switch (node->op) {
                        case GGML_OP_RESHAPE:
                        case GGML_OP_TRANSPOSE:
                        case GGML_OP_VIEW:
                        case GGML_OP_PERMUTE:
                        case GGML_OP_NONE:
                            nothing_to_do = true;
                            break;
                        default:
                            break;
                    }
                }
                if (nothing_to_do) {
                    continue;
                }

#if 0
                static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
                if (!disable_fusion) {
                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
                        ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
                        i++;
                        continue;
                    }

                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
                        i += 2;
                        ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
                        continue;
                    }
                }
#endif
#ifndef NDEBUG
                // Debug builds: the node must live in this device's buffer type and every source
                // must have a buffer attached.
                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
                for (int j = 0; j < GGML_MAX_SRC; j++) {
                    if (node->src[j] != nullptr) {
                        assert(node->src[j]->buffer);
                        //assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
                        //       ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
                    }
                }
#else
                GGML_UNUSED(integrated);
#endif // NDEBUG

                // skip_next is an out-parameter of the dispatcher; when it comes back true the
                // following node is stepped over.
                bool skip_next = false;
                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next);
                if (!ok) {
                    GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
                }
                GGML_ASSERT(ok);
                if (skip_next) {
                    ++i;
                }
            }
        }
#ifdef USE_CUDA_GRAPH
        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
            auto & graph_state = *cuda_ctx->cuda_graph;

            // Release the previously captured graph before the new capture result is stored.
            if (graph_state.graph != nullptr) {
                CUDA_CHECK(cudaGraphDestroy(graph_state.graph));
                graph_state.graph = nullptr;
            }

            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph_state.graph));
            graph_evaluated_or_captured = true; // CUDA graph has been captured

            // Decrement the capture counter under the lock; the thread that brings it from 1 to 0
            // wakes all waiters on the condition variable.
            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
                ggml_cuda_lock_cv.notify_all();
            }
        } else {
            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
        }
    }

    if (use_cuda_graph) {
        // First use: create the executable graph from the captured graph.
        if (cuda_ctx->cuda_graph->instance == nullptr) {
            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
        }
        // A fresh capture requires the executable to be updated before launching.
        if (cuda_graph_update_required) {
            update_cuda_graph_executable(cuda_ctx);
        }
        // Launch graph
        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
#else
        graph_evaluated_or_captured = true;
#endif // USE_CUDA_GRAPH
    }
}
|
||||||
|
|
||||||
|
// Backend entry point for computing a ggml graph on a CUDA device.
//
// Selects the device, decides whether the CUDA graph path can be used for this call, starts a
// stream capture when the graph has to be (re)captured, and then hands over to
// evaluate_and_capture_cuda_graph. Always returns GGML_STATUS_SUCCESS; failures inside are
// reported through CUDA_CHECK / GGML_ASSERT.
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;

    ggml_cuda_set_device(cuda_ctx->device);

    // Without USE_CUDA_GRAPH both flags simply stay false.
    bool use_cuda_graph             = false;
    bool cuda_graph_update_required = false;

#ifdef USE_CUDA_GRAPH
    // Read once: setting GGML_CUDA_DISABLE_GRAPHS (to any value) turns the graph path off.
    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);

    // Objects required for CUDA Graph
    if (!cuda_ctx->cuda_graph) {
        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
    }

    // While no graph has been captured yet, check the device's compute capability.
    if (cuda_ctx->cuda_graph->graph == nullptr && ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
        cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
#ifndef NDEBUG
        GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
#endif
    }

    // CUDA graphs are disabled by the env var, an old GPU, a use-case that changes too rapidly,
    // or a previous graph capture failure.
    // NOTE(review): an earlier comment also mentioned disabling for multi-GPU; no such check is
    // present in this function - confirm whether it is handled elsewhere.
    const bool graphs_disabled = disable_cuda_graphs_due_to_env
        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture;
    use_cuda_graph = !graphs_disabled;

    if (use_cuda_graph) {
        cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);

        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);

        // Count consecutive calls that needed a graph update; any call that did not resets the count.
        const bool updating_graph = use_cuda_graph && cuda_graph_update_required;
        cuda_ctx->cuda_graph->number_consecutive_updates =
            updating_graph ? cuda_ctx->cuda_graph->number_consecutive_updates + 1 : 0;

        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
#ifndef NDEBUG
            GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
#endif
        }
    }

    if (use_cuda_graph && cuda_graph_update_required) {
        // Start CUDA graph capture. The capture counter is incremented under the lock first;
        // the matching decrement happens when the capture is ended.
        {
            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
        }

        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
    }

    if (!use_cuda_graph) {
        cuda_ctx->cuda_graph->use_cpy_indirection = false;
    }

#endif // USE_CUDA_GRAPH

    bool graph_evaluated_or_captured = false;

    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);

    return GGML_STATUS_SUCCESS;
}
|
||||||
|
|
||||||
|
/*
|
||||||
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
||||||
|
|
||||||
@@ -3431,6 +3897,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||||||
|
|
||||||
return GGML_STATUS_SUCCESS;
|
return GGML_STATUS_SUCCESS;
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
||||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
||||||
@@ -3440,6 +3907,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
case GGML_UNARY_OP_SWIGLU:
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
|
case GGML_UNARY_OP_SWIGLU_OAI:
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
case GGML_UNARY_OP_HARDSIGMOID:
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
@@ -3629,6 +4097,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||||||
case GGML_OP_PERMUTE:
|
case GGML_OP_PERMUTE:
|
||||||
case GGML_OP_TRANSPOSE:
|
case GGML_OP_TRANSPOSE:
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
|
case GGML_OP_ADD_ID:
|
||||||
case GGML_OP_MULTI_ADD:
|
case GGML_OP_MULTI_ADD:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
case GGML_OP_DIV:
|
case GGML_OP_DIV:
|
||||||
|
|||||||
72
ggml/src/ggml-cuda/add-id.cu
Normal file
72
ggml/src/ggml-cuda/add-id.cu
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
#include "add-id.cuh"
|
||||||
|
|
||||||
|
// dst[i2][i1][:] = src0[i2][i1][:] + src1[row][:], where row is the int32 read from src2 at
// position (i1, i2).
//
// One thread block handles one (i1, i2) pair (blockIdx.x, blockIdx.y); the threads of the block
// stride over the ne0 elements of the row.
//
//   ne0, ne1     dst row length and number of rows per i2 slice; dst is addressed as contiguous floats
//   nb01, nb02   byte strides of src0 along dims 1 and 2
//   nb11         byte stride between rows of src1
//   nb21         byte stride between i2 slices of src2 (elements within a slice are consecutive int32)
static __global__ void add_id_kernel(
        const float * src0, const float * src1, const int32_t * src2, float * dst,
        int64_t ne0, int64_t ne1,
        size_t nb01, size_t nb02,
        size_t nb11,
        size_t nb21
    ) {

    const int64_t i1 = blockIdx.x;
    const int64_t i2 = blockIdx.y;

    // Row of src1 selected for this (i1, i2) pair.
    const char * ids_bytes = (const char *) src2;
    const int    i11       = *(const int32_t *) (ids_bytes + i1*sizeof(int32_t) + i2*nb21);

    // dst strides are derived from its extents.
    const size_t nb1 = ne0 * sizeof(float);
    const size_t nb2 = ne1 * nb1;

    float       * out_row  = (float *)((char *)dst + i1*nb1 + i2*nb2);
    const float * in_row   = (const float *)((const char *)src0 + i1*nb01 + i2*nb02);
    const float * bias_row = (const float *)((const char *)src1 + i11*nb11);

    for (int64_t col = threadIdx.x; col < ne0; col += blockDim.x) {
        out_row[col] = in_row[col] + bias_row[col];
    }
}
|
||||||
|
|
||||||
|
// ggml op entry point for ADD_ID: dst = src0 + rows of src1 selected by the int32 ids in src2.
//
// Validates the tensor types (F32 for dst/src0/src1, I32 for src2) and that the innermost
// dimension of every source is densely packed, then forwards to ggml_cuda_add_id so that the
// launch configuration is defined in a single place.
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    GGML_TENSOR_TERNARY_OP_LOCALS

    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src2->type == GGML_TYPE_I32);

    GGML_ASSERT(nb00 == sizeof(float));
    GGML_ASSERT(nb10 == sizeof(float));
    GGML_ASSERT(nb20 == sizeof(int32_t));

    const float   * src0_d = (const float *)src0->data;
    const float   * src1_d = (const float *)src1->data;
    const int32_t * src2_d = (const int32_t *)src2->data;
    float         * dst_d  = (float *)dst->data;

    // Same kernel, grid (ne01 x ne02) and block size (min(ne00, 768)) as before, now via the shared helper.
    ggml_cuda_add_id(src0_d, src1_d, src2_d, dst_d,
            ne00, ne01, ne02,
            ne0, ne1, nb01, nb02, nb11, nb21, ctx.stream());
}
|
||||||
|
|
||||||
|
void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst,
|
||||||
|
int64_t ne00, int64_t ne01, int64_t ne02,
|
||||||
|
int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream) {
|
||||||
|
int threads = std::min((int)ne00, 768); // cols
|
||||||
|
dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
|
||||||
|
add_id_kernel<<<blocks, threads, 0, stream>>>(
|
||||||
|
src0, src1, src2, dst,
|
||||||
|
ne0, ne1,
|
||||||
|
nb01, nb02,
|
||||||
|
nb11,
|
||||||
|
nb21
|
||||||
|
);
|
||||||
|
}
|
||||||
8
ggml/src/ggml-cuda/add-id.cuh
Normal file
8
ggml/src/ggml-cuda/add-id.cuh
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
#include "common.cuh"
|
||||||
|
|
||||||
|
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst,
|
||||||
|
int64_t ne00, int64_t ne01, int64_t ne02,
|
||||||
|
int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream);
|
||||||
|
|
||||||
@@ -108,6 +108,23 @@ static const char * cu_get_error_str(CUresult err) {
|
|||||||
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||||
|
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||||
|
do { \
|
||||||
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
|
||||||
|
const int id = ggml_cuda_get_device(); \
|
||||||
|
if (!shared_memory_limit_raised[id]) { \
|
||||||
|
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
|
||||||
|
shared_memory_limit_raised[id] = true; \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
#else
|
||||||
|
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||||
|
do { \
|
||||||
|
GGML_UNUSED(nbytes); \
|
||||||
|
} while (0)
|
||||||
|
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||||
|
|
||||||
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
||||||
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
||||||
#else
|
#else
|
||||||
@@ -808,37 +825,7 @@ struct ggml_tensor_extra_gpu {
|
|||||||
#define USE_CUDA_GRAPH
|
#define USE_CUDA_GRAPH
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
struct ggml_graph_node_properties {
|
struct ggml_cuda_graph;
|
||||||
void * node_address;
|
|
||||||
ggml_op node_op;
|
|
||||||
int64_t ne[GGML_MAX_DIMS];
|
|
||||||
size_t nb[GGML_MAX_DIMS];
|
|
||||||
void * src_address[GGML_MAX_SRC];
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ggml_cuda_graph {
|
|
||||||
#ifdef USE_CUDA_GRAPH
|
|
||||||
~ggml_cuda_graph() {
|
|
||||||
if (instance != nullptr) {
|
|
||||||
CUDA_CHECK(cudaGraphExecDestroy(instance));
|
|
||||||
}
|
|
||||||
if (graph != nullptr) {
|
|
||||||
CUDA_CHECK(cudaGraphDestroy(graph));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cudaGraph_t graph = nullptr;
|
|
||||||
cudaGraphExec_t instance = nullptr;
|
|
||||||
size_t num_nodes = 0;
|
|
||||||
std::vector<cudaGraphNode_t> nodes;
|
|
||||||
std::vector<cudaKernelNodeParams> params;
|
|
||||||
bool disable_due_to_gpu_arch = false;
|
|
||||||
bool disable_due_to_too_many_updates = false;
|
|
||||||
bool disable_due_to_failed_graph_capture = false;
|
|
||||||
int number_consecutive_updates = 0;
|
|
||||||
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
|
||||||
std::vector<char **> updated_kernel_arg;
|
|
||||||
#endif
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ggml_backend_cuda_context {
|
struct ggml_backend_cuda_context {
|
||||||
int device;
|
int device;
|
||||||
@@ -850,26 +837,9 @@ struct ggml_backend_cuda_context {
|
|||||||
|
|
||||||
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
||||||
|
|
||||||
explicit ggml_backend_cuda_context(int device) :
|
explicit ggml_backend_cuda_context(int device);
|
||||||
device(device),
|
|
||||||
name(GGML_CUDA_NAME + std::to_string(device)) {
|
|
||||||
}
|
|
||||||
|
|
||||||
~ggml_backend_cuda_context() {
|
~ggml_backend_cuda_context();
|
||||||
if (copy_event != nullptr) {
|
|
||||||
CUDA_CHECK(cudaEventDestroy(copy_event));
|
|
||||||
}
|
|
||||||
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
|
|
||||||
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
|
|
||||||
if (streams[i][j] != nullptr) {
|
|
||||||
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (cublas_handles[i] != nullptr) {
|
|
||||||
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cudaStream_t stream(int device, int stream) {
|
cudaStream_t stream(int device, int stream) {
|
||||||
if (streams[device][stream] == nullptr) {
|
if (streams[device][stream] == nullptr) {
|
||||||
|
|||||||
262
ggml/src/ggml-cuda/cpy-utils.cuh
Normal file
262
ggml/src/ggml-cuda/cpy-utils.cuh
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml-common.h"
|
||||||
|
|
||||||
|
// Copies a single element from *src to *dst.
// Identical source and destination types are copied as-is; otherwise the value is converted to
// float first and then assigned to the destination type.
template<typename src_t, typename dst_t>
static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
    if constexpr (!std::is_same_v<src_t, dst_t>) {
        *dst = float(*src);
    } else {
        *dst = *src;
    }
}
|
||||||
|
|
||||||
|
// Returns the index of the entry of val[0..n-1] that is closest to x.
// The bisection below relies on val being sorted in ascending order. Values at or below val[0]
// map to 0, values at or above val[n-1] map to n-1, and an exact tie between two neighbours
// resolves to the upper one.
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
    if (x <= val[0]) {
        return 0;
    }
    if (x >= val[n-1]) {
        return n-1;
    }

    // Narrow [lo, hi] until the two indices are adjacent and bracket x.
    int lo = 0;
    int hi = n-1;
    while (hi - lo > 1) {
        const int mid = (lo + hi)/2;
        if (x < val[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    const float dist_below = x - val[hi-1];
    const float dist_above = val[hi] - x;
    return dist_below < dist_above ? hi-1 : hi;
}
|
||||||
|
|
||||||
|
// Quantizes QK4_0 consecutive floats from x into one block_q4_0 at y.
// The scale is the largest-magnitude value (sign kept) divided by -8; each value is scaled,
// offset by 8.5, truncated and capped at 15. Element j goes to the low nibble of qs[j] and
// element j + QK4_0/2 to the high nibble.
static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
    // Find the element with the largest magnitude, remembering its signed value.
    float best_abs  = 0.0f;
    float best_val  = 0.0f;
    for (int j = 0; j < QK4_0; ++j) {
        const float v  = x[j];
        const float av = fabsf(v);
        if (best_abs < av) {
            best_abs = av;
            best_val = v;
        }
    }

    const float d  = best_val / -8;
    const float id = d ? 1.0f/d : 0.0f; // all-zero input: avoid dividing by zero

    y->d = d;

    constexpr int half_block = QK4_0/2;
    for (int j = 0; j < half_block; ++j) {
        const float lo_scaled = x[j]*id;
        const float hi_scaled = x[half_block + j]*id;

        const uint8_t lo_q = min(15, (int8_t)(lo_scaled + 8.5f));
        const uint8_t hi_q = min(15, (int8_t)(hi_scaled + 8.5f));

        y->qs[j] = lo_q | (hi_q << 4);
    }
}
|
||||||
|
|
||||||
|
// Quantizes QK4_1 consecutive floats from x into one block_q4_1 at y.
// Stores the step (max - min)/15 in dm.x and the minimum in dm.y; each value is shifted by the
// minimum, scaled, rounded via +0.5 and truncation, and capped at 15. Element j goes to the low
// nibble of qs[j] and element j + QK4_1/2 to the high nibble.
static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
    float lo =  FLT_MAX;
    float hi = -FLT_MAX;
    for (int j = 0; j < QK4_1; ++j) {
        const float v = x[j];
        lo = v < lo ? v : lo;
        hi = v > hi ? v : hi;
    }

    const float d  = (hi - lo) / 15;
    const float id = d ? 1.0f/d : 0.0f; // constant block: avoid dividing by zero

    y->dm.x = d;
    y->dm.y = lo;

    constexpr int half_block = QK4_1/2;
    for (int j = 0; j < half_block; ++j) {
        const float first  = (x[j] - lo)*id;
        const float second = (x[half_block + j] - lo)*id;

        const uint8_t q_first  = min(15, (int8_t)(first + 0.5f));
        const uint8_t q_second = min(15, (int8_t)(second + 0.5f));

        y->qs[j] = q_first | (q_second << 4);
    }
}
|
||||||
|
|
||||||
|
// Quantizes QK5_0 consecutive floats from x into one block_q5_0 at y.
// The scale is the largest-magnitude value (sign kept) divided by -16; each value is scaled,
// offset by 16.5, truncated and capped at 31. The low 4 bits of each result are packed into qs
// (element j low nibble, element j + QK5_0/2 high nibble); the 5th bits are collected into a
// 32-bit mask that is copied into qh.
static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
    float best_abs = 0.0f;
    float best_val = 0.0f;
    for (int j = 0; j < QK5_0; ++j) {
        const float v  = x[j];
        const float av = fabsf(v);
        if (best_abs < av) {
            best_abs = av;
            best_val = v;
        }
    }

    const float d  = best_val / -16;
    const float id = d ? 1.0f/d : 0.0f; // all-zero input: avoid dividing by zero

    y->d = d;

    constexpr int half_block = QK5_0/2;
    uint32_t high_bits = 0;
    for (int j = 0; j < half_block; ++j) {
        const float lo_scaled = x[j]*id;
        const float hi_scaled = x[half_block + j]*id;

        const uint8_t lo_q = min(31, (int8_t)(lo_scaled + 16.5f));
        const uint8_t hi_q = min(31, (int8_t)(hi_scaled + 16.5f));

        y->qs[j] = (lo_q & 0xf) | ((hi_q & 0xf) << 4);

        // Bit j holds the 5th bit of the first-half element, bit j + half_block that of the second.
        high_bits |= ((lo_q & 0x10u) >> 4) << j;
        high_bits |= ((hi_q & 0x10u) >> 4) << (j + half_block);
    }
    memcpy(y->qh, &high_bits, sizeof(high_bits));
}
|
||||||
|
|
||||||
|
// Quantizes QK5_1 consecutive floats from x into one block_q5_1 at y.
// Stores the step (max - min)/31 in dm.x and the minimum in dm.y; each value is shifted by the
// minimum, scaled and rounded via +0.5 and truncation (no explicit cap is applied). The low
// 4 bits go into qs, the 5th bits into a 32-bit mask copied into qh.
static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
    float lo = x[0];
    float hi = x[0];
    for (int j = 1; j < QK5_1; ++j) {
        const float v = x[j];
        if (v < lo) {
            lo = v;
        }
        if (v > hi) {
            hi = v;
        }
    }

    const float d  = (hi - lo) / 31;
    const float id = d ? 1.0f/d : 0.0f; // constant block: avoid dividing by zero

    y->dm.x = d;
    y->dm.y = lo;

    constexpr int half_block = QK5_1/2;
    uint32_t high_bits = 0;
    for (int j = 0; j < half_block; ++j) {
        const float first  = (x[j] - lo)*id;
        const float second = (x[half_block + j] - lo)*id;

        const uint8_t q_first  = (uint8_t)(first + 0.5f);
        const uint8_t q_second = (uint8_t)(second + 0.5f);

        y->qs[j] = (q_first & 0xf) | ((q_second & 0xf) << 4);

        // Bit j holds the 5th bit of the first-half element, bit j + half_block that of the second.
        high_bits |= ((q_first  & 0x10u) >> 4) << j;
        high_bits |= ((q_second & 0x10u) >> 4) << (j + half_block);
    }
    memcpy(y->qh, &high_bits, sizeof(high_bits));
}
|
||||||
|
|
||||||
|
// Quantizes QK8_0 consecutive floats from x into one block_q8_0 at y.
// The scale is the largest absolute value divided by 127; each value is multiplied by the inverse
// scale and rounded with roundf before being stored in qs.
static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
    float abs_max = 0.0f;
    for (int j = 0; j < QK8_0; j++) {
        abs_max = fmaxf(abs_max, fabsf(x[j]));
    }

    const float d  = abs_max / 127;
    const float id = d ? 1.0f/d : 0.0f; // all-zero input: avoid dividing by zero

    y->d = d;

    for (int j = 0; j < QK8_0; ++j) {
        const float scaled = x[j]*id;
        y->qs[j] = roundf(scaled);
    }
}
|
||||||
|
|
||||||
|
// Quantizes QK4_NL consecutive floats from x into one block_iq4_nl at y.
// A first scale is taken from the largest-magnitude value divided by kvalues_iq4nl[0]; each scaled
// value is mapped to the nearest kvalues_iq4nl entry (best_index_int8). The stored scale is then
// the x^2-weighted least-squares ratio sumqx/sumq2 over the chosen entries, falling back to the
// first scale when sumq2 is not positive.
static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
    float best_abs = 0.0f;
    float best_val = 0.0f;
    for (int j = 0; j < QK4_NL; ++j) {
        const float v  = x[j];
        const float av = fabsf(v);
        if (best_abs < av) {
            best_abs = av;
            best_val = v;
        }
    }

    float d = best_val / kvalues_iq4nl[0];
    const float id = d ? 1.0f/d : 0.0f; // all-zero input: avoid dividing by zero

    constexpr int half_block = QK4_NL/2;
    float sumqx = 0;
    float sumq2 = 0;
    for (int j = 0; j < half_block; ++j) {
        const float x0 = x[j]*id;
        const float x1 = x[half_block + j]*id;

        const uint8_t idx0 = best_index_int8(16, kvalues_iq4nl, x0);
        const uint8_t idx1 = best_index_int8(16, kvalues_iq4nl, x1);
        y->qs[j] = idx0 | (idx1 << 4);

        // Accumulate the weighted sums for the scale refinement (weights are the squared inputs).
        const float v0 = kvalues_iq4nl[idx0];
        const float v1 = kvalues_iq4nl[idx1];
        const float w0 = x[j]*x[j];
        const float w1 = x[half_block + j]*x[half_block + j];
        sumqx += w0*v0*x[j] + w1*v1*x[half_block + j];
        sumq2 += w0*v0*v0 + w1*v1*v1;
    }

    y->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
|
||||||
|
|
||||||
|
// Quantizes QK6_0 consecutive floats from xi into one block_q6_0 at y.
// The scale is the largest-magnitude value (sign kept) divided by -32; each value is scaled,
// offset by 32.5, truncated and capped at 63. The low 4 bits are packed into qs (element j low
// nibble, element j + QK6_0/2 high nibble); the upper 2 bits of both elements are packed into qh.
static __device__ void quantize_f32_q6_0_block(const float * __restrict__ xi, block_q6_0 * __restrict__ y) {

    float amax = 0.0f;
    float vmax = 0.0f;

    for (int j = 0; j < QK6_0; ++j) {
        const float v  = xi[j];
        const float av = fabsf(v);
        if (amax < av) {
            amax = av;
            vmax = v;
        }
    }

    const float d  = vmax / -32;
    const float id = d ? 1.0f/d : 0.0f; // all-zero input: avoid dividing by zero

    y->d = d;
    // qh is filled with |= below, so it has to start out zeroed.
    memset(y->qh, 0, QK6_0/4);

    for (int j = 0; j < QK6_0/2; ++j) {
        const float x0 = xi[0       + j]*id;
        // Second half of the q6_0 block: offset by QK6_0/2. The previous code used QK4_0/2 here,
        // which only works as long as both block sizes happen to be equal.
        const float x1 = xi[QK6_0/2 + j]*id;

        const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f));
        const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f));

        y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);

        // h carries the two high bits of xi0 (bits 0-1) and of xi1 (bits 2-3); it lands in the
        // low or high nibble of qh[j % (QK6_0/4)] depending on which quarter j falls in.
        const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
        y->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4)));
    }
}
|
||||||
|
|
||||||
|
// Wrapper functions for cpy.cu compatibility
|
||||||
|
// Byte-pointer adapter: treats cxi as float input and cdsti as a block_q4_0 output and forwards
// to quantize_f32_q4_0_block.
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
    const float * src = (const float *)cxi;
    block_q4_0  * out = (block_q4_0 *)cdsti;
    quantize_f32_q4_0_block(src, out);
}
|
||||||
|
|
||||||
|
// Byte-pointer adapter: treats cxi as float input and cdsti as a block_q4_1 output and forwards
// to quantize_f32_q4_1_block.
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
    const float * src = (const float *)cxi;
    block_q4_1  * out = (block_q4_1 *)cdsti;
    quantize_f32_q4_1_block(src, out);
}
|
||||||
|
|
||||||
|
// Byte-pointer adapter: treats cxi as float input and cdsti as a block_q5_0 output and forwards
// to quantize_f32_q5_0_block.
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
    const float * src = (const float *)cxi;
    block_q5_0  * out = (block_q5_0 *)cdsti;
    quantize_f32_q5_0_block(src, out);
}
|
||||||
|
|
||||||
|
// Byte-pointer adapter: treats cxi as float input and cdsti as a block_q5_1 output and forwards
// to quantize_f32_q5_1_block.
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
    const float * src = (const float *)cxi;
    block_q5_1  * out = (block_q5_1 *)cdsti;
    quantize_f32_q5_1_block(src, out);
}
|
||||||
|
|
||||||
|
// Byte-pointer adapter: treats cxi as float input and cdsti as a block_q6_0 output and forwards
// to quantize_f32_q6_0_block.
static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) {
    const float * src = (const float *)cxi;
    block_q6_0  * out = (block_q6_0 *)cdsti;
    quantize_f32_q6_0_block(src, out);
}
|
||||||
|
|
||||||
|
// Byte-pointer adapter: treats cxi as float input and cdsti as a block_q8_0 output and forwards
// to quantize_f32_q8_0_block.
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
    const float * src = (const float *)cxi;
    block_q8_0  * out = (block_q8_0 *)cdsti;
    quantize_f32_q8_0_block(src, out);
}
|
||||||
|
|
||||||
|
// Byte-pointer adapter: treats cxi as float input and cdsti as a block_iq4_nl output and forwards
// to quantize_f32_iq4_nl_block.
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
    const float  * src = (const float *)cxi;
    block_iq4_nl * out = (block_iq4_nl *)cdsti;
    quantize_f32_iq4_nl_block(src, out);
}
|
||||||
|
|
||||||
|
template<typename src_t, typename dst_t>
|
||||||
|
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
|
||||||
|
convert_flt((const src_t *)cxi, (dst_t *)cdsti);
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,11 @@
|
|||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
|
|
||||||
#define CUDA_CPY_BLOCK_SIZE 32
|
#define CUDA_CPY_BLOCK_SIZE 64
|
||||||
|
|
||||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
|
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
|
||||||
|
|
||||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
|
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
|
||||||
|
|
||||||
|
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
|
||||||
|
|||||||
@@ -86,6 +86,24 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
|
|||||||
#endif // GGML_CUDA_F16
|
#endif // GGML_CUDA_F16
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dequantizes two values of block ib of a block_q6_0 array into v.
// v.x is built from the low nibble of qs[iqs] plus two high bits from qh, v.y from the high
// nibble of qs[iqs] plus the next two high bits; both are then shifted by -32 and multiplied by
// the block scale d.
static __device__ __forceinline__ void dequantize_q6_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
    const block_q6_0 & blk = ((const block_q6_0 *) vx)[ib];

    const dfloat d = blk.d;

    // qh[iqs%8] holds the high bits for iqs < 8 in its low nibble and for iqs >= 8 in its high
    // nibble; after the shift the four relevant bits sit in bits 0-3 of h.
    const uint8_t h      = blk.qh[iqs%8] >> 2*(iqs/8);
    const uint8_t packed = blk.qs[iqs];

    v.x = ((packed & 0xf) | ((h & 0x3) << 4));
    v.y = ((packed >>  4) | ((h & 0xc) << 2));

#ifdef GGML_CUDA_F16
    v = __hsub2(v, {32.0f, 32.0f});
    v = __hmul2(v, {d, d});
#else
    v.x = (v.x - 32.0f) * d;
    v.y = (v.y - 32.0f) * d;
#endif // GGML_CUDA_F16
}
|
||||||
|
|
||||||
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
|
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
|
||||||
const block_q8_0 * x = (const block_q8_0 *) vx;
|
const block_q8_0 * x = (const block_q8_0 *) vx;
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)(
|
|||||||
const char * __restrict__ K,
|
const char * __restrict__ K,
|
||||||
const char * __restrict__ V,
|
const char * __restrict__ V,
|
||||||
const char * __restrict__ mask,
|
const char * __restrict__ mask,
|
||||||
|
const char * __restrict__ sinks,
|
||||||
float * __restrict__ dst,
|
float * __restrict__ dst,
|
||||||
float2 * __restrict__ dst_meta,
|
float2 * __restrict__ dst_meta,
|
||||||
const float scale,
|
const float scale,
|
||||||
@@ -747,6 +748,7 @@ void launch_fattn(
|
|||||||
const ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
const ggml_tensor * mask = dst->src[3];
|
const ggml_tensor * mask = dst->src[3];
|
||||||
|
const ggml_tensor * sinks = dst->src[4];
|
||||||
|
|
||||||
ggml_tensor * KQV = dst;
|
ggml_tensor * KQV = dst;
|
||||||
|
|
||||||
@@ -837,6 +839,7 @@ void launch_fattn(
|
|||||||
K_data,
|
K_data,
|
||||||
V_data,
|
V_data,
|
||||||
mask ? ((const char *) mask->data) : nullptr,
|
mask ? ((const char *) mask->data) : nullptr,
|
||||||
|
sinks ? ((const char *) sinks->data) : nullptr,
|
||||||
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||||
scale, max_bias, m0, m1, softcap, n_head_log2,
|
scale, max_bias, m0, m1, softcap, n_head_log2,
|
||||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
@@ -1008,7 +1011,8 @@ void launch_fattn_mma(
|
|||||||
const ggml_tensor * K = dst->src[1];
|
const ggml_tensor * K = dst->src[1];
|
||||||
const ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
const ggml_tensor * mask = dst->src[3];
|
const ggml_tensor * mask = dst->src[3];
|
||||||
|
const ggml_tensor * sinks = dst->src[4];
|
||||||
|
|
||||||
ggml_tensor * KQV = dst;
|
ggml_tensor * KQV = dst;
|
||||||
|
|
||||||
@@ -1162,6 +1166,7 @@ void launch_fattn_mma(
|
|||||||
K_data,
|
K_data,
|
||||||
V_data,
|
V_data,
|
||||||
mask ? ((const char *) mask->data) : nullptr,
|
mask ? ((const char *) mask->data) : nullptr,
|
||||||
|
sinks ? ((const char *)sinks->data) : nullptr,
|
||||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
|
|||||||
@@ -425,6 +425,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||||||
const half2 * const __restrict__ K_h2,
|
const half2 * const __restrict__ K_h2,
|
||||||
const half2 * const __restrict__ V_h2,
|
const half2 * const __restrict__ V_h2,
|
||||||
const half2 * const __restrict__ mask_h2,
|
const half2 * const __restrict__ mask_h2,
|
||||||
|
const float * const __restrict__ sinks_f,
|
||||||
float2 * const __restrict__ dstk,
|
float2 * const __restrict__ dstk,
|
||||||
float2 * const __restrict__ dstk_fixup,
|
float2 * const __restrict__ dstk_fixup,
|
||||||
const float scale,
|
const float scale,
|
||||||
@@ -584,6 +585,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If attention sinks are used, potentially re-scale if KQ_max is small.
|
||||||
|
// Also add the sink as a value to KQ_rowsum; this is done after synchronization of KQ_rowsum
|
||||||
|
// so it's being done unconditionally for every thread.
|
||||||
|
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
|
||||||
|
float KQ_max_scale[cols_per_thread];
|
||||||
|
#pragma unroll
|
||||||
|
for (int col = 0; col < cols_per_thread; ++col) {
|
||||||
|
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
|
||||||
|
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
|
||||||
|
const float sink = sinks_f[jc % ncols2];
|
||||||
|
|
||||||
|
const float KQ_max_new = fmaxf(KQ_max[col], sink);
|
||||||
|
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
|
||||||
|
KQ_max_scale[col] = expf(KQ_max_diff);
|
||||||
|
KQ_max[col] = KQ_max_new;
|
||||||
|
|
||||||
|
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
|
||||||
|
|
||||||
|
const float KQ_max_add = expf(sink - KQ_max_new);
|
||||||
|
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ntiles == 1) {
|
||||||
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
|
||||||
|
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll
|
||||||
|
for (int col = 0; col < cols_per_thread; ++col) {
|
||||||
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
|
||||||
|
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Write VKQ accumulators to shared memory in column-major format.
|
// Write VKQ accumulators to shared memory in column-major format.
|
||||||
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
||||||
// Also for np > 1 the combination is done via these values in shared memory.
|
// Also for np > 1 the combination is done via these values in shared memory.
|
||||||
@@ -823,6 +870,7 @@ static __global__ void flash_attn_mma_ext_f16(
|
|||||||
const char * __restrict__ K,
|
const char * __restrict__ K,
|
||||||
const char * __restrict__ V,
|
const char * __restrict__ V,
|
||||||
const char * __restrict__ mask,
|
const char * __restrict__ mask,
|
||||||
|
const char * __restrict__ sinks,
|
||||||
float * __restrict__ dst,
|
float * __restrict__ dst,
|
||||||
float2 * __restrict__ dst_meta,
|
float2 * __restrict__ dst_meta,
|
||||||
const float scale,
|
const float scale,
|
||||||
@@ -896,6 +944,7 @@ static __global__ void flash_attn_mma_ext_f16(
|
|||||||
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
||||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
|
||||||
|
const float * sinks_f = sinks ? (const float *) sinks + channel * ncols2 : nullptr;
|
||||||
|
|
||||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||||
|
|
||||||
@@ -906,12 +955,12 @@ static __global__ void flash_attn_mma_ext_f16(
|
|||||||
if (kb0_start == 0) {
|
if (kb0_start == 0) {
|
||||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||||
} else {
|
} else {
|
||||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -934,6 +983,7 @@ static __global__ void flash_attn_mma_ext_f16(
|
|||||||
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
||||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
|
||||||
|
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
|
||||||
|
|
||||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||||
|
|
||||||
@@ -943,10 +993,10 @@ static __global__ void flash_attn_mma_ext_f16(
|
|||||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||||
constexpr bool needs_fixup = false;
|
constexpr bool needs_fixup = false;
|
||||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
|
||||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
||||||
|
|||||||
@@ -43,37 +43,37 @@ struct fattn_mma_f16_config;
|
|||||||
// Perhaps the 256 head size needs a closer look
|
// Perhaps the 256 head size needs a closer look
|
||||||
// to see if this implementation is better.
|
// to see if this implementation is better.
|
||||||
//
|
//
|
||||||
//template <>
|
template <>
|
||||||
//struct fattn_mma_f16_config< 64, 64> {
|
struct fattn_mma_f16_config< 64, 64> {
|
||||||
// static constexpr int nbatch_fa = 64;
|
static constexpr int nbatch_fa = 64;
|
||||||
// static constexpr int nwarps_max = 4;
|
static constexpr int nwarps_max = 4;
|
||||||
// static constexpr bool Q_in_reg = true;
|
static constexpr bool Q_in_reg = true;
|
||||||
// static constexpr int nstages_target = 2;
|
static constexpr int nstages_target = 2;
|
||||||
//
|
|
||||||
// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
// return 32;
|
return 32;
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||||
// return 32;
|
return 32;
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
// return 32;
|
return 32;
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||||
// return 32;
|
return 32;
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
// return 32;
|
return 32;
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||||
// return 32;
|
return 32;
|
||||||
// }
|
}
|
||||||
//};
|
};
|
||||||
//
|
//
|
||||||
//template <>
|
//template <>
|
||||||
//struct fattn_mma_f16_config< 80, 80> {
|
//struct fattn_mma_f16_config< 80, 80> {
|
||||||
@@ -493,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
||||||
} else {
|
} else {
|
||||||
constexpr bool use_cp_async = nstages == 1;
|
constexpr bool use_cp_async = nstages == 1;
|
||||||
if constexpr (ncols2 > 1 || mask_h2) {
|
if (ncols2 > 1 || mask_h2) {
|
||||||
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
|
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -576,7 +576,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
float KQ_rowsum_add[cols_per_thread] = {0.0f};
|
float KQ_rowsum_add[cols_per_thread] = {0.0f};
|
||||||
|
|
||||||
if constexpr (ntiles == 1) {
|
if constexpr (ntiles == 1) {
|
||||||
if constexpr (ncols2 > 1 || mask_h2) {
|
if (ncols2 > 1 || mask_h2) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
|
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
|
||||||
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
|
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
|
||||||
@@ -818,6 +818,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||||||
const half2 * const __restrict__ K_h2,
|
const half2 * const __restrict__ K_h2,
|
||||||
const half2 * const __restrict__ V_h2,
|
const half2 * const __restrict__ V_h2,
|
||||||
const half2 * const __restrict__ mask_h2,
|
const half2 * const __restrict__ mask_h2,
|
||||||
|
const float * const __restrict__ sinks_f,
|
||||||
float2 * const __restrict__ dstk,
|
float2 * const __restrict__ dstk,
|
||||||
float2 * const __restrict__ dstk_fixup,
|
float2 * const __restrict__ dstk_fixup,
|
||||||
const float scale,
|
const float scale,
|
||||||
@@ -975,6 +976,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If attention sinks are used, potentially re-scale if KQ_max is small.
|
||||||
|
// Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum
|
||||||
|
// so it's being done unconditionally for every thread.
|
||||||
|
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
|
||||||
|
float KQ_max_scale[cols_per_thread];
|
||||||
|
#pragma unroll
|
||||||
|
for (int col = 0; col < cols_per_thread; ++col) {
|
||||||
|
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
|
||||||
|
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
|
||||||
|
const float sink = sinks_f[jc % ncols2];
|
||||||
|
|
||||||
|
const float KQ_max_new = fmaxf(KQ_max[col], sink);
|
||||||
|
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
|
||||||
|
KQ_max_scale[col] = expf(KQ_max_diff);
|
||||||
|
KQ_max[col] = KQ_max_new;
|
||||||
|
|
||||||
|
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
|
||||||
|
|
||||||
|
const float KQ_max_add = expf(sink - KQ_max_new);
|
||||||
|
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ntiles == 1) {
|
||||||
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
|
||||||
|
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll
|
||||||
|
for (int col = 0; col < cols_per_thread; ++col) {
|
||||||
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
|
||||||
|
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Finally, sum up partial KQ rowsums.
|
// Finally, sum up partial KQ rowsums.
|
||||||
// The partial sums are spread across 8/4 threads each, does not need full reduce.
|
// The partial sums are spread across 8/4 threads each, does not need full reduce.
|
||||||
{
|
{
|
||||||
@@ -1222,7 +1269,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
|
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(sinks_f);
|
||||||
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
||||||
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
||||||
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
|
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
|
||||||
@@ -1239,6 +1286,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
const char * __restrict__ K,
|
const char * __restrict__ K,
|
||||||
const char * __restrict__ V,
|
const char * __restrict__ V,
|
||||||
const char * __restrict__ mask,
|
const char * __restrict__ mask,
|
||||||
|
const char * __restrict__ sinks,
|
||||||
float * __restrict__ dst,
|
float * __restrict__ dst,
|
||||||
float2 * __restrict__ dst_meta,
|
float2 * __restrict__ dst_meta,
|
||||||
const float scale,
|
const float scale,
|
||||||
@@ -1323,6 +1371,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||||
|
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
|
||||||
|
|
||||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||||
|
|
||||||
@@ -1335,12 +1384,12 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
if (kb0_start == 0) {
|
if (kb0_start == 0) {
|
||||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||||
} else {
|
} else {
|
||||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1362,6 +1411,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||||
|
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
|
||||||
|
|
||||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||||
|
|
||||||
@@ -1373,7 +1423,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||||
constexpr bool needs_fixup = false;
|
constexpr bool needs_fixup = false;
|
||||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||||
@@ -1535,7 +1585,8 @@ static void launch_fattn_new_mma(
|
|||||||
const ggml_tensor * K = dst->src[1];
|
const ggml_tensor * K = dst->src[1];
|
||||||
const ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
const ggml_tensor * mask = dst->src[3];
|
const ggml_tensor * mask = dst->src[3];
|
||||||
|
const ggml_tensor * sinks = dst->src[4];
|
||||||
|
|
||||||
ggml_tensor * KQV = dst;
|
ggml_tensor * KQV = dst;
|
||||||
|
|
||||||
@@ -1709,6 +1760,7 @@ static void launch_fattn_new_mma(
|
|||||||
K_data,
|
K_data,
|
||||||
V_data,
|
V_data,
|
||||||
mask ? ((const char *) mask->data) : nullptr,
|
mask ? ((const char *) mask->data) : nullptr,
|
||||||
|
sinks ? ((const char *)sinks->data) : nullptr,
|
||||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||||
scale, max_bias, m0, m1, logit_softcap, n_head_log2,
|
scale, max_bias, m0, m1, logit_softcap, n_head_log2,
|
||||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
@@ -1853,6 +1905,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
|||||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||||
|
|
||||||
|
if (use_gqa_opt && gqa_ratio % 16 == 0) {
|
||||||
|
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 16>(ctx, dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
|
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
|
||||||
return;
|
return;
|
||||||
@@ -1878,8 +1935,6 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
|||||||
const ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
const ggml_tensor * mask = dst->src[3];
|
const ggml_tensor * mask = dst->src[3];
|
||||||
|
|
||||||
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
|
||||||
|
|
||||||
float max_bias = 0.0f;
|
float max_bias = 0.0f;
|
||||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||||
|
|
||||||
@@ -1888,6 +1943,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
|||||||
|
|
||||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||||
|
|
||||||
|
if (K->ne[0] == 64 && V->ne[0] == 64) {
|
||||||
|
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
||||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||||||
const char * __restrict__ K,
|
const char * __restrict__ K,
|
||||||
const char * __restrict__ V,
|
const char * __restrict__ V,
|
||||||
const char * __restrict__ mask,
|
const char * __restrict__ mask,
|
||||||
|
const char * __restrict__ sinks,
|
||||||
float * __restrict__ dst,
|
float * __restrict__ dst,
|
||||||
float2 * __restrict__ dst_meta,
|
float2 * __restrict__ dst_meta,
|
||||||
const float scale,
|
const float scale,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||||||
const char * __restrict__ K,
|
const char * __restrict__ K,
|
||||||
const char * __restrict__ V,
|
const char * __restrict__ V,
|
||||||
const char * __restrict__ mask,
|
const char * __restrict__ mask,
|
||||||
|
const char * __restrict__ sinks,
|
||||||
float * __restrict__ dst,
|
float * __restrict__ dst,
|
||||||
float2 * __restrict__ dst_meta,
|
float2 * __restrict__ dst_meta,
|
||||||
const float scale,
|
const float scale,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
const char * __restrict__ K,
|
const char * __restrict__ K,
|
||||||
const char * __restrict__ V,
|
const char * __restrict__ V,
|
||||||
const char * __restrict__ mask,
|
const char * __restrict__ mask,
|
||||||
|
const char * __restrict__ sinks,
|
||||||
float * __restrict__ dst,
|
float * __restrict__ dst,
|
||||||
float2 * __restrict__ dst_meta,
|
float2 * __restrict__ dst_meta,
|
||||||
const float scale,
|
const float scale,
|
||||||
@@ -71,6 +72,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
V += nb22*(blockIdx.y / gqa_ratio);
|
V += nb22*(blockIdx.y / gqa_ratio);
|
||||||
|
|
||||||
const half * maskh = (const half *) mask + ne11*ic0;
|
const half * maskh = (const half *) mask + ne11*ic0;
|
||||||
|
const float * sinksf = (const float *) (sinks);
|
||||||
|
|
||||||
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||||
const half slopeh = __float2half(slopef);
|
const half slopeh = __float2half(slopef);
|
||||||
@@ -270,6 +272,39 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sinksf) {
|
||||||
|
const half sink = __float2half(sinksf[blockIdx.y]);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols; ++j) {
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols; ++j) {
|
||||||
|
half kqmax_new_j = kqmax_shared[j][threadIdx.y];
|
||||||
|
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||||
|
|
||||||
|
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
|
||||||
|
kqmax[j] = kqmax_new_j;
|
||||||
|
|
||||||
|
const half val = hexp(sink - kqmax[j]);
|
||||||
|
kqsum[j] = kqsum[j]*KQ_max_scale;
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
kqsum[j] += val;
|
||||||
|
}
|
||||||
|
|
||||||
|
VKQ[j] *= __half2half2(KQ_max_scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
const char * __restrict__ K,
|
const char * __restrict__ K,
|
||||||
const char * __restrict__ V,
|
const char * __restrict__ V,
|
||||||
const char * __restrict__ mask,
|
const char * __restrict__ mask,
|
||||||
|
const char * __restrict__ sinks,
|
||||||
float * __restrict__ dst,
|
float * __restrict__ dst,
|
||||||
float2 * __restrict__ dst_meta,
|
float2 * __restrict__ dst_meta,
|
||||||
const float scale,
|
const float scale,
|
||||||
@@ -69,6 +70,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
K += nb12*(blockIdx.y / gqa_ratio);
|
K += nb12*(blockIdx.y / gqa_ratio);
|
||||||
V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
|
V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
|
||||||
const half * maskh = (const half *) mask + ne11*ic0;
|
const half * maskh = (const half *) mask + ne11*ic0;
|
||||||
|
const float * sinksf = (const float *) (sinks);
|
||||||
|
|
||||||
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||||
|
|
||||||
@@ -254,6 +256,39 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sinksf) {
|
||||||
|
const float sink = sinksf[blockIdx.y];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols; ++j) {
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols; ++j) {
|
||||||
|
float kqmax_new_j = kqmax_shared[j][threadIdx.y];
|
||||||
|
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||||
|
|
||||||
|
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
|
||||||
|
kqmax[j] = kqmax_new_j;
|
||||||
|
|
||||||
|
const float val = expf(sink - kqmax[j]);
|
||||||
|
kqsum[j] = kqsum[j]*KQ_max_scale;
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
kqsum[j] += val;
|
||||||
|
}
|
||||||
|
|
||||||
|
VKQ[j] *= KQ_max_scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
||||||
|
|||||||
@@ -5,6 +5,8 @@
|
|||||||
// SPDX-License-Identifier: MIT
|
// SPDX-License-Identifier: MIT
|
||||||
//
|
//
|
||||||
|
|
||||||
|
// TODO: attention sinks !!!
|
||||||
|
|
||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
#include "fattn-common.cuh"
|
#include "fattn-common.cuh"
|
||||||
|
|
||||||
@@ -22,6 +24,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
const char * __restrict__ K,
|
const char * __restrict__ K,
|
||||||
const char * __restrict__ V,
|
const char * __restrict__ V,
|
||||||
const char * __restrict__ mask,
|
const char * __restrict__ mask,
|
||||||
|
const char * __restrict__ sinks,
|
||||||
float * __restrict__ dst,
|
float * __restrict__ dst,
|
||||||
float2 * __restrict__ dst_meta,
|
float2 * __restrict__ dst_meta,
|
||||||
const float scale,
|
const float scale,
|
||||||
@@ -93,6 +96,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||||
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
||||||
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
||||||
|
const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
|
||||||
|
|
||||||
const int stride_Q = nb01 / sizeof(float);
|
const int stride_Q = nb01 / sizeof(float);
|
||||||
const int stride_K = nb11 / sizeof(half);
|
const int stride_K = nb11 / sizeof(half);
|
||||||
|
|||||||
@@ -539,7 +539,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// As mentioned above, the new new MMA is slower than then the new MMA.
|
// As mentioned above, the new-new MMA is slower than the new MMA.
|
||||||
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
||||||
//ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
//ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
||||||
}
|
}
|
||||||
|
|||||||
41
ggml/src/ggml-cuda/graph.cuh
Normal file
41
ggml/src/ggml-cuda/graph.cuh
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
struct ggml_graph_node_properties {
|
||||||
|
void * node_address;
|
||||||
|
ggml_op node_op;
|
||||||
|
int64_t ne[GGML_MAX_DIMS];
|
||||||
|
size_t nb[GGML_MAX_DIMS];
|
||||||
|
void * src_address[GGML_MAX_SRC];
|
||||||
|
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||||
|
};
|
||||||
|
|
||||||
|
// State kept for one CUDA graph. The struct is empty unless USE_CUDA_GRAPH is defined.
struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
    // Destroys the executable graph instance and the graph itself, if they were created.
    ~ggml_cuda_graph() {
        if (instance != nullptr) {
            CUDA_CHECK(cudaGraphExecDestroy(instance));
        }
        if (graph != nullptr) {
            CUDA_CHECK(cudaGraphDestroy(graph));
        }
    }

    cudaGraph_t     graph    = nullptr;
    cudaGraphExec_t instance = nullptr;
    size_t          num_nodes = 0;
    std::vector<cudaGraphNode_t>      nodes;
    std::vector<cudaKernelNodeParams> params;

    // Reasons for which graph usage may be switched off.
    bool disable_due_to_gpu_arch             = false;
    bool disable_due_to_too_many_updates     = false;
    bool disable_due_to_failed_graph_capture = false;
    int  number_consecutive_updates          = 0;

    std::vector<ggml_graph_node_properties> ggml_graph_properties;

    bool                use_cpy_indirection = false;
    std::vector<char *> cpy_dest_ptrs;
    // Explicitly null-initialised like every other member, so that code testing this
    // pointer never reads an indeterminate value.
    // NOTE(review): presumably the device-side copy of cpy_dest_ptrs - confirm in the cpy code.
    char ** dest_ptrs_d    = nullptr;
    int     dest_ptrs_size = 0;
    // Index to allow each cpy kernel to be aware of its position within the graph
    // relative to other cpy nodes.
    int graph_cpynode_index = -1;
#endif
};
|
||||||
|
|
||||||
@@ -19,7 +19,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
|
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
|
||||||
static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) {
|
static __global__ void soft_max_f32_nosinks(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) {
|
||||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
@@ -124,7 +124,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
|
static void soft_max_f32_cuda_nosinks(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
|
||||||
int nth = WARP_SIZE;
|
int nth = WARP_SIZE;
|
||||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||||
const dim3 block_dims(nth, 1, 1);
|
const dim3 block_dims(nth, 1, 1);
|
||||||
@@ -142,39 +142,40 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
|
|||||||
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
|
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
|
||||||
switch (ncols_x) {
|
switch (ncols_x) {
|
||||||
case 32:
|
case 32:
|
||||||
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
soft_max_f32_nosinks<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||||
break;
|
break;
|
||||||
case 64:
|
case 64:
|
||||||
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
soft_max_f32_nosinks<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||||
break;
|
break;
|
||||||
case 128:
|
case 128:
|
||||||
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
soft_max_f32_nosinks<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||||
break;
|
break;
|
||||||
case 256:
|
case 256:
|
||||||
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
soft_max_f32_nosinks<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||||
break;
|
break;
|
||||||
case 512:
|
case 512:
|
||||||
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
soft_max_f32_nosinks<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||||
break;
|
break;
|
||||||
case 1024:
|
case 1024:
|
||||||
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
soft_max_f32_nosinks<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||||
break;
|
break;
|
||||||
case 2048:
|
case 2048:
|
||||||
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
soft_max_f32_nosinks<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||||
break;
|
break;
|
||||||
case 4096:
|
case 4096:
|
||||||
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
soft_max_f32_nosinks<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
soft_max_f32_nosinks<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
||||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
soft_max_f32_nosinks<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const ggml_tensor * src1 = dst->src[1];
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
@@ -205,13 +206,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||||||
if (use_f16) {
|
if (use_f16) {
|
||||||
const half * src1_dd = (const half *)src1_d;
|
const half * src1_dd = (const half *)src1_d;
|
||||||
|
|
||||||
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
|
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
|
||||||
} else {
|
} else {
|
||||||
const float * src1_dd = (const float *)src1_d;
|
const float * src1_dd = (const float *)src1_d;
|
||||||
|
|
||||||
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
|
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
@@ -241,10 +243,283 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds
|
|||||||
if (use_f16) {
|
if (use_f16) {
|
||||||
const half * src1_dd = (const half *)src1_d;
|
const half * src1_dd = (const half *)src1_d;
|
||||||
|
|
||||||
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
|
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
|
||||||
} else {
|
} else {
|
||||||
const float * src1_dd = (const float *)src1_d;
|
const float * src1_dd = (const float *)src1_d;
|
||||||
|
|
||||||
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
|
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parameters passed by value to the soft_max_f32 kernel.
// Field order is part of the layout and of positional aggregate initialisation - keep it.
struct soft_max_params {
    int64_t  nheads;
    uint32_t n_head_log2;              // forwarded to get_alibi_slope by the kernel

    int64_t  ncols, nrows_x, nrows_y;  // ncols is the length of one softmax row

    int64_t  ne00, ne01, ne02, ne03;   // ne01/ne02/ne03 form the launch grid

    int64_t  nb11, nb12, nb13;         // byte strides used to locate the mask row

    int64_t  ne12, ne13;               // mask extents; block indices are wrapped modulo these

    float    scale, max_bias;          // scale multiplies every input value
    float    m0, m1;                   // forwarded to get_alibi_slope by the kernel
};
|
||||||
|
|
||||||
|
#ifdef __clang__
|
||||||
|
#pragma clang diagnostic push
|
||||||
|
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||||
|
#endif // __clang__
|
||||||
|
template <bool use_shared, int ncols_template, int block_size_template, typename T>
|
||||||
|
static __global__ void soft_max_f32(
|
||||||
|
const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
|
||||||
|
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
|
||||||
|
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
|
||||||
|
const int64_t i03 = blockIdx.z;
|
||||||
|
const int64_t i02 = blockIdx.y;
|
||||||
|
const int64_t i01 = blockIdx.x;
|
||||||
|
|
||||||
|
//TODO: noncontigous inputs/outputs
|
||||||
|
const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
|
||||||
|
|
||||||
|
const int64_t i11 = i01;
|
||||||
|
const int64_t i12 = i02 % p.ne12;
|
||||||
|
const int64_t i13 = i03 % p.ne13;
|
||||||
|
|
||||||
|
x += int64_t(rowx)*ncols;
|
||||||
|
mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
|
||||||
|
dst += int64_t(rowx)*ncols;
|
||||||
|
|
||||||
|
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
||||||
|
|
||||||
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
|
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
|
||||||
|
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
|
||||||
|
|
||||||
|
extern __shared__ float data_soft_max_f32[];
|
||||||
|
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
||||||
|
// shared memory buffer to cache values between iterations:
|
||||||
|
float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
|
||||||
|
|
||||||
|
float max_val = sinks ? sinks[i02] : -INFINITY;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||||
|
const int col = col0 + tid;
|
||||||
|
|
||||||
|
if (ncols_template == 0 && col >= ncols) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
|
||||||
|
|
||||||
|
vals[col] = val;
|
||||||
|
max_val = max(max_val, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the max value in the block
|
||||||
|
max_val = warp_reduce_max(max_val);
|
||||||
|
if (block_size > WARP_SIZE) {
|
||||||
|
if (warp_id == 0) {
|
||||||
|
buf_iw[lane_id] = -INFINITY;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (lane_id == 0) {
|
||||||
|
buf_iw[warp_id] = max_val;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
max_val = buf_iw[lane_id];
|
||||||
|
max_val = warp_reduce_max(max_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
float tmp = 0.0f; // partial sum
|
||||||
|
#pragma unroll
|
||||||
|
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||||
|
const int col = col0 + tid;
|
||||||
|
|
||||||
|
if (ncols_template == 0 && col >= ncols) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float val = expf(vals[col] - max_val);
|
||||||
|
tmp += val;
|
||||||
|
vals[col] = val;
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the sum of exps in the block
|
||||||
|
tmp = warp_reduce_sum(tmp);
|
||||||
|
if (block_size > WARP_SIZE) {
|
||||||
|
__syncthreads();
|
||||||
|
if (warp_id == 0) {
|
||||||
|
buf_iw[lane_id] = 0.0f;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (lane_id == 0) {
|
||||||
|
buf_iw[warp_id] = tmp;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
tmp = buf_iw[lane_id];
|
||||||
|
tmp = warp_reduce_sum(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sinks) {
|
||||||
|
tmp += expf(sinks[i02] - max_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float inv_sum = 1.0f / tmp;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||||
|
const int col = col0 + tid;
|
||||||
|
|
||||||
|
if (ncols_template == 0 && col >= ncols) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[col] = vals[col] * inv_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#ifdef __clang__
|
||||||
|
#pragma clang diagnostic pop
|
||||||
|
#endif // __clang__
|
||||||
|
|
||||||
|
template<int... Ns, typename T>
|
||||||
|
static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
|
||||||
|
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
|
||||||
|
{
|
||||||
|
const int id = ggml_cuda_get_device();
|
||||||
|
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||||
|
|
||||||
|
auto launch_kernel = [=](auto I) -> bool {
|
||||||
|
constexpr int ncols = decltype(I)::value;
|
||||||
|
constexpr int block = (ncols > 1024 ? 1024 : ncols);
|
||||||
|
|
||||||
|
if (p.ncols == ncols) {
|
||||||
|
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
|
||||||
|
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
|
(x, mask, sinks, dst, p);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
// unary fold over launch_kernel
|
||||||
|
if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
//default case
|
||||||
|
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
|
||||||
|
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
|
||||||
|
int nth = WARP_SIZE;
|
||||||
|
const int64_t ncols_x = params.ncols;
|
||||||
|
|
||||||
|
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||||
|
const dim3 block_dims(nth, 1, 1);
|
||||||
|
const dim3 block_nums(params.ne01, params.ne02, params.ne03);
|
||||||
|
const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
|
||||||
|
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||||
|
|
||||||
|
|
||||||
|
const int id = ggml_cuda_get_device();
|
||||||
|
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||||
|
|
||||||
|
|
||||||
|
if (nbytes_shared <= smpbo) {
|
||||||
|
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
|
||||||
|
} else {
|
||||||
|
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
|
||||||
|
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CUDA implementation of the soft-max op.
//   dst->src[0] : F32 input
//   dst->src[1] : optional mask, F16 or F32
//   dst->src[2] : optional sinks, F32 (the kernel reads one value per index of dim 2 of src0)
//   dst->op_params[0..1] : scale and max_bias, stored as floats
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    const float * src0_d = (const float *) src0->data;
    const void  * src1_d = src1 ? (const void *) src1->data : nullptr;
    const void  * src2_d = src2 ? (const void *) src2->data : nullptr;
    float       * dst_d  = (float *) dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
    // src2 is reinterpreted as const float * below, so any other type would be read as garbage
    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains sinks and it is optional

    const int64_t nrows_x = ggml_nrows(src0);
    const int64_t nrows_y = src0->ne[1];

    const int64_t ne00 = src0->ne[0];

    float scale    = 1.0f;
    float max_bias = 0.0f;

    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));

    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

    // mask strides/extents; the placeholder 1 keeps the kernel's modulo well-defined without a mask
    const int64_t nb11 = src1 ? src1->nb[1] : 1;
    const int64_t nb12 = src1 ? src1->nb[2] : 1;
    const int64_t nb13 = src1 ? src1->nb[3] : 1;

    const int64_t ne12 = src1 ? src1->ne[2] : 1;
    const int64_t ne13 = src1 ? src1->ne[3] : 1;

    // largest power of two not exceeding the head count, and the two bases passed to get_alibi_slope
    const uint32_t n_head      = src0->ne[2];
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    soft_max_params params = {};
    params.nheads      = src0->ne[2];
    params.n_head_log2 = n_head_log2;
    params.ncols       = ne00;
    params.nrows_x     = nrows_x;
    params.nrows_y     = nrows_y;
    params.ne00        = src0->ne[0];
    params.ne01        = src0->ne[1];
    params.ne02        = src0->ne[2];
    params.ne03        = src0->ne[3];
    params.nb11        = nb11;
    params.nb12        = nb12;
    params.nb13        = nb13;
    params.ne12        = ne12;
    params.ne13        = ne13;
    params.scale       = scale;
    params.max_bias    = max_bias;
    params.m0          = m0;
    params.m1          = m1;

    // dispatch on the mask element type; with no mask the float instantiation is used
    if (use_f16) {
        soft_max_f32_cuda(src0_d, (const half  *) src1_d, (const float *) src2_d, dst_d, params, stream);
    } else {
        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
    }
}
|
||||||
|
|
||||||
|
|||||||
@@ -470,3 +470,83 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||||||
|
|
||||||
sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
|
||||||
|
const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// perform base op and multiply with gate (either offset in same tensor or a separate one)
|
||||||
|
const int64_t j0 = (i / n) * o0 + (i % n);
|
||||||
|
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
|
||||||
|
|
||||||
|
float xi = x[j0];
|
||||||
|
float gi = g[j1];
|
||||||
|
xi = fminf(xi, limit);
|
||||||
|
gi = fmaxf(fminf(gi, limit), -limit);
|
||||||
|
|
||||||
|
float out_glu = xi / (1.0f + expf(-xi * alpha));
|
||||||
|
out_glu = out_glu * (1.0f + gi);
|
||||||
|
|
||||||
|
dst[i] = out_glu;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
|
||||||
|
const int64_t num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
||||||
|
swiglu_oai_kernel<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
|
||||||
|
}
|
||||||
|
|
||||||
|
// CUDA implementation of the SWIGLU_OAI op on F32 data.
//   dst->src[0] : input; when src[1] is absent each row holds both halves (2*nc values)
//   dst->src[1] : optional separate gate tensor with nc columns
//   dst->op_params[2..3] : alpha and limit, read as floats
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    cudaStream_t stream = ctx.stream();

    // output columns: the full row with a separate gate, half of it in the fused layout
    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;

    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));

    if (src1) {
        GGML_ASSERT(ggml_is_contiguous_1(src1));
        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
        GGML_ASSERT(src1->ne[0] == nc);
        GGML_ASSERT(src0->type == src1->type);
    }

    // The swapped flag is not read from op_params here; it is fixed to false.
    //const int32_t swapped = ((const int32_t *) dst->op_params)[1];
    const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1);

    const float * op_params = (const float *)dst->op_params;
    const float   alpha     = op_params[2];
    const float   limit     = op_params[3];

    // Without a separate gate tensor both pointers start at src0 and one of them is
    // advanced by nc to address the other half of each row.
    float * src0_p = (float *) src0->data;
    float * src1_p = (float *) (src1 ? src1->data : src0->data);

    if (!src1) {
        src0_p += swapped ? nc : 0;
        src1_p += swapped ? 0 : nc;
    }

    // row strides in bytes; converted to element counts for the kernel
    const int64_t src0_o = src0->nb[1];
    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];

    swiglu_oai_cuda(src0_p, src1_p, (float *) dst->data, ggml_nelements(dst), nc,
            src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
}
|
||||||
|
|
||||||
|
void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n,
|
||||||
|
const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
|
||||||
|
swiglu_oai_cuda(x, g, dst, k, n, o0, o1, alpha, limit, stream);
|
||||||
|
}
|
||||||
|
|||||||
@@ -47,3 +47,9 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
|
|||||||
int64_t nelements, const float * x, const float * y, float * z);
|
int64_t nelements, const float * x, const float * y, float * z);
|
||||||
|
|
||||||
void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n,
|
||||||
|
const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream);
|
||||||
|
|
||||||
|
|||||||
509
ggml/src/ggml.c
509
ggml/src/ggml.c
@@ -2823,7 +2823,6 @@ inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t
|
|||||||
|
|
||||||
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
||||||
|
|
||||||
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
|
|
||||||
inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
|
inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
|
||||||
inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
|
inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
|
||||||
inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
|
inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
|
||||||
@@ -2834,6 +2833,19 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
|
|||||||
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
|
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
|
||||||
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
|
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
|
||||||
|
|
||||||
|
// z[i] = x[i] + y[i] for i in [0, n).
// With AVX2 the bulk is processed 8 floats at a time using unaligned loads/stores;
// the remaining elements (or all of them without AVX2) are handled by the scalar loop.
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
    int pos = 0;
#if defined(__AVX2__)
    // largest multiple of 8 not exceeding n
    const int n_vec = n & ~7;
    while (pos < n_vec) {
        _mm256_storeu_ps(z + pos, _mm256_add_ps(_mm256_loadu_ps(x + pos), _mm256_loadu_ps(y + pos)));
        pos += 8;
    }
#endif
    while (pos < n) {
        z[pos] = x[pos] + y[pos];
        ++pos;
    }
}
|
||||||
|
|
||||||
static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
|
static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
|
||||||
assert(nrc == 1);
|
assert(nrc == 1);
|
||||||
UNUSED(nrc);
|
UNUSED(nrc);
|
||||||
@@ -4004,6 +4016,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||||||
|
|
||||||
"DUP",
|
"DUP",
|
||||||
"ADD",
|
"ADD",
|
||||||
|
"ADD_ID",
|
||||||
"ADD1",
|
"ADD1",
|
||||||
"ACC",
|
"ACC",
|
||||||
"SUB",
|
"SUB",
|
||||||
@@ -4092,13 +4105,14 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||||||
"CROSS_ENTROPY_LOSS_BACK",
|
"CROSS_ENTROPY_LOSS_BACK",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
|
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
|
||||||
|
|
||||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"none",
|
"none",
|
||||||
|
|
||||||
"x",
|
"x",
|
||||||
"x+y",
|
"x+y",
|
||||||
|
"x[i]+y",
|
||||||
"x+y",
|
"x+y",
|
||||||
"view(x,nb,offset)+=y->x",
|
"view(x,nb,offset)+=y->x",
|
||||||
"x-y",
|
"x-y",
|
||||||
@@ -4187,7 +4201,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||||||
"cross_entropy_loss_back(x,y)",
|
"cross_entropy_loss_back(x,y)",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
|
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
|
||||||
|
|
||||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||||
|
|
||||||
@@ -4207,9 +4221,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
|||||||
"HARDSWISH",
|
"HARDSWISH",
|
||||||
"HARDSIGMOID",
|
"HARDSIGMOID",
|
||||||
"SWIGLU",
|
"SWIGLU",
|
||||||
|
"SWIGLU_OAI",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
|
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
|
||||||
|
|
||||||
|
|
||||||
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
||||||
@@ -5917,6 +5932,29 @@ struct ggml_tensor * ggml_add_cast(
|
|||||||
return ggml_add_cast_impl(ctx, a, b, type);
|
return ggml_add_cast_impl(ctx, a, b, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_add_id
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_add_id(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b,
|
||||||
|
struct ggml_tensor * ids) {
|
||||||
|
|
||||||
|
GGML_ASSERT(a->ne[0] == b->ne[0]);
|
||||||
|
GGML_ASSERT(a->ne[1] == ids->ne[0]);
|
||||||
|
GGML_ASSERT(a->ne[2] == ids->ne[1]);
|
||||||
|
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||||
|
|
||||||
|
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||||
|
|
||||||
|
result->op = GGML_OP_ADD_ID;
|
||||||
|
result->src[0] = a;
|
||||||
|
result->src[1] = b;
|
||||||
|
result->src[2] = ids;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_add1
|
// ggml_add1
|
||||||
|
|
||||||
static struct ggml_tensor * ggml_add1_impl(
|
static struct ggml_tensor * ggml_add1_impl(
|
||||||
@@ -6662,6 +6700,36 @@ struct ggml_tensor * ggml_swiglu(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_swiglu_oai(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b,
|
||||||
|
float alpha,
|
||||||
|
float limit) {
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(a));
|
||||||
|
if (b) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(b));
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(a, b));
|
||||||
|
GGML_ASSERT(a->type == b->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t ne[4] = {a->ne[0]/2, a->ne[1], a->ne[2], a->ne[3]};
|
||||||
|
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
|
||||||
|
|
||||||
|
result->op = GGML_OP_UNARY;
|
||||||
|
result->grad = NULL;
|
||||||
|
result->src[0] = a;
|
||||||
|
result->src[1] = b;
|
||||||
|
|
||||||
|
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU_OAI);
|
||||||
|
ggml_set_op_params_f32(result, 2, alpha);
|
||||||
|
ggml_set_op_params_f32(result, 3, limit);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_silu_back
|
// ggml_silu_back
|
||||||
|
|
||||||
struct ggml_tensor * ggml_silu_back(
|
struct ggml_tensor * ggml_silu_back(
|
||||||
@@ -7017,6 +7085,66 @@ struct ggml_tensor * ggml_moe_up_gate(
|
|||||||
result->src[1] = as_gate;
|
result->src[1] = as_gate;
|
||||||
result->src[2] = b;
|
result->src[2] = b;
|
||||||
result->src[3] = ids;
|
result->src[3] = ids;
|
||||||
|
result->src[4] = NULL;
|
||||||
|
result->src[5] = NULL;
|
||||||
|
|
||||||
|
ggml_set_op_params_i32(result, 0, (int32_t) op);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_moe_up_gate_ext(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * as_up,
|
||||||
|
struct ggml_tensor * as_gate,
|
||||||
|
struct ggml_tensor * b,
|
||||||
|
struct ggml_tensor * ids,
|
||||||
|
struct ggml_tensor * as_up_b,
|
||||||
|
struct ggml_tensor * as_gate_b,
|
||||||
|
enum ggml_unary_op op) {
|
||||||
|
|
||||||
|
if (!as_up_b && !as_gate_b) {
|
||||||
|
return ggml_moe_up_gate(ctx, as_up, as_gate, b, ids, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) {
|
||||||
|
struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids);
|
||||||
|
if (as_up_b) {
|
||||||
|
result_up = ggml_add_id(ctx, result_up, as_up_b, ids);
|
||||||
|
}
|
||||||
|
struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids);
|
||||||
|
if (as_gate_b) {
|
||||||
|
result_gate = ggml_add_id(ctx, result_gate, as_gate_b, ids);
|
||||||
|
}
|
||||||
|
return ggml_fused_mul_unary(ctx, result_gate, result_up, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(!ggml_is_transposed(as_up));
|
||||||
|
GGML_ASSERT(!ggml_is_transposed(as_gate));
|
||||||
|
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||||
|
|
||||||
|
GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert)
|
||||||
|
GGML_ASSERT(b->ne[3] == 1); // b is 3d
|
||||||
|
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
|
||||||
|
GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
|
||||||
|
GGML_ASSERT(as_up->ne[0] == b->ne[0]); // can_mul_mat
|
||||||
|
GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
|
||||||
|
|
||||||
|
GGML_ASSERT(as_up->ne[1] == as_up_b->ne[0]);
|
||||||
|
GGML_ASSERT(as_gate->ne[1] == as_gate_b->ne[0]);
|
||||||
|
bool is_node = false;
|
||||||
|
|
||||||
|
const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 };
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||||
|
|
||||||
|
result->op = GGML_OP_MOE_FUSED_UP_GATE;
|
||||||
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
result->src[0] = as_up;
|
||||||
|
result->src[1] = as_gate;
|
||||||
|
result->src[2] = b;
|
||||||
|
result->src[3] = ids;
|
||||||
|
result->src[4] = as_up_b;
|
||||||
|
result->src[5] = as_gate_b;
|
||||||
|
|
||||||
ggml_set_op_params_i32(result, 0, (int32_t) op);
|
ggml_set_op_params_i32(result, 0, (int32_t) op);
|
||||||
|
|
||||||
@@ -7970,6 +8098,22 @@ struct ggml_tensor * ggml_soft_max_ext(
|
|||||||
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
|
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_soft_max_add_sinks(
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * sinks) {
|
||||||
|
if (!sinks) {
|
||||||
|
a->src[2] = NULL;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);
|
||||||
|
GGML_ASSERT(a->src[2] == NULL);
|
||||||
|
GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
|
||||||
|
GGML_ASSERT(sinks->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
a->src[2] = sinks;
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_soft_max_back
|
// ggml_soft_max_back
|
||||||
|
|
||||||
static struct ggml_tensor * ggml_soft_max_back_impl(
|
static struct ggml_tensor * ggml_soft_max_back_impl(
|
||||||
@@ -8833,6 +8977,22 @@ void ggml_flash_attn_ext_set_prec(
|
|||||||
ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
|
ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_flash_attn_ext_add_sinks(
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * sinks) {
|
||||||
|
if (!sinks) {
|
||||||
|
a->src[4] = NULL;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
|
||||||
|
GGML_ASSERT(a->src[4] == NULL);
|
||||||
|
GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
|
||||||
|
GGML_ASSERT(sinks->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
a->src[4] = sinks;
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_flash_attn_back
|
// ggml_flash_attn_back
|
||||||
|
|
||||||
struct ggml_tensor * ggml_flash_attn_back(
|
struct ggml_tensor * ggml_flash_attn_back(
|
||||||
@@ -11497,6 +11657,77 @@ static void ggml_compute_forward_multi_add(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_compute_forward_add_id
|
||||||
|
|
||||||
|
static void ggml_compute_forward_add_id_f32(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
|
const struct ggml_tensor * src1 = dst->src[1];
|
||||||
|
const struct ggml_tensor * src2 = dst->src[2];
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(src2->type == GGML_TYPE_I32);
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||||
|
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int nr = ggml_nrows(src0);
|
||||||
|
|
||||||
|
GGML_TENSOR_TERNARY_OP_LOCALS
|
||||||
|
|
||||||
|
GGML_ASSERT( nb0 == sizeof(float));
|
||||||
|
GGML_ASSERT(nb10 == sizeof(float));
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
|
// row range for this thread
|
||||||
|
const int ir0 = dr*ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
for (int ir = ir0; ir < ir1; ++ir) {
|
||||||
|
// src0 indices
|
||||||
|
const int i3 = ir/(ne2*ne1);
|
||||||
|
const int i2 = (ir - i3*ne2*ne1)/ne1;
|
||||||
|
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||||
|
|
||||||
|
// src1 indices
|
||||||
|
const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
|
||||||
|
|
||||||
|
GGML_ASSERT(i11 >= 0 && i11 < ne11);
|
||||||
|
|
||||||
|
ggml_vec_add_f32(ne0,
|
||||||
|
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
|
||||||
|
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
|
||||||
|
(float *) ((char *) src1->data + i11*nb11));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_add_id(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_add_id_f32(params, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_add1
|
// ggml_compute_forward_add1
|
||||||
|
|
||||||
static void ggml_compute_forward_add1_f32(
|
static void ggml_compute_forward_add1_f32(
|
||||||
@@ -13760,6 +13991,93 @@ static void ggml_compute_forward_swiglu(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_compute_forward_swiglu_oai
|
||||||
|
|
||||||
|
static void ggml_compute_forward_swiglu_oai_f32(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
|
const struct ggml_tensor * src1 = dst->src[1];
|
||||||
|
char * src0_d = (char *) src0->data;
|
||||||
|
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
||||||
|
const size_t src0_o = src0->nb[1];
|
||||||
|
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||||
|
|
||||||
|
if (src1) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||||
|
GGML_ASSERT(src0->type == src1->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||||
|
const int nr = ggml_nrows(src0);
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->ne[0] == nc);
|
||||||
|
GGML_ASSERT(ggml_nrows(dst) == nr);
|
||||||
|
|
||||||
|
const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1);
|
||||||
|
const float alpha = ggml_get_op_params_f32(dst, 2);
|
||||||
|
const float limit = ggml_get_op_params_f32(dst, 3);
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
|
// row range for this thread
|
||||||
|
const int ir0 = dr*ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||||
|
float * src0_p = (float *) (src0_d + i1*src0_o);
|
||||||
|
float * src1_p = (float *) (src1_d + i1*src1_o);
|
||||||
|
float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
|
||||||
|
|
||||||
|
if (!src1) {
|
||||||
|
src0_p += swapped ? nc : 0;
|
||||||
|
src1_p += swapped ? 0 : nc;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int k = 0; k < nc; k++) {
|
||||||
|
const float x = MIN(src0_p[k], limit);
|
||||||
|
const float y = MAX(MIN(src1_p[k], limit), -limit);
|
||||||
|
const float out_glu = x / (1.f + expf(alpha * (-x)));
|
||||||
|
dst_p[k] = out_glu * (y + 1.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
for (int k = 0; k < nc; k++) {
|
||||||
|
const float x = dst_p[k];
|
||||||
|
GGML_UNUSED(x);
|
||||||
|
assert(!isnan(x));
|
||||||
|
assert(!isinf(x));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_swiglu_oai(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_swiglu_oai_f32(params, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_fused_mul_unary
|
// ggml_compute_forward_fused_mul_unary
|
||||||
|
|
||||||
static void ggml_compute_forward_fused_mul_unary_f32(
|
static void ggml_compute_forward_fused_mul_unary_f32(
|
||||||
@@ -15167,6 +15485,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
|||||||
|
|
||||||
const struct ggml_tensor * src1 = dst->src[2];
|
const struct ggml_tensor * src1 = dst->src[2];
|
||||||
const struct ggml_tensor * ids = dst->src[3];
|
const struct ggml_tensor * ids = dst->src[3];
|
||||||
|
const struct ggml_tensor * up_b = dst->src[4];
|
||||||
|
const struct ggml_tensor * gate_b = dst->src[5];
|
||||||
const struct ggml_tensor * src0_1 = dst->src[0];
|
const struct ggml_tensor * src0_1 = dst->src[0];
|
||||||
const struct ggml_tensor * src0_2 = dst->src[1];
|
const struct ggml_tensor * src0_2 = dst->src[1];
|
||||||
const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works
|
const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works
|
||||||
@@ -15191,6 +15511,9 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
|||||||
GGML_ASSERT(nb2 <= nb3);
|
GGML_ASSERT(nb2 <= nb3);
|
||||||
GGML_ASSERT(ne13 == 1);
|
GGML_ASSERT(ne13 == 1);
|
||||||
|
|
||||||
|
const size_t nb41 = up_b ? up_b->nb[1] : 0;
|
||||||
|
const size_t nb51 = up_b ? gate_b->nb[1] : 0;
|
||||||
|
|
||||||
// row groups
|
// row groups
|
||||||
const int n_ids = ids->ne[0]; // n_expert_used
|
const int n_ids = ids->ne[0]; // n_expert_used
|
||||||
const int n_as = ne02; // n_expert
|
const int n_as = ne02; // n_expert
|
||||||
@@ -15278,6 +15601,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
|||||||
|
|
||||||
const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
|
const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
|
||||||
const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
|
const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
|
||||||
|
const char * up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL;
|
||||||
|
const char * gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL;
|
||||||
|
|
||||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||||
@@ -15288,6 +15613,7 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
|||||||
if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0],
|
if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0],
|
||||||
type, src0_1_cur, src0_2_cur, nb01,
|
type, src0_1_cur, src0_2_cur, nb01,
|
||||||
vec_dot_type, (const char *)wdata, row_size,
|
vec_dot_type, (const char *)wdata, row_size,
|
||||||
|
up_b_cur, gate_b_cur,
|
||||||
(float *)dst->data, nb1, nb2,
|
(float *)dst->data, nb1, nb2,
|
||||||
matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
|
matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
|
||||||
|
|
||||||
@@ -16645,6 +16971,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|||||||
|
|
||||||
const struct ggml_tensor * src0 = dst->src[0];
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
const struct ggml_tensor * src1 = dst->src[1];
|
const struct ggml_tensor * src1 = dst->src[1];
|
||||||
|
const struct ggml_tensor * src2 = dst->src[2];
|
||||||
|
|
||||||
assert(ggml_is_contiguous(dst));
|
assert(ggml_is_contiguous(dst));
|
||||||
assert(ggml_are_same_shape(src0, dst));
|
assert(ggml_are_same_shape(src0, dst));
|
||||||
@@ -16662,6 +16989,13 @@ static void ggml_compute_forward_soft_max_f32(
|
|||||||
|
|
||||||
GGML_TENSOR_UNARY_OP_LOCALS
|
GGML_TENSOR_UNARY_OP_LOCALS
|
||||||
|
|
||||||
|
const int64_t nb11 = src1 ? src1->nb[1] : 1;
|
||||||
|
const int64_t nb12 = src1 ? src1->nb[2] : 1;
|
||||||
|
const int64_t nb13 = src1 ? src1->nb[3] : 1;
|
||||||
|
|
||||||
|
const int64_t ne12 = src1 ? src1->ne[2] : 1;
|
||||||
|
const int64_t ne13 = src1 ? src1->ne[3] : 1;
|
||||||
|
|
||||||
//const int64_t ne11 = src1 ? src1->ne[1] : 1;
|
//const int64_t ne11 = src1 ? src1->ne[1] : 1;
|
||||||
|
|
||||||
// TODO: is this supposed to be ceil instead of floor?
|
// TODO: is this supposed to be ceil instead of floor?
|
||||||
@@ -16673,67 +17007,80 @@ static void ggml_compute_forward_soft_max_f32(
|
|||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||||
|
|
||||||
const int nc = src0->ne[0];
|
const int nc = src0->ne[0];
|
||||||
const int nr = ggml_nrows(src0);
|
|
||||||
|
|
||||||
// rows per thread
|
|
||||||
const int dr = (nr + nth - 1)/nth;
|
|
||||||
|
|
||||||
// row range for this thread
|
|
||||||
const int ir0 = dr*ith;
|
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
|
||||||
|
|
||||||
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
||||||
|
|
||||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||||
|
|
||||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
// sinks
|
||||||
// ALiBi
|
const float * sk = src2 ? (float *)((char *) src2->data) : NULL;
|
||||||
const uint32_t h = (i1/ne01)%ne02; // head
|
|
||||||
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
|
||||||
|
|
||||||
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||||
|
const int64_t i11 = i01;
|
||||||
|
const int64_t i12 = i02%ne12;
|
||||||
|
const int64_t i13 = i03%ne13;
|
||||||
|
|
||||||
// broadcast the mask across rows
|
// ALiBi
|
||||||
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
const uint32_t h = i02; // head
|
||||||
float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
||||||
|
|
||||||
ggml_vec_cpy_f32 (nc, wp, sp);
|
float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||||
ggml_vec_scale_f32(nc, wp, scale);
|
float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||||
if (mp_f32) {
|
|
||||||
if (use_f16) {
|
// broadcast the mask across rows
|
||||||
for (int i = 0; i < nc; ++i) {
|
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
|
||||||
wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
|
float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
|
||||||
|
|
||||||
|
ggml_vec_cpy_f32 (ne00, wp, sp);
|
||||||
|
ggml_vec_scale_f32(ne00, wp, scale);
|
||||||
|
if (mp_f32) {
|
||||||
|
if (use_f16) {
|
||||||
|
for (int i = 0; i < ne00; ++i) {
|
||||||
|
wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < ne00; ++i) {
|
||||||
|
wp[i] += slope*mp_f32[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
for (int i = 0; i < nc; ++i) {
|
#ifndef NDEBUG
|
||||||
wp[i] += slope*mp_f32[i];
|
for (int i = 0; i < ne00; ++i) {
|
||||||
|
//printf("p[%d] = %f\n", i, p[i]);
|
||||||
|
assert(!isnan(wp[i]));
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
float max = -INFINITY;
|
||||||
|
ggml_vec_max_f32(ne00, &max, wp);
|
||||||
|
|
||||||
|
// if we have sinks, make a correction as if they were included in the softmax
|
||||||
|
if (sk) {
|
||||||
|
max = MAX(max, sk[i02]);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
|
||||||
|
assert(sum > 0.0);
|
||||||
|
|
||||||
|
if (sk) {
|
||||||
|
sum += (ggml_float) expf(sk[i02] - max);
|
||||||
|
}
|
||||||
|
|
||||||
|
sum = 1.0/sum;
|
||||||
|
ggml_vec_scale_f32(ne00, dp, sum);
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
for (int i = 0; i < ne00; ++i) {
|
||||||
|
assert(!isnan(dp[i]));
|
||||||
|
assert(!isinf(dp[i]));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//#ifndef NDEBUG
|
|
||||||
// for (int i = 0; i < nc; ++i) {
|
|
||||||
// //printf("p[%d] = %f\n", i, p[i]);
|
|
||||||
// assert(!isnan(wp[i]));
|
|
||||||
// }
|
|
||||||
//#endif
|
|
||||||
|
|
||||||
float max = -INFINITY;
|
|
||||||
ggml_vec_max_f32(nc, &max, wp);
|
|
||||||
|
|
||||||
ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
|
|
||||||
//assert(sum > 0.0);
|
|
||||||
|
|
||||||
sum = 1.0/sum;
|
|
||||||
ggml_vec_scale_f32(nc, dp, sum);
|
|
||||||
|
|
||||||
//#ifndef NDEBUG
|
|
||||||
// for (int i = 0; i < nc; ++i) {
|
|
||||||
// assert(!isnan(dp[i]));
|
|
||||||
// assert(!isinf(dp[i]));
|
|
||||||
// }
|
|
||||||
//#endif
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -16755,7 +17102,6 @@ static void ggml_compute_forward_soft_max(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// ggml_compute_forward_soft_max_back
|
// ggml_compute_forward_soft_max_back
|
||||||
|
|
||||||
static void ggml_compute_forward_soft_max_back_f32(
|
static void ggml_compute_forward_soft_max_back_f32(
|
||||||
@@ -18308,12 +18654,14 @@ static void ggml_compute_forward_argsort_thresh(
|
|||||||
|
|
||||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
const struct ggml_tensor * q,
|
|
||||||
const struct ggml_tensor * k,
|
|
||||||
const struct ggml_tensor * v,
|
|
||||||
const struct ggml_tensor * mask,
|
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const struct ggml_tensor * q = dst->src[0];
|
||||||
|
const struct ggml_tensor * k = dst->src[1];
|
||||||
|
const struct ggml_tensor * v = dst->src[2];
|
||||||
|
const struct ggml_tensor * mask = dst->src[3];
|
||||||
|
const struct ggml_tensor * sinks = dst->src[4];
|
||||||
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
||||||
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
||||||
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
||||||
@@ -18383,6 +18731,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#if GGML_USE_IQK_MULMAT
|
#if GGML_USE_IQK_MULMAT
|
||||||
|
// Sinks (when present) are forwarded to the iqk FA implementation
|
||||||
if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias,
|
if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias,
|
||||||
q->ne[3], q->ne[2], q->nb[3], q->nb[2],
|
q->ne[3], q->ne[2], q->nb[3], q->nb[2],
|
||||||
k->ne[3], k->ne[2], k->nb[3], k->nb[2],
|
k->ne[3], k->ne[2], k->nb[3], k->nb[2],
|
||||||
@@ -18390,7 +18739,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
dst->ne[2], dst->ne[1], dst->nb[1],
|
dst->ne[2], dst->ne[1], dst->nb[1],
|
||||||
k->type, v->type,
|
k->type, v->type,
|
||||||
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
|
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
|
||||||
q->data, k->data, v->data, mask->data,
|
q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL,
|
||||||
scale, softcap, (float *)dst->data,
|
scale, softcap, (float *)dst->data,
|
||||||
params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
|
params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
|
||||||
|
|
||||||
@@ -18447,6 +18796,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
|
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
|
||||||
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
|
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
|
||||||
|
|
||||||
|
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
|
||||||
|
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
|
||||||
|
|
||||||
const int64_t Dkv = MAX(Dk, Dv);
|
const int64_t Dkv = MAX(Dk, Dv);
|
||||||
|
|
||||||
// loop over n_batch and n_head
|
// loop over n_batch and n_head
|
||||||
@@ -18552,6 +18904,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sinks) {
|
||||||
|
const float s = ((float *)((char *) sinks->data))[h];
|
||||||
|
|
||||||
|
float ms = 1.0f;
|
||||||
|
float vs = 1.0f;
|
||||||
|
|
||||||
|
if (s > M) {
|
||||||
|
ms = expf(M - s);
|
||||||
|
ggml_vec_scale_f32(Dv, VKQ32, ms);
|
||||||
|
} else {
|
||||||
|
vs = expf(s - M);
|
||||||
|
}
|
||||||
|
|
||||||
|
S = S*ms + vs;
|
||||||
|
}
|
||||||
|
|
||||||
// V /= S
|
// V /= S
|
||||||
const float S_inv = 1.0f/S;
|
const float S_inv = 1.0f/S;
|
||||||
ggml_vec_scale_f32(Dv, VKQ32, S_inv);
|
ggml_vec_scale_f32(Dv, VKQ32, S_inv);
|
||||||
@@ -18571,17 +18939,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
|
|
||||||
static void ggml_compute_forward_flash_attn_ext(
|
static void ggml_compute_forward_flash_attn_ext(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
const struct ggml_tensor * q,
|
|
||||||
const struct ggml_tensor * k,
|
|
||||||
const struct ggml_tensor * v,
|
|
||||||
const struct ggml_tensor * mask,
|
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
switch (dst->op_params[3]) {
|
switch (dst->op_params[3]) {
|
||||||
case GGML_PREC_DEFAULT:
|
case GGML_PREC_DEFAULT:
|
||||||
case GGML_PREC_F32:
|
case GGML_PREC_F32:
|
||||||
{
|
{
|
||||||
// uses F32 accumulators
|
// uses F32 accumulators
|
||||||
ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
|
ggml_compute_forward_flash_attn_ext_f16(params, dst);
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
@@ -19350,6 +19714,10 @@ static void ggml_compute_forward_unary(
|
|||||||
{
|
{
|
||||||
ggml_compute_forward_swiglu(params, dst);
|
ggml_compute_forward_swiglu(params, dst);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_UNARY_OP_SWIGLU_OAI:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_swiglu_oai(params, dst);
|
||||||
|
} break;
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_hardswish(params, dst);
|
ggml_compute_forward_hardswish(params, dst);
|
||||||
@@ -19898,6 +20266,10 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||||||
{
|
{
|
||||||
ggml_compute_forward_add(params, tensor);
|
ggml_compute_forward_add(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_ADD_ID:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_add_id(params, tensor);
|
||||||
|
} break;
|
||||||
case GGML_OP_ADD1:
|
case GGML_OP_ADD1:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_add1(params, tensor);
|
ggml_compute_forward_add1(params, tensor);
|
||||||
@@ -20136,7 +20508,7 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
|
ggml_compute_forward_flash_attn_ext(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_ATTN_BACK:
|
case GGML_OP_FLASH_ATTN_BACK:
|
||||||
{
|
{
|
||||||
@@ -20486,6 +20858,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||||||
src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
|
src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_ADD_ID:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error"); // TODO: implement
|
||||||
|
} break;
|
||||||
case GGML_OP_ADD1:
|
case GGML_OP_ADD1:
|
||||||
{
|
{
|
||||||
if (src0->grad) {
|
if (src0->grad) {
|
||||||
@@ -21719,6 +22095,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
|
case GGML_OP_ADD_ID:
|
||||||
case GGML_OP_ADD1:
|
case GGML_OP_ADD1:
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
case GGML_OP_MULTI_ADD:
|
case GGML_OP_MULTI_ADD:
|
||||||
@@ -21758,6 +22135,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
case GGML_UNARY_OP_SWIGLU:
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
|
case GGML_UNARY_OP_SWIGLU_OAI:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
@@ -21952,6 +22330,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
|||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
|
case GGML_OP_ADD_ID:
|
||||||
case GGML_OP_ADD1:
|
case GGML_OP_ADD1:
|
||||||
{
|
{
|
||||||
if (ggml_is_quantized(node->src[0]->type)) {
|
if (ggml_is_quantized(node->src[0]->type)) {
|
||||||
|
|||||||
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_128_128) {
|
|||||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||||
if (nk%64 == 0) {
|
if (nk%64 == 0) {
|
||||||
iqk_flash_helper_T<128, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
iqk_flash_helper_T<128, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
iqk_flash_helper_T<128, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
iqk_flash_helper_T<128, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (nk%128 == 0) {
|
if (nk%128 == 0) {
|
||||||
return iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
if (nk%64 == 0) {
|
if (nk%64 == 0) {
|
||||||
return iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
return iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_192_128) {
|
|||||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||||
if (nk%64 == 0) {
|
if (nk%64 == 0) {
|
||||||
iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (nk%128 == 0) {
|
if (nk%128 == 0) {
|
||||||
return iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
if (nk%64 == 0) {
|
if (nk%64 == 0) {
|
||||||
return iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
return iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_256_256) {
|
|||||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||||
if (nk%64 == 0) {
|
if (nk%64 == 0) {
|
||||||
iqk_flash_helper_T<256, 256, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
iqk_flash_helper_T<256, 256, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
iqk_flash_helper_T<256, 256, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
iqk_flash_helper_T<256, 256, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (nk%128 == 0) {
|
if (nk%128 == 0) {
|
||||||
return iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
if (nk%64 == 0) {
|
if (nk%64 == 0) {
|
||||||
return iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
return iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ namespace {
|
|||||||
template <int step_k, typename KHelper, typename VHelper>
|
template <int step_k, typename KHelper, typename VHelper>
|
||||||
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
|
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
|
||||||
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||||
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
|
const float * q, const char * mask, float scale, float softcap, float * qkv,
|
||||||
|
const float * sinkf, float * M, float * S) {
|
||||||
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
|
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
|
||||||
nq1 -= n;
|
nq1 -= n;
|
||||||
if (nq1 == 0) return true;
|
if (nq1 == 0) return true;
|
||||||
@@ -21,29 +22,29 @@ inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
|
|||||||
};
|
};
|
||||||
if (nq1 >= 16) {
|
if (nq1 >= 16) {
|
||||||
int n_step = nq1/16;
|
int n_step = nq1/16;
|
||||||
FlashAttn<576, 512, 16, step_k> fa(scale, softcap);
|
FlashAttn<576, 512, 16, step_k> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||||
if (update(16*n_step)) return;
|
if (update(16*n_step)) return;
|
||||||
}
|
}
|
||||||
if (nq1 >= 8) {
|
if (nq1 >= 8) {
|
||||||
int n_step = nq1/8;
|
int n_step = nq1/8;
|
||||||
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
|
FlashAttn<576, 512, 8, step_k> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||||
if (update(8*n_step)) return;
|
if (update(8*n_step)) return;
|
||||||
}
|
}
|
||||||
if (nq1 >= 4) {
|
if (nq1 >= 4) {
|
||||||
int n_step = nq1/4;
|
int n_step = nq1/4;
|
||||||
FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
|
FlashAttn<576, 512, 4, step_k> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||||
if (update(4*n_step)) return;
|
if (update(4*n_step)) return;
|
||||||
}
|
}
|
||||||
if (nq1 >= 2) {
|
if (nq1 >= 2) {
|
||||||
int n_step = nq1/2;
|
int n_step = nq1/2;
|
||||||
FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
|
FlashAttn<576, 512, 2, step_k> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||||
if (update(2*n_step)) return;
|
if (update(2*n_step)) return;
|
||||||
}
|
}
|
||||||
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
|
FlashAttn<576, 512, 1, step_k> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,37 +52,37 @@ template <int step_k>
|
|||||||
inline bool iqk_deepseek_helper(ggml_type type_k,
|
inline bool iqk_deepseek_helper(ggml_type type_k,
|
||||||
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||||
const float * q, const char * k, const char * v, const char * mask,
|
const float * q, const char * k, const char * v, const char * mask,
|
||||||
float scale, float softcap, float * qkv, float * M, float * S) {
|
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
|
||||||
if (type_k == GGML_TYPE_Q8_0) {
|
if (type_k == GGML_TYPE_Q8_0) {
|
||||||
HelperQ80 kh((const char *)k, stride_k);
|
HelperQ80 kh((const char *)k, stride_k);
|
||||||
HelperQ80 vh((const char *)v, stride_v);
|
HelperQ80 vh((const char *)v, stride_v);
|
||||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (type_k == GGML_TYPE_Q8_0_R8) {
|
if (type_k == GGML_TYPE_Q8_0_R8) {
|
||||||
HelperQ80R8<576> kh((const char *)k, stride_k);
|
HelperQ80R8<576> kh((const char *)k, stride_k);
|
||||||
HelperQ80 vh((const char *)v, stride_v);
|
HelperQ80 vh((const char *)v, stride_v);
|
||||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (type_k == GGML_TYPE_Q6_0) {
|
if (type_k == GGML_TYPE_Q6_0) {
|
||||||
HelperQ60 kh((const char *)k, stride_k);
|
HelperQ60 kh((const char *)k, stride_k);
|
||||||
HelperQ60 vh((const char *)v, stride_v);
|
HelperQ60 vh((const char *)v, stride_v);
|
||||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#if GGML_IQK_FA_ALL_QUANTS
|
#if GGML_IQK_FA_ALL_QUANTS
|
||||||
if (type_k == GGML_TYPE_Q8_KV) {
|
if (type_k == GGML_TYPE_Q8_KV) {
|
||||||
HelperQ8KV<576> kh((const char *)k, stride_k);
|
HelperQ8KV<576> kh((const char *)k, stride_k);
|
||||||
HelperQ8KV<512> vh((const char *)v, stride_v);
|
HelperQ8KV<512> vh((const char *)v, stride_v);
|
||||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
if (type_k == GGML_TYPE_F16) {
|
if (type_k == GGML_TYPE_F16) {
|
||||||
HelperF16 kh((const char *)k, stride_k);
|
HelperF16 kh((const char *)k, stride_k);
|
||||||
HelperF16 vh((const char *)v, stride_v);
|
HelperF16 vh((const char *)v, stride_v);
|
||||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#ifdef __AVX512BF16__
|
#ifdef __AVX512BF16__
|
||||||
@@ -89,10 +90,10 @@ inline bool iqk_deepseek_helper(ggml_type type_k,
|
|||||||
HelperBF16<576, step_k> kh((const char *)k, stride_k);
|
HelperBF16<576, step_k> kh((const char *)k, stride_k);
|
||||||
HelperBF16<512, step_k> vh((const char *)v, stride_v);
|
HelperBF16<512, step_k> vh((const char *)v, stride_v);
|
||||||
if (nq1 % 8 == 0) {
|
if (nq1 % 8 == 0) {
|
||||||
FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
|
FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
} else {
|
} else {
|
||||||
FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
|
FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
@@ -113,7 +114,7 @@ IQK_FA_CASE(iqk_fa_576_512) {
|
|||||||
}
|
}
|
||||||
stride_q /= sizeof(float); // q stride as float
|
stride_q /= sizeof(float); // q stride as float
|
||||||
return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S);
|
q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_64_64) {
|
|||||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||||
if (nk%64 == 0) {
|
if (nk%64 == 0) {
|
||||||
iqk_flash_helper_T<64, 64, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
iqk_flash_helper_T<64, 64, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
iqk_flash_helper_T<64, 64, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
iqk_flash_helper_T<64, 64, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (nk%128 == 0) {
|
if (nk%128 == 0) {
|
||||||
return iqk_flash_helper_T<64, 64, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<64, 64, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
if (nk%64 == 0) {
|
if (nk%64 == 0) {
|
||||||
return iqk_flash_helper_T<64, 64, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<64, 64, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
return iqk_flash_helper_T<64, 64, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<64, 64, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_96_96) {
|
|||||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||||
if (nk%64 == 0) {
|
if (nk%64 == 0) {
|
||||||
iqk_flash_helper_T<96, 96, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
iqk_flash_helper_T<96, 96, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
iqk_flash_helper_T<96, 96, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
iqk_flash_helper_T<96, 96, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (nk%128 == 0) {
|
if (nk%128 == 0) {
|
||||||
return iqk_flash_helper_T<96, 96, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<96, 96, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
if (nk%64 == 0) {
|
if (nk%64 == 0) {
|
||||||
return iqk_flash_helper_T<96, 96, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<96, 96, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
return iqk_flash_helper_T<96, 96, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_flash_helper_T<96, 96, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1141,10 +1141,25 @@ struct FlashQKV {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename FMS>
|
template <typename FMS>
|
||||||
inline void normalize_and_store_1row(const FMS& fms, int j, const qkv_cache_t * R, float * qkv) const {
|
inline void normalize_and_store_1row(const FMS& fms, int j, qkv_cache_t * R, float * qkv, const float * sinkf) const {
|
||||||
static_assert(q_step == FMS::q_step);
|
static_assert(q_step == FMS::q_step);
|
||||||
GGML_ASSERT(fms.S[j] > 0);
|
float S = fms.S[j];
|
||||||
auto norm = F16::set1(1/fms.S[j]);
|
if (sinkf) {
|
||||||
|
float s = *sinkf;
|
||||||
|
if (s > fms.M[j]) {
|
||||||
|
float m = expf(fms.M[j] - s);
|
||||||
|
auto vm = F16::set1(m);
|
||||||
|
for (int i = 0; i < D/F16::block_size; ++i) {
|
||||||
|
auto Ri = R + F16::block_size*i;
|
||||||
|
F16::store(Ri, F16::mul(vm, F16::load(Ri)));
|
||||||
|
}
|
||||||
|
S = S*m + 1;
|
||||||
|
} else {
|
||||||
|
S += expf(s - fms.M[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GGML_ASSERT(S > 0);
|
||||||
|
auto norm = F16::set1(1/S);
|
||||||
//auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
|
//auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
|
||||||
for (int i = 0; i < D/F16::block_size; ++i) {
|
for (int i = 0; i < D/F16::block_size; ++i) {
|
||||||
auto r = F16::load(R + F16::block_size*i);
|
auto r = F16::load(R + F16::block_size*i);
|
||||||
@@ -1153,7 +1168,7 @@ struct FlashQKV {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename FMS>
|
template <typename FMS>
|
||||||
inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const {
|
inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, const float * sinkf, float * M, float * S) {
|
||||||
static_assert(q_step == FMS::q_step);
|
static_assert(q_step == FMS::q_step);
|
||||||
if (M && S) {
|
if (M && S) {
|
||||||
std::memcpy(M, fms.M, nq1*sizeof(float));
|
std::memcpy(M, fms.M, nq1*sizeof(float));
|
||||||
@@ -1173,7 +1188,7 @@ struct FlashQKV {
|
|||||||
} else {
|
} else {
|
||||||
auto R = qkv_cache;
|
auto R = qkv_cache;
|
||||||
for (int j = 0; j < nq1; ++j) {
|
for (int j = 0; j < nq1; ++j) {
|
||||||
normalize_and_store_1row(fms, j, R, qkv);
|
normalize_and_store_1row(fms, j, R, qkv, sinkf);
|
||||||
qkv += stride_qkv;
|
qkv += stride_qkv;
|
||||||
R += D;
|
R += D;
|
||||||
}
|
}
|
||||||
@@ -1181,7 +1196,7 @@ struct FlashQKV {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename FMS>
|
template <typename FMS>
|
||||||
inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, float * M, float * S) const {
|
inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, const float * sinkf, float * M, float * S) {
|
||||||
static_assert(q_step == FMS::q_step);
|
static_assert(q_step == FMS::q_step);
|
||||||
if (M && S) {
|
if (M && S) {
|
||||||
std::memcpy(M, fms.M, q_step*sizeof(float));
|
std::memcpy(M, fms.M, q_step*sizeof(float));
|
||||||
@@ -1201,7 +1216,7 @@ struct FlashQKV {
|
|||||||
} else {
|
} else {
|
||||||
auto R = qkv_cache;
|
auto R = qkv_cache;
|
||||||
for (int j = 0; j < q_step; ++j) {
|
for (int j = 0; j < q_step; ++j) {
|
||||||
normalize_and_store_1row(fms, j, R, qkv);
|
normalize_and_store_1row(fms, j, R, qkv, sinkf);
|
||||||
qkv += stride_qkv;
|
qkv += stride_qkv;
|
||||||
R += D;
|
R += D;
|
||||||
}
|
}
|
||||||
@@ -1332,7 +1347,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
|||||||
FlashMS<q_step, k_step>& fms,
|
FlashMS<q_step, k_step>& fms,
|
||||||
FlashQKV<Dv, q_step, k_step>& fqkv,
|
FlashQKV<Dv, q_step, k_step>& fqkv,
|
||||||
const float * q, const char * mask, float * qkv,
|
const float * q, const char * mask, float * qkv,
|
||||||
float * M, float * S) {
|
const float * sinkf, float * M, float * S) {
|
||||||
#ifdef __aarch64__
|
#ifdef __aarch64__
|
||||||
float16_t q_f16[Dk*q_step];
|
float16_t q_f16[Dk*q_step];
|
||||||
#endif
|
#endif
|
||||||
@@ -1356,7 +1371,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
|||||||
vh.next_block(k_step);
|
vh.next_block(k_step);
|
||||||
mr += k_step*sizeof(ggml_half);
|
mr += k_step*sizeof(ggml_half);
|
||||||
}
|
}
|
||||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
|
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
|
||||||
|
|
||||||
q += q_step*stride_q;
|
q += q_step*stride_q;
|
||||||
mask += q_step*stride_m;
|
mask += q_step*stride_m;
|
||||||
@@ -1383,7 +1398,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
|||||||
vh.next_block(k_step);
|
vh.next_block(k_step);
|
||||||
mr += k_step*sizeof(ggml_half);
|
mr += k_step*sizeof(ggml_half);
|
||||||
}
|
}
|
||||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
|
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1392,7 +1407,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
|||||||
FlashMS<q_step, k_step>& fms,
|
FlashMS<q_step, k_step>& fms,
|
||||||
FlashQKV<Dv, q_step, k_step>& fqkv,
|
FlashQKV<Dv, q_step, k_step>& fqkv,
|
||||||
const float * q, const char * mask, float * qkv,
|
const float * q, const char * mask, float * qkv,
|
||||||
float * M, float * S, char * qptr) {
|
const float * sinkf, float * M, float * S, char * qptr) {
|
||||||
auto q8 = (typename KHelper::block_q8 *)qptr;
|
auto q8 = (typename KHelper::block_q8 *)qptr;
|
||||||
if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
|
if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
|
||||||
if (nq1 == q_step) {
|
if (nq1 == q_step) {
|
||||||
@@ -1412,7 +1427,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
|||||||
vh.next_block(k_step);
|
vh.next_block(k_step);
|
||||||
mr += k_step*sizeof(ggml_half);
|
mr += k_step*sizeof(ggml_half);
|
||||||
}
|
}
|
||||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
|
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1449,10 +1464,10 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
|||||||
}
|
}
|
||||||
#if FA_TIMING
|
#if FA_TIMING
|
||||||
t1 = Perf::cur_time();
|
t1 = Perf::cur_time();
|
||||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
|
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
|
||||||
perf.accum_nolock(3, t1);
|
perf.accum_nolock(3, t1);
|
||||||
#else
|
#else
|
||||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
|
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
q += q_step*stride_q;
|
q += q_step*stride_q;
|
||||||
@@ -1474,7 +1489,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
|||||||
vh.next_block(k_step);
|
vh.next_block(k_step);
|
||||||
mr += k_step*sizeof(ggml_half);
|
mr += k_step*sizeof(ggml_half);
|
||||||
}
|
}
|
||||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
|
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
#if FA_TIMING
|
#if FA_TIMING
|
||||||
Perf::instance().add(perf);
|
Perf::instance().add(perf);
|
||||||
@@ -1504,7 +1519,7 @@ struct FlashAttn {
|
|||||||
static_assert(k_step%F16::block_size == 0);
|
static_assert(k_step%F16::block_size == 0);
|
||||||
static_assert(q_step <= 4 || q_step%4 == 0);
|
static_assert(q_step <= 4 || q_step%4 == 0);
|
||||||
|
|
||||||
FlashAttn(float scale, float softcap) : fms(scale, softcap) {}
|
FlashAttn(float scale, float softcap, const float * sinkf) : fms(scale, softcap), sinkf(sinkf) {}
|
||||||
|
|
||||||
template <typename KHelper, typename VHelper>
|
template <typename KHelper, typename VHelper>
|
||||||
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||||
@@ -1533,7 +1548,7 @@ struct FlashAttn {
|
|||||||
HelperQ80R8<Dk> khr4(nk1, kh);
|
HelperQ80R8<Dk> khr4(nk1, kh);
|
||||||
#endif
|
#endif
|
||||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr);
|
||||||
return;
|
return;
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1547,29 +1562,30 @@ struct FlashAttn {
|
|||||||
HelperQ8KVR8<Dk> khr4(nk1, kh);
|
HelperQ8KVR8<Dk> khr4(nk1, kh);
|
||||||
#endif
|
#endif
|
||||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr);
|
||||||
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
|
typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
|
||||||
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8);
|
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, (char *)q8);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
|
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FlashMS<q_step, k_step> fms;
|
FlashMS<q_step, k_step> fms;
|
||||||
FlashQKV<Dv, q_step, k_step> fqkv;
|
FlashQKV<Dv, q_step, k_step> fqkv;
|
||||||
|
const float * sinkf;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1927,7 +1943,7 @@ struct FlashAttnBF16 {
|
|||||||
static_assert(k_step%32 == 0);
|
static_assert(k_step%32 == 0);
|
||||||
static_assert(q_step <= 4 || q_step%4 == 0);
|
static_assert(q_step <= 4 || q_step%4 == 0);
|
||||||
|
|
||||||
FlashAttnBF16(float scale, float softcap) : fms(scale, softcap) {}
|
FlashAttnBF16(float scale, float softcap, const float * sinkf) : fms(scale, softcap), sinkf(sinkf) {}
|
||||||
|
|
||||||
template <typename KHelper, typename VHelper>
|
template <typename KHelper, typename VHelper>
|
||||||
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||||
@@ -1967,7 +1983,7 @@ struct FlashAttnBF16 {
|
|||||||
#if FA_TIMING
|
#if FA_TIMING
|
||||||
t1 = Perf::cur_time();
|
t1 = Perf::cur_time();
|
||||||
#endif
|
#endif
|
||||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
|
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
|
||||||
#if FA_TIMING
|
#if FA_TIMING
|
||||||
perf.accum_nolock(4, t1);
|
perf.accum_nolock(4, t1);
|
||||||
#endif
|
#endif
|
||||||
@@ -1990,7 +2006,7 @@ struct FlashAttnBF16 {
|
|||||||
vh.next_block(k_step);
|
vh.next_block(k_step);
|
||||||
mr += k_step*sizeof(ggml_half);
|
mr += k_step*sizeof(ggml_half);
|
||||||
}
|
}
|
||||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
|
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S);
|
||||||
}
|
}
|
||||||
#if FA_TIMING
|
#if FA_TIMING
|
||||||
Perf::instance().add(perf);
|
Perf::instance().add(perf);
|
||||||
@@ -1999,12 +2015,14 @@ struct FlashAttnBF16 {
|
|||||||
|
|
||||||
FlashMS<q_step, k_step> fms;
|
FlashMS<q_step, k_step> fms;
|
||||||
FlashQKV<Dv, q_step, k_step> fqkv;
|
FlashQKV<Dv, q_step, k_step> fqkv;
|
||||||
|
const float * sinkf;
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
|
template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
|
||||||
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||||
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
|
const float * q, const char * mask, float scale, float softcap, float * qkv,
|
||||||
|
const float * sinkf, float * M, float * S) {
|
||||||
|
|
||||||
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
|
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
|
||||||
nq1 -= n;
|
nq1 -= n;
|
||||||
@@ -2018,48 +2036,48 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
|
|||||||
if (nk1 >= 512) {
|
if (nk1 >= 512) {
|
||||||
if (nq1 >= 128) {
|
if (nq1 >= 128) {
|
||||||
int n_step = nq1/128;
|
int n_step = nq1/128;
|
||||||
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
|
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
if (update(128*n_step)) return;
|
if (update(128*n_step)) return;
|
||||||
}
|
}
|
||||||
if (nq1 >= 64) {
|
if (nq1 >= 64) {
|
||||||
int n_step = nq1/64;
|
int n_step = nq1/64;
|
||||||
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
|
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
if (update(64*n_step)) return;
|
if (update(64*n_step)) return;
|
||||||
}
|
}
|
||||||
if (nq1 >= 32) {
|
if (nq1 >= 32) {
|
||||||
int n_step = nq1/32;
|
int n_step = nq1/32;
|
||||||
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
|
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
if (update(32*n_step)) return;
|
if (update(32*n_step)) return;
|
||||||
}
|
}
|
||||||
if (nq1 >= 16) {
|
if (nq1 >= 16) {
|
||||||
int n_step = nq1/16;
|
int n_step = nq1/16;
|
||||||
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
|
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
if (update(16*n_step)) return;
|
if (update(16*n_step)) return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (nq1 >= 8) {
|
if (nq1 >= 8) {
|
||||||
int n_step = nq1/8;
|
int n_step = nq1/8;
|
||||||
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
|
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
if (update(8*n_step)) return;
|
if (update(8*n_step)) return;
|
||||||
}
|
}
|
||||||
else if (nq1 >= 4) {
|
else if (nq1 >= 4) {
|
||||||
int n_step = nq1/4;
|
int n_step = nq1/4;
|
||||||
FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
|
FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
if (update(4*n_step)) return;
|
if (update(4*n_step)) return;
|
||||||
}
|
}
|
||||||
else if (nq1 >= 2) {
|
else if (nq1 >= 2) {
|
||||||
int n_step = nq1/2;
|
int n_step = nq1/2;
|
||||||
FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
|
FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
if (update(2*n_step)) return;
|
if (update(2*n_step)) return;
|
||||||
}
|
}
|
||||||
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
|
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2067,26 +2085,26 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
|
|||||||
template <int Dk, int Dv, int k_step>
|
template <int Dk, int Dv, int k_step>
|
||||||
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||||
const float * q, const char * k, const char * v, const char * mask,
|
const float * q, const char * k, const char * v, const char * mask,
|
||||||
float scale, float softcap, float * qkv, float * M, float * S) {
|
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
|
||||||
HelperBF16<Dk, k_step> kh(k, stride_k);
|
HelperBF16<Dk, k_step> kh(k, stride_k);
|
||||||
HelperBF16<Dv, k_step> vh(v, stride_v);
|
HelperBF16<Dv, k_step> vh(v, stride_v);
|
||||||
if (nk1 >= 4096) {
|
if (nk1 >= 4096) {
|
||||||
if (nq1 >= 64) {
|
if (nq1 >= 64) {
|
||||||
FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap);
|
FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
else if (nq1 >= 16) {
|
else if (nq1 >= 16) {
|
||||||
FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap);
|
FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (nq1 >= 8) {
|
if (nq1 >= 8) {
|
||||||
FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap);
|
FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
} else {
|
} else {
|
||||||
FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap);
|
FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap, sinkf);
|
||||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2096,43 +2114,43 @@ template <int Dk, int Dv, int k_step, typename KHelper>
|
|||||||
inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
|
inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
|
||||||
int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
|
int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
|
||||||
const float * q, const char * v, const char * mask,
|
const float * q, const char * v, const char * mask,
|
||||||
float scale, float softcap, float * qkv, float * M, float * S) {
|
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
|
||||||
|
|
||||||
switch (type_v) {
|
switch (type_v) {
|
||||||
case GGML_TYPE_F16: {
|
case GGML_TYPE_F16: {
|
||||||
HelperF16 vh(v, stride_v);
|
HelperF16 vh(v, stride_v);
|
||||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
#ifdef __AVX512BF16__
|
#ifdef __AVX512BF16__
|
||||||
case GGML_TYPE_BF16: {
|
case GGML_TYPE_BF16: {
|
||||||
HelperBF16<Dv, k_step> vh(v, stride_v);
|
HelperBF16<Dv, k_step> vh(v, stride_v);
|
||||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
#endif
|
#endif
|
||||||
case GGML_TYPE_Q8_0: {
|
case GGML_TYPE_Q8_0: {
|
||||||
HelperQ80 vh(v, stride_v);
|
HelperQ80 vh(v, stride_v);
|
||||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q8_KV: {
|
case GGML_TYPE_Q8_KV: {
|
||||||
HelperQ8KV<Dv> vh(v, stride_v);
|
HelperQ8KV<Dv> vh(v, stride_v);
|
||||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q6_0: {
|
case GGML_TYPE_Q6_0: {
|
||||||
HelperQ60 vh(v, stride_v);
|
HelperQ60 vh(v, stride_v);
|
||||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
#if GGML_IQK_FA_ALL_QUANTS
|
#if GGML_IQK_FA_ALL_QUANTS
|
||||||
case GGML_TYPE_Q4_0: {
|
case GGML_TYPE_Q4_0: {
|
||||||
HelperQ40 vh(v, stride_v);
|
HelperQ40 vh(v, stride_v);
|
||||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_1: {
|
case GGML_TYPE_Q4_1: {
|
||||||
HelperQ41 vh(v, stride_v);
|
HelperQ41 vh(v, stride_v);
|
||||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_IQ4_NL: {
|
case GGML_TYPE_IQ4_NL: {
|
||||||
HelperIQ4nl vh(v, stride_v);
|
HelperIQ4nl vh(v, stride_v);
|
||||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
#endif
|
#endif
|
||||||
default: return false;
|
default: return false;
|
||||||
@@ -2144,42 +2162,42 @@ template <int Dk, int Dv, int k_step>
|
|||||||
inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||||
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||||
const float * q, const char * k, const char * v, const char * mask,
|
const float * q, const char * k, const char * v, const char * mask,
|
||||||
float scale, float softcap, float * qkv, float * M, float * S) {
|
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
|
||||||
|
|
||||||
bool result = false;
|
bool result = false;
|
||||||
switch (type_k) {
|
switch (type_k) {
|
||||||
case GGML_TYPE_F16: {
|
case GGML_TYPE_F16: {
|
||||||
HelperF16 kh(k, stride_k);
|
HelperF16 kh(k, stride_k);
|
||||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q8_0: {
|
case GGML_TYPE_Q8_0: {
|
||||||
HelperQ80 kh(k, stride_k);
|
HelperQ80 kh(k, stride_k);
|
||||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q8_0_R8: {
|
case GGML_TYPE_Q8_0_R8: {
|
||||||
HelperQ80R8<Dk> kh(k, stride_k);
|
HelperQ80R8<Dk> kh(k, stride_k);
|
||||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q6_0: {
|
case GGML_TYPE_Q6_0: {
|
||||||
HelperQ60 kh(k, stride_k);
|
HelperQ60 kh(k, stride_k);
|
||||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
#if GGML_IQK_FA_ALL_QUANTS
|
#if GGML_IQK_FA_ALL_QUANTS
|
||||||
case GGML_TYPE_Q8_KV: {
|
case GGML_TYPE_Q8_KV: {
|
||||||
HelperQ8KV<Dk> kh(k, stride_k);
|
HelperQ8KV<Dk> kh(k, stride_k);
|
||||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_0: {
|
case GGML_TYPE_Q4_0: {
|
||||||
HelperQ40 kh(k, stride_k);
|
HelperQ40 kh(k, stride_k);
|
||||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_1: {
|
case GGML_TYPE_Q4_1: {
|
||||||
HelperQ41 kh(k, stride_k);
|
HelperQ41 kh(k, stride_k);
|
||||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_IQ4_NL: {
|
case GGML_TYPE_IQ4_NL: {
|
||||||
HelperIQ4nl kh(k, stride_k);
|
HelperIQ4nl kh(k, stride_k);
|
||||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||||
} break;
|
} break;
|
||||||
#endif
|
#endif
|
||||||
default: break;
|
default: break;
|
||||||
@@ -2194,7 +2212,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
|||||||
int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,\
|
int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,\
|
||||||
const float * q, const void * k, const void * v, const void * mask,\
|
const float * q, const void * k, const void * v, const void * mask,\
|
||||||
float scale, float softcap,\
|
float scale, float softcap,\
|
||||||
float * qkv, float * M, float * S)
|
float * qkv, const float * sinkf, float * M, float * S)
|
||||||
|
|
||||||
IQK_FA_CASE(iqk_fa_576_512);
|
IQK_FA_CASE(iqk_fa_576_512);
|
||||||
IQK_FA_CASE(iqk_fa_192_128);
|
IQK_FA_CASE(iqk_fa_192_128);
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
|||||||
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||||
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
||||||
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||||
|
const void * sinks, // attention sinks. If not null, read as float (not fp16), one value per iq2 index of q
|
||||||
float scale, // scale applied before softmax
|
float scale, // scale applied before softmax
|
||||||
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||||
float * qkv, // v*softmax(scale*(k*q))
|
float * qkv, // v*softmax(scale*(k*q))
|
||||||
@@ -139,7 +140,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
|||||||
auto work_this_thread = (float *)(result_buffer + ith*size_thread);
|
auto work_this_thread = (float *)(result_buffer + ith*size_thread);
|
||||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||||
Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
|
Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
|
||||||
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth,
|
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, nullptr, 0,
|
||||||
scale, softcap,
|
scale, softcap,
|
||||||
work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false;
|
work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false;
|
||||||
|
|
||||||
@@ -182,51 +183,6 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
|||||||
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
|
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
|
||||||
auto result_size = (Dv + 16)*rk2*sizeof(float);
|
auto result_size = (Dv + 16)*rk2*sizeof(float);
|
||||||
int gcd = simple_gcd(nek2, nth);
|
int gcd = simple_gcd(nek2, nth);
|
||||||
if (false && gcd > 1) {
|
|
||||||
int nth_g = nth/gcd;
|
|
||||||
int ith_g = ith%nth_g;
|
|
||||||
int nek1_32 = nek1/32;
|
|
||||||
int nek1_pt = (nek1_32 + nth_g - 1)/nth_g;
|
|
||||||
int ith_mid = nth_g;
|
|
||||||
if (nek1_pt*nth_g > nek1_32) {
|
|
||||||
ith_mid = nek1_32 - nth_g*(nek1_pt - 1);
|
|
||||||
}
|
|
||||||
nek1_pt *= 32;
|
|
||||||
int nek1_mid = ith_mid*nek1_pt;
|
|
||||||
int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32;
|
|
||||||
for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) {
|
|
||||||
int ik01 = ith_g < ith_mid ? ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread;
|
|
||||||
auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size);
|
|
||||||
auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
|
|
||||||
auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
|
|
||||||
auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
|
|
||||||
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
|
|
||||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
|
||||||
Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv,
|
|
||||||
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
|
|
||||||
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
barrier(barrier_data);
|
|
||||||
|
|
||||||
for (int iq2 = ith; iq2 < neq2; iq2 += nth) {
|
|
||||||
int ik02 = iq2/rk2;
|
|
||||||
int il = iq2 - ik02*rk2;
|
|
||||||
auto Racc = qkv + iq2*nb1/sizeof(float);
|
|
||||||
float M = -INFINITY, S = 0;
|
|
||||||
for (int ig = 0; ig < nth_g; ++ig) {
|
|
||||||
int istep_k = ik02*nth_g + ig;
|
|
||||||
auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
|
|
||||||
const float * R = this_result + il*Dv;
|
|
||||||
const float * Mj = this_result + Dv*rk2;
|
|
||||||
const float * Sj = Mj + rk2;
|
|
||||||
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
|
|
||||||
}
|
|
||||||
float norm = S > 0 ? 1/S : 1;
|
|
||||||
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
int nth_k = nth/gcd;
|
int nth_k = nth/gcd;
|
||||||
int nek2_k = nek2/gcd;
|
int nek2_k = nek2/gcd;
|
||||||
int nchunk = nek2_k*nek1/32;
|
int nchunk = nek2_k*nek1/32;
|
||||||
@@ -259,7 +215,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
|||||||
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
|
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
|
||||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||||
Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv,
|
Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv,
|
||||||
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
|
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, nullptr, 0,
|
||||||
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
|
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -281,6 +237,16 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
|||||||
const float * Sj = Mj + rk2;
|
const float * Sj = Mj + rk2;
|
||||||
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
|
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
|
||||||
}
|
}
|
||||||
|
if (sinks) {
|
||||||
|
float s = ((const float *)sinks)[iq2];
|
||||||
|
if (s > M) {
|
||||||
|
float m = expf(M - s);
|
||||||
|
for (int i = 0; i < Dv; ++i) Racc[i] *= m;
|
||||||
|
S = S*m + 1;
|
||||||
|
} else {
|
||||||
|
S += expf(s - M);
|
||||||
|
}
|
||||||
|
}
|
||||||
float norm = S > 0 ? 1/S : 1;
|
float norm = S > 0 ? 1/S : 1;
|
||||||
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
|
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
|
||||||
}
|
}
|
||||||
@@ -306,6 +272,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
|||||||
int counter = 0;
|
int counter = 0;
|
||||||
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
|
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
|
||||||
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
|
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
|
||||||
|
auto sinksf = sinks ? (const float *)sinks + iq2 : nullptr;
|
||||||
if (counter++ % (nth/ntg) == ith/ntg) {
|
if (counter++ % (nth/ntg) == ith/ntg) {
|
||||||
int iq1 = (ith%ntg)*neq1g;
|
int iq1 = (ith%ntg)*neq1g;
|
||||||
int this_neq1 = std::min(neq1g, neq1-iq1);
|
int this_neq1 = std::min(neq1g, neq1-iq1);
|
||||||
@@ -314,7 +281,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
|||||||
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
|
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
|
||||||
(const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3),
|
(const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3),
|
||||||
(const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3),
|
(const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3),
|
||||||
(const void *)((const char *)mask + iq1*stride_m),
|
(const void *)((const char *)mask + iq1*stride_m), sinksf, 1,
|
||||||
scale, softcap,
|
scale, softcap,
|
||||||
(float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
|
(float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ bool iqk_flash_attn_impl(int type_k, // type of k
|
|||||||
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||||
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
||||||
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||||
|
const float * sinksf, // attention sinks
|
||||||
|
int nsinks, // number of sinks
|
||||||
float scale, // scale applied before softmax
|
float scale, // scale applied before softmax
|
||||||
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||||
float * qkv, // v*softmax(scale*(k*q))
|
float * qkv, // v*softmax(scale*(k*q))
|
||||||
|
|||||||
@@ -120,16 +120,21 @@ struct MulMat {
|
|||||||
funcs[n_left-1](n, vx, bx, info, nrc_x);
|
funcs[n_left-1](n, vx, bx, info, nrc_x);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inline void gelu(int n, const float * src, float * dst);
|
inline static void gelu(int n, const float * src, float * dst);
|
||||||
inline void relu(int n, const float * src, float * dst);
|
inline static void relu(int n, const float * src, float * dst);
|
||||||
inline void silu(int n, const float * src, float * dst);
|
inline static void silu(int n, const float * src, float * dst);
|
||||||
inline void activate(ggml_unary_op op, int n, const float * src, float * dst) {
|
inline static void swiglu_oai(int n, const float * src, float * dst);
|
||||||
|
inline static void clamp_oai(int n, float *x);
|
||||||
|
inline static void activate(ggml_unary_op op, int n, const float * src, float * dst) {
|
||||||
if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst);
|
if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst);
|
||||||
else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst);
|
else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst);
|
||||||
else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst);
|
else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst);
|
||||||
|
else if (op == GGML_UNARY_OP_SWIGLU_OAI) swiglu_oai(n, src, dst);
|
||||||
else GGML_ABORT("fatal error");
|
else GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, DataInfo& info, int nrc_x, int nrc_y, int unary_op) {
|
inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx,
|
||||||
|
const float * up_b, const float * gate_b,
|
||||||
|
DataInfo& info, int nrc_x, int nrc_y, int unary_op) {
|
||||||
#ifdef __aarch64__
|
#ifdef __aarch64__
|
||||||
constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small)
|
constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small)
|
||||||
#else
|
#else
|
||||||
@@ -137,6 +142,29 @@ struct MulMat {
|
|||||||
#endif
|
#endif
|
||||||
auto op = ggml_unary_op(unary_op);
|
auto op = ggml_unary_op(unary_op);
|
||||||
float tmp[k_x_step*16];
|
float tmp[k_x_step*16];
|
||||||
|
auto process = [&tmp, n, op, vx_gate, vx_up, gate_b, up_b, bx, xstep = k_x_step] (mul_mat_t func, const DataInfo& this_info, int ix, int this_nrc_x, int ny) {
|
||||||
|
func(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||||
|
for (int ky = 0; ky < ny; ++ky) {
|
||||||
|
if (gate_b) {
|
||||||
|
auto b = gate_b + ix;
|
||||||
|
auto x = this_info.dst_row(ky);
|
||||||
|
for (int j = 0; j < this_nrc_x; ++j) x[j] += b[j];
|
||||||
|
}
|
||||||
|
activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*xstep);
|
||||||
|
}
|
||||||
|
func(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||||
|
for (int ky = 0; ky < ny; ++ky) {
|
||||||
|
auto result = this_info.dst_row(ky);
|
||||||
|
if (up_b) {
|
||||||
|
auto b = up_b + ix;
|
||||||
|
for (int j = 0; j < this_nrc_x; ++j) result[j] += b[j];
|
||||||
|
}
|
||||||
|
if (op == GGML_UNARY_OP_SWIGLU_OAI) {
|
||||||
|
clamp_oai(this_nrc_x, result);
|
||||||
|
}
|
||||||
|
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*xstep + j];
|
||||||
|
}
|
||||||
|
};
|
||||||
if (func16 && nrc_y >= 16) {
|
if (func16 && nrc_y >= 16) {
|
||||||
int n_step = (nrc_y - info.cur_y)/16;
|
int n_step = (nrc_y - info.cur_y)/16;
|
||||||
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
|
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
|
||||||
@@ -144,15 +172,7 @@ struct MulMat {
|
|||||||
this_info.s += ix;
|
this_info.s += ix;
|
||||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||||
for (int iy = 0; iy < n_step; ++iy) {
|
for (int iy = 0; iy < n_step; ++iy) {
|
||||||
func16(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
process(func16, this_info, ix, this_nrc_x, 16);
|
||||||
for (int ky = 0; ky < 16; ++ky) {
|
|
||||||
activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
|
||||||
}
|
|
||||||
func16(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
|
||||||
for (int ky = 0; ky < 16; ++ky) {
|
|
||||||
auto result = this_info.dst_row(ky);
|
|
||||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
|
||||||
}
|
|
||||||
this_info.cur_y += 16;
|
this_info.cur_y += 16;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -175,23 +195,11 @@ struct MulMat {
|
|||||||
this_info.s += ix;
|
this_info.s += ix;
|
||||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||||
for (int iy = 0; iy < my1; ++iy) {
|
for (int iy = 0; iy < my1; ++iy) {
|
||||||
funcs[ny1-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
process(funcs[ny1-1], this_info, ix, this_nrc_x, ny1);
|
||||||
for (int ky = 0; ky < ny1; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
|
||||||
funcs[ny1-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
|
||||||
for (int ky = 0; ky < ny1; ++ky) {
|
|
||||||
auto result = this_info.dst_row(ky);
|
|
||||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
|
||||||
}
|
|
||||||
this_info.cur_y += ny1;
|
this_info.cur_y += ny1;
|
||||||
}
|
}
|
||||||
for (int iy = 0; iy < my2; ++iy) {
|
for (int iy = 0; iy < my2; ++iy) {
|
||||||
funcs[ny2-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
process(funcs[ny2-1], this_info, ix, this_nrc_x, ny2);
|
||||||
for (int ky = 0; ky < ny2; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
|
||||||
funcs[ny2-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
|
||||||
for (int ky = 0; ky < ny2; ++ky) {
|
|
||||||
auto result = this_info.dst_row(ky);
|
|
||||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
|
||||||
}
|
|
||||||
this_info.cur_y += ny2;
|
this_info.cur_y += ny2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -203,13 +211,7 @@ struct MulMat {
|
|||||||
this_info.s += ix;
|
this_info.s += ix;
|
||||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||||
for (int iy = 0; iy < n_step; ++iy) {
|
for (int iy = 0; iy < n_step; ++iy) {
|
||||||
funcs[ny-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
process(funcs[ny-1], this_info, ix, this_nrc_x, ny);
|
||||||
for (int ky = 0; ky < ny; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
|
||||||
funcs[ny-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
|
||||||
for (int ky = 0; ky < ny; ++ky) {
|
|
||||||
auto result = this_info.dst_row(ky);
|
|
||||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
|
||||||
}
|
|
||||||
this_info.cur_y += ny;
|
this_info.cur_y += ny;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -222,13 +224,7 @@ struct MulMat {
|
|||||||
auto this_info = info;
|
auto this_info = info;
|
||||||
this_info.s += ix;
|
this_info.s += ix;
|
||||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||||
funcs[n_left-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
process(funcs[n_left-1], this_info, ix, this_nrc_x, n_left);
|
||||||
for (int ky = 0; ky < n_left; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
|
||||||
funcs[n_left-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
|
||||||
for (int ky = 0; ky < n_left; ++ky) {
|
|
||||||
auto result = this_info.dst_row(ky);
|
|
||||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -731,6 +727,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
|
|||||||
extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
|
extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
|
||||||
int typeA, const void * Aup, const void * Agate, long strideA,
|
int typeA, const void * Aup, const void * Agate, long strideA,
|
||||||
int typeB, const void * B, long strideB,
|
int typeB, const void * B, long strideB,
|
||||||
|
const char * up_b_c, const char * gate_b_c,
|
||||||
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
|
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
|
||||||
|
|
||||||
const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;
|
const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;
|
||||||
@@ -774,7 +771,9 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
|
|||||||
if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
|
if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
|
||||||
GGML_ABORT("Fatal error");
|
GGML_ABORT("Fatal error");
|
||||||
}
|
}
|
||||||
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op);
|
auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr;
|
||||||
|
auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr;
|
||||||
|
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op);
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@@ -795,7 +794,10 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
|
|||||||
nrc_x *= num_rows;
|
nrc_x *= num_rows;
|
||||||
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
|
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
|
||||||
row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
|
row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
|
||||||
mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny, unary_op);
|
auto up_b = up_b_c ? (const float *)up_b_c + first_x : nullptr;
|
||||||
|
auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x : nullptr;
|
||||||
|
mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx,
|
||||||
|
up_b, gate_b, info, nrc_x, Ny, unary_op);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -993,6 +995,46 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// TODO: these swiglu_oai constants shouldn't be hard coded
|
||||||
|
constexpr float k_swiglu_oai_alpha = 1.702f;
|
||||||
|
constexpr float k_swiglu_oai_limit = 7.f;
|
||||||
|
|
||||||
|
void MulMat::swiglu_oai(int n, const float * x, float * y) {
|
||||||
|
// int i = 0;
|
||||||
|
//#if defined __AVX512F__ && defined __AVX512DQ__
|
||||||
|
// {
|
||||||
|
// auto max = _mm512_set1_ps(k_swiglu_oai_limit);
|
||||||
|
// auto alpha = _mm512_set1_ps(-k_swiglu_oai_alpha);
|
||||||
|
// for (; i + 15 < n; i += 16) {
|
||||||
|
// auto xc = v_clamp_max(_mm512_loadu_ps(x + i), max);
|
||||||
|
// _mm512_storeu_ps(y + i, v_silu_oai(xc, alpha));
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//#endif
|
||||||
|
//#if defined __AVX2__ && defined __FMA__
|
||||||
|
// if (i + 7 < n) {
|
||||||
|
// auto max = _mm256_set1_ps(k_swiglu_oai_limit);
|
||||||
|
// auto alpha = _mm256_set1_ps(-k_swiglu_oai_alpha);
|
||||||
|
// for (; i + 7 < n; i += 8) {
|
||||||
|
// auto xc = v_clamp_max(_mm256_loadu_ps(x + i), max);
|
||||||
|
// _mm256_storeu_ps(y + i, v_silu_oai(xc, alpha));
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//#endif
|
||||||
|
// for (; i < n; ++i) {
|
||||||
|
// auto xi = std::min(x[i], k_swiglu_oai_limit);
|
||||||
|
// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
|
||||||
|
// }
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
auto xi = std::min(x[i], k_swiglu_oai_limit);
|
||||||
|
y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MulMat::clamp_oai(int n, float * x) {
|
||||||
|
for (int i = 0; i < n; ++i) x[i] = 1.f + std::max(std::min(x[i], k_swiglu_oai_limit), -k_swiglu_oai_limit);
|
||||||
|
}
|
||||||
|
|
||||||
#if defined(__ARM_NEON) && defined(__aarch64__)
|
#if defined(__ARM_NEON) && defined(__aarch64__)
|
||||||
void MulMat::gelu(int n, const float * x, float * y) {
|
void MulMat::gelu(int n, const float * x, float * y) {
|
||||||
constexpr float GELU_COEF_A = 0.044715f;
|
constexpr float GELU_COEF_A = 0.044715f;
|
||||||
@@ -1040,6 +1082,37 @@ void MulMat::gelu(int n, const float * x, float * y) {
|
|||||||
for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i])));
|
for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i])));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//void MulMat::swiglu_oai(int n, const float * x, float * y) {
|
||||||
|
// int i = 0;
|
||||||
|
//#if defined __AVX512F__ && defined __AVX512DQ__
|
||||||
|
// {
|
||||||
|
// auto limit = _mm512_set1_ps(k_swiglu_oai_limit);
|
||||||
|
// auto alpha = _mm512_set1_ps(k_swiglu_oai_alpha);
|
||||||
|
// for (; i + 15 < n; i += 16) {
|
||||||
|
// auto xi = _mm512_loadu_ps(x + i);
|
||||||
|
// auto mask = _mm512_cmp
|
||||||
|
//
|
||||||
|
// }
|
||||||
|
// __m512 c1 = _mm512_set1_ps(GELU_COEF_A);
|
||||||
|
// __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI);
|
||||||
|
// for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_gelu(_mm512_loadu_ps(x + i), c1, c2));
|
||||||
|
// }
|
||||||
|
//#endif
|
||||||
|
//#if defined __AVX2__ && defined __FMA__
|
||||||
|
// if (i + 7 < n) {
|
||||||
|
// __m256 c1 = _mm256_set1_ps(GELU_COEF_A);
|
||||||
|
// __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI);
|
||||||
|
// for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2));
|
||||||
|
//
|
||||||
|
// }
|
||||||
|
//#endif
|
||||||
|
// for (; i < n; ++i) {
|
||||||
|
// auto xi = std::min(x[i], k_swiglu_oai_limit);
|
||||||
|
// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
|
||||||
void MulMat::silu(int n, const float * x, float * y) {
|
void MulMat::silu(int n, const float * x, float * y) {
|
||||||
int i = 0;
|
int i = 0;
|
||||||
#if defined __AVX512F__ && defined __AVX512DQ__
|
#if defined __AVX512F__ && defined __AVX512DQ__
|
||||||
@@ -1188,6 +1261,8 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
|
|||||||
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||||
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
||||||
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||||
|
const float * sinksf, // attention sinks. If not null, an array of floats (length presumably nsinks - TODO confirm)
|
||||||
|
[[maybe_unused]] int nsinks,
|
||||||
float scale, // scale applied before softmax
|
float scale, // scale applied before softmax
|
||||||
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||||
float * qkv, // v*softmax(scale*(k*q))
|
float * qkv, // v*softmax(scale*(k*q))
|
||||||
@@ -1197,32 +1272,32 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
|
|||||||
|
|
||||||
if (Dk == 576 && Dv == 512) {
|
if (Dk == 576 && Dv == 512) {
|
||||||
return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, k, v, mask, scale, softcap, qkv, M, S);
|
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Dk == 192 && Dv == 128) {
|
if (Dk == 192 && Dv == 128) {
|
||||||
return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, k, v, mask, scale, softcap, qkv, M, S);
|
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Dk == 256 && Dv == 256) {
|
if (Dk == 256 && Dv == 256) {
|
||||||
return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, k, v, mask, scale, softcap, qkv, M, S);
|
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Dk == 128 && Dv == 128) {
|
if (Dk == 128 && Dv == 128) {
|
||||||
return iqk_fa_128_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_fa_128_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, k, v, mask, scale, softcap, qkv, M, S);
|
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Dk == 96 && Dv == 96) {
|
if (Dk == 96 && Dv == 96) {
|
||||||
return iqk_fa_96_96(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_fa_96_96(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, k, v, mask, scale, softcap, qkv, M, S);
|
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Dk == 64 && Dv == 64) {
|
if (Dk == 64 && Dv == 64) {
|
||||||
return iqk_fa_64_64(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
return iqk_fa_64_64(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||||
q, k, v, mask, scale, softcap, qkv, M, S);
|
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
|
|||||||
IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
|
IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
|
||||||
int typeA, const void * Aup, const void * Agate, long strideA,
|
int typeA, const void * Aup, const void * Agate, long strideA,
|
||||||
int typeB, const void * B, long strideB,
|
int typeB, const void * B, long strideB,
|
||||||
|
const char * up_b, const char * gate_b,
|
||||||
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
|
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
|
||||||
|
|
||||||
IQK_API int iqk_dequant_type(int type, int Ny);
|
IQK_API int iqk_dequant_type(int type, int Ny);
|
||||||
@@ -57,6 +58,7 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
|
|||||||
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||||
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
||||||
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||||
|
const void * sinks, // attention sinks, presumably optional (may be null); element type/count not stated here - TODO confirm
|
||||||
float scale, // scale applied before softmax
|
float scale, // scale applied before softmax
|
||||||
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||||
float * qkv, // v*softmax(scale*(k*q))
|
float * qkv, // v*softmax(scale*(k*q))
|
||||||
|
|||||||
@@ -61,6 +61,13 @@ static inline float32x4_t v_silu(float32x4_t x) {
|
|||||||
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
|
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
|
||||||
return vdivq_f32(x, one_plus_exp_neg_x);
|
return vdivq_f32(x, one_plus_exp_neg_x);
|
||||||
}
|
}
|
||||||
|
// NEON, 4 lanes: returns x / (1 + exp(alpha * x)) per lane.
// alpha is multiplied in as given, so a caller that wants x * sigmoid(a * x)
// has to pass alpha = -a.
static inline float32x4_t v_silu_oai(float32x4_t x, float32x4_t alpha) {
    const float32x4_t exp_term = v_expf(vmulq_f32(alpha, x));
    const float32x4_t denom = vaddq_f32(vdupq_n_f32(1.0f), exp_term);
    return vdivq_f32(x, denom);
}
|
||||||
static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
|
static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
|
||||||
const float32x4_t one = vdupq_n_f32(1.0f);
|
const float32x4_t one = vdupq_n_f32(1.0f);
|
||||||
float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
|
float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
|
||||||
@@ -131,6 +138,17 @@ static inline __m512 v_silu(__m512 x) {
|
|||||||
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
|
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
|
||||||
return _mm512_div_ps(x, one_plus_exp_neg_x);
|
return _mm512_div_ps(x, one_plus_exp_neg_x);
|
||||||
}
|
}
|
||||||
|
static inline __m512 v_silu_oai(__m512 x, __m512 alpha) {
|
||||||
|
const __m512 one = _mm512_set1_ps(1);
|
||||||
|
const __m512 neg_x = _mm512_mul_ps(alpha, x);
|
||||||
|
const __m512 exp_neg_x = v_expf(neg_x);
|
||||||
|
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
|
||||||
|
return _mm512_div_ps(x, one_plus_exp_neg_x);
|
||||||
|
}
|
||||||
|
// AVX512, 16 lanes: per-lane upper clamp, i.e. (x > max) ? max : x.
// The comparison is ordered (_CMP_GT_OQ), so a NaN lane in x compares false
// and is returned unchanged.
static inline __m512 v_clamp_max(__m512 x, __m512 max) {
    const __mmask16 above_limit = _mm512_cmp_ps_mask(x, max, _CMP_GT_OQ);
    // Lanes with the mask bit set take `max`, the others keep `x`.
    return _mm512_mask_blend_ps(above_limit, x, max);
}
|
||||||
#endif // __AVX512__
|
#endif // __AVX512__
|
||||||
|
|
||||||
#if defined(__AVX2__) && defined(__FMA__)
|
#if defined(__AVX2__) && defined(__FMA__)
|
||||||
@@ -195,12 +213,23 @@ static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) {
|
|||||||
}
|
}
|
||||||
static inline __m256 v_silu(__m256 x) {
|
static inline __m256 v_silu(__m256 x) {
|
||||||
const __m256 one = _mm256_set1_ps(1);
|
const __m256 one = _mm256_set1_ps(1);
|
||||||
const __m256 zero = _mm256_setzero_ps();
|
const __m256 zero = _mm256_setzero_ps();
|
||||||
const __m256 neg_x = _mm256_sub_ps(zero, x);
|
const __m256 neg_x = _mm256_sub_ps(zero, x);
|
||||||
const __m256 exp_neg_x = v_expf(neg_x);
|
const __m256 exp_neg_x = v_expf(neg_x);
|
||||||
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
|
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
|
||||||
return _mm256_div_ps(x, one_plus_exp_neg_x);
|
return _mm256_div_ps(x, one_plus_exp_neg_x);
|
||||||
}
|
}
|
||||||
|
static inline __m256 v_silu_oai(__m256 x, __m256 alpha) {
|
||||||
|
const __m256 one = _mm256_set1_ps(1);
|
||||||
|
const __m256 neg_x = _mm256_mul_ps(alpha, x);
|
||||||
|
const __m256 exp_neg_x = v_expf(neg_x);
|
||||||
|
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
|
||||||
|
return _mm256_div_ps(x, one_plus_exp_neg_x);
|
||||||
|
}
|
||||||
|
// AVX2, 8 lanes: per-lane upper clamp, i.e. (x > max) ? max : x.
// The comparison is ordered (_CMP_GT_OQ), so a NaN lane in x compares false
// and is returned unchanged.
static inline __m256 v_clamp_max(__m256 x, __m256 max) {
    // All-ones in lanes where x exceeds max, all-zeros elsewhere.
    const __m256 above_limit = _mm256_cmp_ps(x, max, _CMP_GT_OQ);
    // blendv selects `max` where the mask's sign bit is set and `x` otherwise,
    // which matches the and/andnot/or select for an all-ones/all-zeros mask.
    return _mm256_blendv_ps(x, max, above_limit);
}
|
||||||
|
|
||||||
#endif // __AVX2__
|
#endif // __AVX2__
|
||||||
|
|
||||||
|
|||||||
@@ -70,50 +70,52 @@ extern "C" {
|
|||||||
typedef int32_t llama_seq_id;
|
typedef int32_t llama_seq_id;
|
||||||
|
|
||||||
enum llama_vocab_type {
|
enum llama_vocab_type {
|
||||||
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
|
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
|
||||||
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
|
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
|
||||||
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
||||||
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
||||||
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
|
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
|
||||||
|
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
|
||||||
|
LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming
|
||||||
};
|
};
|
||||||
|
|
||||||
// pre-tokenization types
|
// pre-tokenization types
|
||||||
enum llama_vocab_pre_type {
|
//enum llama_vocab_pre_type {
|
||||||
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
|
// LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
|
||||||
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
|
// LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
|
||||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
|
// LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
|
||||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
|
// LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
|
||||||
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
|
// LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
|
||||||
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
|
// LLAMA_VOCAB_PRE_TYPE_MPT = 5,
|
||||||
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
|
// LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
|
||||||
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
|
// LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
|
||||||
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
|
// LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
|
||||||
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
|
// LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
|
||||||
LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
|
// LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
|
||||||
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
|
// LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
|
||||||
LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
|
// LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
|
||||||
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
|
// LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
|
||||||
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
|
// LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
|
||||||
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
|
// LLAMA_VOCAB_PRE_TYPE_PORO = 15,
|
||||||
LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
|
// LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
|
||||||
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
|
// LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
|
||||||
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
|
// LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
|
||||||
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
|
// LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
|
||||||
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
|
// LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
|
||||||
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
|
// LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
|
||||||
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
|
// LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
|
||||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, //llama.cpp lists this as 28
|
// LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, //llama.cpp lists this as 28
|
||||||
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
|
// LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
|
||||||
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
|
// LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
|
||||||
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
|
// LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
|
||||||
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
|
// LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
|
||||||
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
|
// LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
|
||||||
LLAMA_VOCAB_PRE_TYPE_FALCON_3 = 34,
|
// LLAMA_VOCAB_PRE_TYPE_FALCON_3 = 34,
|
||||||
LLAMA_VOCAB_PRE_TYPE_FALCON_E = 35,
|
// LLAMA_VOCAB_PRE_TYPE_FALCON_E = 35,
|
||||||
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 36, //llama.cpp lists this as 35
|
// LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 36, //llama.cpp lists this as 35
|
||||||
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 37, //llama.cpp lists this as 36
|
// LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 37, //llama.cpp lists this as 36
|
||||||
LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 38, //llama.cpp lists this as 37
|
// LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 38, //llama.cpp lists this as 37
|
||||||
};
|
//};
|
||||||
|
|
||||||
// note: these values should be synchronized with ggml_rope
|
// note: these values should be synchronized with ggml_rope
|
||||||
// TODO: maybe move this enum to ggml.h (ggml_rope_type)
|
// TODO: maybe move this enum to ggml.h (ggml_rope_type)
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ add_library(llama
|
|||||||
llama-vocab.cpp
|
llama-vocab.cpp
|
||||||
llama-grammar.cpp
|
llama-grammar.cpp
|
||||||
llama-sampling.cpp
|
llama-sampling.cpp
|
||||||
|
llama-mmap.cpp
|
||||||
|
llama-model-loader.cpp
|
||||||
unicode.h
|
unicode.h
|
||||||
unicode.cpp
|
unicode.cpp
|
||||||
unicode-data.cpp
|
unicode-data.cpp
|
||||||
|
|||||||
288
src/llama-arch.h
Normal file
288
src/llama-arch.h
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
// Identifiers for the model architectures known to the loader.
// Values are implicit (0, 1, 2, ...), so the order of the enumerators is
// significant: insert new entries with care. LLM_ARCH_UNKNOWN is the last entry.
enum llm_arch {
    LLM_ARCH_LLAMA,
    LLM_ARCH_LLAMA4,
    LLM_ARCH_DECI,
    LLM_ARCH_FALCON,
    LLM_ARCH_BAICHUAN,
    LLM_ARCH_GROK,
    LLM_ARCH_GPT2,
    LLM_ARCH_GPTJ,
    LLM_ARCH_GPTNEOX,
    LLM_ARCH_MPT,
    LLM_ARCH_STARCODER,
    LLM_ARCH_REFACT,
    LLM_ARCH_BERT,
    LLM_ARCH_NOMIC_BERT,
    LLM_ARCH_JINA_BERT_V2,
    LLM_ARCH_BLOOM,
    LLM_ARCH_STABLELM,
    LLM_ARCH_QWEN,
    LLM_ARCH_QWEN2,
    LLM_ARCH_QWEN2MOE,
    LLM_ARCH_QWEN3,
    LLM_ARCH_QWEN3MOE,
    LLM_ARCH_PHI2,
    LLM_ARCH_PHI3,
    LLM_ARCH_PLAMO,
    LLM_ARCH_CODESHELL,
    LLM_ARCH_ORION,
    LLM_ARCH_INTERNLM2,
    LLM_ARCH_MINICPM,
    LLM_ARCH_GEMMA,
    LLM_ARCH_GEMMA2,
    LLM_ARCH_GEMMA3,
    LLM_ARCH_STARCODER2,
    LLM_ARCH_MAMBA,
    LLM_ARCH_XVERSE,
    LLM_ARCH_COMMAND_R,
    LLM_ARCH_DBRX,
    LLM_ARCH_OLMO,
    LLM_ARCH_OPENELM,
    LLM_ARCH_ARCTIC,
    LLM_ARCH_DEEPSEEK2,
    LLM_ARCH_CHATGLM,
    LLM_ARCH_GLM4,
    LLM_ARCH_GLM4_MOE,
    LLM_ARCH_BITNET,
    LLM_ARCH_BITNET_25,
    LLM_ARCH_BITNET_B158,
    LLM_ARCH_T5,
    LLM_ARCH_T5ENCODER,
    LLM_ARCH_JAIS,
    LLM_ARCH_GRANITE,
    LLM_ARCH_GRANITE_MOE,
    LLM_ARCH_COHERE2,
    LLM_ARCH_DOTS1,
    LLM_ARCH_HUNYUAN_MOE,
    LLM_ARCH_OPENAI_MOE,
    LLM_ARCH_UNKNOWN,   // keep last
};
|
||||||
|
|
||||||
|
// Identifiers for the metadata keys read from a model file.
// Values are implicit (0, 1, 2, ...), so enumerator order is significant.
// The section comments below only mirror the name prefixes.
enum llm_kv {
    // general.*
    LLM_KV_GENERAL_TYPE,
    LLM_KV_GENERAL_ARCHITECTURE,
    LLM_KV_GENERAL_QUANTIZATION_VERSION,
    LLM_KV_GENERAL_ALIGNMENT,
    LLM_KV_GENERAL_NAME,
    LLM_KV_GENERAL_AUTHOR,
    LLM_KV_GENERAL_VERSION,
    LLM_KV_GENERAL_URL,
    LLM_KV_GENERAL_DESCRIPTION,
    LLM_KV_GENERAL_LICENSE,
    LLM_KV_GENERAL_SOURCE_URL,
    LLM_KV_GENERAL_SOURCE_HF_REPO,

    // model hyper-parameters
    LLM_KV_VOCAB_SIZE,
    LLM_KV_CONTEXT_LENGTH,
    LLM_KV_EMBEDDING_LENGTH,
    LLM_KV_BLOCK_COUNT,
    LLM_KV_LEADING_DENSE_BLOCK_COUNT,
    LLM_KV_FEED_FORWARD_LENGTH,
    LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
    LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
    LLM_KV_USE_PARALLEL_RESIDUAL,
    LLM_KV_TENSOR_DATA_LAYOUT,
    LLM_KV_EXPERT_COUNT,
    LLM_KV_EXPERT_USED_COUNT,
    LLM_KV_EXPERT_SHARED_COUNT,
    LLM_KV_EXPERT_WEIGHTS_SCALE,
    LLM_KV_EXPERT_WEIGHTS_NORM,
    LLM_KV_EXPERT_GATING_FUNC,
    LLM_KV_NEXTN_PREDICT_LAYERS,
    LLM_KV_POOLING_TYPE,
    LLM_KV_LOGIT_SCALE,
    LLM_KV_DECODER_START_TOKEN_ID,
    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
    LLM_KV_FINAL_LOGIT_SOFTCAPPING,
    LLM_KV_SWIN_NORM,
    LLM_KV_RESCALE_EVERY_N_LAYERS,
    LLM_KV_TIME_MIX_EXTRA_DIM,
    LLM_KV_TIME_DECAY_EXTRA_DIM,
    LLM_KV_RESIDUAL_SCALE,
    LLM_KV_EMBEDDING_SCALE,
    LLM_KV_TOKEN_SHIFT_COUNT,
    LLM_KV_INTERLEAVE_MOE_LAYER_STEP,

    // attention
    LLM_KV_ATTENTION_HEAD_COUNT,
    LLM_KV_ATTENTION_HEAD_COUNT_KV,
    LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
    LLM_KV_ATTENTION_CLAMP_KQV,
    LLM_KV_ATTENTION_KEY_LENGTH,
    LLM_KV_ATTENTION_VALUE_LENGTH,
    LLM_KV_ATTENTION_LAYERNORM_EPS,
    LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
    LLM_KV_ATTENTION_CAUSAL,
    LLM_KV_ATTENTION_Q_LORA_RANK,
    LLM_KV_ATTENTION_KV_LORA_RANK,
    LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
    LLM_KV_ATTENTION_SLIDING_WINDOW,
    LLM_KV_ATTENTION_SCALE,

    // rope
    LLM_KV_ROPE_DIMENSION_COUNT,
    LLM_KV_ROPE_FREQ_BASE,
    LLM_KV_ROPE_SCALE_LINEAR,
    LLM_KV_ROPE_SCALING_TYPE,
    LLM_KV_ROPE_SCALING_FACTOR,
    LLM_KV_ROPE_SCALING_ATTN_FACTOR,
    LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
    LLM_KV_ROPE_SCALING_FINETUNED,
    LLM_KV_ROPE_SCALING_YARN_LOG_MUL,

    // split files
    LLM_KV_SPLIT_NO,
    LLM_KV_SPLIT_COUNT,
    LLM_KV_SPLIT_TENSORS_COUNT,

    // ssm
    LLM_KV_SSM_INNER_SIZE,
    LLM_KV_SSM_CONV_KERNEL,
    LLM_KV_SSM_STATE_SIZE,
    LLM_KV_SSM_TIME_STEP_RANK,

    // tokenizer
    LLM_KV_TOKENIZER_MODEL,
    LLM_KV_TOKENIZER_PRE,
    LLM_KV_TOKENIZER_LIST,
    LLM_KV_TOKENIZER_TOKEN_TYPE,
    LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,
    LLM_KV_TOKENIZER_SCORES,
    LLM_KV_TOKENIZER_MERGES,
    LLM_KV_TOKENIZER_BOS_ID,
    LLM_KV_TOKENIZER_EOS_ID,
    LLM_KV_TOKENIZER_UNK_ID,
    LLM_KV_TOKENIZER_SEP_ID,
    LLM_KV_TOKENIZER_PAD_ID,
    LLM_KV_TOKENIZER_CLS_ID,
    LLM_KV_TOKENIZER_MASK_ID,
    LLM_KV_TOKENIZER_ADD_BOS,
    LLM_KV_TOKENIZER_ADD_EOS,
    LLM_KV_TOKENIZER_ADD_SEP,
    LLM_KV_TOKENIZER_ADD_PREFIX,
    LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
    LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
    LLM_KV_TOKENIZER_HF_JSON,
    LLM_KV_TOKENIZER_RWKV,
    LLM_KV_TOKENIZER_CHAT_TEMPLATE,
    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
    LLM_KV_TOKENIZER_FIM_PRE_ID,
    LLM_KV_TOKENIZER_FIM_SUF_ID,
    LLM_KV_TOKENIZER_FIM_MID_ID,
    LLM_KV_TOKENIZER_FIM_PAD_ID,
    LLM_KV_TOKENIZER_FIM_REP_ID,
    LLM_KV_TOKENIZER_FIM_SEP_ID,
    LLM_KV_TOKENIZER_PREFIX_ID,
    LLM_KV_TOKENIZER_SUFFIX_ID,
    LLM_KV_TOKENIZER_MIDDLE_ID,
    LLM_KV_TOKENIZER_EOT_ID,
    LLM_KV_TOKENIZER_EOM_ID,

    // adapters
    LLM_KV_ADAPTER_TYPE,
    LLM_KV_ADAPTER_LORA_ALPHA,
};
|
||||||
|
|
||||||
|
struct LLM_KV {
|
||||||
|
LLM_KV(llm_arch arch, const char* suffix = nullptr);
|
||||||
|
|
||||||
|
llm_arch arch;
|
||||||
|
const char* suffix;
|
||||||
|
std::string operator()(llm_kv kv) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Identifiers for the tensors a model can contain.
// The numeric values are implicit and sequential starting at 0, so the order
// of the enumerators below must not be changed.
enum llm_tensor {
    // TOKEN_* / POS_* / OUTPUT* / ROPE_*
    LLM_TENSOR_TOKEN_EMBD,
    LLM_TENSOR_TOKEN_EMBD_NORM,
    LLM_TENSOR_TOKEN_TYPES,
    LLM_TENSOR_POS_EMBD,
    LLM_TENSOR_OUTPUT,
    LLM_TENSOR_OUTPUT_NORM,
    LLM_TENSOR_ROPE_FREQS,
    LLM_TENSOR_ROPE_FACTORS_LONG,
    LLM_TENSOR_ROPE_FACTORS_SHORT,

    // ATTN_*
    LLM_TENSOR_ATTN_Q,
    LLM_TENSOR_ATTN_K,
    LLM_TENSOR_ATTN_V,
    LLM_TENSOR_ATTN_QKV,
    LLM_TENSOR_ATTN_OUT,
    LLM_TENSOR_ATTN_NORM,
    LLM_TENSOR_ATTN_NORM_2,
    LLM_TENSOR_ATTN_OUT_NORM,
    LLM_TENSOR_ATTN_POST_NORM,
    LLM_TENSOR_ATTN_ROT_EMBD,
    LLM_TENSOR_ATTN_SINKS,

    // FFN_*
    LLM_TENSOR_FFN_GATE_INP,
    LLM_TENSOR_FFN_GATE_INP_SHEXP,
    LLM_TENSOR_FFN_NORM,
    LLM_TENSOR_FFN_POST_NORM,
    LLM_TENSOR_FFN_GATE,
    LLM_TENSOR_FFN_DOWN,
    LLM_TENSOR_FFN_UP,
    LLM_TENSOR_FFN_ACT,
    LLM_TENSOR_FFN_DOWN_EXP,  // split experts for backward compatibility
    LLM_TENSOR_FFN_GATE_EXP,
    LLM_TENSOR_FFN_UP_EXP,
    LLM_TENSOR_FFN_NORM_EXPS,
    LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
    LLM_TENSOR_FFN_GATE_EXPS,
    LLM_TENSOR_FFN_UP_EXPS,
    LLM_TENSOR_FFN_DOWN_SHEXP,
    LLM_TENSOR_FFN_GATE_SHEXP,
    LLM_TENSOR_FFN_UP_SHEXP,
    LLM_TENSOR_FFN_EXP_PROBS_B,

    LLM_TENSOR_ATTN_Q_NORM,
    LLM_TENSOR_ATTN_K_NORM,
    LLM_TENSOR_LAYER_OUT_NORM,

    // SSM_*
    LLM_TENSOR_SSM_IN,
    LLM_TENSOR_SSM_CONV1D,
    LLM_TENSOR_SSM_X,
    LLM_TENSOR_SSM_DT,
    LLM_TENSOR_SSM_A,
    LLM_TENSOR_SSM_D,
    LLM_TENSOR_SSM_OUT,

    LLM_TENSOR_ATTN_Q_A,
    LLM_TENSOR_ATTN_Q_B,
    LLM_TENSOR_ATTN_KV_A_MQA,
    LLM_TENSOR_ATTN_KV_B,
    LLM_TENSOR_ATTN_K_B,
    LLM_TENSOR_ATTN_V_B,
    LLM_TENSOR_ATTN_Q_A_NORM,
    LLM_TENSOR_ATTN_KV_A_NORM,
    LLM_TENSOR_ATTN_SUB_NORM,
    LLM_TENSOR_FFN_SUB_NORM,

    // DEC_*
    LLM_TENSOR_DEC_ATTN_NORM,
    LLM_TENSOR_DEC_ATTN_Q,
    LLM_TENSOR_DEC_ATTN_K,
    LLM_TENSOR_DEC_ATTN_V,
    LLM_TENSOR_DEC_ATTN_OUT,
    LLM_TENSOR_DEC_ATTN_REL_B,
    LLM_TENSOR_DEC_CROSS_ATTN_NORM,
    LLM_TENSOR_DEC_CROSS_ATTN_Q,
    LLM_TENSOR_DEC_CROSS_ATTN_K,
    LLM_TENSOR_DEC_CROSS_ATTN_V,
    LLM_TENSOR_DEC_CROSS_ATTN_OUT,
    LLM_TENSOR_DEC_CROSS_ATTN_REL_B,
    LLM_TENSOR_DEC_FFN_NORM,
    LLM_TENSOR_DEC_FFN_GATE,
    LLM_TENSOR_DEC_FFN_DOWN,
    LLM_TENSOR_DEC_FFN_UP,
    LLM_TENSOR_DEC_OUTPUT_NORM,

    // ENC_*
    LLM_TENSOR_ENC_ATTN_NORM,
    LLM_TENSOR_ENC_ATTN_Q,
    LLM_TENSOR_ENC_ATTN_K,
    LLM_TENSOR_ENC_ATTN_V,
    LLM_TENSOR_ENC_ATTN_OUT,
    LLM_TENSOR_ENC_ATTN_REL_B,
    LLM_TENSOR_ENC_FFN_NORM,
    LLM_TENSOR_ENC_FFN_GATE,
    LLM_TENSOR_ENC_FFN_DOWN,
    LLM_TENSOR_ENC_FFN_UP,
    LLM_TENSOR_ENC_OUTPUT_NORM,

    // NEXTN_*
    LLM_TENSOR_NEXTN_EH_PROJ,
    LLM_TENSOR_NEXTN_EMBED_TOKENS,
    LLM_TENSOR_NEXTN_ENORM,
    LLM_TENSOR_NEXTN_HNORM,
    LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
    LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
};
|
||||||
|
|
||||||
|
llm_arch llm_arch_from_string(const std::string & name);
|
||||||
@@ -486,9 +486,9 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
|
|||||||
|
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
const llama_token id = candidates->data[i].id;
|
const llama_token id = candidates->data[i].id;
|
||||||
const std::string & piece = vocab->cache_token_to_piece.at(id);
|
const std::string & piece = vocab->token_to_piece(id);
|
||||||
|
|
||||||
if (llama_token_is_eog_impl(*vocab, id)) {
|
if (vocab->is_eog(id)) {
|
||||||
if (!allow_eog) {
|
if (!allow_eog) {
|
||||||
candidates->data[i].logit = -INFINITY;
|
candidates->data[i].logit = -INFINITY;
|
||||||
}
|
}
|
||||||
@@ -511,7 +511,7 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
|
|||||||
void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
|
void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
if (llama_token_is_eog_impl(*vocab, token)) {
|
if (vocab->is_eog(token)) {
|
||||||
for (const auto & stack : grammar->stacks) {
|
for (const auto & stack : grammar->stacks) {
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
return;
|
return;
|
||||||
@@ -520,7 +520,7 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc
|
|||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string & piece = vocab->cache_token_to_piece.at(token);
|
const std::string & piece = vocab->token_to_piece(token);
|
||||||
|
|
||||||
// Note terminating 0 in decoded string
|
// Note terminating 0 in decoded string
|
||||||
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
|
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
|
||||||
|
|||||||
@@ -10,6 +10,11 @@
|
|||||||
#define LLAMA_API_INTERNAL
|
#define LLAMA_API_INTERNAL
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
#include <climits>
|
||||||
|
#include <cstdarg>
|
||||||
|
#include <vector>
|
||||||
|
#include <cinttypes>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
#ifdef __GNUC__
|
#ifdef __GNUC__
|
||||||
#ifdef __MINGW32__
|
#ifdef __MINGW32__
|
||||||
@@ -33,6 +38,7 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
|
|||||||
#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
||||||
#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
||||||
#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||||
|
#define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||||
|
|
||||||
//
|
//
|
||||||
// helpers
|
// helpers
|
||||||
@@ -166,3 +172,49 @@ struct ring_buffer {
|
|||||||
size_t pos = 0;
|
size_t pos = 0;
|
||||||
std::vector<T> data;
|
std::vector<T> data;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// printf-style formatting into a std::string.
// The arguments are walked twice: once to measure the output, once to render
// it, which is why a copy of the va_list is taken up front.
LLAMA_ATTRIBUTE_FORMAT(1, 2)
static std::string format(const char * fmt, ...) {
    va_list args;
    va_list args_copy;
    va_start(args, fmt);
    va_copy(args_copy, args);

    // first pass: length of the formatted text, excluding the terminating 0
    const int needed = vsnprintf(NULL, 0, fmt, args);
    GGML_ASSERT(needed >= 0 && needed < INT_MAX); // NOLINT

    // second pass: render into a buffer with room for the terminating 0
    std::vector<char> storage(needed + 1);
    const int written = vsnprintf(storage.data(), needed + 1, fmt, args_copy);
    GGML_ASSERT(written == needed);

    va_end(args_copy);
    va_end(args);

    return std::string(storage.data(), needed);
}
|
||||||
|
|
||||||
|
// Renders a shape as "%5d, %5d, ..." (each extent right-aligned to width 5).
// Throws std::out_of_range for an empty shape (via at()); output longer than
// the 256-byte scratch buffer is truncated.
static std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) {
    char text[256];
    snprintf(text, sizeof(text), "%5" PRId64, ne.at(0));

    size_t idx = 1;
    while (idx < ne.size()) {
        // append after what has been written so far
        const size_t used = strlen(text);
        snprintf(text + used, sizeof(text) - used, ", %5" PRId64, ne.at(idx));
        ++idx;
    }

    return std::string(text);
}
|
||||||
|
|
||||||
|
// Renders all GGML_MAX_DIMS extents of a tensor as "%5d, %5d, ..." using a
// 256-byte scratch buffer.
static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
    char text[256];
    snprintf(text, sizeof(text), "%5" PRId64, t->ne[0]);

    int dim = 1;
    while (dim < GGML_MAX_DIMS) {
        // append after what has been written so far
        const size_t used = strlen(text);
        snprintf(text + used, sizeof(text) - used, ", %5" PRId64, t->ne[dim]);
        ++dim;
    }

    return std::string(text);
}
|
||||||
|
|
||||||
|
// Thin wrapper around a single T. The constructor is user-provided and empty,
// so constructing a no_init<T> performs no explicit initialization of `value`
// beyond T's own default-initialization. It must not become `= default`,
// which would change what value-initialization does.
template <typename T>
struct no_init {
    no_init() {
        // intentionally empty
    }

    T value;
};
|
||||||
|
|
||||||
|
|
||||||
|
struct gguf_context;
|
||||||
|
std::string gguf_kv_to_str(const gguf_context * ctx_gguf, int i);
|
||||||
|
|
||||||
|
ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer);
|
||||||
|
|||||||
650
src/llama-mmap.cpp
Normal file
650
src/llama-mmap.cpp
Normal file
@@ -0,0 +1,650 @@
|
|||||||
|
#include "llama-mmap.h"
|
||||||
|
|
||||||
|
#include "llama-impl.h"
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <climits>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <cerrno>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#ifdef __has_include
|
||||||
|
#if __has_include(<unistd.h>)
|
||||||
|
#include <unistd.h>
|
||||||
|
#if defined(_POSIX_MAPPED_FILES)
|
||||||
|
#include <sys/mman.h>
|
||||||
|
#include <fcntl.h>
|
||||||
|
#endif
|
||||||
|
#if defined(_POSIX_MEMLOCK_RANGE)
|
||||||
|
#include <sys/resource.h>
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
#define WIN32_LEAN_AND_MEAN
|
||||||
|
#ifndef NOMINMAX
|
||||||
|
#define NOMINMAX
|
||||||
|
#endif
|
||||||
|
#include <windows.h>
|
||||||
|
#ifndef PATH_MAX
|
||||||
|
#define PATH_MAX MAX_PATH
|
||||||
|
#endif
|
||||||
|
#include <io.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__APPLE__)
|
||||||
|
#include <TargetConditionals.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// TODO: consider moving to llama-impl.h if needed in more places
|
||||||
|
#if defined(_WIN32)
|
||||||
|
// Converts a Win32 error code into the system-provided message text.
// Returns the fixed string "FormatMessageA failed" when no message could be
// produced. The buffer allocated by FormatMessageA is released with LocalFree.
static std::string llama_format_win_err(DWORD err) {
    LPSTR msg;
    const DWORD flags = FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS;

    size_t msg_len = FormatMessageA(flags, NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&msg, 0, NULL);
    if (msg_len == 0) {
        return "FormatMessageA failed";
    }

    std::string text(msg, msg_len);
    LocalFree(msg);
    return text;
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// llama_file
|
||||||
|
|
||||||
|
// Backend of llama_file.
//
// The file is always opened through ggml_fopen() and owned as a FILE* (`fp`),
// which the destructor closes. On Windows the underlying OS handle is also
// extracted and all seek/tell/read/write calls go through the Win32 API;
// elsewhere the C stdio functions are used. `size` is measured once in the
// constructor by seeking to the end of the file.
struct llama_file::impl {
#if defined(_WIN32)
    HANDLE fp_win32;

    // Message text for a Win32 error code; falls back to the hex code when the
    // system has no message for it.
    std::string GetErrorMessageWin32(DWORD error_code) const {
        std::string ret;
        LPSTR lpMsgBuf = NULL;
        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
                                    NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
        if (!bufLen) {
            ret = format("Win32 error code: %lx", error_code);
        } else {
            ret = lpMsgBuf;
            LocalFree(lpMsgBuf);
        }

        return ret;
    }

    // Opens `fname` with the given fopen-style `mode` and records the file size.
    // Throws std::runtime_error when the file cannot be opened or measured.
    impl(const char * fname, const char * mode) {
        fp = ggml_fopen(fname, mode);
        if (fp == NULL) {
            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
        }
        fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp));
        probe_size();
    }

    // Current file position. Throws std::runtime_error on failure.
    size_t tell() const {
        LARGE_INTEGER li;
        li.QuadPart = 0;
        // moving by 0 from the current position reports the position in `li`
        BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT);
        if (!ret) {
            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
        }

        return li.QuadPart;
    }

    // Moves the file position; `whence` is SEEK_SET/SEEK_CUR/SEEK_END, which
    // are asserted to match the Win32 FILE_* constants so it can be passed through.
    void seek(size_t offset, int whence) const {
        static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN");
        static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT");
        static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END");

        LARGE_INTEGER li;
        li.QuadPart = offset;
        BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence);
        if (!ret) {
            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
        }
    }

    // Reads exactly `len` bytes into `ptr`, in chunks of at most 64 MiB.
    // Throws std::runtime_error on a read error or a short read.
    void read_raw(void * ptr, size_t len) const {
        size_t bytes_read = 0;
        while (bytes_read < len) {
            size_t chunk_size = std::min<size_t>(len - bytes_read, 64*1024*1024);
            DWORD chunk_read = 0;
            BOOL result = ReadFile(fp_win32, reinterpret_cast<char*>(ptr) + bytes_read, chunk_size, &chunk_read, NULL);
            if (!result) {
                throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
            }
            if (chunk_read < chunk_size || chunk_read == 0) {
                throw std::runtime_error("unexpectedly reached end of file");
            }

            bytes_read += chunk_read;
        }
    }

    // Reads one uint32_t in the file's raw byte order.
    uint32_t read_u32() const {
        uint32_t val;
        read_raw(&val, sizeof(val));
        return val;
    }

    // Writes exactly `len` bytes from `ptr`, in chunks of at most 64 MiB.
    // Throws std::runtime_error on a write error or a short write.
    void write_raw(const void * ptr, size_t len) const {
        size_t bytes_written = 0;
        while (bytes_written < len) {
            size_t chunk_size = std::min<size_t>(len - bytes_written, 64*1024*1024);
            DWORD chunk_written = 0;
            BOOL result = WriteFile(fp_win32, reinterpret_cast<char const*>(ptr) + bytes_written, chunk_size, &chunk_written, NULL);
            if (!result) {
                throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
            }
            if (chunk_written < chunk_size || chunk_written == 0) {
                throw std::runtime_error("unexpectedly failed to write bytes");
            }

            bytes_written += chunk_written;
        }
    }

    // Writes one uint32_t as raw bytes.
    void write_u32(uint32_t val) const {
        write_raw(&val, sizeof(val));
    }

    ~impl() {
        if (fp) {
            std::fclose(fp);
        }
    }
#else
    // Opens `fname` with the given fopen-style `mode` and records the file size.
    // Throws std::runtime_error when the file cannot be opened or measured.
    impl(const char * fname, const char * mode) {
        fp = ggml_fopen(fname, mode);
        if (fp == NULL) {
            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
        }
        probe_size();
    }

    // Current file position. Throws std::runtime_error on failure.
    size_t tell() const {
// TODO: this ifdef is never true?
#ifdef _WIN32
        __int64 ret = _ftelli64(fp);
#else
        long ret = std::ftell(fp);
#endif
        if (ret == -1) {
            throw std::runtime_error(format("ftell error: %s", strerror(errno)));
        }

        return (size_t) ret;
    }

    // Moves the file position; `whence` is SEEK_SET/SEEK_CUR/SEEK_END.
    // Throws std::runtime_error on failure.
    void seek(size_t offset, int whence) const {
// TODO: this ifdef is never true?
#ifdef _WIN32
        int ret = _fseeki64(fp, (__int64) offset, whence);
#else
        int ret = std::fseek(fp, (long) offset, whence);
#endif
        if (ret != 0) {
            throw std::runtime_error(format("seek error: %s", strerror(errno)));
        }
    }

    // Reads exactly `len` bytes into `ptr`; a zero-length read is a no-op.
    // Throws std::runtime_error on a read error or premature end of file.
    void read_raw(void * ptr, size_t len) const {
        if (len == 0) {
            return;
        }
        errno = 0;
        // one item of `len` bytes: the result is 1 only when everything was read
        std::size_t ret = std::fread(ptr, len, 1, fp);
        if (ferror(fp)) {
            throw std::runtime_error(format("read error: %s", strerror(errno)));
        }
        if (ret != 1) {
            throw std::runtime_error("unexpectedly reached end of file");
        }
    }

    // Reads one uint32_t in the file's raw byte order.
    uint32_t read_u32() const {
        uint32_t ret;
        read_raw(&ret, sizeof(ret));
        return ret;
    }

    // Writes exactly `len` bytes from `ptr`; a zero-length write is a no-op.
    // Throws std::runtime_error when the write is incomplete.
    void write_raw(const void * ptr, size_t len) const {
        if (len == 0) {
            return;
        }
        errno = 0;
        size_t ret = std::fwrite(ptr, len, 1, fp);
        if (ret != 1) {
            throw std::runtime_error(format("write error: %s", strerror(errno)));
        }
    }

    // Writes one uint32_t as raw bytes.
    void write_u32(uint32_t val) const {
        write_raw(&val, sizeof(val));
    }

    ~impl() {
        if (fp) {
            std::fclose(fp);
        }
    }
#endif

    // Measures the file by seeking to its end, then rewinds to the start.
    // Called from the constructors only. If any step throws, `fp` is closed
    // before the exception propagates: the destructor does not run for an
    // object whose constructor threw, so the stream would otherwise leak.
    void probe_size() {
        try {
            seek(0, SEEK_END);
            size = tell();
            seek(0, SEEK_SET);
        } catch (...) {
            std::fclose(fp);
            fp = NULL;
            throw;
        }
    }

    FILE * fp = NULL;  // owned stream, closed by the destructor
    size_t size = 0;   // file size in bytes, measured at construction
};
|
||||||
|
|
||||||
|
llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique<impl>(fname, mode)) {}
|
||||||
|
llama_file::~llama_file() = default;
|
||||||
|
|
||||||
|
size_t llama_file::tell() const { return pimpl->tell(); }
|
||||||
|
size_t llama_file::size() const { return pimpl->size; }
|
||||||
|
|
||||||
|
int llama_file::file_id() const {
|
||||||
|
#ifdef _WIN32
|
||||||
|
return _fileno(pimpl->fp);
|
||||||
|
#else
|
||||||
|
#if defined(fileno)
|
||||||
|
return fileno(pimpl->fp);
|
||||||
|
#else
|
||||||
|
return ::fileno(pimpl->fp);
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
|
||||||
|
void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
|
||||||
|
|
||||||
|
uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
|
||||||
|
|
||||||
|
void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
|
||||||
|
void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); }
|
||||||
|
|
||||||
|
// llama_mmap
|
||||||
|
|
||||||
|
// Backend of llama_mmap.
//
// Three variants are selected at compile time:
//   * _POSIX_MAPPED_FILES: mmap()-based; tracks the still-mapped byte ranges in
//     `mapped_fragments` so that parts can be released early via
//     unmap_fragment() and the rest in the destructor. On Linux, `use_thp`
//     first tries an anonymous MAP_HUGETLB mapping that the file is read into.
//   * _WIN32: CreateFileMappingA/MapViewOfFile; unmap_fragment() is a no-op.
//   * otherwise: every operation throws "mmap not supported".
struct llama_mmap::impl {
#ifdef _POSIX_MAPPED_FILES
    // [first, second) byte offsets (relative to `addr`) that are still mapped
    std::vector<std::pair<size_t, size_t>> mapped_fragments;

    // Maps `file`. `prefetch` is the number of leading bytes to ask the kernel
    // to read ahead (forced to 0 when `numa` is set). Throws std::runtime_error
    // when the mapping (or the huge-page read) fails.
    impl(struct llama_file * file, size_t prefetch, bool numa, [[maybe_unused]] bool use_thp) {
        size = file->size();
        int fd = file->file_id();
        int flags = MAP_SHARED;
        if (numa) { prefetch = 0; }
#ifdef __linux__
        // posix_fadvise() reports failure through its return value, not errno
        if (const int err = posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) {
            LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n",
                    strerror(err));
        }
        if (prefetch) { flags |= MAP_POPULATE; }
        if (use_thp) {
            const size_t huge = get_default_huge_page_size();
            // mapping length rounded up to a whole number of huge pages
            const size_t huge_size = huge*((file->size() + huge - 1)/huge);
            addr = mmap(nullptr, huge_size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB, -1, 0);
            if (addr != MAP_FAILED) {
                printf("%s: using THP with page size %zu MiB ", __func__, huge/(1024*1024));
                fflush(stdout);
                size_t tot = 0;
                while (tot < file->size()) {
                    auto n_read = pread(fd, static_cast<char*>(addr) + tot, file->size() - tot, tot);
                    if (n_read <= 0) {
                        // n_read == 0 means the file ended early; without this check the
                        // loop would never terminate. Capture the reason before munmap()
                        // can overwrite errno, and release the mapping because the
                        // destructor will not run once the constructor throws.
                        const std::string reason = n_read < 0 ? strerror(errno) : "unexpected end of file";
                        munmap(addr, huge_size);
                        throw std::runtime_error(format("Reading into mapped huge pages failed at %zu (%s)", tot, reason.c_str()));
                    }
                    printf("."); fflush(stdout);
                    tot += n_read;
                }
                printf(" done\n");
                // record the full rounded-up length: munmap() on a huge-page mapping
                // requires the length to be a multiple of the huge page size
                mapped_fragments.emplace_back(0, huge_size);
                mapped_page_size = huge;
                return;
            }
            else {
                // fall through to the regular file-backed mapping below
                fprintf(stderr, "%s: mmap with huge page size %zu MiB failed (%s)\n", __func__, huge/(1024*1024), strerror(errno));
            }
        }
#endif
        addr = mmap(NULL, file->size(), PROT_READ, flags, fd, 0);
        if (addr == MAP_FAILED) {
            throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
        }

        if (prefetch > 0) {
            // posix_madvise() reports failure through its return value, not errno
            if (const int err = posix_madvise(addr, std::min(file->size(), prefetch), POSIX_MADV_WILLNEED)) {
                LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
                        strerror(err));
            }
        }
        if (numa) {
            if (const int err = posix_madvise(addr, file->size(), POSIX_MADV_RANDOM)) {
                LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n",
                        strerror(err));
            }
        }

        mapped_fragments.emplace_back(0, file->size());
    }

#ifdef __linux__
    // Huge page size in bytes, taken from the "Hugepagesize:" line of
    // /proc/meminfo (which is in KiB); defaults to 2048 KiB when unavailable.
    static int get_default_huge_page_size() {
        int pg_size = 2048;
        std::ifstream in("/proc/meminfo");
        if (in) {
            std::string line;
            while (true) {
                std::getline(in, line);
                if (in.fail()) break;
                if (auto pos = line.find("Hugepagesize:"); pos != std::string::npos) {
                    // 13 == strlen("Hugepagesize:")
                    std::istringstream str(line.data() + pos + 13);
                    int aux;
                    str >> aux;
                    if (!str.fail()) pg_size = aux;
                    break;
                }
            }
        }
        return pg_size * 1024;
    }
#endif


    // Shrinks [*first, *last) to the largest page-aligned range it contains:
    // *first is rounded up, *last is rounded down, and an empty result is
    // normalized to *last == *first. `page_size` must be a power of two.
    static void align_range(size_t * first, size_t * last, size_t page_size) {
        size_t offset_in_page = *first & (page_size - 1);
        size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page;
        *first += offset_to_page;

        *last = *last & ~(page_size - 1);

        if (*last <= *first) {
            *last = *first;
        }
    }

    // Unmaps the whole pages contained in byte range [first, last) and updates
    // `mapped_fragments` accordingly. A failing munmap() is only logged.
    void unmap_fragment(size_t first, size_t last) {
        int page_size = mapped_page_size > 0 ? mapped_page_size : sysconf(_SC_PAGESIZE);
        align_range(&first, &last, page_size);
        size_t len = last - first;

        if (len == 0) {
            return;
        }

        GGML_ASSERT(first % page_size == 0);
        GGML_ASSERT(last % page_size == 0);
        GGML_ASSERT(last > first);

        void * next_page_start = (uint8_t *) addr + first;

        if (munmap(next_page_start, len)) {
            LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
        }

        std::vector<std::pair<size_t, size_t>> new_mapped_fragments;
        for (const auto & frag : mapped_fragments) {
            if (frag.first < first && frag.second > last) {
                // the unmapped range lies strictly inside: split in two
                new_mapped_fragments.emplace_back(frag.first, first);
                new_mapped_fragments.emplace_back(last, frag.second);
            } else if (frag.first < first && frag.second > first) {
                // the tail of the fragment was unmapped
                new_mapped_fragments.emplace_back(frag.first, first);
            } else if (frag.first < last && frag.second > last) {
                // the head of the fragment was unmapped
                new_mapped_fragments.emplace_back(last, frag.second);
            } else if (frag.first >= first && frag.second <= last) {
                // fully unmapped: drop it
            } else {
                // untouched
                new_mapped_fragments.push_back(frag);
            }
        }
        mapped_fragments = std::move(new_mapped_fragments);
    }

    // Releases every range that is still mapped; failures are only logged.
    ~impl() {
        for (const auto & frag : mapped_fragments) {
            if (munmap((char *) addr + frag.first, frag.second - frag.first)) {
                LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
            }
        }
    }
#elif defined(_WIN32)
    // Maps `file` read-only. `prefetch` > 0 requests PrefetchVirtualMemory for
    // that many leading bytes when available. Throws std::runtime_error when the
    // mapping cannot be created.
    impl(struct llama_file * file, size_t prefetch, bool numa, [[maybe_unused]] bool use_thp) {
        GGML_UNUSED(numa);

        size = file->size();

        HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id());

        HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);

        if (hMapping == NULL) {
            DWORD error = GetLastError();
            throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
        }

        addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
        // read the error before CloseHandle() can replace it
        DWORD error = GetLastError();
        CloseHandle(hMapping);

        if (addr == NULL) {
            throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
        }

        if (prefetch > 0) {
#if _WIN32_WINNT >= 0x602
            // looked up at run time rather than linked directly
            BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG);
            HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll");

            pPrefetchVirtualMemory = (decltype(pPrefetchVirtualMemory))(void *) GetProcAddress(hKernel32, "PrefetchVirtualMemory");

            if (pPrefetchVirtualMemory) {
                WIN32_MEMORY_RANGE_ENTRY range;
                range.VirtualAddress = addr;
                range.NumberOfBytes = (SIZE_T) std::min(size, prefetch);
                if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
                    LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n",
                            llama_format_win_err(GetLastError()).c_str());
                }
            }
#else
            LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n");
#endif
        }
    }

    // Partial unmapping is not implemented here; the call is a no-op.
    void unmap_fragment(size_t first, size_t last) {
        GGML_UNUSED(first);
        GGML_UNUSED(last);
    }

    ~impl() {
        if (!UnmapViewOfFile(addr)) {
            LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
                    llama_format_win_err(GetLastError()).c_str());
        }
    }
#else
    // No memory-mapping facility detected: construction always fails.
    impl(struct llama_file * file, size_t prefetch, bool numa, [[maybe_unused]] bool use_thp) {
        GGML_UNUSED(file);
        GGML_UNUSED(prefetch);
        GGML_UNUSED(numa);

        throw std::runtime_error("mmap not supported");
    }

    void unmap_fragment(size_t first, size_t last) {
        GGML_UNUSED(first);
        GGML_UNUSED(last);

        throw std::runtime_error("mmap not supported");
    }
#endif

    void * addr;                  // start of the mapping
    size_t size;                  // size of the mapped file in bytes
    size_t mapped_page_size = 0;  // huge page size when the huge-page path was taken, else 0
};
|
||||||
|
|
||||||
|
llama_mmap::llama_mmap(struct llama_file * file, size_t prefetch, bool numa, bool use_thp) :
|
||||||
|
pimpl(std::make_unique<impl>(file, prefetch, numa, use_thp)) {}
|
||||||
|
llama_mmap::~llama_mmap() = default;
|
||||||
|
|
||||||
|
size_t llama_mmap::size() const { return pimpl->size; }
|
||||||
|
void * llama_mmap::addr() const { return pimpl->addr; }
|
||||||
|
|
||||||
|
void llama_mmap::unmap_fragment(size_t first, size_t last) { pimpl->unmap_fragment(first, last); }
|
||||||
|
|
||||||
|
#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32)
|
||||||
|
const bool llama_mmap::SUPPORTED = true;
|
||||||
|
#else
|
||||||
|
const bool llama_mmap::SUPPORTED = false;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// llama_mlock
|
||||||
|
|
||||||
|
struct llama_mlock::impl {
|
||||||
|
#ifdef _POSIX_MEMLOCK_RANGE
|
||||||
|
static size_t lock_granularity() {
|
||||||
|
return (size_t) sysconf(_SC_PAGESIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool raw_lock(const void * addr, size_t size) const {
|
||||||
|
if (!mlock(addr, size)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#define MLOCK_SUGGESTION \
|
||||||
|
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
|
||||||
|
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n"
|
||||||
|
#else
|
||||||
|
#define MLOCK_SUGGESTION \
|
||||||
|
"Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
char* errmsg = std::strerror(errno);
|
||||||
|
bool suggest = (errno == ENOMEM);
|
||||||
|
#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX)
|
||||||
|
// visionOS/tvOS dont't support RLIMIT_MEMLOCK
|
||||||
|
// Skip resource limit checks on visionOS/tvOS
|
||||||
|
suggest = false;
|
||||||
|
#else
|
||||||
|
struct rlimit lock_limit;
|
||||||
|
if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
|
||||||
|
suggest = false;
|
||||||
|
}
|
||||||
|
if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
|
||||||
|
suggest = false;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
|
||||||
|
size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void raw_unlock(void * addr, size_t size) {
|
||||||
|
if (munlock(addr, size)) {
|
||||||
|
LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#elif defined(_WIN32)
|
||||||
|
static size_t lock_granularity() {
|
||||||
|
SYSTEM_INFO si;
|
||||||
|
GetSystemInfo(&si);
|
||||||
|
return (size_t) si.dwPageSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool raw_lock(void * ptr, size_t len) const {
|
||||||
|
for (int tries = 1; ; tries++) {
|
||||||
|
if (VirtualLock(ptr, len)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (tries == 2) {
|
||||||
|
LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
|
||||||
|
len, size, llama_format_win_err(GetLastError()).c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
SIZE_T min_ws_size, max_ws_size;
|
||||||
|
if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
|
||||||
|
LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n",
|
||||||
|
llama_format_win_err(GetLastError()).c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t increment = len + 1048576;
|
||||||
|
min_ws_size += increment;
|
||||||
|
max_ws_size += increment;
|
||||||
|
if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
|
||||||
|
LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n",
|
||||||
|
llama_format_win_err(GetLastError()).c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void raw_unlock(void * ptr, size_t len) {
|
||||||
|
if (!VirtualUnlock(ptr, len)) {
|
||||||
|
LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n",
|
||||||
|
llama_format_win_err(GetLastError()).c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
static size_t lock_granularity() {
|
||||||
|
return (size_t) 65536;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool raw_lock(const void * addr, size_t len) const {
|
||||||
|
LLAMA_LOG_WARN("warning: mlock not supported on this system\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void raw_unlock(const void * addr, size_t len) {}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
impl() : addr(NULL), size(0), failed_already(false) {}
|
||||||
|
|
||||||
|
void init(void * ptr) {
|
||||||
|
GGML_ASSERT(addr == NULL && size == 0);
|
||||||
|
addr = ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void grow_to(size_t target_size) {
|
||||||
|
GGML_ASSERT(addr);
|
||||||
|
if (failed_already) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
size_t granularity = lock_granularity();
|
||||||
|
target_size = (target_size + granularity - 1) & ~(granularity - 1);
|
||||||
|
if (target_size > size) {
|
||||||
|
if (raw_lock((uint8_t *) addr + size, target_size - size)) {
|
||||||
|
size = target_size;
|
||||||
|
} else {
|
||||||
|
failed_already = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base address of the region, set once by init().
void * addr;
// Number of bytes (from addr) that grow_to() has successfully locked so far.
size_t size;

// Set when raw_lock() fails; grow_to() becomes a no-op afterwards.
bool failed_already;
|
||||||
|
};
|
||||||
|
|
||||||
|
llama_mlock::llama_mlock() : pimpl(std::make_unique<impl>()) {}
|
||||||
|
llama_mlock::~llama_mlock() = default;
|
||||||
|
|
||||||
|
void llama_mlock::init(void * ptr) { pimpl->init(ptr); }
|
||||||
|
void llama_mlock::grow_to(size_t target_size) { pimpl->grow_to(target_size); }
|
||||||
|
|
||||||
|
#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32)
|
||||||
|
const bool llama_mlock::SUPPORTED = true;
|
||||||
|
#else
|
||||||
|
const bool llama_mlock::SUPPORTED = false;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Returns the platform's PATH_MAX constant as a size_t.
size_t llama_path_max() {
    const size_t limit = PATH_MAX;
    return limit;
}
|
||||||
68
src/llama-mmap.h
Normal file
68
src/llama-mmap.h
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
struct llama_file;
|
||||||
|
struct llama_mmap;
|
||||||
|
struct llama_mlock;
|
||||||
|
|
||||||
|
using llama_files = std::vector<std::unique_ptr<llama_file>>;
|
||||||
|
using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
|
||||||
|
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
|
||||||
|
|
||||||
|
// File handle wrapper. All state lives in a private impl object (pimpl), so
// this header exposes no platform details. The unique_ptr member makes the
// type non-copyable.
struct llama_file {
    // fname / mode: presumably forwarded to an fopen-style call -- confirm in llama-mmap.cpp.
    llama_file(const char * fname, const char * mode);
    ~llama_file();

    // Position / size queries.
    size_t tell() const;
    size_t size() const;

    int file_id() const; // fileno overload

    // whence: presumably one of SEEK_SET / SEEK_CUR / SEEK_END -- confirm in the implementation.
    void seek(size_t offset, int whence) const;

    // Reading.
    void     read_raw(void * ptr, size_t len) const;
    uint32_t read_u32() const;

    // Writing.
    void write_raw(const void * ptr, size_t len) const;
    void write_u32(uint32_t val) const;

private:
    struct impl;
    std::unique_ptr<impl> pimpl;
};
|
||||||
|
|
||||||
|
// Memory mapping of a llama_file. Implementation details are hidden behind a
// private impl object; copying is explicitly disabled.
struct llama_mmap {
    llama_mmap(const llama_mmap &) = delete;

    // prefetch defaults to (size_t) -1; numa and use_thp default to false.
    // NOTE(review): the exact meaning of these knobs is defined in
    // llama-mmap.cpp and is not visible from this header.
    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false, bool use_thp = false);
    ~llama_mmap();

    size_t size() const;
    void * addr() const;

    // NOTE(review): [first, last) vs [first, last] is not stated here -- confirm in the implementation.
    void unmap_fragment(size_t first, size_t last);

    // Defined in the implementation file.
    static const bool SUPPORTED;

private:
    struct impl;
    std::unique_ptr<impl> pimpl;
};
|
||||||
|
|
||||||
|
// Memory-locking helper. Usage implied by the interface: call init() with a
// base pointer, then grow_to() with the number of bytes that should be
// covered. Implementation details are hidden behind a private impl object.
struct llama_mlock {
    llama_mlock();
    ~llama_mlock();

    void init(void * ptr);
    void grow_to(size_t target_size);

    // Defined in the implementation file.
    static const bool SUPPORTED;

private:
    struct impl;
    std::unique_ptr<impl> pimpl;
};
|
||||||
|
|
||||||
|
size_t llama_path_max();
|
||||||
1082
src/llama-model-loader.cpp
Normal file
1082
src/llama-model-loader.cpp
Normal file
File diff suppressed because it is too large
Load Diff
169
src/llama-model-loader.h
Normal file
169
src/llama-model-loader.h
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama.h"
|
||||||
|
#include "llama-impl.h"
|
||||||
|
#include "llama-mmap.h"
|
||||||
|
#include "llama-arch.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// GGUF container versions known to the loader. The numeric values are spelled
// out explicitly rather than left to implicit enumeration.
enum llama_fver {
    GGUF_FILE_VERSION_V1 = 1,
    GGUF_FILE_VERSION_V2 = 2,
    GGUF_FILE_VERSION_V3 = 3,
};
|
||||||
|
|
||||||
|
// Maps a GGUF file version to a human-readable label.
// Values that match no enumerator yield "unknown".
static const char * llama_file_version_name(llama_fver version) {
    const char * label = "unknown";
    // No default case on purpose: the compiler can then flag enumerators
    // that are added later but not handled here.
    switch (version) {
        case GGUF_FILE_VERSION_V1: label = "GGUF V1 (support until nov 2023)"; break;
        case GGUF_FILE_VERSION_V2: label = "GGUF V2";                          break;
        case GGUF_FILE_VERSION_V3: label = "GGUF V3 (latest)";                 break;
    }
    return label;
}
|
||||||
|
|
||||||
|
using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
|
||||||
|
|
||||||
|
struct llama_model_loader {
|
||||||
|
int n_kv = 0;
|
||||||
|
int n_tensors = 0;
|
||||||
|
int n_created = 0;
|
||||||
|
|
||||||
|
int64_t n_elements = 0;
|
||||||
|
size_t n_bytes = 0;
|
||||||
|
|
||||||
|
bool use_mmap = false;
|
||||||
|
bool check_tensors;
|
||||||
|
bool repack_tensors = false;
|
||||||
|
bool use_thp = false;
|
||||||
|
|
||||||
|
llama_files files;
|
||||||
|
llama_ftype ftype;
|
||||||
|
llama_fver fver;
|
||||||
|
|
||||||
|
llama_mmaps mappings;
|
||||||
|
|
||||||
|
// Holds information on a model weight
|
||||||
|
struct llama_tensor_weight {
|
||||||
|
uint16_t idx; // source file index
|
||||||
|
size_t offs; // tensor data offset in the original file
|
||||||
|
|
||||||
|
ggml_tensor * tensor;
|
||||||
|
|
||||||
|
llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
|
||||||
|
const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
|
||||||
|
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
|
||||||
|
|
||||||
|
if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size()) {
|
||||||
|
throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
std::vector<llama_tensor_weight> weights;
|
||||||
|
|
||||||
|
std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
|
||||||
|
const llama_model_tensor_buft_override * tensor_buft_overrides;
|
||||||
|
|
||||||
|
gguf_context * meta = NULL;
|
||||||
|
std::vector<ggml_context *> contexts;
|
||||||
|
|
||||||
|
std::string arch_name;
|
||||||
|
LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
|
||||||
|
|
||||||
|
llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp,
|
||||||
|
const llama_model_kv_override * param_overrides_p,
|
||||||
|
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);
|
||||||
|
|
||||||
|
~llama_model_loader();
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
typename std::enable_if<std::is_integral<T>::value, bool>::type
|
||||||
|
get_arr_n(const std::string & key, T & result, const bool required = true);
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
typename std::enable_if<std::is_integral<T>::value, bool>::type
|
||||||
|
get_arr_n(const enum llm_kv kid, T & result, const bool required = true);
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true);
|
||||||
|
|
||||||
|
template<typename T, size_t N_MAX>
|
||||||
|
bool get_arr(const std::string & key, std::array<T, N_MAX> & result, const bool required = true);
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
bool get_arr(const enum llm_kv kid, T & result, const bool required = true);
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
bool get_key(const std::string & key, T & result, const bool required = true);
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
bool get_key(const enum llm_kv kid, T & result, const bool required = true);
|
||||||
|
|
||||||
|
// get array of n <= N_MAX elements, or a single element repeated n times
|
||||||
|
template<typename T, size_t N_MAX>
|
||||||
|
bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, const bool required = true);
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
bool get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required = true);
|
||||||
|
|
||||||
|
const std::string& get_arch_name() const { return arch_name; }
|
||||||
|
|
||||||
|
enum llm_arch get_arch() const { return llm_kv.arch; }
|
||||||
|
|
||||||
|
const char * get_tensor_name(int i) const;
|
||||||
|
|
||||||
|
const llama_tensor_weight * get_weight(const char * name) const;
|
||||||
|
|
||||||
|
const llama_tensor_weight * get_weight(int i) const {
|
||||||
|
return get_weight(get_tensor_name(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_tensor_weight & require_weight(const char * name) const;
|
||||||
|
|
||||||
|
struct ggml_tensor * get_tensor_meta(const char * name) const;
|
||||||
|
|
||||||
|
struct ggml_tensor * require_tensor_meta(const char * name) const;
|
||||||
|
|
||||||
|
struct ggml_tensor * get_tensor_meta(int i) const {
|
||||||
|
return get_tensor_meta(get_tensor_name(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated);
|
||||||
|
|
||||||
|
const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const;
|
||||||
|
|
||||||
|
static const int TENSOR_NOT_REQUIRED = 1 << 0;
|
||||||
|
static const int TENSOR_DUPLICATED = 1 << 1;
|
||||||
|
static const int TENSOR_SKIP = 1 << 2;
|
||||||
|
|
||||||
|
struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, int flags = 0);
|
||||||
|
|
||||||
|
struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base,
|
||||||
|
const std::string & name, const std::vector<int64_t> & ne, size_t offset, bool required = true);
|
||||||
|
|
||||||
|
void done_getting_tensors() const;
|
||||||
|
|
||||||
|
void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr, bool use_thp = false);
|
||||||
|
|
||||||
|
void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const;
|
||||||
|
|
||||||
|
// for backwards compatibility, does not support ggml-backend
|
||||||
|
void load_data_for(struct ggml_tensor * cur) const;
|
||||||
|
|
||||||
|
size_t size_done = 0;
|
||||||
|
size_t size_data = 0;
|
||||||
|
std::vector<std::pair<size_t, size_t>> mmaps_used;
|
||||||
|
|
||||||
|
// Returns false if cancelled by progress_callback
|
||||||
|
bool load_all_data(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
llama_buf_map & bufs_mmap,
|
||||||
|
llama_mlocks * lmlocks,
|
||||||
|
llama_progress_callback progress_callback,
|
||||||
|
void * progress_callback_user_data);
|
||||||
|
};
|
||||||
@@ -734,7 +734,7 @@ llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_da
|
|||||||
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
||||||
static void get_overlapping_token_sequences(const llama_vocab& vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
|
static void get_overlapping_token_sequences(const llama_vocab& vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
|
||||||
for (llama_token token_id = 0; token_id < (llama_token)vocab.n_tokens(); token_id++) {
|
for (llama_token token_id = 0; token_id < (llama_token)vocab.n_tokens(); token_id++) {
|
||||||
std::string word = llama_detokenize(vocab, { token_id }, true);
|
auto word = vocab.detokenize( { token_id }, true);
|
||||||
if (word.find(str) != std::string::npos) {
|
if (word.find(str) != std::string::npos) {
|
||||||
token_sequences.emplace(token_id, std::vector<llama_token>());
|
token_sequences.emplace(token_id, std::vector<llama_token>());
|
||||||
}
|
}
|
||||||
@@ -751,7 +751,8 @@ static void get_overlapping_token_sequences(const llama_vocab& vocab, const std:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (match) {
|
if (match) {
|
||||||
std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
|
auto tokenization = vocab.tokenize(str.substr(i), false, false);
|
||||||
|
//std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
|
||||||
if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
|
if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
|
||||||
tokenization.resize(max_tail_len);
|
tokenization.resize(max_tail_len);
|
||||||
}
|
}
|
||||||
|
|||||||
3137
src/llama-vocab.cpp
3137
src/llama-vocab.cpp
File diff suppressed because it is too large
Load Diff
@@ -1,155 +1,178 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "llama-impl.h"
|
#include "llama.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <unordered_map>
|
#include <memory>
|
||||||
#include <map>
|
|
||||||
|
// pre-tokenization types
|
||||||
|
// Identifiers for the pre-tokenization variants.
// The numeric values are assigned explicitly and are consecutive from 0 to 38.
enum llama_vocab_pre_type {
    LLAMA_VOCAB_PRE_TYPE_DEFAULT        =  0,
    LLAMA_VOCAB_PRE_TYPE_LLAMA3         =  1,
    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM   =  2,
    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER =  3,
    LLAMA_VOCAB_PRE_TYPE_FALCON         =  4,
    LLAMA_VOCAB_PRE_TYPE_MPT            =  5,
    LLAMA_VOCAB_PRE_TYPE_STARCODER      =  6,
    LLAMA_VOCAB_PRE_TYPE_GPT2           =  7,
    LLAMA_VOCAB_PRE_TYPE_REFACT         =  8,
    LLAMA_VOCAB_PRE_TYPE_COMMAND_R      =  9,
    LLAMA_VOCAB_PRE_TYPE_STABLELM2      = 10,
    LLAMA_VOCAB_PRE_TYPE_QWEN2          = 11,
    LLAMA_VOCAB_PRE_TYPE_OLMO           = 12,
    LLAMA_VOCAB_PRE_TYPE_DBRX           = 13,
    LLAMA_VOCAB_PRE_TYPE_SMAUG          = 14,
    LLAMA_VOCAB_PRE_TYPE_PORO           = 15,
    LLAMA_VOCAB_PRE_TYPE_CHATGLM3       = 16,
    LLAMA_VOCAB_PRE_TYPE_CHATGLM4       = 17,
    LLAMA_VOCAB_PRE_TYPE_VIKING         = 18,
    LLAMA_VOCAB_PRE_TYPE_JAIS           = 19,
    LLAMA_VOCAB_PRE_TYPE_TEKKEN         = 20,
    LLAMA_VOCAB_PRE_TYPE_SMOLLM         = 21,
    LLAMA_VOCAB_PRE_TYPE_CODESHELL      = 22,
    LLAMA_VOCAB_PRE_TYPE_BLOOM          = 23,
    LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH   = 24,
    LLAMA_VOCAB_PRE_TYPE_EXAONE         = 25,
    LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
    LLAMA_VOCAB_PRE_TYPE_MINERVA        = 27,
    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 28,
    LLAMA_VOCAB_PRE_TYPE_GPT4O          = 29,
    LLAMA_VOCAB_PRE_TYPE_SUPERBPE       = 30,
    LLAMA_VOCAB_PRE_TYPE_TRILLION       = 31,
    LLAMA_VOCAB_PRE_TYPE_BAILINGMOE     = 32,
    LLAMA_VOCAB_PRE_TYPE_LLAMA4         = 33,
    LLAMA_VOCAB_PRE_TYPE_PIXTRAL        = 34,
    LLAMA_VOCAB_PRE_TYPE_SEED_CODER     = 35,
    LLAMA_VOCAB_PRE_TYPE_HUNYUAN        = 36,
    LLAMA_VOCAB_PRE_TYPE_KIMI_K2        = 37,
    LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE  = 38,
};
|
||||||
|
|
||||||
|
struct LLM_KV;
|
||||||
|
struct llama_model_loader;
|
||||||
|
|
||||||
struct llama_vocab {
|
struct llama_vocab {
|
||||||
using id = llama_token;
|
|
||||||
using token = std::string;
|
|
||||||
using tattr = llama_token_attr;
|
|
||||||
|
|
||||||
struct token_data {
|
struct token_data {
|
||||||
token text;
|
std::string text;
|
||||||
float score;
|
float score;
|
||||||
tattr attr;
|
llama_token_attr attr;
|
||||||
};
|
};
|
||||||
|
|
||||||
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
|
llama_vocab();
|
||||||
enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
~llama_vocab();
|
||||||
|
|
||||||
int max_token_len = 0; // used for optimizing longest token search
|
void load(llama_model_loader & ml, const LLM_KV & kv);
|
||||||
|
|
||||||
|
std::string get_tokenizer_model() const;
|
||||||
|
std::string get_tokenizer_pre() const;
|
||||||
|
|
||||||
|
enum llama_vocab_type get_type() const;
|
||||||
|
enum llama_vocab_pre_type get_pre_type() const;
|
||||||
|
|
||||||
uint32_t n_tokens() const;
|
uint32_t n_tokens() const;
|
||||||
|
uint32_t n_token_types() const;
|
||||||
|
|
||||||
std::unordered_map<token, id> token_to_id;
|
std::string type_name() const;
|
||||||
std::vector<token_data> id_to_token;
|
|
||||||
|
|
||||||
std::vector<id> cache_special_tokens;
|
bool is_normal (llama_token id) const;
|
||||||
std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
|
bool is_unknown (llama_token id) const;
|
||||||
|
bool is_control (llama_token id) const;
|
||||||
|
bool is_byte (llama_token id) const;
|
||||||
|
bool is_user_defined(llama_token id) const;
|
||||||
|
bool is_unused (llama_token id) const;
|
||||||
|
bool is_eog (llama_token id) const;
|
||||||
|
|
||||||
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
|
uint8_t token_to_byte(llama_token id) const;
|
||||||
|
llama_token byte_to_token(uint8_t ch) const;
|
||||||
|
|
||||||
// default LLaMA special tokens
|
llama_token text_to_token(const std::string & text) const;
|
||||||
id special_bos_id = 1;
|
|
||||||
id special_eos_id = 2;
|
|
||||||
id special_unk_id = 0;
|
|
||||||
id special_sep_id = -1;
|
|
||||||
id special_pad_id = -1;
|
|
||||||
id special_cls_id = -1;
|
|
||||||
id special_mask_id = -1;
|
|
||||||
|
|
||||||
id linefeed_id = 13;
|
const token_data & get_token_data(llama_token id) const;
|
||||||
|
|
||||||
// fim tokens
|
const char * token_get_text (llama_token id) const;
|
||||||
llama_token special_fim_pre_id = -1;
|
float token_get_score(llama_token id) const;
|
||||||
llama_token special_fim_suf_id = -1;
|
llama_token_attr token_get_attr (llama_token id) const;
|
||||||
llama_token special_fim_mid_id = -1;
|
|
||||||
llama_token special_fim_pad_id = -1;
|
|
||||||
llama_token special_fim_rep_id = -1; // repo
|
|
||||||
llama_token special_fim_sep_id = -1; // file separator
|
|
||||||
|
|
||||||
id special_prefix_id = -1;
|
llama_token token_bos() const;
|
||||||
id special_suffix_id = -1;
|
llama_token token_eos() const;
|
||||||
id special_middle_id = -1;
|
llama_token token_eot() const;
|
||||||
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
|
llama_token token_eom() const;
|
||||||
id special_eom_id = -1;
|
llama_token token_unk() const;
|
||||||
|
llama_token token_sep() const;
|
||||||
|
llama_token token_nl () const;
|
||||||
|
llama_token token_pad() const;
|
||||||
|
llama_token token_mask() const;
|
||||||
|
|
||||||
// tokenizer flags
|
llama_token token_prefix() const;
|
||||||
bool tokenizer_add_space_prefix = false;
|
llama_token token_middle() const;
|
||||||
bool tokenizer_add_bos = false;
|
llama_token token_suffix() const;
|
||||||
bool tokenizer_add_eos = false;
|
|
||||||
bool tokenizer_ignore_merges = false;
|
|
||||||
bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces
|
|
||||||
bool tokenizer_remove_extra_whitespaces = false;
|
|
||||||
bool tokenizer_escape_whitespaces = true;
|
|
||||||
bool tokenizer_treat_whitespace_as_suffix = false;
|
|
||||||
|
|
||||||
std::vector<char> precompiled_charsmap;
|
llama_token token_fim_pre() const;
|
||||||
|
llama_token token_fim_suf() const;
|
||||||
|
llama_token token_fim_mid() const;
|
||||||
|
llama_token token_fim_pad() const;
|
||||||
|
llama_token token_fim_rep() const;
|
||||||
|
llama_token token_fim_sep() const;
|
||||||
|
|
||||||
|
bool get_add_space_prefix () const;
|
||||||
|
bool get_add_bos () const;
|
||||||
|
bool get_add_eos () const;
|
||||||
|
bool get_add_sep () const;
|
||||||
|
bool get_ignore_merges () const;
|
||||||
|
bool get_clean_spaces () const;
|
||||||
|
bool get_remove_extra_whitespaces () const;
|
||||||
|
bool get_escape_whitespaces () const;
|
||||||
|
bool get_treat_whitespace_as_suffix() const;
|
||||||
|
|
||||||
|
int max_token_len() const;
|
||||||
|
|
||||||
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
||||||
|
std::vector<std::string> get_bpe_merges() const;
|
||||||
|
|
||||||
|
std::vector<char> get_precompiled_charsmap() const;
|
||||||
|
|
||||||
|
int32_t tokenize(
|
||||||
|
const char * text,
|
||||||
|
int32_t text_len,
|
||||||
|
llama_token * tokens,
|
||||||
|
int32_t n_tokens_max,
|
||||||
|
bool add_special,
|
||||||
|
bool parse_special) const;
|
||||||
|
|
||||||
|
std::vector<llama_token> tokenize(
|
||||||
|
const std::string & raw_text,
|
||||||
|
bool add_special,
|
||||||
|
bool parse_special = false) const;
|
||||||
|
|
||||||
|
// does not write null-terminator to buf
|
||||||
|
int32_t token_to_piece(
|
||||||
|
llama_token token,
|
||||||
|
char * buf,
|
||||||
|
int32_t length,
|
||||||
|
int32_t lstrip,
|
||||||
|
bool special) const;
|
||||||
|
|
||||||
|
// use cached data
|
||||||
|
const std::string & token_to_piece(llama_token token) const;
|
||||||
|
|
||||||
|
int32_t detokenize(
|
||||||
|
const llama_token * tokens,
|
||||||
|
int32_t n_tokens,
|
||||||
|
char * text,
|
||||||
|
int32_t text_len_max,
|
||||||
|
bool remove_special,
|
||||||
|
bool unparse_special) const;
|
||||||
|
|
||||||
|
std::string detokenize(
|
||||||
|
const std::vector<llama_token> & tokens,
|
||||||
|
bool special) const;
|
||||||
|
|
||||||
|
void print_info() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct impl;
|
||||||
|
std::unique_ptr<impl> pimpl;
|
||||||
};
|
};
|
||||||
|
|
||||||
const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
|
const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
|
||||||
|
|
||||||
//
|
|
||||||
// internal API
|
|
||||||
//
|
|
||||||
|
|
||||||
// TODO: rename to llama_tokenize_impl
|
|
||||||
// TODO: This should probably be in llama.h
|
|
||||||
std::vector<llama_vocab::id> llama_tokenize_internal(
|
|
||||||
const llama_vocab & vocab,
|
|
||||||
std::string raw_text,
|
|
||||||
bool add_special,
|
|
||||||
bool parse_special = false);
|
|
||||||
|
|
||||||
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
|
|
||||||
|
|
||||||
const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
|
|
||||||
|
|
||||||
float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
|
|
||||||
|
|
||||||
llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
|
|
||||||
|
|
||||||
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
|
|
||||||
|
|
||||||
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
|
|
||||||
|
|
||||||
llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
|
|
||||||
|
|
||||||
int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab);
|
|
||||||
int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab);
|
|
||||||
|
|
||||||
llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
|
|
||||||
|
|
||||||
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_eot_impl (const struct llama_vocab & vocab);
|
|
||||||
llama_token llama_token_eom_impl (const struct llama_vocab & vocab);
|
|
||||||
|
|
||||||
int32_t llama_tokenize_impl(
|
|
||||||
const struct llama_vocab & vocab,
|
|
||||||
const char * text,
|
|
||||||
int32_t text_len,
|
|
||||||
llama_token * tokens,
|
|
||||||
int32_t n_tokens_max,
|
|
||||||
bool add_special,
|
|
||||||
bool parse_special);
|
|
||||||
|
|
||||||
// does not write null-terminator to buf
|
|
||||||
int32_t llama_token_to_piece_impl(
|
|
||||||
const struct llama_vocab & vocab,
|
|
||||||
llama_token token,
|
|
||||||
char * buf,
|
|
||||||
int32_t length,
|
|
||||||
int32_t lstrip,
|
|
||||||
bool special);
|
|
||||||
|
|
||||||
int32_t llama_detokenize_impl(
|
|
||||||
const struct llama_vocab & vocab,
|
|
||||||
const llama_token * tokens,
|
|
||||||
int32_t n_tokens,
|
|
||||||
char * text,
|
|
||||||
int32_t text_len_max,
|
|
||||||
bool remove_special,
|
|
||||||
bool unparse_special);
|
|
||||||
|
|
||||||
std::string llama_detokenize(
|
|
||||||
const struct llama_vocab& vocab,
|
|
||||||
const std::vector<llama_token>& tokens,
|
|
||||||
bool special);
|
|
||||||
|
|||||||
3497
src/llama.cpp
3497
src/llama.cpp
File diff suppressed because it is too large
Load Diff
691
src/unicode.cpp
691
src/unicode.cpp
@@ -5,20 +5,19 @@
|
|||||||
#include "unicode.h"
|
#include "unicode.h"
|
||||||
#include "unicode-data.h"
|
#include "unicode-data.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
#include <codecvt>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <locale>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <locale>
|
|
||||||
#include <codecvt>
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
size_t unicode_len_utf8(char src) {
|
size_t unicode_len_utf8(char src) {
|
||||||
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||||
@@ -26,7 +25,7 @@ size_t unicode_len_utf8(char src) {
|
|||||||
return lookup[highbits];
|
return lookup[highbits];
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string unicode_cpts_to_utf8(const std::vector<uint32_t>& cps) {
|
static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
|
||||||
std::string result;
|
std::string result;
|
||||||
for (size_t i = 0; i < cps.size(); ++i) {
|
for (size_t i = 0; i < cps.size(); ++i) {
|
||||||
result.append(unicode_cpt_to_utf8(cps[i]));
|
result.append(unicode_cpt_to_utf8(cps[i]));
|
||||||
@@ -34,7 +33,7 @@ static std::string unicode_cpts_to_utf8(const std::vector<uint32_t>& cps) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
|
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
|
||||||
assert(offset < utf8.size());
|
assert(offset < utf8.size());
|
||||||
if (!(utf8[offset + 0] & 0x80)) {
|
if (!(utf8[offset + 0] & 0x80)) {
|
||||||
auto result = utf8[offset + 0];
|
auto result = utf8[offset + 0];
|
||||||
@@ -45,7 +44,7 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
|
|||||||
throw std::invalid_argument("invalid character");
|
throw std::invalid_argument("invalid character");
|
||||||
}
|
}
|
||||||
if (!(utf8[offset + 0] & 0x20)) {
|
if (!(utf8[offset + 0] & 0x20)) {
|
||||||
if (offset + 1 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80)) {
|
if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) {
|
||||||
throw std::invalid_argument("invalid character");
|
throw std::invalid_argument("invalid character");
|
||||||
}
|
}
|
||||||
auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
|
auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
|
||||||
@@ -53,7 +52,7 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
if (!(utf8[offset + 0] & 0x10)) {
|
if (!(utf8[offset + 0] & 0x10)) {
|
||||||
if (offset + 2 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80) || !((utf8[offset + 2] & 0xc0) == 0x80)) {
|
if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) {
|
||||||
throw std::invalid_argument("invalid character");
|
throw std::invalid_argument("invalid character");
|
||||||
}
|
}
|
||||||
auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
|
auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
|
||||||
@@ -61,7 +60,7 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
if (!(utf8[offset + 0] & 0x08)) {
|
if (!(utf8[offset + 0] & 0x08)) {
|
||||||
if (offset + 3 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80) || !((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
|
if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
|
||||||
throw std::invalid_argument("invalid character");
|
throw std::invalid_argument("invalid character");
|
||||||
}
|
}
|
||||||
auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
|
auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
|
||||||
@@ -71,15 +70,15 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
|
|||||||
throw std::invalid_argument("failed to convert utf8 to codepoint");
|
throw std::invalid_argument("failed to convert utf8 to codepoint");
|
||||||
}
|
}
|
||||||
|
|
||||||
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
|
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
|
||||||
// std::vector<uint16_t> result;
|
// std::vector<uint16_t> result;
|
||||||
// if (/* 0x0000 <= cp && */ cp <= 0xffff) {
|
// if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
|
||||||
// result.emplace_back(cp);
|
// result.emplace_back(cpt);
|
||||||
// return result;
|
// return result;
|
||||||
// }
|
// }
|
||||||
// if (0x10000 <= cp && cp <= 0x10ffff) {
|
// if (0x10000 <= cpt && cpt <= 0x10ffff) {
|
||||||
// result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
|
// result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
|
||||||
// result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
|
// result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
|
||||||
// return result;
|
// return result;
|
||||||
// }
|
// }
|
||||||
// throw std::invalid_argument("failed to convert codepoint to utf16");
|
// throw std::invalid_argument("failed to convert codepoint to utf16");
|
||||||
@@ -120,14 +119,14 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
|
|||||||
// return result;
|
// return result;
|
||||||
//}
|
//}
|
||||||
|
|
||||||
static std::vector<codepoint_flags> unicode_cpt_flags_array() {
|
static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
|
||||||
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
|
std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
|
||||||
|
|
||||||
assert(unicode_ranges_flags.front().first == 0);
|
assert (unicode_ranges_flags.begin()[0].first == 0);
|
||||||
assert(unicode_ranges_flags.back().first == MAX_CODEPOINTS);
|
assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
|
||||||
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
|
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
|
||||||
const auto range_ini = unicode_ranges_flags[i - 1]; // codepoint_ini, flags
|
const auto range_ini = unicode_ranges_flags.begin()[i-1]; // codepoint_ini, flags
|
||||||
const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
|
const auto range_end = unicode_ranges_flags.begin()[i]; // codepoint_end, flags
|
||||||
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
|
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
|
||||||
cpt_flags[cpt] = range_ini.second;
|
cpt_flags[cpt] = range_ini.second;
|
||||||
}
|
}
|
||||||
@@ -145,7 +144,7 @@ static std::vector<codepoint_flags> unicode_cpt_flags_array() {
|
|||||||
cpt_flags[p.second].is_uppercase = true;
|
cpt_flags[p.second].is_uppercase = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& range : unicode_ranges_nfd) { // start, last, nfd
|
for (auto &range : unicode_ranges_nfd) { // start, last, nfd
|
||||||
cpt_flags[range.nfd].is_nfd = true;
|
cpt_flags[range.nfd].is_nfd = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -200,55 +199,38 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
|
|||||||
return map;
|
return map;
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline bool is_valid_utf8(const std::string& str) {
|
static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
|
||||||
int remaining_bytes = 0; // 当前多字节字符剩余的字节数
|
|
||||||
for (unsigned char c : str) {
|
|
||||||
if (remaining_bytes == 0) {
|
|
||||||
if ((c & 0x80) == 0x00) continue; // 1字节字符
|
|
||||||
else if ((c & 0xE0) == 0xC0) remaining_bytes = 1; // 2字节
|
|
||||||
else if ((c & 0xF0) == 0xE0) remaining_bytes = 2; // 3字节
|
|
||||||
else if ((c & 0xF8) == 0xF0) remaining_bytes = 3; // 4字节
|
|
||||||
else return false; // 非法起始字节
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// 检查后续字节是否为10xxxxxx
|
|
||||||
if ((c & 0xC0) != 0x80)
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
remaining_bytes--;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return (remaining_bytes == 0); // 确保多字节字符完整
|
|
||||||
}
|
|
||||||
|
|
||||||
static inline std::wstring unicode_wstring_from_utf8(const std::string& s) {
|
|
||||||
#if defined(__clang__)
|
#if defined(__clang__)
|
||||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||||
# pragma clang diagnostic push
|
# pragma clang diagnostic push
|
||||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||||
|
#elif defined(__GNUC__)
|
||||||
|
# pragma GCC diagnostic push
|
||||||
|
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||||
#endif
|
#endif
|
||||||
bool isvalid = is_valid_utf8(s);
|
|
||||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
|
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
|
||||||
|
|
||||||
#if defined(__clang__)
|
#if defined(__clang__)
|
||||||
# pragma clang diagnostic pop
|
# pragma clang diagnostic pop
|
||||||
|
#elif defined(__GNUC__)
|
||||||
|
# pragma GCC diagnostic pop
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return conv.from_bytes(s);
|
return conv.from_bytes(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string>& bpe_words) {
|
static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
|
||||||
std::vector<std::string> bpe_encoded_words;
|
std::vector<std::string> bpe_encoded_words;
|
||||||
for (const auto& word : bpe_words) {
|
for (const auto & word : bpe_words) {
|
||||||
std::string text_utf;
|
std::string text_utf;
|
||||||
auto utf_word = unicode_cpts_from_utf8(word);
|
auto utf_word = unicode_cpts_from_utf8(word);
|
||||||
for (size_t i = 0; i < utf_word.size(); ++i) {
|
for (size_t i = 0; i < utf_word.size(); ++i) {
|
||||||
text_utf += unicode_cpt_to_utf8(utf_word[i]);
|
text_utf += unicode_cpt_to_utf8(utf_word[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string encoded_token;
|
std::string encoded_token;
|
||||||
for (char& c : text_utf) {
|
for (char & c : text_utf) {
|
||||||
encoded_token += unicode_byte_to_utf8(c);
|
encoded_token += unicode_byte_to_utf8(c);
|
||||||
}
|
}
|
||||||
bpe_encoded_words.emplace_back(encoded_token);
|
bpe_encoded_words.emplace_back(encoded_token);
|
||||||
@@ -257,7 +239,7 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
|
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
|
||||||
static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& text, const std::vector<size_t>& offsets) {
|
static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
|
||||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||||
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||||
|
|
||||||
@@ -271,16 +253,16 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
|
|||||||
start = offset_end;
|
start = offset_end;
|
||||||
|
|
||||||
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
|
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
|
||||||
auto _get_cpt = [&](const size_t pos) -> uint32_t {
|
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
|
||||||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto _get_flags = [&](const size_t pos) -> codepoint_flags {
|
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
|
||||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
|
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t _prev_end = offset_ini;
|
size_t _prev_end = offset_ini;
|
||||||
auto _add_token = [&](const size_t end) -> size_t {
|
auto _add_token = [&] (const size_t end) -> size_t {
|
||||||
assert(_prev_end <= end && end <= offset_end);
|
assert(_prev_end <= end && end <= offset_end);
|
||||||
size_t len = end - _prev_end;
|
size_t len = end - _prev_end;
|
||||||
if (len > 0) {
|
if (len > 0) {
|
||||||
@@ -296,29 +278,29 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
|
|||||||
return len;
|
return len;
|
||||||
};
|
};
|
||||||
|
|
||||||
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/) {
|
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
||||||
const uint32_t cpt = _get_cpt(pos);
|
const uint32_t cpt = _get_cpt(pos);
|
||||||
const auto flags = _get_flags(pos);
|
const auto flags = _get_flags(pos);
|
||||||
|
|
||||||
// regex: 's|'t|'re|'ve|'m|'ll|'d
|
// regex: 's|'t|'re|'ve|'m|'ll|'d
|
||||||
if (cpt == '\'' && pos + 1 < offset_end) {
|
if (cpt == '\'' && pos+1 < offset_end) {
|
||||||
uint32_t cpt_next = _get_cpt(pos + 1);
|
uint32_t cpt_next = _get_cpt(pos+1);
|
||||||
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
||||||
pos += _add_token(pos + 2);
|
pos += _add_token(pos+2);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (pos + 2 < offset_end) {
|
if (pos+2 < offset_end) {
|
||||||
uint32_t cpt_next_next = _get_cpt(pos + 2);
|
uint32_t cpt_next_next = _get_cpt(pos+2);
|
||||||
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
||||||
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
||||||
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
||||||
pos += _add_token(pos + 3);
|
pos += _add_token(pos+3);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
|
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
||||||
// regex: <space>?\p{L}+
|
// regex: <space>?\p{L}+
|
||||||
if (flags2.is_letter) {
|
if (flags2.is_letter) {
|
||||||
pos += (cpt == ' ');
|
pos += (cpt == ' ');
|
||||||
@@ -348,12 +330,12 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t num_whitespaces = 0;
|
size_t num_whitespaces = 0;
|
||||||
while (_get_flags(pos + num_whitespaces).is_whitespace) {
|
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
||||||
num_whitespaces++;
|
num_whitespaces++;
|
||||||
}
|
}
|
||||||
|
|
||||||
// regex: \s+(?!\S)
|
// regex: \s+(?!\S)
|
||||||
if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
|
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
|
||||||
pos += num_whitespaces - 1;
|
pos += num_whitespaces - 1;
|
||||||
_add_token(pos);
|
_add_token(pos);
|
||||||
continue;
|
continue;
|
||||||
@@ -374,6 +356,207 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
|
|||||||
return bpe_offsets;
|
return bpe_offsets;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
|
||||||
|
static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
|
||||||
|
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||||
|
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||||
|
|
||||||
|
const auto cpts = unicode_cpts_from_utf8(text);
|
||||||
|
|
||||||
|
size_t start = 0;
|
||||||
|
for (auto offset : offsets) {
|
||||||
|
const size_t offset_ini = start;
|
||||||
|
const size_t offset_end = start + offset;
|
||||||
|
assert(offset_end <= cpts.size());
|
||||||
|
start = offset_end;
|
||||||
|
|
||||||
|
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
|
||||||
|
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
|
||||||
|
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
|
||||||
|
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t _prev_end = offset_ini;
|
||||||
|
auto _add_token = [&] (const size_t end) -> size_t {
|
||||||
|
assert(_prev_end <= end && end <= offset_end);
|
||||||
|
size_t len = end - _prev_end;
|
||||||
|
if (len > 0) {
|
||||||
|
bpe_offsets.push_back(len);
|
||||||
|
}
|
||||||
|
_prev_end = end;
|
||||||
|
//if (len > 0) {
|
||||||
|
// std::string s = "";
|
||||||
|
// for(size_t p = end-len; p < end; p++)
|
||||||
|
// s += unicode_cpt_to_utf8(cpts[p]);
|
||||||
|
// printf(">>> '%s'\n", s.c_str());
|
||||||
|
//}
|
||||||
|
return len;
|
||||||
|
};
|
||||||
|
|
||||||
|
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
||||||
|
const uint32_t cpt = _get_cpt(pos);
|
||||||
|
const auto flags = _get_flags(pos);
|
||||||
|
|
||||||
|
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
|
||||||
|
if (cpt == '\'' && pos+1 < offset_end) {
|
||||||
|
uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
|
||||||
|
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
||||||
|
pos += _add_token(pos+2);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (pos+2 < offset_end) {
|
||||||
|
uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
|
||||||
|
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
||||||
|
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
||||||
|
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
||||||
|
pos += _add_token(pos+3);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// regex: [^\r\n\p{L}\p{N}]?\p{L}+
|
||||||
|
if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
|
||||||
|
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
|
||||||
|
pos++;
|
||||||
|
while (_get_flags(pos).is_letter) {
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
_add_token(pos);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// regex: \p{N}{1,3}
|
||||||
|
if (flags.is_number) {
|
||||||
|
size_t ini = pos;
|
||||||
|
while (_get_flags(pos).is_number) {
|
||||||
|
if (++pos - ini >= 3 ) {
|
||||||
|
_add_token(pos);
|
||||||
|
ini = pos;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_add_token(pos);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
|
||||||
|
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
||||||
|
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
|
||||||
|
pos += (cpt == ' ');
|
||||||
|
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
|
||||||
|
flags2 = _get_flags(++pos);
|
||||||
|
}
|
||||||
|
uint32_t cpt2 = _get_cpt(pos);
|
||||||
|
while (cpt2 == '\r' || cpt2 == '\n') {
|
||||||
|
cpt2 = _get_cpt(++pos);
|
||||||
|
}
|
||||||
|
_add_token(pos);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t num_whitespaces = 0;
|
||||||
|
size_t last_end_r_or_n = 0;
|
||||||
|
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
||||||
|
uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
|
||||||
|
if (cpt2 == '\r' || cpt2 == '\n') {
|
||||||
|
last_end_r_or_n = pos + num_whitespaces + 1;
|
||||||
|
}
|
||||||
|
num_whitespaces++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// regex: \s*[\r\n]+
|
||||||
|
if (last_end_r_or_n > 0) {
|
||||||
|
pos = last_end_r_or_n;
|
||||||
|
_add_token(pos);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// regex: \s+(?!\S)
|
||||||
|
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
|
||||||
|
pos += num_whitespaces - 1;
|
||||||
|
_add_token(pos);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// regex: \s+
|
||||||
|
if (num_whitespaces > 0) {
|
||||||
|
pos += num_whitespaces;
|
||||||
|
_add_token(pos);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// no matches
|
||||||
|
_add_token(++pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bpe_offsets;
|
||||||
|
}
|
||||||
|
|
||||||
|
// use std::wregex to split the text
|
||||||
|
static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
|
||||||
|
std::wregex expr(regex_expr);
|
||||||
|
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||||
|
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||||
|
size_t start = 0;
|
||||||
|
for (auto offset : offsets) {
|
||||||
|
std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
|
||||||
|
std::wcregex_iterator end;
|
||||||
|
|
||||||
|
int64_t start_idx = 0;
|
||||||
|
while (it != end) {
|
||||||
|
std::wcmatch match = *it;
|
||||||
|
if (match.position() > start_idx) {
|
||||||
|
bpe_offsets.emplace_back(match.position() - start_idx);
|
||||||
|
}
|
||||||
|
bpe_offsets.emplace_back(match.length());
|
||||||
|
start_idx = match.position() + match.length();
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (start_idx < (int64_t) offset) {
|
||||||
|
bpe_offsets.emplace_back(offset - start_idx);
|
||||||
|
}
|
||||||
|
start += offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
return bpe_offsets;
|
||||||
|
}
|
||||||
|
|
||||||
|
// use std::regex to split the text
|
||||||
|
static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
||||||
|
std::regex expr(regex_expr);
|
||||||
|
std::vector<size_t> bpe_offsets; // store the offset of each word
|
||||||
|
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
||||||
|
size_t start = 0;
|
||||||
|
for (auto offset : offsets) {
|
||||||
|
std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
|
||||||
|
std::cregex_iterator end;
|
||||||
|
|
||||||
|
int64_t start_idx = 0;
|
||||||
|
while (it != end) {
|
||||||
|
std::cmatch match = *it;
|
||||||
|
if (match.position() > start_idx) {
|
||||||
|
bpe_offsets.emplace_back(match.position() - start_idx);
|
||||||
|
}
|
||||||
|
bpe_offsets.emplace_back(match.length());
|
||||||
|
start_idx = match.position() + match.length();
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (start_idx < (int64_t) offset) {
|
||||||
|
bpe_offsets.emplace_back(offset - start_idx);
|
||||||
|
}
|
||||||
|
start += offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
return bpe_offsets;
|
||||||
|
}
|
||||||
|
|
||||||
// K2 system regex patterns (from tokenization_kimi.py):
|
// K2 system regex patterns (from tokenization_kimi.py):
|
||||||
// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
|
// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
|
||||||
static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
|
static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
|
||||||
@@ -394,8 +577,8 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
|
|||||||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
|
||||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
|
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t _prev_end = offset_ini;
|
size_t _prev_end = offset_ini;
|
||||||
@@ -546,220 +729,17 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
|
|||||||
return bpe_offsets;
|
return bpe_offsets;
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
|
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
||||||
static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string& text, const std::vector<size_t>& offsets) {
|
|
||||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
|
||||||
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
|
||||||
|
|
||||||
const auto cpts = unicode_cpts_from_utf8(text);
|
|
||||||
|
|
||||||
size_t start = 0;
|
|
||||||
for (auto offset : offsets) {
|
|
||||||
const size_t offset_ini = start;
|
|
||||||
const size_t offset_end = start + offset;
|
|
||||||
assert(offset_end <= cpts.size());
|
|
||||||
start = offset_end;
|
|
||||||
|
|
||||||
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
|
|
||||||
auto _get_cpt = [&](const size_t pos) -> uint32_t {
|
|
||||||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
|
||||||
};
|
|
||||||
|
|
||||||
auto _get_flags = [&](const size_t pos) -> codepoint_flags {
|
|
||||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
|
|
||||||
};
|
|
||||||
|
|
||||||
size_t _prev_end = offset_ini;
|
|
||||||
auto _add_token = [&](const size_t end) -> size_t {
|
|
||||||
assert(_prev_end <= end && end <= offset_end);
|
|
||||||
size_t len = end - _prev_end;
|
|
||||||
if (len > 0) {
|
|
||||||
bpe_offsets.push_back(len);
|
|
||||||
}
|
|
||||||
_prev_end = end;
|
|
||||||
//if (len > 0) {
|
|
||||||
// std::string s = "";
|
|
||||||
// for(size_t p = end-len; p < end; p++)
|
|
||||||
// s += unicode_cpt_to_utf8(cpts[p]);
|
|
||||||
// printf(">>> '%s'\n", s.c_str());
|
|
||||||
//}
|
|
||||||
return len;
|
|
||||||
};
|
|
||||||
|
|
||||||
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/) {
|
|
||||||
const uint32_t cpt = _get_cpt(pos);
|
|
||||||
const auto flags = _get_flags(pos);
|
|
||||||
|
|
||||||
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
|
|
||||||
if (cpt == '\'' && pos + 1 < offset_end) {
|
|
||||||
uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
|
|
||||||
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
|
||||||
pos += _add_token(pos + 2);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (pos + 2 < offset_end) {
|
|
||||||
uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
|
|
||||||
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
|
||||||
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
|
||||||
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
|
||||||
pos += _add_token(pos + 3);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// regex: [^\r\n\p{L}\p{N}]?\p{L}+
|
|
||||||
if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
|
|
||||||
if (flags.is_letter || _get_flags(pos + 1).is_letter) { // one or more letters
|
|
||||||
pos++;
|
|
||||||
while (_get_flags(pos).is_letter) {
|
|
||||||
pos++;
|
|
||||||
}
|
|
||||||
_add_token(pos);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// regex: \p{N}{1,3}
|
|
||||||
if (flags.is_number) {
|
|
||||||
size_t ini = pos;
|
|
||||||
while (_get_flags(pos).is_number) {
|
|
||||||
if (++pos - ini >= 3) {
|
|
||||||
_add_token(pos);
|
|
||||||
ini = pos;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_add_token(pos);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
|
|
||||||
auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
|
|
||||||
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
|
|
||||||
pos += (cpt == ' ');
|
|
||||||
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
|
|
||||||
flags2 = _get_flags(++pos);
|
|
||||||
}
|
|
||||||
uint32_t cpt2 = _get_cpt(pos);
|
|
||||||
while (cpt2 == '\r' || cpt2 == '\n') {
|
|
||||||
cpt2 = _get_cpt(++pos);
|
|
||||||
}
|
|
||||||
_add_token(pos);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t num_whitespaces = 0;
|
|
||||||
size_t last_end_r_or_n = 0;
|
|
||||||
while (_get_flags(pos + num_whitespaces).is_whitespace) {
|
|
||||||
uint32_t cpt2 = _get_cpt(pos + num_whitespaces);
|
|
||||||
if (cpt2 == '\r' || cpt2 == '\n') {
|
|
||||||
last_end_r_or_n = pos + num_whitespaces + 1;
|
|
||||||
}
|
|
||||||
num_whitespaces++;
|
|
||||||
}
|
|
||||||
|
|
||||||
// regex: \s*[\r\n]+
|
|
||||||
if (last_end_r_or_n > 0) {
|
|
||||||
pos = last_end_r_or_n;
|
|
||||||
_add_token(pos);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// regex: \s+(?!\S)
|
|
||||||
if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
|
|
||||||
pos += num_whitespaces - 1;
|
|
||||||
_add_token(pos);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// regex: \s+
|
|
||||||
if (num_whitespaces > 0) {
|
|
||||||
pos += num_whitespaces;
|
|
||||||
_add_token(pos);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// no matches
|
|
||||||
_add_token(++pos);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return bpe_offsets;
|
|
||||||
}
|
|
||||||
|
|
||||||
// use std::wregex to split the text
|
|
||||||
static std::vector<size_t> unicode_regex_split_stl(const std::wstring& wtext, const std::wstring& regex_expr, const std::vector<size_t>& offsets) {
|
|
||||||
std::wregex expr(regex_expr);
|
|
||||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
|
||||||
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
|
||||||
size_t start = 0;
|
|
||||||
for (auto offset : offsets) {
|
|
||||||
std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
|
|
||||||
std::wcregex_iterator end;
|
|
||||||
|
|
||||||
int64_t start_idx = 0;
|
|
||||||
while (it != end) {
|
|
||||||
std::wcmatch match = *it;
|
|
||||||
if (match.position() > start_idx) {
|
|
||||||
bpe_offsets.emplace_back(match.position() - start_idx);
|
|
||||||
}
|
|
||||||
bpe_offsets.emplace_back(match.length());
|
|
||||||
start_idx = match.position() + match.length();
|
|
||||||
++it;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (start_idx < (int64_t)offset) {
|
|
||||||
bpe_offsets.emplace_back(offset - start_idx);
|
|
||||||
}
|
|
||||||
start += offset;
|
|
||||||
}
|
|
||||||
|
|
||||||
return bpe_offsets;
|
|
||||||
}
|
|
||||||
|
|
||||||
// use std::regex to split the text
|
|
||||||
static std::vector<size_t> unicode_regex_split_stl(const std::string& text, const std::string& regex_expr, const std::vector<size_t>& offsets) {
|
|
||||||
std::regex expr(regex_expr);
|
|
||||||
std::vector<size_t> bpe_offsets; // store the offset of each word
|
|
||||||
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
|
|
||||||
size_t start = 0;
|
|
||||||
for (auto offset : offsets) {
|
|
||||||
std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
|
|
||||||
std::cregex_iterator end;
|
|
||||||
|
|
||||||
int64_t start_idx = 0;
|
|
||||||
while (it != end) {
|
|
||||||
std::cmatch match = *it;
|
|
||||||
if (match.position() > start_idx) {
|
|
||||||
bpe_offsets.emplace_back(match.position() - start_idx);
|
|
||||||
}
|
|
||||||
bpe_offsets.emplace_back(match.length());
|
|
||||||
start_idx = match.position() + match.length();
|
|
||||||
++it;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (start_idx < (int64_t)offset) {
|
|
||||||
bpe_offsets.emplace_back(offset - start_idx);
|
|
||||||
}
|
|
||||||
start += offset;
|
|
||||||
}
|
|
||||||
|
|
||||||
return bpe_offsets;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::vector<size_t> unicode_regex_split_custom(const std::string& text, const std::string& regex_expr, const std::vector<size_t>& offsets) {
|
|
||||||
std::vector<size_t> bpe_offsets;
|
std::vector<size_t> bpe_offsets;
|
||||||
|
|
||||||
if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
|
if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
|
||||||
bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
|
bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
|
||||||
}
|
} else if (
|
||||||
else if (
|
regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
|
||||||
regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
|
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
|
||||||
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
|
|
||||||
|
|
||||||
bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
|
bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
|
||||||
}
|
} else if (regex_expr == "\\p{Han}+") {
|
||||||
else if (regex_expr == "\\p{Han}+") {
|
|
||||||
// K2's first pattern - handle all K2 patterns together
|
// K2's first pattern - handle all K2 patterns together
|
||||||
bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
|
bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
|
||||||
}
|
}
|
||||||
@@ -771,71 +751,100 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string& text, c
|
|||||||
// interface
|
// interface
|
||||||
//
|
//
|
||||||
|
|
||||||
std::string unicode_cpt_to_utf8(uint32_t cp) {
|
std::string unicode_cpt_to_utf8(uint32_t cpt) {
|
||||||
std::string result;
|
std::string result;
|
||||||
|
|
||||||
if (/* 0x00 <= cp && */ cp <= 0x7f) {
|
if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
|
||||||
result.push_back(cp);
|
result.push_back(cpt);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
if (0x80 <= cp && cp <= 0x7ff) {
|
if (0x80 <= cpt && cpt <= 0x7ff) {
|
||||||
result.push_back(0xc0 | ((cp >> 6) & 0x1f));
|
result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
|
||||||
result.push_back(0x80 | (cp & 0x3f));
|
result.push_back(0x80 | (cpt & 0x3f));
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
if (0x800 <= cp && cp <= 0xffff) {
|
if (0x800 <= cpt && cpt <= 0xffff) {
|
||||||
result.push_back(0xe0 | ((cp >> 12) & 0x0f));
|
result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
|
||||||
result.push_back(0x80 | ((cp >> 6) & 0x3f));
|
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
|
||||||
result.push_back(0x80 | (cp & 0x3f));
|
result.push_back(0x80 | (cpt & 0x3f));
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
if (0x10000 <= cp && cp <= 0x10ffff) {
|
if (0x10000 <= cpt && cpt <= 0x10ffff) {
|
||||||
result.push_back(0xf0 | ((cp >> 18) & 0x07));
|
result.push_back(0xf0 | ((cpt >> 18) & 0x07));
|
||||||
result.push_back(0x80 | ((cp >> 12) & 0x3f));
|
result.push_back(0x80 | ((cpt >> 12) & 0x3f));
|
||||||
result.push_back(0x80 | ((cp >> 6) & 0x3f));
|
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
|
||||||
result.push_back(0x80 | (cp & 0x3f));
|
result.push_back(0x80 | (cpt & 0x3f));
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument("invalid codepoint");
|
throw std::invalid_argument("invalid codepoint");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t>& cpts) {
|
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
|
||||||
auto comp = [](const uint32_t cpt, const range_nfd& range) {
|
auto comp = [] (const uint32_t cpt, const range_nfd & range) {
|
||||||
return cpt < range.first;
|
return cpt < range.first;
|
||||||
};
|
};
|
||||||
std::vector<uint32_t> result(cpts.size());
|
std::vector<uint32_t> result(cpts.size());
|
||||||
for (size_t i = 0; i < cpts.size(); ++i) {
|
for (size_t i = 0; i < cpts.size(); ++i) {
|
||||||
const uint32_t cpt = cpts[i];
|
const uint32_t cpt = cpts[i];
|
||||||
auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
|
auto it = std::upper_bound(unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - 1;
|
||||||
result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
|
result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string& utf8) {
|
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
|
||||||
std::vector<uint32_t> result;
|
std::vector<uint32_t> result;
|
||||||
result.reserve(utf8.size());
|
result.reserve(utf8.size());
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
while (offset < utf8.size()) {
|
while (offset < utf8.size()) {
|
||||||
result.push_back(unicode_cpt_from_utf8(utf8, offset));
|
try {
|
||||||
|
result.push_back(unicode_cpt_from_utf8(utf8, offset));
|
||||||
|
}
|
||||||
|
catch (const std::invalid_argument & /*ex*/) {
|
||||||
|
// Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
|
||||||
|
++offset;
|
||||||
|
result.emplace_back(0xFFFD); // replacement character
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
codepoint_flags unicode_cpt_flags(const uint32_t cp) {
|
unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
|
||||||
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
|
||||||
static const auto cpt_flags = unicode_cpt_flags_array();
|
static const auto cpt_flags = unicode_cpt_flags_array();
|
||||||
return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
|
return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
|
||||||
}
|
}
|
||||||
|
|
||||||
codepoint_flags unicode_cpt_flags(const std::string& utf8) {
|
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
|
||||||
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
|
||||||
if (utf8.empty()) {
|
if (utf8.empty()) {
|
||||||
return undef; // undefined
|
return undef; // undefined
|
||||||
}
|
}
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
|
return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string unicode_byte_to_utf8(uint8_t byte) {
|
||||||
|
static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
|
||||||
|
return map.at(byte);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t unicode_utf8_to_byte(const std::string & utf8) {
|
||||||
|
static std::unordered_map<std::string, uint8_t> map = unicode_utf8_to_byte_map();
|
||||||
|
return map.at(utf8);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t unicode_tolower(uint32_t cpt) {
|
||||||
|
// binary search
|
||||||
|
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
|
||||||
|
[](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
|
||||||
|
return pair.first < value;
|
||||||
|
});
|
||||||
|
if (it != unicode_map_lowercase.end() && it->first == cpt) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
return cpt; // Return the original code point if no lowercase mapping is found
|
||||||
}
|
}
|
||||||
|
|
||||||
bool unicode_cpt_is_han(uint32_t cpt) {
|
bool unicode_cpt_is_han(uint32_t cpt) {
|
||||||
@@ -870,53 +879,37 @@ bool unicode_cpt_is_han(uint32_t cpt) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string unicode_byte_to_utf8(uint8_t byte) {
|
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
||||||
static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
|
|
||||||
return map.at(byte);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t unicode_utf8_to_byte(const std::string& utf8) {
|
|
||||||
static std::unordered_map<std::string, uint8_t> map = unicode_utf8_to_byte_map();
|
|
||||||
return map.at(utf8);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t unicode_tolower(uint32_t cp) {
|
|
||||||
auto it = unicode_map_lowercase.find(cp);
|
|
||||||
return it == unicode_map_lowercase.end() ? cp : it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> unicode_regex_split(const std::string& text, const std::vector<std::string>& regex_exprs) {
|
|
||||||
// unicode categories
|
// unicode categories
|
||||||
static const std::map<std::string, int> k_ucat_enum = {
|
static const std::map<std::string, int> k_ucat_enum = {
|
||||||
{ "\\p{N}", codepoint_flags::NUMBER },
|
{ "\\p{N}", unicode_cpt_flags::NUMBER },
|
||||||
{ "\\p{L}", codepoint_flags::LETTER },
|
{ "\\p{L}", unicode_cpt_flags::LETTER },
|
||||||
{ "\\p{P}", codepoint_flags::PUNCTUATION },
|
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
|
||||||
{ "\\p{M}", codepoint_flags::ACCENT_MARK },
|
{ "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
|
||||||
{ "\\p{S}", codepoint_flags::SYMBOL },
|
{ "\\p{S}", unicode_cpt_flags::SYMBOL },
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::map<int, int> k_ucat_cpt = {
|
static const std::map<int, int> k_ucat_cpt = {
|
||||||
{ codepoint_flags::NUMBER, 0xD1 },
|
{ unicode_cpt_flags::NUMBER, 0xD1 },
|
||||||
{ codepoint_flags::LETTER, 0xD2 },
|
{ unicode_cpt_flags::LETTER, 0xD2 },
|
||||||
{ codepoint_flags::PUNCTUATION, 0xD3 },
|
{ unicode_cpt_flags::PUNCTUATION, 0xD3 },
|
||||||
{ codepoint_flags::ACCENT_MARK, 0xD4 },
|
{ unicode_cpt_flags::ACCENT_MARK, 0xD4 },
|
||||||
{ codepoint_flags::SYMBOL, 0xD5 },
|
{ unicode_cpt_flags::SYMBOL, 0xD5 },
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::map<int, std::string> k_ucat_map = {
|
static const std::map<int, std::string> k_ucat_map = {
|
||||||
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
|
{ unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
|
||||||
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
{ unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||||
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}i
|
{ unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||||
{ codepoint_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
|
{ unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
|
||||||
{ codepoint_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
|
{ unicode_cpt_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
|
||||||
};
|
};
|
||||||
|
|
||||||
// compute collapsed codepoints only if needed by at least one regex
|
// compute collapsed codepoints only if needed by at least one regex
|
||||||
bool need_collapse = false;
|
bool need_collapse = false;
|
||||||
for (auto& regex_expr : regex_exprs) {
|
for (const auto & regex_expr : regex_exprs) {
|
||||||
// search for unicode categories
|
// search for unicode categories
|
||||||
for (const auto& ucat : k_ucat_enum) {
|
for (const auto & ucat : k_ucat_enum) {
|
||||||
if (std::string::npos != regex_expr.find(ucat.first)) {
|
if (std::string::npos != regex_expr.find(ucat.first)) {
|
||||||
need_collapse = true;
|
need_collapse = true;
|
||||||
break;
|
break;
|
||||||
@@ -927,7 +920,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
|
|||||||
const auto cpts = unicode_cpts_from_utf8(text);
|
const auto cpts = unicode_cpts_from_utf8(text);
|
||||||
|
|
||||||
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
|
// generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
|
||||||
// ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
|
// ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935
|
||||||
std::string text_collapsed;
|
std::string text_collapsed;
|
||||||
if (need_collapse) {
|
if (need_collapse) {
|
||||||
// collapse all unicode categories
|
// collapse all unicode categories
|
||||||
@@ -940,25 +933,23 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto flags = unicode_cpt_flags(cpts[i]);
|
const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
|
||||||
|
|
||||||
if (flags.is_whitespace) {
|
if (flags.is_whitespace) {
|
||||||
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
|
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
|
||||||
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
|
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
|
||||||
text_collapsed[i] = (char)0x0B; // <vertical tab> as whitespace fallback
|
text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
|
||||||
}
|
} else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
|
||||||
else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
|
|
||||||
text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
|
text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
|
||||||
}
|
} else {
|
||||||
else {
|
text_collapsed[i] = (char) 0xD0; // fallback
|
||||||
text_collapsed[i] = (char)0xD0; // fallback
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<size_t> bpe_offsets = { cpts.size() };
|
std::vector<size_t> bpe_offsets = { cpts.size() };
|
||||||
|
|
||||||
for (auto& regex_expr : regex_exprs) {
|
for (const auto & regex_expr : regex_exprs) {
|
||||||
// first, see if we have an efficient custom regex implementation
|
// first, see if we have an efficient custom regex implementation
|
||||||
auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
|
auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
|
||||||
|
|
||||||
@@ -972,7 +963,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
|
|||||||
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
|
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
|
||||||
// with the corresponding collapsed representation
|
// with the corresponding collapsed representation
|
||||||
bool use_collapsed = false;
|
bool use_collapsed = false;
|
||||||
for (auto& ucat : k_ucat_enum) {
|
for (const auto & ucat : k_ucat_enum) {
|
||||||
if (std::string::npos != regex_expr.find(ucat.first)) {
|
if (std::string::npos != regex_expr.find(ucat.first)) {
|
||||||
use_collapsed = true;
|
use_collapsed = true;
|
||||||
break;
|
break;
|
||||||
@@ -1031,15 +1022,14 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
|
|||||||
//printf("text_collapsed: %s\n", text_collapsed.c_str());
|
//printf("text_collapsed: %s\n", text_collapsed.c_str());
|
||||||
//printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
|
//printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
|
||||||
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
|
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
// no unicode category used, we can use std::wregex directly
|
// no unicode category used, we can use std::wregex directly
|
||||||
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
|
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
|
||||||
|
|
||||||
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
|
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
|
||||||
std::wstring wtext(cpts.begin(), cpts.end());
|
std::wstring wtext(cpts.begin(), cpts.end());
|
||||||
for (size_t i = 0; i < wtext.size(); ++i) {
|
for (size_t i = 0; i < wtext.size(); ++i) {
|
||||||
if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
|
if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
|
||||||
wtext[i] = 0x0B;
|
wtext[i] = 0x0B;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1048,8 +1038,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
|
|||||||
//printf("regex_expr: %s\n", regex_expr.c_str());
|
//printf("regex_expr: %s\n", regex_expr.c_str());
|
||||||
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
|
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
|
||||||
}
|
}
|
||||||
}
|
} catch (std::regex_error & e) {
|
||||||
catch (std::regex_error& e) {
|
|
||||||
fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
|
fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
|
||||||
fprintf(stderr, "Regex error: %s\n", e.what());
|
fprintf(stderr, "Regex error: %s\n", e.what());
|
||||||
throw std::runtime_error("Failed to process regex");
|
throw std::runtime_error("Failed to process regex");
|
||||||
@@ -1060,7 +1049,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
|
|||||||
bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
|
bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
|
||||||
|
|
||||||
size_t start = 0;
|
size_t start = 0;
|
||||||
for (size_t& offset : bpe_offsets) {
|
for (size_t & offset : bpe_offsets) {
|
||||||
bpe_words.emplace_back();
|
bpe_words.emplace_back();
|
||||||
for (size_t i = start; i < start + offset; ++i) {
|
for (size_t i = start; i < start + offset; ++i) {
|
||||||
bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
|
bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
|
||||||
|
|||||||
@@ -4,9 +4,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// TODO: prefix all symbols with "llama_"
|
struct unicode_cpt_flags {
|
||||||
|
|
||||||
struct codepoint_flags {
|
|
||||||
enum {
|
enum {
|
||||||
UNDEFINED = 0x0001,
|
UNDEFINED = 0x0001,
|
||||||
NUMBER = 0x0002, // regex: \p{N}
|
NUMBER = 0x0002, // regex: \p{N}
|
||||||
@@ -35,7 +33,7 @@ struct codepoint_flags {
|
|||||||
uint16_t is_nfd : 1;
|
uint16_t is_nfd : 1;
|
||||||
|
|
||||||
// decode from uint16
|
// decode from uint16
|
||||||
inline codepoint_flags(const uint16_t flags=0) {
|
inline unicode_cpt_flags(const uint16_t flags = 0) {
|
||||||
*reinterpret_cast<uint16_t*>(this) = flags;
|
*reinterpret_cast<uint16_t*>(this) = flags;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,19 +48,20 @@ struct codepoint_flags {
|
|||||||
|
|
||||||
size_t unicode_len_utf8(char src);
|
size_t unicode_len_utf8(char src);
|
||||||
|
|
||||||
std::string unicode_cpt_to_utf8(uint32_t cp);
|
std::string unicode_cpt_to_utf8 (uint32_t cpt);
|
||||||
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
|
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
|
||||||
|
|
||||||
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
|
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
|
||||||
|
|
||||||
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
|
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
|
||||||
|
|
||||||
codepoint_flags unicode_cpt_flags(const uint32_t cp);
|
unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
|
||||||
codepoint_flags unicode_cpt_flags(const std::string & utf8);
|
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
|
||||||
|
|
||||||
std::string unicode_byte_to_utf8(uint8_t byte);
|
std::string unicode_byte_to_utf8(uint8_t byte);
|
||||||
uint8_t unicode_utf8_to_byte(const std::string & utf8);
|
uint8_t unicode_utf8_to_byte(const std::string & utf8);
|
||||||
|
|
||||||
uint32_t unicode_tolower(uint32_t cp);
|
uint32_t unicode_tolower(uint32_t cpt);
|
||||||
|
|
||||||
bool unicode_cpt_is_han(uint32_t cpt);
|
bool unicode_cpt_is_han(uint32_t cpt);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user