Files
ai-toolkit/ui_scripts/upsample_ideogram4_caption.py

392 lines
13 KiB
Python

"""Upsample a short user idea into a full Ideogram4 structured-JSON caption.
Runs the Ideogram4 generation ("magic prompt") system prompt through
Qwen/Qwen3-VL-8B-Instruct as a text-only request and returns the resulting JSON.
Nothing is written to disk -- the upsampled JSON object is printed to stdout
(progress/logs go to stderr so stdout stays clean for the caller to parse).
"""
import argparse
import json
import os
import re
import sys
from typing import Optional
import torch
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make the repo importable (e.g. `toolkit.util.quantize`) regardless of cwd.
if REPO_ROOT not in sys.path:
sys.path.insert(0, REPO_ROOT)
from toolkit.ideogram_caption import normalize_caption_dict
# The generation prompt lives here. It's a `name = """<content>"""` file, but the
# content intentionally contains literal `\uNNNN` and `\n` sequences that are not
# valid Python escapes, so it cannot be imported -- we read the triple-quoted
# content verbatim as text instead.
_PROMPT_PATH = os.path.join(
REPO_ROOT,
"extensions_built_in",
"captioner",
"prompts",
"ideogram4_upsample_prompt.py",
)
# Swapped into the prompt's {{mode_directive}} slot. Both keep the FIDELITY rules;
# they only differ on how much the model may expand beyond the literal prompt.
FAITHFUL_DIRECTIVE = (
"- **Fill in ONLY what the structure needs.** Add a concrete background shell, "
"bounding boxes, and the required elements/text -- nothing else. Do NOT add new "
"subjects, props, narrative, mood, or a setting the user did not specify. If the "
"prompt names no location, keep the background minimal. If the prompt is sparse, "
"the scene stays sparse."
)
CREATIVE_DIRECTIVE = (
"- **Expand the scene while keeping the user's idea intact.** Place the subject in "
"a specific, believable setting and build a real background environment with fitting "
"secondary details (props, depth layers, atmosphere) that serve the idea -- never a "
"blank or 'plain' background when a setting can be implied. Everything you add must "
"support, never replace or contradict, what the user asked for, and you must not "
"introduce a different main subject. The FIDELITY rules above still hold: triggers "
"verbatim, no invented appearance for a named person, no elaboration of a named style."
)
DTYPE_MAP = {
"float32": torch.float32,
"fp32": torch.float32,
"float16": torch.float16,
"fp16": torch.float16,
"bfloat16": torch.bfloat16,
"bf16": torch.bfloat16,
}
def log(message: str) -> None:
print(message, file=sys.stderr, flush=True)
def load_generation_prompt() -> str:
with open(_PROMPT_PATH, "r", encoding="utf-8") as f:
src = f.read()
# Extract the triple-quoted body verbatim (see note on _PROMPT_PATH).
start = src.find('"""')
end = src.rfind('"""')
if start == -1 or end <= start:
raise RuntimeError(f"Could not parse prompt body from {_PROMPT_PATH}")
return src[start + 3 : end]
def build_prompt(
template: str,
aspect_ratio: str,
original_prompt: str,
creative: bool = False,
instructions: str = "",
) -> str:
directive = CREATIVE_DIRECTIVE if creative else FAITHFUL_DIRECTIVE
prompt = template.replace("{{mode_directive}}", directive)
prompt = prompt.replace("{{user_instructions}}", instructions.strip() or "None.")
prompt = prompt.replace("{{aspect_ratio}}", aspect_ratio)
prompt = prompt.replace("{{original_prompt}}", original_prompt)
return prompt
def extract_json(raw: str):
"""Pull the JSON object out of the model output, tolerating code fences and
stray preamble. Returns the parsed dict or None."""
text = raw.strip()
fence = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
if fence:
text = fence.group(1).strip()
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1 or end <= start:
return None
try:
return json.loads(text[start : end + 1])
except json.JSONDecodeError:
return None
def sanitize_bbox(bbox):
"""The generation prompt already emits normalized 0-1000 [y1,x1,y2,x2]. Clamp
to range, sort each axis pair, coerce to ints (keeps y/x order). Returns the
cleaned box or None to drop it."""
if not isinstance(bbox, (list, tuple)) or len(bbox) != 4:
return None
try:
y1, x1, y2, x2 = [float(v) for v in bbox]
except (TypeError, ValueError):
return None
y1, y2 = sorted((max(0, min(1000, round(y1))), max(0, min(1000, round(y2)))))
x1, x2 = sorted((max(0, min(1000, round(x1))), max(0, min(1000, round(x2)))))
if y2 <= y1 or x2 <= x1:
return None
return [y1, x1, y2, x2]
def sanitize_caption(data: dict) -> dict:
"""Clamp each bbox to valid 0-1000 [y1,x1,y2,x2], then hand off to the shared
normalizer for the rest: drop aspect_ratio, enforce the photo/art_style branch
and key order, canonicalize medium, and cap/uppercase color palettes (16 per
image, 5 per element)."""
decon = data.get("compositional_deconstruction", {})
elements = decon.get("elements", []) if isinstance(decon, dict) else []
if isinstance(elements, list):
for el in elements:
if isinstance(el, dict) and "bbox" in el:
cleaned = sanitize_bbox(el["bbox"])
if cleaned is None:
el.pop("bbox", None)
else:
el["bbox"] = cleaned
return normalize_caption_dict(data)
def upsample_one(
model,
processor,
device,
template,
idea,
aspect_ratio,
gen_kwargs,
creative=False,
instructions="",
) -> Optional[dict]:
"""Run one idea through the generation prompt. Returns the cleaned caption
dict, or None if the model output couldn't be parsed."""
full_prompt = build_prompt(
template, aspect_ratio, idea.strip(), creative, instructions
)
messages = [{"role": "user", "content": [{"type": "text", "text": full_prompt}]}]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
).to(device)
generated_ids = model.generate(**inputs, **gen_kwargs)
trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0].strip()
data = extract_json(output_text)
if data is None:
log("Failed to parse JSON from model output. Raw output follows:")
log(output_text)
return None
return sanitize_caption(data)
def normalize_item(item, default_aspect_ratio):
"""Accept either a bare prompt string or {'prompt': ..., 'aspect_ratio': ...}.
Returns (idea, aspect_ratio) or None if the item is malformed/empty."""
if isinstance(item, str):
idea, aspect_ratio = item, default_aspect_ratio
elif isinstance(item, dict) and isinstance(item.get("prompt"), str):
idea = item["prompt"]
aspect_ratio = item.get("aspect_ratio") or default_aspect_ratio
else:
return None
if not idea.strip():
return None
return idea, aspect_ratio
def load_model(
model_name_or_path: str,
dtype: torch.dtype,
device: torch.device,
quantize: bool,
qtype: str,
):
from transformers import (
Qwen3VLForConditionalGeneration,
Qwen3VLMoeForConditionalGeneration,
AutoProcessor,
)
ModelClass = (
Qwen3VLMoeForConditionalGeneration
if "B-A" in model_name_or_path
else Qwen3VLForConditionalGeneration
)
log(f"Loading {model_name_or_path}")
model = ModelClass.from_pretrained(
model_name_or_path, dtype=dtype, device_map="cpu"
)
if quantize:
# Lazy import so the common (non-quantized) path needs no toolkit deps.
from optimum.quanto import freeze
from toolkit.util.quantize import quantize as quantize_model, get_qtype
log(f"Quantizing model ({qtype})")
quantize_model(model, weights=get_qtype(qtype))
freeze(model)
model.to(device)
model.eval()
processor = AutoProcessor.from_pretrained(model_name_or_path)
return model, processor
def main() -> int:
parser = argparse.ArgumentParser(
description="Upsample a short idea into an Ideogram4 structured-JSON caption."
)
parser.add_argument(
"--prompt",
default=None,
help="A single user idea to upsample (prints one JSON object).",
)
parser.add_argument(
"--prompts",
default=None,
help=(
"JSON list to upsample in one model load (prints a JSON list, same order). "
'Each item is a prompt string or {"prompt": "...", "aspect_ratio": "W:H"}. '
"Failed/empty items come back as null."
),
)
parser.add_argument(
"--aspect_ratio",
default="auto",
help="Default aspect ratio as 'W:H', or 'auto'. Per-item values override it.",
)
parser.add_argument("--model_name_or_path", default="Qwen/Qwen3-VL-8B-Instruct")
parser.add_argument("--max_new_tokens", type=int, default=3072)
parser.add_argument("--device", default="cuda")
parser.add_argument("--dtype", default="bf16", choices=list(DTYPE_MAP.keys()))
parser.add_argument("--quantize", action="store_true")
parser.add_argument("--qtype", default="float8")
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Sampling temperature. <= 0 uses greedy decoding.",
)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument(
"--creative",
action="store_true",
help="Expand the prompt into a populated scene (default: faithful/minimal).",
)
parser.add_argument(
"--instructions",
default="",
help="Extra user instructions injected into the system prompt for every item.",
)
parser.add_argument("--pretty", action="store_true", help="Indent the output JSON.")
parser.add_argument(
"--stream",
action="store_true",
help=(
"Emit one compact JSON line per prompt as it completes "
'({"index": i, "caption": {...}|null}) instead of a single final list.'
),
)
args = parser.parse_args()
if bool(args.prompt) == bool(args.prompts):
print(
"Provide exactly one of --prompt or --prompts.", file=sys.stderr, flush=True
)
return 2
# Resolve the work list up front so we can fail fast on bad input.
if args.prompts is not None:
try:
raw_items = json.loads(args.prompts)
except json.JSONDecodeError as e:
print(f"Failed to parse --prompts JSON: {e}", file=sys.stderr, flush=True)
return 2
if not isinstance(raw_items, list) or len(raw_items) == 0:
print(
"--prompts must be a non-empty JSON list.", file=sys.stderr, flush=True
)
return 2
batch = True
else:
if not args.prompt.strip():
print("--prompt must not be empty.", file=sys.stderr, flush=True)
return 2
raw_items = [args.prompt]
batch = False
if args.seed is not None:
torch.manual_seed(args.seed)
device = torch.device(args.device)
dtype = DTYPE_MAP[args.dtype]
indent = 2 if args.pretty else None
template = load_generation_prompt()
gen_kwargs = {"max_new_tokens": args.max_new_tokens}
if args.temperature and args.temperature > 0:
gen_kwargs.update(do_sample=True, temperature=args.temperature)
else:
gen_kwargs.update(do_sample=False)
with torch.no_grad():
model, processor = load_model(
args.model_name_or_path, dtype, device, args.quantize, args.qtype
)
results = []
for idx, item in enumerate(raw_items):
norm = normalize_item(item, args.aspect_ratio)
if norm is None:
log(f"[{idx + 1}/{len(raw_items)}] invalid/empty item, skipping")
result = None
else:
idea, aspect_ratio = norm
log(
f"[{idx + 1}/{len(raw_items)}] Generating (aspect_ratio={aspect_ratio})..."
)
result = upsample_one(
model,
processor,
device,
template,
idea,
aspect_ratio,
gen_kwargs,
args.creative,
args.instructions,
)
results.append(result)
# Stream each result on its own compact line so callers can update live.
if args.stream:
print(
json.dumps({"index": idx, "caption": result}, ensure_ascii=False),
flush=True,
)
if args.stream:
return 0 if any(r is not None for r in results) else 1
if batch:
print(json.dumps(results, ensure_ascii=False, indent=indent), flush=True)
# Non-zero only if nothing succeeded.
return 0 if any(r is not None for r in results) else 1
if results[0] is None:
return 1
print(json.dumps(results[0], ensure_ascii=False, indent=indent), flush=True)
return 0
if __name__ == "__main__":
sys.exit(main())