mirror of
https://github.com/theroyallab/YALS.git
synced 2026-04-22 07:18:52 +00:00
Templating: Switch to @huggingface/jinja
Since llama.cpp will be using arbitrary Jinja templates soon, use the more secure Jinja package from Hugging Face. This comes with increased security at the cost of most Jinja features. However, Hugging Face's Jinja package will be updated to keep up with chat templates for AI models. Since the AST isn't typed in the package, include the types here.
Signed-off-by: kingbri <8082010+bdashore3@users.noreply.github.com>
This commit is contained in:
@@ -421,13 +421,24 @@ export class Model {
|
||||
|
||||
let promptTemplate: PromptTemplate | undefined = undefined;
|
||||
if (params.prompt_template) {
|
||||
promptTemplate = await PromptTemplate.fromFile(
|
||||
`templates/${params.prompt_template}`,
|
||||
);
|
||||
try {
|
||||
promptTemplate = await PromptTemplate.fromFile(
|
||||
`templates/${params.prompt_template}`,
|
||||
);
|
||||
|
||||
logger.info(
|
||||
`Using template "${promptTemplate.name}" for chat completions`,
|
||||
);
|
||||
logger.info(
|
||||
`Using template "${promptTemplate.name}" for chat completions`,
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
logger.warn(
|
||||
"Could not create a prompt template because of the following error:\n " +
|
||||
`${error.stack}\n\n` +
|
||||
"YALS will continue loading without the prompt template.\n" +
|
||||
"Please proofread the template and make sure it's compatible with huggingface's jinja subset.",
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new Model(
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
// @ts-types="@/types/nunjucks.d.ts"
|
||||
import nunjucks from "nunjucks";
|
||||
// @ts-types="@/types/jinja.d.ts"
|
||||
import {
|
||||
ArrayLiteral,
|
||||
Identifier,
|
||||
Literal,
|
||||
SetStatement,
|
||||
Template,
|
||||
} from "@huggingface/jinja";
|
||||
import * as z from "@/common/myZod.ts";
|
||||
import * as Path from "@std/path";
|
||||
|
||||
@@ -11,15 +17,10 @@ const TemplateMetadataSchema = z.object({
|
||||
|
||||
type TemplateMetadata = z.infer<typeof TemplateMetadataSchema>;
|
||||
|
||||
/**
 * Throws an `Error` carrying the given message.
 *
 * The PromptTemplate constructor in this view registers it as the
 * `raise_exception` global on the nunjucks environment.
 *
 * @param message Text used as the thrown error's message.
 * @throws Error Always.
 */
function raiseException(message: string) {
|
||||
throw new Error(message);
|
||||
}
|
||||
|
||||
export class PromptTemplate {
|
||||
name: string;
|
||||
rawTemplate: string;
|
||||
environment: nunjucks.Environment;
|
||||
template: nunjucks.Template;
|
||||
template: Template;
|
||||
metadata: TemplateMetadata;
|
||||
|
||||
public constructor(
|
||||
@@ -28,53 +29,49 @@ export class PromptTemplate {
|
||||
) {
|
||||
this.name = name;
|
||||
this.rawTemplate = rawTemplate;
|
||||
this.environment = nunjucks.configure({ autoescape: false })
|
||||
.addGlobal("raise_exception", raiseException);
|
||||
this.template = new nunjucks.Template(rawTemplate, this.environment);
|
||||
this.metadata = this.extractMetadata(rawTemplate);
|
||||
this.template = new Template(rawTemplate);
|
||||
this.metadata = this.extractMetadata(this.template);
|
||||
}
|
||||
|
||||
public render(context: object = {}): string {
|
||||
public render(context: Record<string, unknown> = {}): string {
|
||||
return this.template.render(context);
|
||||
}
|
||||
|
||||
private extractMetadata(rawTemplate: string): TemplateMetadata {
|
||||
const ast = nunjucks.parser.parse(rawTemplate);
|
||||
private extractMetadata(template: Template) {
|
||||
const metadata: TemplateMetadata = TemplateMetadataSchema.parse({});
|
||||
if (!ast.children) {
|
||||
return metadata;
|
||||
}
|
||||
|
||||
ast.children.forEach((node) => {
|
||||
// Targets is unique to a setNode
|
||||
if ("targets" in node) {
|
||||
const setNode = node as nunjucks.SetNode;
|
||||
if (setNode.targets.length === 0) {
|
||||
return;
|
||||
}
|
||||
template.parsed.body.forEach((statement) => {
|
||||
if (statement.type === "Set") {
|
||||
const setStatement = statement as SetStatement;
|
||||
|
||||
const assignee = setStatement.assignee as Identifier;
|
||||
const foundMetaKey = Object.keys(TemplateMetadataSchema.shape)
|
||||
.find(
|
||||
(key) => key === setNode.targets[0].value,
|
||||
(key) => key === assignee.value,
|
||||
) as keyof TemplateMetadata;
|
||||
|
||||
if (foundMetaKey) {
|
||||
// Get field schema from overall schema
|
||||
const fieldSchema =
|
||||
TemplateMetadataSchema.shape[foundMetaKey];
|
||||
|
||||
// Only use for validation. For some reason, the parsed data can't be assigned
|
||||
let result;
|
||||
if (setNode.value.children) {
|
||||
result = setNode.value.children.map((child) =>
|
||||
child.value
|
||||
);
|
||||
} else {
|
||||
result = setNode.value.value;
|
||||
let result: unknown;
|
||||
if (setStatement.value.type === "ArrayLiteral") {
|
||||
const arrayValue = setStatement.value as ArrayLiteral;
|
||||
result = arrayValue.value.map((e) => {
|
||||
const literalValue = e as Literal<unknown>;
|
||||
return literalValue.value;
|
||||
});
|
||||
} else if (setStatement.value.type.endsWith("Literal")) {
|
||||
const literalValue = setStatement.value as Literal<
|
||||
unknown
|
||||
>;
|
||||
result = literalValue.value;
|
||||
}
|
||||
|
||||
const parsedValue = fieldSchema.safeParse(result);
|
||||
if (parsedValue.success) {
|
||||
metadata[foundMetaKey] = result;
|
||||
// deno-lint-ignore no-explicit-any
|
||||
metadata[foundMetaKey] = parsedValue.data as any;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
},
|
||||
"imports": {
|
||||
"@/": "./",
|
||||
"@types/nunjucks": "npm:@types/nunjucks@^3.2.6",
|
||||
"hono": "npm:hono@^4.6.14",
|
||||
"@huggingface/jinja": "npm:@huggingface/jinja@^0.3.2",
|
||||
"@scalar/hono-api-reference": "npm:@scalar/hono-api-reference@^0.5.165",
|
||||
@@ -16,7 +15,6 @@
|
||||
"hono-openapi": "npm:hono-openapi@^0.3.0",
|
||||
"logtape": "jsr:@logtape/logtape@^0.8.0",
|
||||
"@std/yaml": "jsr:@std/yaml@^1.0.5",
|
||||
"nunjucks": "npm:nunjucks@^3.2.4",
|
||||
"zod": "npm:zod@^3.24.1",
|
||||
"zod-openapi": "npm:zod-openapi@^4.2.1"
|
||||
},
|
||||
|
||||
24
deno.lock
generated
24
deno.lock
generated
@@ -8,10 +8,8 @@
|
||||
"npm:@huggingface/jinja@~0.3.2": "0.3.2",
|
||||
"npm:@scalar/hono-api-reference@~0.5.165": "0.5.165_hono@4.6.14",
|
||||
"npm:@types/node@*": "22.5.4",
|
||||
"npm:@types/nunjucks@^3.2.6": "3.2.6",
|
||||
"npm:hono-openapi@0.3": "0.3.0_arktype@2.0.0-rc.25_hono@4.6.14_effect@3.12.0_@sinclair+typebox@0.34.13_valibot@1.0.0-beta.9_zod@3.24.1",
|
||||
"npm:hono@^4.6.14": "4.6.14",
|
||||
"npm:nunjucks@^3.2.4": "3.2.4",
|
||||
"npm:zod-openapi@^4.2.1": "4.2.1_zod@3.24.1",
|
||||
"npm:zod@^3.24.1": "3.24.1"
|
||||
},
|
||||
@@ -117,9 +115,6 @@
|
||||
"undici-types"
|
||||
]
|
||||
},
|
||||
"@types/nunjucks@3.2.6": {
|
||||
"integrity": "sha512-pHiGtf83na1nCzliuAdq8GowYiXvH5l931xZ0YEHaLMNFgynpEqx+IPStlu7UaDkehfvl01e4x/9Tpwhy7Ue3w=="
|
||||
},
|
||||
"@unhead/schema@1.11.14": {
|
||||
"integrity": "sha512-V9W9u5tF1/+TiLqxu+Qvh1ShoMDkPEwHoEo4DKdDG6ko7YlbzFfDxV6el9JwCren45U/4Vy/4Xi7j8OH02wsiA==",
|
||||
"dependencies": [
|
||||
@@ -133,9 +128,6 @@
|
||||
"valibot"
|
||||
]
|
||||
},
|
||||
"a-sync-waterfall@1.0.1": {
|
||||
"integrity": "sha512-RYTOHHdWipFUliRFMCS4X2Yn2X8M87V/OpSqWzKKOGhzqyUxzyVmhHDH9sAvG+ZuQf/TAOFsLCpMw09I1ufUnA=="
|
||||
},
|
||||
"argparse@2.0.1": {
|
||||
"integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q=="
|
||||
},
|
||||
@@ -146,15 +138,9 @@
|
||||
"@ark/util"
|
||||
]
|
||||
},
|
||||
"asap@2.0.6": {
|
||||
"integrity": "sha512-BSHWgDSAiKs50o2Re8ppvp3seVHXSRM44cdSsT9FfNEUUZLOGWVCsiWaRPWM1Znn+mqZ1OfVZ3z3DWEzSp7hRA=="
|
||||
},
|
||||
"clone@2.1.2": {
|
||||
"integrity": "sha512-3Pe/CF1Nn94hyhIYpjtiLhdCoEoz0DqQ+988E9gmeEdQZlojxnOb74wctFyuwWQHzqyf9X7C7MG8juUpqBJT8w=="
|
||||
},
|
||||
"commander@5.1.0": {
|
||||
"integrity": "sha512-P0CysNDQ7rtVw4QIQtm+MRxV66vKFSvlsQvGYXZWR3qFU0jlMKHZZZgw8e+8DSah4UDKMqnknRDQz+xuQXQ/Zg=="
|
||||
},
|
||||
"effect@3.12.0": {
|
||||
"integrity": "sha512-b/u9s3b9HfTo0qygVouegP0hkbiuxRIeaCe1ppf8P88hPyl6lKCbErtn7Az4jG7LuU7f0Wgm4c8WXbMcL2j8+g==",
|
||||
"dependencies": [
|
||||
@@ -206,14 +192,6 @@
|
||||
"clone"
|
||||
]
|
||||
},
|
||||
"nunjucks@3.2.4": {
|
||||
"integrity": "sha512-26XRV6BhkgK0VOxfbU5cQI+ICFUtMLixv1noZn1tGU38kQH5A5nmmbk/O45xdyBhD1esk47nKrY0mvQpZIhRjQ==",
|
||||
"dependencies": [
|
||||
"a-sync-waterfall",
|
||||
"asap",
|
||||
"commander"
|
||||
]
|
||||
},
|
||||
"openapi-types@12.1.3": {
|
||||
"integrity": "sha512-N4YtSYJqghVu4iek2ZUvcN/0aqH1kRDuNqzcycDxhOUpg7GdvLa2F3DgS6yBNhInhv2r/6I0Flkn7CqL8+nIcw=="
|
||||
},
|
||||
@@ -247,10 +225,8 @@
|
||||
"jsr:@std/yaml@^1.0.5",
|
||||
"npm:@huggingface/jinja@~0.3.2",
|
||||
"npm:@scalar/hono-api-reference@~0.5.165",
|
||||
"npm:@types/nunjucks@^3.2.6",
|
||||
"npm:hono-openapi@0.3",
|
||||
"npm:hono@^4.6.14",
|
||||
"npm:nunjucks@^3.2.4",
|
||||
"npm:zod-openapi@^4.2.1",
|
||||
"npm:zod@^3.24.1"
|
||||
]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
{# Metadata #}
|
||||
{%- set stop_strings = ["### Instruction:", "### Input:", "### Response:"] -%}
|
||||
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
|
||||
{# Template #}
|
||||
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
|
||||
|
||||
|
||||
62
types/jinja.d.ts
vendored
Normal file
62
types/jinja.d.ts
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
export * from "@huggingface/jinja";
|
||||
|
||||
// Type declarations for the AST node classes of "@huggingface/jinja" that are
// declared below: statements, expressions, identifiers and literal nodes.
declare module "@huggingface/jinja" {
|
||||
/**
 * Base class of the node classes declared in this block. `type` is a string
 * naming the node kind; the templating code in this change compares it
 * against values such as "Set" and "ArrayLiteral".
 */
export class Statement {
|
||||
type: string;
|
||||
}
|
||||
|
||||
/** A Statement subclass; base class of Literal and Identifier below. */
export class Expression extends Statement {
|
||||
type: string;
|
||||
}
|
||||
|
||||
/**
 * Abstract base for literal nodes; `value` holds the literal's payload of
 * type T. Declared without an `export` modifier, unlike the other classes in
 * this block.
 */
abstract class Literal<T> extends Expression {
|
||||
value: T;
|
||||
type: string;
|
||||
constructor(value: T);
|
||||
}
|
||||
|
||||
/** Expression node for a name; `value` is the identifier's name. */
export class Identifier extends Expression {
|
||||
value: string;
|
||||
type: string;
|
||||
|
||||
/**
|
||||
* @param {string} value The name of the identifier
|
||||
*/
|
||||
constructor(value: string);
|
||||
}
|
||||
|
||||
/** Literal whose `value` is a number. */
export class NumericLiteral extends Literal<number> {
|
||||
type: string;
|
||||
}
|
||||
|
||||
/** Literal whose `value` is a string. */
export class StringLiteral extends Literal<string> {
|
||||
type: string;
|
||||
}
|
||||
|
||||
/** Literal whose `value` is a boolean. */
export class BooleanLiteral extends Literal<boolean> {
|
||||
type: string;
|
||||
}
|
||||
|
||||
/** Literal whose `value` is null. */
export class NullLiteral extends Literal<null> {
|
||||
type: string;
|
||||
}
|
||||
|
||||
/** Literal whose `value` is an array of Expression nodes. */
export class ArrayLiteral extends Literal<Expression[]> {
|
||||
type: string;
|
||||
}
|
||||
|
||||
/** Literal whose `value` is an array of Expression nodes. */
export class TupleLiteral extends Literal<Expression[]> {
|
||||
type: string;
|
||||
}
|
||||
|
||||
/** Literal whose `value` is a Map from Expression keys to Expression values. */
export class ObjectLiteral extends Literal<Map<Expression, Expression>> {
|
||||
type: string;
|
||||
}
|
||||
|
||||
/**
 * Statement with an `assignee` expression and a `value` expression.
 * NOTE(review): presumably the node produced for a `{% set name = value %}`
 * tag — confirm against the package's parser.
 */
export class SetStatement extends Statement {
|
||||
assignee: Expression;
|
||||
value: Expression;
|
||||
type: string;
|
||||
constructor(assignee: Expression, value: Expression);
|
||||
}
|
||||
}
|
||||
35
types/nunjucks.d.ts
vendored
35
types/nunjucks.d.ts
vendored
@@ -1,35 +0,0 @@
|
||||
// Extended types for nunjucks
|
||||
// These are the bare minimum to parse out template metadata through the AST
|
||||
|
||||
// deno-lint-ignore-file no-explicit-any no-unused-vars
|
||||
// @ts-types="@types/nunjucks"
|
||||
import nunjucks from "nunjucks";
|
||||
export * from "nunjucks";
|
||||
|
||||
// Additional declarations for the "nunjucks" module: AST node shapes and the
// `parser` export.
declare module "nunjucks" {
|
||||
/**
 * Generic AST node. `value` and `children` are optional.
 * NOTE(review): `lineno` / `colno` presumably give the node's position in
 * the template source — confirm against nunjucks.
 */
export interface Node {
|
||||
lineno: number;
|
||||
colno: number;
|
||||
value?: any;
|
||||
children?: Node[];
|
||||
}
|
||||
|
||||
/** Node whose `value` is always a string. */
export interface TargetValue extends Node {
|
||||
value: string;
|
||||
}
|
||||
|
||||
/**
 * Node carrying a list of `targets` and a `value` node. The metadata
 * extraction code in this view recognises it by the presence of `targets`.
 */
export interface SetNode extends Node {
|
||||
targets: TargetValue[];
|
||||
value: Node;
|
||||
}
|
||||
|
||||
/** Node whose `children` array is always present. */
export interface NodeList extends Node {
|
||||
children: Node[];
|
||||
}
|
||||
|
||||
/** Shape of the parser export: `parse` takes template source and returns a Node. */
export interface ParserModule {
|
||||
parse(src: string, extensions?: any, opts?: any): Node;
|
||||
}
|
||||
|
||||
/** The parser object exposed on the nunjucks module. */
export const parser: ParserModule;
|
||||
}
|
||||
Reference in New Issue
Block a user