mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-06-29 10:47:22 +00:00
392 lines
13 KiB
Python
392 lines
13 KiB
Python
"""Upsample a short user idea into a full Ideogram4 structured-JSON caption.
|
|
|
|
Runs the Ideogram4 generation ("magic prompt") system prompt through
|
|
Qwen/Qwen3-VL-8B-Instruct as a text-only request and returns the resulting JSON.
|
|
Nothing is written to disk -- the upsampled JSON object is printed to stdout
|
|
(progress/logs go to stderr so stdout stays clean for the caller to parse).
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
import sys
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
|
|
# Repository root: the parent of the directory that contains this script.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
# Put the root at the front of the import path (once) so repo packages such as
# `toolkit.util.quantize` resolve no matter which directory the script is run from.
if sys.path.count(REPO_ROOT) == 0:
    sys.path.insert(0, REPO_ROOT)
|
|
|
|
from toolkit.ideogram_caption import normalize_caption_dict
|
|
|
|
# Location of the generation-prompt template. The file has the shape
# `name = """<content>"""`, but its content deliberately carries literal
# `\uNNNN` and `\n` sequences that must stay exactly as written, so the file is
# never imported as a module -- its triple-quoted content is read verbatim as
# plain text instead.
_PROMPT_PATH = os.path.join(
    REPO_ROOT, "extensions_built_in", "captioner", "prompts",
    "ideogram4_upsample_prompt.py",
)
|
|
|
|
# One of these two texts is substituted for the template's {{mode_directive}}
# placeholder. Neither relaxes the FIDELITY rules; the only difference is how far
# the model is allowed to go beyond what the user literally wrote.

# Default mode: add only what the structured format requires.
FAITHFUL_DIRECTIVE = (
    "- **Fill in ONLY what the structure needs.** "
    "Add a concrete background shell, bounding boxes, and the required "
    "elements/text -- nothing else. "
    "Do NOT add new subjects, props, narrative, mood, or a setting the user "
    "did not specify. "
    "If the prompt names no location, keep the background minimal. "
    "If the prompt is sparse, the scene stays sparse."
)

# Opt-in mode: build out a fuller scene around the user's idea.
CREATIVE_DIRECTIVE = (
    "- **Expand the scene while keeping the user's idea intact.** "
    "Place the subject in a specific, believable setting and build a real "
    "background environment with fitting secondary details (props, depth "
    "layers, atmosphere) that serve the idea -- never a blank or 'plain' "
    "background when a setting can be implied. "
    "Everything you add must support, never replace or contradict, what the "
    "user asked for, and you must not introduce a different main subject. "
    "The FIDELITY rules above still hold: triggers verbatim, no invented "
    "appearance for a named person, no elaboration of a named style."
)
|
|
|
|
# Accepted --dtype spellings (long and short form) mapped to torch dtypes.
# Insertion order is long name then short name for each dtype; argparse lists
# the keys in this order as the valid choices.
DTYPE_MAP = {
    alias: torch_dtype
    for torch_dtype, aliases in (
        (torch.float32, ("float32", "fp32")),
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
    )
    for alias in aliases
}
|
|
|
|
|
|
def log(message: str) -> None:
    """Write *message* as a single line to stderr and flush right away.

    Progress and diagnostics go through here so that stdout carries nothing
    but the JSON result.
    """
    stream = sys.stderr
    print(message, end="\n", file=stream, flush=True)
|
|
|
|
|
|
def load_generation_prompt() -> str:
    """Return the prompt template stored in the file at ``_PROMPT_PATH``.

    The file is read as UTF-8 text and everything between the first and the
    last triple double-quote delimiter is returned unchanged (no escape
    processing, no stripping).

    Raises:
        RuntimeError: if no delimiter is found, or the last delimiter does not
            come after the first one.
    """
    delimiter = '"""'
    with open(_PROMPT_PATH, "r", encoding="utf-8") as handle:
        source = handle.read()
    opening = source.find(delimiter)
    closing = source.rfind(delimiter)
    if opening == -1 or closing <= opening:
        raise RuntimeError(f"Could not parse prompt body from {_PROMPT_PATH}")
    body_start = opening + len(delimiter)
    return source[body_start:closing]
