Templating: Switch to @huggingface/jinja

Since llamacpp will be using arbitrary jinja templates soon, use the
more secure jinja package from huggingface. This comes with increased
security at the cost of most jinja features.

However, huggingface's Jinja will be updated to keep up with chat templates
for AI models.

Since the AST isn't typed in the package, include the types here.

Signed-off-by: kingbri <8082010+bdashore3@users.noreply.github.com>
This commit is contained in:
kingbri
2024-12-30 13:56:47 -05:00
parent cd8197cedd
commit d767d31ed6
7 changed files with 112 additions and 104 deletions

View File

@@ -421,13 +421,24 @@ export class Model {
let promptTemplate: PromptTemplate | undefined = undefined;
if (params.prompt_template) {
promptTemplate = await PromptTemplate.fromFile(
`templates/${params.prompt_template}`,
);
try {
promptTemplate = await PromptTemplate.fromFile(
`templates/${params.prompt_template}`,
);
logger.info(
`Using template "${promptTemplate.name}" for chat completions`,
);
logger.info(
`Using template "${promptTemplate.name}" for chat completions`,
);
} catch (error) {
if (error instanceof Error) {
logger.warn(
"Could not create a prompt template because of the following error:\n " +
`${error.stack}\n\n` +
"YALS will continue loading without the prompt template.\n" +
"Please proofread the template and make sure it's compatible with huggingface's jinja subset.",
);
}
}
}
return new Model(

View File

@@ -1,5 +1,11 @@
// @ts-types="@/types/nunjucks.d.ts"
import nunjucks from "nunjucks";
// @ts-types="@/types/jinja.d.ts"
import {
ArrayLiteral,
Identifier,
Literal,
SetStatement,
Template,
} from "@huggingface/jinja";
import * as z from "@/common/myZod.ts";
import * as Path from "@std/path";
@@ -11,15 +17,10 @@ const TemplateMetadataSchema = z.object({
type TemplateMetadata = z.infer<typeof TemplateMetadataSchema>;
function raiseException(message: string) {
throw new Error(message);
}
export class PromptTemplate {
name: string;
rawTemplate: string;
environment: nunjucks.Environment;
template: nunjucks.Template;
template: Template;
metadata: TemplateMetadata;
public constructor(
@@ -28,53 +29,49 @@ export class PromptTemplate {
) {
this.name = name;
this.rawTemplate = rawTemplate;
this.environment = nunjucks.configure({ autoescape: false })
.addGlobal("raise_exception", raiseException);
this.template = new nunjucks.Template(rawTemplate, this.environment);
this.metadata = this.extractMetadata(rawTemplate);
this.template = new Template(rawTemplate);
this.metadata = this.extractMetadata(this.template);
}
public render(context: object = {}): string {
public render(context: Record<string, unknown> = {}): string {
return this.template.render(context);
}
private extractMetadata(rawTemplate: string): TemplateMetadata {
const ast = nunjucks.parser.parse(rawTemplate);
private extractMetadata(template: Template) {
const metadata: TemplateMetadata = TemplateMetadataSchema.parse({});
if (!ast.children) {
return metadata;
}
ast.children.forEach((node) => {
// Targets is unique to a setNode
if ("targets" in node) {
const setNode = node as nunjucks.SetNode;
if (setNode.targets.length === 0) {
return;
}
template.parsed.body.forEach((statement) => {
if (statement.type === "Set") {
const setStatement = statement as SetStatement;
const assignee = setStatement.assignee as Identifier;
const foundMetaKey = Object.keys(TemplateMetadataSchema.shape)
.find(
(key) => key === setNode.targets[0].value,
(key) => key === assignee.value,
) as keyof TemplateMetadata;
if (foundMetaKey) {
// Get field schema from overall schema
const fieldSchema =
TemplateMetadataSchema.shape[foundMetaKey];
// Only use for validation. For some reason, the parsed data can't be assigned
let result;
if (setNode.value.children) {
result = setNode.value.children.map((child) =>
child.value
);
} else {
result = setNode.value.value;
let result: unknown;
if (setStatement.value.type === "ArrayLiteral") {
const arrayValue = setStatement.value as ArrayLiteral;
result = arrayValue.value.map((e) => {
const literalValue = e as Literal<unknown>;
return literalValue.value;
});
} else if (setStatement.value.type.endsWith("Literal")) {
const literalValue = setStatement.value as Literal<
unknown
>;
result = literalValue.value;
}
const parsedValue = fieldSchema.safeParse(result);
if (parsedValue.success) {
metadata[foundMetaKey] = result;
// deno-lint-ignore no-explicit-any
metadata[foundMetaKey] = parsedValue.data as any;
}
}
}

View File

@@ -7,7 +7,6 @@
},
"imports": {
"@/": "./",
"@types/nunjucks": "npm:@types/nunjucks@^3.2.6",
"hono": "npm:hono@^4.6.14",
"@huggingface/jinja": "npm:@huggingface/jinja@^0.3.2",
"@scalar/hono-api-reference": "npm:@scalar/hono-api-reference@^0.5.165",
@@ -16,7 +15,6 @@
"hono-openapi": "npm:hono-openapi@^0.3.0",
"logtape": "jsr:@logtape/logtape@^0.8.0",
"@std/yaml": "jsr:@std/yaml@^1.0.5",
"nunjucks": "npm:nunjucks@^3.2.4",
"zod": "npm:zod@^3.24.1",
"zod-openapi": "npm:zod-openapi@^4.2.1"
},

24
deno.lock generated
View File

@@ -8,10 +8,8 @@
"npm:@huggingface/jinja@~0.3.2": "0.3.2",
"npm:@scalar/hono-api-reference@~0.5.165": "0.5.165_hono@4.6.14",
"npm:@types/node@*": "22.5.4",
"npm:@types/nunjucks@^3.2.6": "3.2.6",
"npm:hono-openapi@0.3": "0.3.0_arktype@2.0.0-rc.25_hono@4.6.14_effect@3.12.0_@sinclair+typebox@0.34.13_valibot@1.0.0-beta.9_zod@3.24.1",
"npm:hono@^4.6.14": "4.6.14",
"npm:nunjucks@^3.2.4": "3.2.4",
"npm:zod-openapi@^4.2.1": "4.2.1_zod@3.24.1",
"npm:zod@^3.24.1": "3.24.1"
},
@@ -117,9 +115,6 @@
"undici-types"
]
},
"@types/nunjucks@3.2.6": {
"integrity": "sha512-pHiGtf83na1nCzliuAdq8GowYiXvH5l931xZ0YEHaLMNFgynpEqx+IPStlu7UaDkehfvl01e4x/9Tpwhy7Ue3w=="
},
"@unhead/schema@1.11.14": {
"integrity": "sha512-V9W9u5tF1/+TiLqxu+Qvh1ShoMDkPEwHoEo4DKdDG6ko7YlbzFfDxV6el9JwCren45U/4Vy/4Xi7j8OH02wsiA==",
"dependencies": [
@@ -133,9 +128,6 @@
"valibot"
]
},
"a-sync-waterfall@1.0.1": {
"integrity": "sha512-RYTOHHdWipFUliRFMCS4X2Yn2X8M87V/OpSqWzKKOGhzqyUxzyVmhHDH9sAvG+ZuQf/TAOFsLCpMw09I1ufUnA=="
},
"argparse@2.0.1": {
"integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q=="
},
@@ -146,15 +138,9 @@
"@ark/util"
]
},
"asap@2.0.6": {
"integrity": "sha512-BSHWgDSAiKs50o2Re8ppvp3seVHXSRM44cdSsT9FfNEUUZLOGWVCsiWaRPWM1Znn+mqZ1OfVZ3z3DWEzSp7hRA=="
},
"clone@2.1.2": {
"integrity": "sha512-3Pe/CF1Nn94hyhIYpjtiLhdCoEoz0DqQ+988E9gmeEdQZlojxnOb74wctFyuwWQHzqyf9X7C7MG8juUpqBJT8w=="
},
"commander@5.1.0": {
"integrity": "sha512-P0CysNDQ7rtVw4QIQtm+MRxV66vKFSvlsQvGYXZWR3qFU0jlMKHZZZgw8e+8DSah4UDKMqnknRDQz+xuQXQ/Zg=="
},
"effect@3.12.0": {
"integrity": "sha512-b/u9s3b9HfTo0qygVouegP0hkbiuxRIeaCe1ppf8P88hPyl6lKCbErtn7Az4jG7LuU7f0Wgm4c8WXbMcL2j8+g==",
"dependencies": [
@@ -206,14 +192,6 @@
"clone"
]
},
"nunjucks@3.2.4": {
"integrity": "sha512-26XRV6BhkgK0VOxfbU5cQI+ICFUtMLixv1noZn1tGU38kQH5A5nmmbk/O45xdyBhD1esk47nKrY0mvQpZIhRjQ==",
"dependencies": [
"a-sync-waterfall",
"asap",
"commander"
]
},
"openapi-types@12.1.3": {
"integrity": "sha512-N4YtSYJqghVu4iek2ZUvcN/0aqH1kRDuNqzcycDxhOUpg7GdvLa2F3DgS6yBNhInhv2r/6I0Flkn7CqL8+nIcw=="
},
@@ -247,10 +225,8 @@
"jsr:@std/yaml@^1.0.5",
"npm:@huggingface/jinja@~0.3.2",
"npm:@scalar/hono-api-reference@~0.5.165",
"npm:@types/nunjucks@^3.2.6",
"npm:hono-openapi@0.3",
"npm:hono@^4.6.14",
"npm:nunjucks@^3.2.4",
"npm:zod-openapi@^4.2.1",
"npm:zod@^3.24.1"
]

View File

@@ -1,6 +1,5 @@
{# Metadata #}
{%- set stop_strings = ["### Instruction:", "### Input:", "### Response:"] -%}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{# Template #}
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}

62
types/jinja.d.ts vendored Normal file
View File

@@ -0,0 +1,62 @@
export * from "@huggingface/jinja";
declare module "@huggingface/jinja" {
export class Statement {
type: string;
}
export class Expression extends Statement {
type: string;
}
abstract class Literal<T> extends Expression {
value: T;
type: string;
constructor(value: T);
}
export class Identifier extends Expression {
value: string;
type: string;
/**
* @param {string} value The name of the identifier
*/
constructor(value: string);
}
export class NumericLiteral extends Literal<number> {
type: string;
}
export class StringLiteral extends Literal<string> {
type: string;
}
export class BooleanLiteral extends Literal<boolean> {
type: string;
}
export class NullLiteral extends Literal<null> {
type: string;
}
export class ArrayLiteral extends Literal<Expression[]> {
type: string;
}
export class TupleLiteral extends Literal<Expression[]> {
type: string;
}
export class ObjectLiteral extends Literal<Map<Expression, Expression>> {
type: string;
}
export class SetStatement extends Statement {
assignee: Expression;
value: Expression;
type: string;
constructor(assignee: Expression, value: Expression);
}
}

35
types/nunjucks.d.ts vendored
View File

@@ -1,35 +0,0 @@
// Extended types for nunjucks
// These are the bare minimum to parse out template metadata through the AST
// deno-lint-ignore-file no-explicit-any no-unused-vars
// @ts-types="@types/nunjucks"
import nunjucks from "nunjucks";
export * from "nunjucks";
declare module "nunjucks" {
export interface Node {
lineno: number;
colno: number;
value?: any;
children?: Node[];
}
export interface TargetValue extends Node {
value: string;
}
export interface SetNode extends Node {
targets: TargetValue[];
value: Node;
}
export interface NodeList extends Node {
children: Node[];
}
export interface ParserModule {
parse(src: string, extensions?: any, opts?: any): Node;
}
export const parser: ParserModule;
}