|
|
|
|
|
|
def build_prompt(
    template: str,
    aspect_ratio: str,
    original_prompt: str,
    creative: bool = False,
    instructions: str = "",
) -> str:
    """Fill the generation-prompt template's ``{{...}}`` placeholders.

    Args:
        template: Prompt text containing the placeholders
            ``{{mode_directive}}``, ``{{user_instructions}}``,
            ``{{aspect_ratio}}`` and ``{{original_prompt}}``.
        aspect_ratio: Text substituted for ``{{aspect_ratio}}``.
        original_prompt: The user's idea, substituted for ``{{original_prompt}}``.
        creative: Selects ``CREATIVE_DIRECTIVE`` when true, otherwise
            ``FAITHFUL_DIRECTIVE``, for ``{{mode_directive}}``.
        instructions: Extra user instructions; blank or whitespace-only input
            is rendered as ``"None."``.

    Returns:
        The template with every known placeholder replaced. Unknown
        ``{{...}}`` tokens are left untouched.
    """
    values = {
        "mode_directive": CREATIVE_DIRECTIVE if creative else FAITHFUL_DIRECTIVE,
        "user_instructions": instructions.strip() or "None.",
        "aspect_ratio": aspect_ratio,
        "original_prompt": original_prompt,
    }
    placeholder = r"\{\{(" + "|".join(re.escape(name) for name in values) + r")\}\}"
    # Substitute in ONE pass over the template. Chained str.replace calls would
    # re-scan text that was just inserted, so a user instruction or idea that
    # happens to contain e.g. "{{original_prompt}}" would itself be expanded.
    # A callable replacement also keeps backslashes in the values literal.
    return re.sub(placeholder, lambda match: values[match.group(1)], template)
|
|
|
|
|
|
def extract_json(raw: str):
    """Pull the JSON object out of the model output.

    Tolerates a Markdown code fence around the object, text before the first
    ``{``, and arbitrary text after the object's closing brace (including text
    that itself contains braces).

    Args:
        raw: The decoded model output.

    Returns:
        The parsed dict, or None when no ``{`` is present or the text starting
        at the first ``{`` is not a complete, valid JSON object (for example
        because generation was cut off).
    """
    text = raw.strip()
    # Prefer the contents of the first fenced block when the model used one.
    fence = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
    if fence:
        text = fence.group(1).strip()
    start = text.find("{")
    if start == -1:
        return None
    try:
        # raw_decode stops at the brace that closes the object opened at
        # `start`, so trailing commentary cannot break the parse the way a
        # "first { to last }" slice does. Only the first "{" is tried on
        # purpose: scanning later braces could return a nested fragment of a
        # truncated object as if it were the whole caption.
        parsed, _ = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return None
    return parsed
|
|
|
|
|
|
def sanitize_bbox(bbox):
    """Clean one bounding box given as ``[y1, x1, y2, x2]`` on a 0-1000 scale.

    Each coordinate is converted to float, rounded to an int and clamped to
    0..1000; then each axis pair is sorted so the smaller value comes first.
    The y-before-x order of the box is kept.

    Args:
        bbox: Candidate box taken from the model's JSON.

    Returns:
        ``[y1, x1, y2, x2]`` as ints, or None (meaning "drop this box") when
        the input is not a 4-item list/tuple, a coordinate is not numeric or
        not finite, or the cleaned box has zero height or width.
    """
    if not isinstance(bbox, (list, tuple)) or len(bbox) != 4:
        return None
    try:
        # round() must be inside the guard: it raises ValueError for NaN and
        # OverflowError for +/-Infinity, and json.loads accepts both tokens.
        y1, x1, y2, x2 = [min(1000, max(0, round(float(v)))) for v in bbox]
    except (TypeError, ValueError, OverflowError):
        return None
    y1, y2 = sorted((y1, y2))
    x1, x2 = sorted((x1, x2))
    # After sorting, "<=" can only mean the two edges coincide.
    if y2 <= y1 or x2 <= x1:
        return None
    return [y1, x1, y2, x2]
|
|
|
|
|
|
def sanitize_caption(data: dict) -> dict:
    """Clean every element bbox in *data*, then run the shared normalizer.

    For each dict under ``compositional_deconstruction.elements`` that has a
    ``bbox`` key, the box is passed through ``sanitize_bbox``; a rejected box
    is removed from the element, an accepted one replaces the original. The
    edit happens in place on *data*.

    Everything else is delegated to ``normalize_caption_dict``, whose result is
    returned. According to the original note on this function, that normalizer
    drops aspect_ratio, enforces the photo/art_style branch and key order,
    canonicalizes medium, and caps/uppercases color palettes (16 per image,
    5 per element) -- its implementation is not visible from this file.
    """
    section = data.get("compositional_deconstruction", {})
    if isinstance(section, dict):
        candidates = section.get("elements", [])
    else:
        candidates = []
    if not isinstance(candidates, list):
        # Unexpected shape: nothing to walk, leave it to the normalizer.
        candidates = []
    for element in candidates:
        if not isinstance(element, dict) or "bbox" not in element:
            continue
        box = sanitize_bbox(element["bbox"])
        if box is None:
            del element["bbox"]
        else:
            element["bbox"] = box
    return normalize_caption_dict(data)
|
|
|
|
|
|
def upsample_one(
    model,
    processor,
    device,
    template,
    idea,
    aspect_ratio,
    gen_kwargs,
    creative=False,
    instructions="",
) -> Optional[dict]:
    """Upsample a single idea into a cleaned caption dict.

    The idea (stripped) is rendered into the generation prompt, sent to the
    model as one text-only user message, and the decoded completion is parsed
    as JSON and sanitized.

    Args:
        model: Loaded generation model exposing ``generate``.
        processor: Matching processor exposing ``apply_chat_template`` and
            ``batch_decode``.
        device: Device the encoded inputs are moved to.
        template: Prompt template passed to ``build_prompt``.
        idea: The user's short idea.
        aspect_ratio: Aspect ratio text for the prompt.
        gen_kwargs: Keyword arguments forwarded to ``model.generate``.
        creative: Forwarded to ``build_prompt`` to pick the mode directive.
        instructions: Extra user instructions forwarded to ``build_prompt``.

    Returns:
        The sanitized caption dict, or None when no JSON could be parsed from
        the model output (the raw output is then logged to stderr).
    """
    prompt_text = build_prompt(
        template,
        aspect_ratio,
        idea.strip(),
        creative,
        instructions,
    )
    chat = [{"role": "user", "content": [{"type": "text", "text": prompt_text}]}]
    encoded = processor.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    sequences = model.generate(**encoded, **gen_kwargs)
    # Drop the first len(prompt) tokens of every returned sequence so only the
    # newly generated continuation is decoded.
    completions = []
    for prompt_ids, full_ids in zip(encoded.input_ids, sequences):
        completions.append(full_ids[len(prompt_ids) :])
    decoded = processor.batch_decode(
        completions,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    output_text = decoded[0].strip()

    caption = extract_json(output_text)
    if caption is not None:
        return sanitize_caption(caption)
    log("Failed to parse JSON from model output. Raw output follows:")
    log(output_text)
    return None
|
|
|
|
|
|
def normalize_item(item, default_aspect_ratio):
    """Turn one ``--prompts`` entry into an ``(idea, aspect_ratio)`` pair.

    Accepts either a bare prompt string or a dict of the form
    ``{'prompt': ..., 'aspect_ratio': ...}``.

    Args:
        item: The raw entry from the parsed JSON list.
        default_aspect_ratio: Used when the entry carries no usable
            ``aspect_ratio``.

    Returns:
        ``(idea, aspect_ratio)``, or None if the item is malformed (wrong
        type, non-string prompt) or the prompt is empty/whitespace-only. The
        idea is returned unstripped.
    """
    if isinstance(item, str):
        idea, aspect_ratio = item, default_aspect_ratio
    elif isinstance(item, dict) and isinstance(item.get("prompt"), str):
        idea = item["prompt"]
        aspect_ratio = item.get("aspect_ratio")
        # Only a non-blank string is usable: the value is later spliced into
        # the prompt text, and a JSON number/list/bool there would raise and
        # abort the whole batch. Anything else falls back to the default.
        if not isinstance(aspect_ratio, str) or not aspect_ratio.strip():
            aspect_ratio = default_aspect_ratio
    else:
        return None
    if not idea.strip():
        return None
    return idea, aspect_ratio
|
|
|
|
|
|
def load_model(
    model_name_or_path: str,
    dtype: torch.dtype,
    device: torch.device,
    quantize: bool,
    qtype: str,
):
    """Load the Qwen3-VL model and its processor.

    The weights are first loaded on the CPU, optionally quantized and frozen
    there, and only then moved to *device* and switched to eval mode.

    Args:
        model_name_or_path: Checkpoint identifier or local path.
        dtype: Torch dtype passed to ``from_pretrained``.
        device: Target device for the finished model.
        quantize: Whether to quantize the weights before moving to *device*.
        qtype: Quantization type name, resolved through ``get_qtype``.

    Returns:
        ``(model, processor)``.
    """
    from transformers import (
        AutoProcessor,
        Qwen3VLForConditionalGeneration,
        Qwen3VLMoeForConditionalGeneration,
    )

    # A "B-A" substring in the name picks the MoE class.
    # NOTE(review): presumably this matches MoE checkpoint names of the form
    # "<N>B-A<M>B" -- confirm against the checkpoints actually used.
    if "B-A" in model_name_or_path:
        model_cls = Qwen3VLMoeForConditionalGeneration
    else:
        model_cls = Qwen3VLForConditionalGeneration

    log(f"Loading {model_name_or_path}")
    model = model_cls.from_pretrained(
        model_name_or_path,
        dtype=dtype,
        device_map="cpu",
    )

    if quantize:
        # Imported only on this path so an unquantized run does not need the
        # quantization dependencies.
        from optimum.quanto import freeze
        from toolkit.util.quantize import get_qtype, quantize as quantize_model

        log(f"Quantizing model ({qtype})")
        weight_qtype = get_qtype(qtype)
        quantize_model(model, weights=weight_qtype)
        freeze(model)

    model.to(device)
    model.eval()
    processor = AutoProcessor.from_pretrained(model_name_or_path)
    return model, processor
|
|
|
|
|
|
def main() -> int:
    """Command-line entry point: upsample one idea or a JSON list of ideas.

    Output on stdout depends on the mode:
      * ``--prompt``: one JSON object (nothing is printed on failure).
      * ``--prompts``: one JSON list in input order, ``null`` for failed or
        invalid items.
      * ``--stream`` (either mode): one compact JSON line per item,
        ``{"index": i, "caption": {...}|null}``, printed as each completes,
        and no final list.
    Progress and error messages go to stderr.

    Returns:
        0 if at least one caption was produced, 1 if none was, 2 if the
        command-line input was invalid (checked before the model is loaded).
    """
    parser = argparse.ArgumentParser(
        description="Upsample a short idea into an Ideogram4 structured-JSON caption."
    )
    parser.add_argument(
        "--prompt",
        default=None,
        help="A single user idea to upsample (prints one JSON object).",
    )
    parser.add_argument(
        "--prompts",
        default=None,
        help=(
            "JSON list to upsample in one model load (prints a JSON list, same order). "
            'Each item is a prompt string or {"prompt": "...", "aspect_ratio": "W:H"}. '
            "Failed/empty items come back as null."
        ),
    )
    parser.add_argument(
        "--aspect_ratio",
        default="auto",
        help="Default aspect ratio as 'W:H', or 'auto'. Per-item values override it.",
    )
    parser.add_argument("--model_name_or_path", default="Qwen/Qwen3-VL-8B-Instruct")
    parser.add_argument("--max_new_tokens", type=int, default=3072)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dtype", default="bf16", choices=list(DTYPE_MAP.keys()))
    parser.add_argument("--quantize", action="store_true")
    parser.add_argument("--qtype", default="float8")
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature. <= 0 uses greedy decoding.",
    )
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--creative",
        action="store_true",
        help="Expand the prompt into a populated scene (default: faithful/minimal).",
    )
    parser.add_argument(
        "--instructions",
        default="",
        help="Extra user instructions injected into the system prompt for every item.",
    )
    parser.add_argument("--pretty", action="store_true", help="Indent the output JSON.")
    parser.add_argument(
        "--stream",
        action="store_true",
        help=(
            "Emit one compact JSON line per prompt as it completes "
            '({"index": i, "caption": {...}|null}) instead of a single final list.'
        ),
    )
    args = parser.parse_args()

    # "Was the option given?" is tested with `is None` here AND in the branch
    # below. Mixing truthiness with `is not None` let an explicitly empty
    # --prompt/--prompts slip past this check and be silently ignored or
    # reported with a misleading error.
    if (args.prompt is None) == (args.prompts is None):
        log("Provide exactly one of --prompt or --prompts.")
        return 2

    # Resolve the work list up front so we can fail fast on bad input.
    if args.prompts is not None:
        try:
            raw_items = json.loads(args.prompts)
        except json.JSONDecodeError as e:
            log(f"Failed to parse --prompts JSON: {e}")
            return 2
        if not isinstance(raw_items, list) or len(raw_items) == 0:
            log("--prompts must be a non-empty JSON list.")
            return 2
        batch = True
    else:
        if not args.prompt.strip():
            log("--prompt must not be empty.")
            return 2
        raw_items = [args.prompt]
        batch = False

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = torch.device(args.device)
    dtype = DTYPE_MAP[args.dtype]
    indent = 2 if args.pretty else None

    template = load_generation_prompt()

    gen_kwargs = {"max_new_tokens": args.max_new_tokens}
    if args.temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=args.temperature)
    else:
        # Zero or negative temperature means greedy decoding.
        gen_kwargs.update(do_sample=False)

    total = len(raw_items)
    results = []
    with torch.no_grad():
        model, processor = load_model(
            args.model_name_or_path, dtype, device, args.quantize, args.qtype
        )

        for idx, item in enumerate(raw_items):
            norm = normalize_item(item, args.aspect_ratio)
            if norm is None:
                log(f"[{idx + 1}/{total}] invalid/empty item, skipping")
                result = None
            else:
                idea, aspect_ratio = norm
                log(f"[{idx + 1}/{total}] Generating (aspect_ratio={aspect_ratio})...")
                result = upsample_one(
                    model,
                    processor,
                    device,
                    template,
                    idea,
                    aspect_ratio,
                    gen_kwargs,
                    args.creative,
                    args.instructions,
                )
            # Keep a slot for every input (None on failure) so output order
            # and indices line up with the input list.
            results.append(result)
            # Stream each result on its own compact line so callers can update live.
            if args.stream:
                print(
                    json.dumps({"index": idx, "caption": result}, ensure_ascii=False),
                    flush=True,
                )

    # Non-zero only if nothing succeeded.
    exit_code = 0 if any(r is not None for r in results) else 1

    if args.stream:
        # Everything was already printed line by line.
        return exit_code

    if batch:
        print(json.dumps(results, ensure_ascii=False, indent=indent), flush=True)
        return exit_code

    if results[0] is None:
        return 1
    print(json.dumps(results[0], ensure_ascii=False, indent=indent), flush=True)
    return 0
|
|
|
|
|
|
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|