mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 11:11:15 +00:00
remove everything related to training
This commit is contained in:
@@ -1,64 +0,0 @@
|
|||||||
from PIL import Image
|
|
||||||
|
|
||||||
from modules import scripts_postprocessing, ui_components
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
|
|
||||||
def center_crop(image: Image, w: int, h: int):
|
|
||||||
iw, ih = image.size
|
|
||||||
if ih / h < iw / w:
|
|
||||||
sw = w * ih / h
|
|
||||||
box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
|
|
||||||
else:
|
|
||||||
sh = h * iw / w
|
|
||||||
box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
|
|
||||||
return image.resize((w, h), Image.Resampling.LANCZOS, box)
|
|
||||||
|
|
||||||
|
|
||||||
def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
|
|
||||||
iw, ih = image.size
|
|
||||||
err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h))
|
|
||||||
wh = max(((w, h) for w in range(mindim, maxdim + 1, 64) for h in range(mindim, maxdim + 1, 64)
|
|
||||||
if minarea <= w * h <= maxarea and err(w, h) <= threshold),
|
|
||||||
key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1],
|
|
||||||
default=None
|
|
||||||
)
|
|
||||||
return wh and center_crop(image, *wh)
|
|
||||||
|
|
||||||
|
|
||||||
class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing):
|
|
||||||
name = "Auto-sized crop"
|
|
||||||
order = 4020
|
|
||||||
|
|
||||||
def ui(self):
|
|
||||||
with ui_components.InputAccordion(False, label="Auto-sized crop") as enable:
|
|
||||||
gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
|
|
||||||
with gr.Row():
|
|
||||||
mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="postprocess_multicrop_mindim")
|
|
||||||
maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="postprocess_multicrop_maxdim")
|
|
||||||
with gr.Row():
|
|
||||||
minarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area lower bound", value=64 * 64, elem_id="postprocess_multicrop_minarea")
|
|
||||||
maxarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area upper bound", value=640 * 640, elem_id="postprocess_multicrop_maxarea")
|
|
||||||
with gr.Row():
|
|
||||||
objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="postprocess_multicrop_objective")
|
|
||||||
threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="postprocess_multicrop_threshold")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"enable": enable,
|
|
||||||
"mindim": mindim,
|
|
||||||
"maxdim": maxdim,
|
|
||||||
"minarea": minarea,
|
|
||||||
"maxarea": maxarea,
|
|
||||||
"objective": objective,
|
|
||||||
"threshold": threshold,
|
|
||||||
}
|
|
||||||
|
|
||||||
def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, mindim, maxdim, minarea, maxarea, objective, threshold):
|
|
||||||
if not enable:
|
|
||||||
return
|
|
||||||
|
|
||||||
cropped = multicrop_pic(pp.image, mindim, maxdim, minarea, maxarea, objective, threshold)
|
|
||||||
if cropped is not None:
|
|
||||||
pp.image = cropped
|
|
||||||
else:
|
|
||||||
print(f"skipped {pp.image.width}x{pp.image.height} image (can't find suitable size within error threshold)")
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
from modules import scripts_postprocessing, ui_components, deepbooru, shared
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
|
|
||||||
class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing):
|
|
||||||
name = "Caption"
|
|
||||||
order = 4040
|
|
||||||
|
|
||||||
def ui(self):
|
|
||||||
with ui_components.InputAccordion(False, label="Caption") as enable:
|
|
||||||
option = gr.CheckboxGroup(value=["Deepbooru"], choices=["Deepbooru", "BLIP"], show_label=False)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"enable": enable,
|
|
||||||
"option": option,
|
|
||||||
}
|
|
||||||
|
|
||||||
def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
|
|
||||||
if not enable:
|
|
||||||
return
|
|
||||||
|
|
||||||
captions = [pp.caption]
|
|
||||||
|
|
||||||
if "Deepbooru" in option:
|
|
||||||
captions.append(deepbooru.model.tag(pp.image))
|
|
||||||
|
|
||||||
if "BLIP" in option:
|
|
||||||
captions.append(shared.interrogator.interrogate(pp.image.convert("RGB")))
|
|
||||||
|
|
||||||
pp.caption = ", ".join([x for x in captions if x])
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
from PIL import ImageOps, Image
|
|
||||||
|
|
||||||
from modules import scripts_postprocessing, ui_components
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
|
|
||||||
class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing):
|
|
||||||
name = "Create flipped copies"
|
|
||||||
order = 4030
|
|
||||||
|
|
||||||
def ui(self):
|
|
||||||
with ui_components.InputAccordion(False, label="Create flipped copies") as enable:
|
|
||||||
with gr.Row():
|
|
||||||
option = gr.CheckboxGroup(value=["Horizontal"], choices=["Horizontal", "Vertical", "Both"], show_label=False)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"enable": enable,
|
|
||||||
"option": option,
|
|
||||||
}
|
|
||||||
|
|
||||||
def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
|
|
||||||
if not enable:
|
|
||||||
return
|
|
||||||
|
|
||||||
if "Horizontal" in option:
|
|
||||||
pp.extra_images.append(ImageOps.mirror(pp.image))
|
|
||||||
|
|
||||||
if "Vertical" in option:
|
|
||||||
pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM))
|
|
||||||
|
|
||||||
if "Both" in option:
|
|
||||||
pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT))
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
|
|
||||||
from modules import scripts_postprocessing, ui_components, errors
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
from modules.textual_inversion import autocrop
|
|
||||||
|
|
||||||
|
|
||||||
class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing):
|
|
||||||
name = "Auto focal point crop"
|
|
||||||
order = 4010
|
|
||||||
|
|
||||||
def ui(self):
|
|
||||||
with ui_components.InputAccordion(False, label="Auto focal point crop") as enable:
|
|
||||||
face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight")
|
|
||||||
entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight")
|
|
||||||
edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight")
|
|
||||||
debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"enable": enable,
|
|
||||||
"face_weight": face_weight,
|
|
||||||
"entropy_weight": entropy_weight,
|
|
||||||
"edges_weight": edges_weight,
|
|
||||||
"debug": debug,
|
|
||||||
}
|
|
||||||
|
|
||||||
def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug):
|
|
||||||
if not enable:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not pp.shared.target_width or not pp.shared.target_height:
|
|
||||||
return
|
|
||||||
|
|
||||||
dnn_model_path = None
|
|
||||||
try:
|
|
||||||
dnn_model_path = autocrop.download_and_cache_models()
|
|
||||||
except Exception:
|
|
||||||
errors.report("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", exc_info=True)
|
|
||||||
|
|
||||||
autocrop_settings = autocrop.Settings(
|
|
||||||
crop_width=pp.shared.target_width,
|
|
||||||
crop_height=pp.shared.target_height,
|
|
||||||
face_points_weight=face_weight,
|
|
||||||
entropy_points_weight=entropy_weight,
|
|
||||||
corner_points_weight=edges_weight,
|
|
||||||
annotate_image=debug,
|
|
||||||
dnn_model_path=dnn_model_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
result, *others = autocrop.crop_image(pp.image, autocrop_settings)
|
|
||||||
|
|
||||||
pp.image = result
|
|
||||||
pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others]
|
|
||||||
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
from modules import scripts_postprocessing, ui_components
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
|
|
||||||
def split_pic(image, inverse_xy, width, height, overlap_ratio):
|
|
||||||
if inverse_xy:
|
|
||||||
from_w, from_h = image.height, image.width
|
|
||||||
to_w, to_h = height, width
|
|
||||||
else:
|
|
||||||
from_w, from_h = image.width, image.height
|
|
||||||
to_w, to_h = width, height
|
|
||||||
h = from_h * to_w // from_w
|
|
||||||
if inverse_xy:
|
|
||||||
image = image.resize((h, to_w))
|
|
||||||
else:
|
|
||||||
image = image.resize((to_w, h))
|
|
||||||
|
|
||||||
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
|
|
||||||
y_step = (h - to_h) / (split_count - 1)
|
|
||||||
for i in range(split_count):
|
|
||||||
y = int(y_step * i)
|
|
||||||
if inverse_xy:
|
|
||||||
splitted = image.crop((y, 0, y + to_h, to_w))
|
|
||||||
else:
|
|
||||||
splitted = image.crop((0, y, to_w, y + to_h))
|
|
||||||
yield splitted
|
|
||||||
|
|
||||||
|
|
||||||
class ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing):
|
|
||||||
name = "Split oversized images"
|
|
||||||
order = 4000
|
|
||||||
|
|
||||||
def ui(self):
|
|
||||||
with ui_components.InputAccordion(False, label="Split oversized images") as enable:
|
|
||||||
with gr.Row():
|
|
||||||
split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_split_threshold")
|
|
||||||
overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="postprocess_overlap_ratio")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"enable": enable,
|
|
||||||
"split_threshold": split_threshold,
|
|
||||||
"overlap_ratio": overlap_ratio,
|
|
||||||
}
|
|
||||||
|
|
||||||
def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio):
|
|
||||||
if not enable:
|
|
||||||
return
|
|
||||||
|
|
||||||
width = pp.shared.target_width
|
|
||||||
height = pp.shared.target_height
|
|
||||||
|
|
||||||
if not width or not height:
|
|
||||||
return
|
|
||||||
|
|
||||||
if pp.image.height > pp.image.width:
|
|
||||||
ratio = (pp.image.width * height) / (pp.image.height * width)
|
|
||||||
inverse_xy = False
|
|
||||||
else:
|
|
||||||
ratio = (pp.image.height * width) / (pp.image.width * height)
|
|
||||||
inverse_xy = True
|
|
||||||
|
|
||||||
if ratio >= 1.0 or ratio > split_threshold:
|
|
||||||
return
|
|
||||||
|
|
||||||
result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio)
|
|
||||||
|
|
||||||
pp.image = result
|
|
||||||
pp.extra_images = [pp.create_copy(x) for x in others]
|
|
||||||
|
|
||||||
@@ -21,8 +21,7 @@ from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, post
|
|||||||
from modules.api import models
|
from modules.api import models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
|
||||||
from PIL import PngImagePlugin
|
from PIL import PngImagePlugin
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
@@ -235,8 +234,6 @@ class Api:
|
|||||||
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
|
||||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
|
||||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
||||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||||
@@ -801,52 +798,6 @@ class Api:
|
|||||||
finally:
|
finally:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
|
||||||
def train_embedding(self, args: dict):
|
|
||||||
try:
|
|
||||||
shared.state.begin(job="train_embedding")
|
|
||||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
|
||||||
error = None
|
|
||||||
filename = ''
|
|
||||||
if not apply_optimizations:
|
|
||||||
sd_hijack.undo_optimizations()
|
|
||||||
try:
|
|
||||||
embedding, filename = train_embedding(**args) # can take a long time to complete
|
|
||||||
except Exception as e:
|
|
||||||
error = e
|
|
||||||
finally:
|
|
||||||
if not apply_optimizations:
|
|
||||||
sd_hijack.apply_optimizations()
|
|
||||||
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
|
||||||
except Exception as msg:
|
|
||||||
return models.TrainResponse(info=f"train embedding error: {msg}")
|
|
||||||
finally:
|
|
||||||
shared.state.end()
|
|
||||||
|
|
||||||
def train_hypernetwork(self, args: dict):
|
|
||||||
try:
|
|
||||||
shared.state.begin(job="train_hypernetwork")
|
|
||||||
shared.loaded_hypernetworks = []
|
|
||||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
|
||||||
error = None
|
|
||||||
filename = ''
|
|
||||||
if not apply_optimizations:
|
|
||||||
sd_hijack.undo_optimizations()
|
|
||||||
try:
|
|
||||||
hypernetwork, filename = train_hypernetwork(**args)
|
|
||||||
except Exception as e:
|
|
||||||
error = e
|
|
||||||
finally:
|
|
||||||
shared.sd_model.cond_stage_model.to(devices.device)
|
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
|
||||||
if not apply_optimizations:
|
|
||||||
sd_hijack.apply_optimizations()
|
|
||||||
shared.state.end()
|
|
||||||
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
|
||||||
except Exception as exc:
|
|
||||||
return models.TrainResponse(info=f"train embedding error: {exc}")
|
|
||||||
finally:
|
|
||||||
shared.state.end()
|
|
||||||
|
|
||||||
def get_memory(self):
|
def get_memory(self):
|
||||||
try:
|
try:
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -5,14 +5,12 @@ import os
|
|||||||
import inspect
|
import inspect
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
|
|
||||||
import modules.textual_inversion.dataset
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from backend.nn.unet import default
|
from backend.nn.unet import default
|
||||||
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
||||||
from modules.textual_inversion import textual_inversion, saving_settings
|
from modules.textual_inversion import textual_inversion
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
||||||
|
|
||||||
@@ -436,348 +434,348 @@ def statistics(data):
|
|||||||
recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
|
recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
|
||||||
return total_information, recent_information
|
return total_information, recent_information
|
||||||
|
|
||||||
|
#
|
||||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
# def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||||
# Remove illegal characters from name.
|
# # Remove illegal characters from name.
|
||||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
# name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||||
assert name, "Name cannot be empty!"
|
# assert name, "Name cannot be empty!"
|
||||||
|
#
|
||||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
# fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||||
if not overwrite_old:
|
# if not overwrite_old:
|
||||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
# assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
#
|
||||||
if type(layer_structure) == str:
|
# if type(layer_structure) == str:
|
||||||
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
# layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
||||||
|
#
|
||||||
if use_dropout and dropout_structure and type(dropout_structure) == str:
|
# if use_dropout and dropout_structure and type(dropout_structure) == str:
|
||||||
dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
|
# dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
|
||||||
else:
|
# else:
|
||||||
dropout_structure = [0] * len(layer_structure)
|
# dropout_structure = [0] * len(layer_structure)
|
||||||
|
#
|
||||||
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
# hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
||||||
name=name,
|
# name=name,
|
||||||
enable_sizes=[int(x) for x in enable_sizes],
|
# enable_sizes=[int(x) for x in enable_sizes],
|
||||||
layer_structure=layer_structure,
|
# layer_structure=layer_structure,
|
||||||
activation_func=activation_func,
|
# activation_func=activation_func,
|
||||||
weight_init=weight_init,
|
# weight_init=weight_init,
|
||||||
add_layer_norm=add_layer_norm,
|
# add_layer_norm=add_layer_norm,
|
||||||
use_dropout=use_dropout,
|
# use_dropout=use_dropout,
|
||||||
dropout_structure=dropout_structure
|
# dropout_structure=dropout_structure
|
||||||
)
|
# )
|
||||||
hypernet.save(fn)
|
# hypernet.save(fn)
|
||||||
|
#
|
||||||
shared.reload_hypernetworks()
|
# shared.reload_hypernetworks()
|
||||||
|
#
|
||||||
|
#
|
||||||
def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):
|
# def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):
|
||||||
from modules import images, processing
|
# from modules import images, processing
|
||||||
|
#
|
||||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
# save_hypernetwork_every = save_hypernetwork_every or 0
|
||||||
create_image_every = create_image_every or 0
|
# create_image_every = create_image_every or 0
|
||||||
template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
|
# template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
|
||||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
# textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
||||||
template_file = template_file.path
|
# template_file = template_file.path
|
||||||
|
#
|
||||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
# path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||||
hypernetwork = Hypernetwork()
|
# hypernetwork = Hypernetwork()
|
||||||
hypernetwork.load(path)
|
# hypernetwork.load(path)
|
||||||
shared.loaded_hypernetworks = [hypernetwork]
|
# shared.loaded_hypernetworks = [hypernetwork]
|
||||||
|
#
|
||||||
shared.state.job = "train-hypernetwork"
|
# shared.state.job = "train-hypernetwork"
|
||||||
shared.state.textinfo = "Initializing hypernetwork training..."
|
# shared.state.textinfo = "Initializing hypernetwork training..."
|
||||||
shared.state.job_count = steps
|
# shared.state.job_count = steps
|
||||||
|
#
|
||||||
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
|
# hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
|
||||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
# filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||||
|
#
|
||||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
# log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
||||||
unload = shared.opts.unload_models_when_training
|
# unload = shared.opts.unload_models_when_training
|
||||||
|
#
|
||||||
if save_hypernetwork_every > 0:
|
# if save_hypernetwork_every > 0:
|
||||||
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
|
# hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
|
||||||
os.makedirs(hypernetwork_dir, exist_ok=True)
|
# os.makedirs(hypernetwork_dir, exist_ok=True)
|
||||||
else:
|
# else:
|
||||||
hypernetwork_dir = None
|
# hypernetwork_dir = None
|
||||||
|
#
|
||||||
if create_image_every > 0:
|
# if create_image_every > 0:
|
||||||
images_dir = os.path.join(log_directory, "images")
|
# images_dir = os.path.join(log_directory, "images")
|
||||||
os.makedirs(images_dir, exist_ok=True)
|
# os.makedirs(images_dir, exist_ok=True)
|
||||||
else:
|
# else:
|
||||||
images_dir = None
|
# images_dir = None
|
||||||
|
#
|
||||||
checkpoint = sd_models.select_checkpoint()
|
# checkpoint = sd_models.select_checkpoint()
|
||||||
|
#
|
||||||
initial_step = hypernetwork.step or 0
|
# initial_step = hypernetwork.step or 0
|
||||||
if initial_step >= steps:
|
# if initial_step >= steps:
|
||||||
shared.state.textinfo = "Model has already been trained beyond specified max steps"
|
# shared.state.textinfo = "Model has already been trained beyond specified max steps"
|
||||||
return hypernetwork, filename
|
# return hypernetwork, filename
|
||||||
|
#
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
# scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||||
|
#
|
||||||
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
# clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
||||||
if clip_grad:
|
# if clip_grad:
|
||||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
# clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
||||||
|
#
|
||||||
if shared.opts.training_enable_tensorboard:
|
# if shared.opts.training_enable_tensorboard:
|
||||||
tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
|
# tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
|
||||||
|
#
|
||||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
# # dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
# shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
|
#
|
||||||
pin_memory = shared.opts.pin_memory
|
# pin_memory = shared.opts.pin_memory
|
||||||
|
#
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
|
# ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
|
||||||
|
#
|
||||||
if shared.opts.save_training_settings_to_txt:
|
# if shared.opts.save_training_settings_to_txt:
|
||||||
saved_params = dict(
|
# saved_params = dict(
|
||||||
model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
|
# model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
|
||||||
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
|
# **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
|
||||||
)
|
# )
|
||||||
saving_settings.save_settings_to_file(log_directory, {**saved_params, **locals()})
|
# saving_settings.save_settings_to_file(log_directory, {**saved_params, **locals()})
|
||||||
|
#
|
||||||
latent_sampling_method = ds.latent_sampling_method
|
# latent_sampling_method = ds.latent_sampling_method
|
||||||
|
#
|
||||||
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
# dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||||
|
#
|
||||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
# old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||||
|
#
|
||||||
if unload:
|
# if unload:
|
||||||
shared.parallel_processing_allowed = False
|
# shared.parallel_processing_allowed = False
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
# shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
# shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
#
|
||||||
weights = hypernetwork.weights()
|
# weights = hypernetwork.weights()
|
||||||
hypernetwork.train()
|
# hypernetwork.train()
|
||||||
|
#
|
||||||
# Here we use optimizer from saved HN, or we can specify as UI option.
|
# # Here we use optimizer from saved HN, or we can specify as UI option.
|
||||||
if hypernetwork.optimizer_name in optimizer_dict:
|
# if hypernetwork.optimizer_name in optimizer_dict:
|
||||||
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
|
# optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
|
||||||
optimizer_name = hypernetwork.optimizer_name
|
# optimizer_name = hypernetwork.optimizer_name
|
||||||
else:
|
# else:
|
||||||
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
|
# print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
|
||||||
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
|
# optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
|
||||||
optimizer_name = 'AdamW'
|
# optimizer_name = 'AdamW'
|
||||||
|
#
|
||||||
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
|
# if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
|
||||||
try:
|
# try:
|
||||||
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
# optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
||||||
except RuntimeError as e:
|
# except RuntimeError as e:
|
||||||
print("Cannot resume from saved optimizer!")
|
# print("Cannot resume from saved optimizer!")
|
||||||
print(e)
|
# print(e)
|
||||||
|
#
|
||||||
scaler = torch.cuda.amp.GradScaler()
|
# scaler = torch.cuda.amp.GradScaler()
|
||||||
|
#
|
||||||
batch_size = ds.batch_size
|
# batch_size = ds.batch_size
|
||||||
gradient_step = ds.gradient_step
|
# gradient_step = ds.gradient_step
|
||||||
# n steps = batch_size * gradient_step * n image processed
|
# # n steps = batch_size * gradient_step * n image processed
|
||||||
steps_per_epoch = len(ds) // batch_size // gradient_step
|
# steps_per_epoch = len(ds) // batch_size // gradient_step
|
||||||
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
# max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
||||||
loss_step = 0
|
# loss_step = 0
|
||||||
_loss_step = 0 #internal
|
# _loss_step = 0 #internal
|
||||||
# size = len(ds.indexes)
|
# # size = len(ds.indexes)
|
||||||
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
# # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||||
loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
|
# loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
|
||||||
# losses = torch.zeros((size,))
|
# # losses = torch.zeros((size,))
|
||||||
# previous_mean_losses = [0]
|
# # previous_mean_losses = [0]
|
||||||
# previous_mean_loss = 0
|
# # previous_mean_loss = 0
|
||||||
# print("Mean loss of {} elements".format(size))
|
# # print("Mean loss of {} elements".format(size))
|
||||||
|
#
|
||||||
steps_without_grad = 0
|
# steps_without_grad = 0
|
||||||
|
#
|
||||||
last_saved_file = "<none>"
|
# last_saved_file = "<none>"
|
||||||
last_saved_image = "<none>"
|
# last_saved_image = "<none>"
|
||||||
forced_filename = "<none>"
|
# forced_filename = "<none>"
|
||||||
|
#
|
||||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
# pbar = tqdm.tqdm(total=steps - initial_step)
|
||||||
try:
|
# try:
|
||||||
sd_hijack_checkpoint.add()
|
# sd_hijack_checkpoint.add()
|
||||||
|
#
|
||||||
for _ in range((steps-initial_step) * gradient_step):
|
# for _ in range((steps-initial_step) * gradient_step):
|
||||||
if scheduler.finished:
|
# if scheduler.finished:
|
||||||
break
|
# break
|
||||||
if shared.state.interrupted:
|
# if shared.state.interrupted:
|
||||||
break
|
# break
|
||||||
for j, batch in enumerate(dl):
|
# for j, batch in enumerate(dl):
|
||||||
# works as a drop_last=True for gradient accumulation
|
# # works as a drop_last=True for gradient accumulation
|
||||||
if j == max_steps_per_epoch:
|
# if j == max_steps_per_epoch:
|
||||||
break
|
# break
|
||||||
scheduler.apply(optimizer, hypernetwork.step)
|
# scheduler.apply(optimizer, hypernetwork.step)
|
||||||
if scheduler.finished:
|
# if scheduler.finished:
|
||||||
break
|
# break
|
||||||
if shared.state.interrupted:
|
# if shared.state.interrupted:
|
||||||
break
|
# break
|
||||||
|
#
|
||||||
if clip_grad:
|
# if clip_grad:
|
||||||
clip_grad_sched.step(hypernetwork.step)
|
# clip_grad_sched.step(hypernetwork.step)
|
||||||
|
#
|
||||||
with devices.autocast():
|
# with devices.autocast():
|
||||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
# x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
if use_weight:
|
# if use_weight:
|
||||||
w = batch.weight.to(devices.device, non_blocking=pin_memory)
|
# w = batch.weight.to(devices.device, non_blocking=pin_memory)
|
||||||
if tag_drop_out != 0 or shuffle_tags:
|
# if tag_drop_out != 0 or shuffle_tags:
|
||||||
shared.sd_model.cond_stage_model.to(devices.device)
|
# shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
|
# c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
# shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
else:
|
# else:
|
||||||
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
# c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
||||||
if use_weight:
|
# if use_weight:
|
||||||
loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
|
# loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
|
||||||
del w
|
# del w
|
||||||
else:
|
# else:
|
||||||
loss = shared.sd_model.forward(x, c)[0] / gradient_step
|
# loss = shared.sd_model.forward(x, c)[0] / gradient_step
|
||||||
del x
|
# del x
|
||||||
del c
|
# del c
|
||||||
|
#
|
||||||
_loss_step += loss.item()
|
# _loss_step += loss.item()
|
||||||
scaler.scale(loss).backward()
|
# scaler.scale(loss).backward()
|
||||||
|
#
|
||||||
# go back until we reach gradient accumulation steps
|
# # go back until we reach gradient accumulation steps
|
||||||
if (j + 1) % gradient_step != 0:
|
# if (j + 1) % gradient_step != 0:
|
||||||
continue
|
# continue
|
||||||
loss_logging.append(_loss_step)
|
# loss_logging.append(_loss_step)
|
||||||
if clip_grad:
|
# if clip_grad:
|
||||||
clip_grad(weights, clip_grad_sched.learn_rate)
|
# clip_grad(weights, clip_grad_sched.learn_rate)
|
||||||
|
#
|
||||||
scaler.step(optimizer)
|
# scaler.step(optimizer)
|
||||||
scaler.update()
|
# scaler.update()
|
||||||
hypernetwork.step += 1
|
# hypernetwork.step += 1
|
||||||
pbar.update()
|
# pbar.update()
|
||||||
optimizer.zero_grad(set_to_none=True)
|
# optimizer.zero_grad(set_to_none=True)
|
||||||
loss_step = _loss_step
|
# loss_step = _loss_step
|
||||||
_loss_step = 0
|
# _loss_step = 0
|
||||||
|
#
|
||||||
steps_done = hypernetwork.step + 1
|
# steps_done = hypernetwork.step + 1
|
||||||
|
#
|
||||||
epoch_num = hypernetwork.step // steps_per_epoch
|
# epoch_num = hypernetwork.step // steps_per_epoch
|
||||||
epoch_step = hypernetwork.step % steps_per_epoch
|
# epoch_step = hypernetwork.step % steps_per_epoch
|
||||||
|
#
|
||||||
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
|
# description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
|
||||||
pbar.set_description(description)
|
# pbar.set_description(description)
|
||||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
# if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||||
# Before saving, change name to match current checkpoint.
|
# # Before saving, change name to match current checkpoint.
|
||||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
# hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
# last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
||||||
hypernetwork.optimizer_name = optimizer_name
|
# hypernetwork.optimizer_name = optimizer_name
|
||||||
if shared.opts.save_optimizer_state:
|
# if shared.opts.save_optimizer_state:
|
||||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
# hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
# save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
# hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
if shared.opts.training_enable_tensorboard:
|
# if shared.opts.training_enable_tensorboard:
|
||||||
epoch_num = hypernetwork.step // len(ds)
|
# epoch_num = hypernetwork.step // len(ds)
|
||||||
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
# epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
||||||
mean_loss = sum(loss_logging) / len(loss_logging)
|
# mean_loss = sum(loss_logging) / len(loss_logging)
|
||||||
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
|
# textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
|
||||||
|
#
|
||||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
# textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
||||||
"loss": f"{loss_step:.7f}",
|
# "loss": f"{loss_step:.7f}",
|
||||||
"learn_rate": scheduler.learn_rate
|
# "learn_rate": scheduler.learn_rate
|
||||||
})
|
# })
|
||||||
|
#
|
||||||
if images_dir is not None and steps_done % create_image_every == 0:
|
# if images_dir is not None and steps_done % create_image_every == 0:
|
||||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
# forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
# last_saved_image = os.path.join(images_dir, forced_filename)
|
||||||
hypernetwork.eval()
|
# hypernetwork.eval()
|
||||||
rng_state = torch.get_rng_state()
|
# rng_state = torch.get_rng_state()
|
||||||
cuda_rng_state = None
|
# cuda_rng_state = None
|
||||||
if torch.cuda.is_available():
|
# if torch.cuda.is_available():
|
||||||
cuda_rng_state = torch.cuda.get_rng_state_all()
|
# cuda_rng_state = torch.cuda.get_rng_state_all()
|
||||||
shared.sd_model.cond_stage_model.to(devices.device)
|
# shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
# shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
#
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
# p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
sd_model=shared.sd_model,
|
# sd_model=shared.sd_model,
|
||||||
do_not_save_grid=True,
|
# do_not_save_grid=True,
|
||||||
do_not_save_samples=True,
|
# do_not_save_samples=True,
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
p.disable_extra_networks = True
|
# p.disable_extra_networks = True
|
||||||
|
#
|
||||||
if preview_from_txt2img:
|
# if preview_from_txt2img:
|
||||||
p.prompt = preview_prompt
|
# p.prompt = preview_prompt
|
||||||
p.negative_prompt = preview_negative_prompt
|
# p.negative_prompt = preview_negative_prompt
|
||||||
p.steps = preview_steps
|
# p.steps = preview_steps
|
||||||
p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
|
# p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
|
||||||
p.cfg_scale = preview_cfg_scale
|
# p.cfg_scale = preview_cfg_scale
|
||||||
p.seed = preview_seed
|
# p.seed = preview_seed
|
||||||
p.width = preview_width
|
# p.width = preview_width
|
||||||
p.height = preview_height
|
# p.height = preview_height
|
||||||
else:
|
# else:
|
||||||
p.prompt = batch.cond_text[0]
|
# p.prompt = batch.cond_text[0]
|
||||||
p.steps = 20
|
# p.steps = 20
|
||||||
p.width = training_width
|
# p.width = training_width
|
||||||
p.height = training_height
|
# p.height = training_height
|
||||||
|
#
|
||||||
preview_text = p.prompt
|
# preview_text = p.prompt
|
||||||
|
#
|
||||||
with closing(p):
|
# with closing(p):
|
||||||
processed = processing.process_images(p)
|
# processed = processing.process_images(p)
|
||||||
image = processed.images[0] if len(processed.images) > 0 else None
|
# image = processed.images[0] if len(processed.images) > 0 else None
|
||||||
|
#
|
||||||
if unload:
|
# if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
# shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
# shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
torch.set_rng_state(rng_state)
|
# torch.set_rng_state(rng_state)
|
||||||
if torch.cuda.is_available():
|
# if torch.cuda.is_available():
|
||||||
torch.cuda.set_rng_state_all(cuda_rng_state)
|
# torch.cuda.set_rng_state_all(cuda_rng_state)
|
||||||
hypernetwork.train()
|
# hypernetwork.train()
|
||||||
if image is not None:
|
# if image is not None:
|
||||||
shared.state.assign_current_image(image)
|
# shared.state.assign_current_image(image)
|
||||||
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
# if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||||
textual_inversion.tensorboard_add_image(tensorboard_writer,
|
# textual_inversion.tensorboard_add_image(tensorboard_writer,
|
||||||
f"Validation at epoch {epoch_num}", image,
|
# f"Validation at epoch {epoch_num}", image,
|
||||||
hypernetwork.step)
|
# hypernetwork.step)
|
||||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
# last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||||
last_saved_image += f", prompt: {preview_text}"
|
# last_saved_image += f", prompt: {preview_text}"
|
||||||
|
#
|
||||||
shared.state.job_no = hypernetwork.step
|
# shared.state.job_no = hypernetwork.step
|
||||||
|
#
|
||||||
shared.state.textinfo = f"""
|
# shared.state.textinfo = f"""
|
||||||
<p>
|
# <p>
|
||||||
Loss: {loss_step:.7f}<br/>
|
# Loss: {loss_step:.7f}<br/>
|
||||||
Step: {steps_done}<br/>
|
# Step: {steps_done}<br/>
|
||||||
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
# Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
||||||
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
# Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
# Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
# </p>
|
||||||
"""
|
# """
|
||||||
except Exception:
|
# except Exception:
|
||||||
errors.report("Exception in training hypernetwork", exc_info=True)
|
# errors.report("Exception in training hypernetwork", exc_info=True)
|
||||||
finally:
|
# finally:
|
||||||
pbar.leave = False
|
# pbar.leave = False
|
||||||
pbar.close()
|
# pbar.close()
|
||||||
hypernetwork.eval()
|
# hypernetwork.eval()
|
||||||
sd_hijack_checkpoint.remove()
|
# sd_hijack_checkpoint.remove()
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
# filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||||
hypernetwork.optimizer_name = optimizer_name
|
# hypernetwork.optimizer_name = optimizer_name
|
||||||
if shared.opts.save_optimizer_state:
|
# if shared.opts.save_optimizer_state:
|
||||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
# hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
# save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
||||||
|
#
|
||||||
del optimizer
|
# del optimizer
|
||||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
# hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||||
shared.sd_model.cond_stage_model.to(devices.device)
|
# shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
# shared.sd_model.first_stage_model.to(devices.device)
|
||||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
# shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||||
|
#
|
||||||
return hypernetwork, filename
|
# return hypernetwork, filename
|
||||||
|
#
|
||||||
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
# def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
||||||
old_hypernetwork_name = hypernetwork.name
|
# old_hypernetwork_name = hypernetwork.name
|
||||||
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
# old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
||||||
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
# old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
||||||
try:
|
# try:
|
||||||
hypernetwork.sd_checkpoint = checkpoint.shorthash
|
# hypernetwork.sd_checkpoint = checkpoint.shorthash
|
||||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
# hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||||
hypernetwork.name = hypernetwork_name
|
# hypernetwork.name = hypernetwork_name
|
||||||
hypernetwork.save(filename)
|
# hypernetwork.save(filename)
|
||||||
except:
|
# except:
|
||||||
hypernetwork.sd_checkpoint = old_sd_checkpoint
|
# hypernetwork.sd_checkpoint = old_sd_checkpoint
|
||||||
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
|
# hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
|
||||||
hypernetwork.name = old_hypernetwork_name
|
# hypernetwork.name = old_hypernetwork_name
|
||||||
raise
|
# raise
|
||||||
|
|||||||
@@ -1,345 +0,0 @@
|
|||||||
import cv2
|
|
||||||
import requests
|
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
from PIL import ImageDraw
|
|
||||||
from modules import paths_internal
|
|
||||||
from pkg_resources import parse_version
|
|
||||||
|
|
||||||
GREEN = "#0F0"
|
|
||||||
BLUE = "#00F"
|
|
||||||
RED = "#F00"
|
|
||||||
|
|
||||||
|
|
||||||
def crop_image(im, settings):
|
|
||||||
""" Intelligently crop an image to the subject matter """
|
|
||||||
|
|
||||||
scale_by = 1
|
|
||||||
if is_landscape(im.width, im.height):
|
|
||||||
scale_by = settings.crop_height / im.height
|
|
||||||
elif is_portrait(im.width, im.height):
|
|
||||||
scale_by = settings.crop_width / im.width
|
|
||||||
elif is_square(im.width, im.height):
|
|
||||||
if is_square(settings.crop_width, settings.crop_height):
|
|
||||||
scale_by = settings.crop_width / im.width
|
|
||||||
elif is_landscape(settings.crop_width, settings.crop_height):
|
|
||||||
scale_by = settings.crop_width / im.width
|
|
||||||
elif is_portrait(settings.crop_width, settings.crop_height):
|
|
||||||
scale_by = settings.crop_height / im.height
|
|
||||||
|
|
||||||
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
|
|
||||||
im_debug = im.copy()
|
|
||||||
|
|
||||||
focus = focal_point(im_debug, settings)
|
|
||||||
|
|
||||||
# take the focal point and turn it into crop coordinates that try to center over the focal
|
|
||||||
# point but then get adjusted back into the frame
|
|
||||||
y_half = int(settings.crop_height / 2)
|
|
||||||
x_half = int(settings.crop_width / 2)
|
|
||||||
|
|
||||||
x1 = focus.x - x_half
|
|
||||||
if x1 < 0:
|
|
||||||
x1 = 0
|
|
||||||
elif x1 + settings.crop_width > im.width:
|
|
||||||
x1 = im.width - settings.crop_width
|
|
||||||
|
|
||||||
y1 = focus.y - y_half
|
|
||||||
if y1 < 0:
|
|
||||||
y1 = 0
|
|
||||||
elif y1 + settings.crop_height > im.height:
|
|
||||||
y1 = im.height - settings.crop_height
|
|
||||||
|
|
||||||
x2 = x1 + settings.crop_width
|
|
||||||
y2 = y1 + settings.crop_height
|
|
||||||
|
|
||||||
crop = [x1, y1, x2, y2]
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
results.append(im.crop(tuple(crop)))
|
|
||||||
|
|
||||||
if settings.annotate_image:
|
|
||||||
d = ImageDraw.Draw(im_debug)
|
|
||||||
rect = list(crop)
|
|
||||||
rect[2] -= 1
|
|
||||||
rect[3] -= 1
|
|
||||||
d.rectangle(rect, outline=GREEN)
|
|
||||||
results.append(im_debug)
|
|
||||||
if settings.desktop_view_image:
|
|
||||||
im_debug.show()
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def focal_point(im, settings):
|
|
||||||
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
|
|
||||||
entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
|
|
||||||
face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
|
|
||||||
|
|
||||||
pois = []
|
|
||||||
|
|
||||||
weight_pref_total = 0
|
|
||||||
if corner_points:
|
|
||||||
weight_pref_total += settings.corner_points_weight
|
|
||||||
if entropy_points:
|
|
||||||
weight_pref_total += settings.entropy_points_weight
|
|
||||||
if face_points:
|
|
||||||
weight_pref_total += settings.face_points_weight
|
|
||||||
|
|
||||||
corner_centroid = None
|
|
||||||
if corner_points:
|
|
||||||
corner_centroid = centroid(corner_points)
|
|
||||||
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
|
|
||||||
pois.append(corner_centroid)
|
|
||||||
|
|
||||||
entropy_centroid = None
|
|
||||||
if entropy_points:
|
|
||||||
entropy_centroid = centroid(entropy_points)
|
|
||||||
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
|
|
||||||
pois.append(entropy_centroid)
|
|
||||||
|
|
||||||
face_centroid = None
|
|
||||||
if face_points:
|
|
||||||
face_centroid = centroid(face_points)
|
|
||||||
face_centroid.weight = settings.face_points_weight / weight_pref_total
|
|
||||||
pois.append(face_centroid)
|
|
||||||
|
|
||||||
average_point = poi_average(pois, settings)
|
|
||||||
|
|
||||||
if settings.annotate_image:
|
|
||||||
d = ImageDraw.Draw(im)
|
|
||||||
max_size = min(im.width, im.height) * 0.07
|
|
||||||
if corner_centroid is not None:
|
|
||||||
color = BLUE
|
|
||||||
box = corner_centroid.bounding(max_size * corner_centroid.weight)
|
|
||||||
d.text((box[0], box[1] - 15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
|
|
||||||
d.ellipse(box, outline=color)
|
|
||||||
if len(corner_points) > 1:
|
|
||||||
for f in corner_points:
|
|
||||||
d.rectangle(f.bounding(4), outline=color)
|
|
||||||
if entropy_centroid is not None:
|
|
||||||
color = "#ff0"
|
|
||||||
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
|
|
||||||
d.text((box[0], box[1] - 15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
|
|
||||||
d.ellipse(box, outline=color)
|
|
||||||
if len(entropy_points) > 1:
|
|
||||||
for f in entropy_points:
|
|
||||||
d.rectangle(f.bounding(4), outline=color)
|
|
||||||
if face_centroid is not None:
|
|
||||||
color = RED
|
|
||||||
box = face_centroid.bounding(max_size * face_centroid.weight)
|
|
||||||
d.text((box[0], box[1] - 15), f"Face: {face_centroid.weight:.02f}", fill=color)
|
|
||||||
d.ellipse(box, outline=color)
|
|
||||||
if len(face_points) > 1:
|
|
||||||
for f in face_points:
|
|
||||||
d.rectangle(f.bounding(4), outline=color)
|
|
||||||
|
|
||||||
d.ellipse(average_point.bounding(max_size), outline=GREEN)
|
|
||||||
|
|
||||||
return average_point
|
|
||||||
|
|
||||||
|
|
||||||
def image_face_points(im, settings):
|
|
||||||
if settings.dnn_model_path is not None:
|
|
||||||
detector = cv2.FaceDetectorYN.create(
|
|
||||||
settings.dnn_model_path,
|
|
||||||
"",
|
|
||||||
(im.width, im.height),
|
|
||||||
0.9, # score threshold
|
|
||||||
0.3, # nms threshold
|
|
||||||
5000 # keep top k before nms
|
|
||||||
)
|
|
||||||
faces = detector.detect(np.array(im))
|
|
||||||
results = []
|
|
||||||
if faces[1] is not None:
|
|
||||||
for face in faces[1]:
|
|
||||||
x = face[0]
|
|
||||||
y = face[1]
|
|
||||||
w = face[2]
|
|
||||||
h = face[3]
|
|
||||||
results.append(
|
|
||||||
PointOfInterest(
|
|
||||||
int(x + (w * 0.5)), # face focus left/right is center
|
|
||||||
int(y + (h * 0.33)), # face focus up/down is close to the top of the head
|
|
||||||
size=w,
|
|
||||||
weight=1 / len(faces[1])
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
else:
|
|
||||||
np_im = np.array(im)
|
|
||||||
gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
|
|
||||||
|
|
||||||
tries = [
|
|
||||||
[f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01],
|
|
||||||
[f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05],
|
|
||||||
[f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05],
|
|
||||||
[f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05],
|
|
||||||
[f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05],
|
|
||||||
[f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05],
|
|
||||||
[f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05],
|
|
||||||
[f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05]
|
|
||||||
]
|
|
||||||
for t in tries:
|
|
||||||
classifier = cv2.CascadeClassifier(t[0])
|
|
||||||
minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
|
|
||||||
try:
|
|
||||||
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
|
|
||||||
minNeighbors=7, minSize=(minsize, minsize),
|
|
||||||
flags=cv2.CASCADE_SCALE_IMAGE)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if faces:
|
|
||||||
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
|
|
||||||
return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]),
|
|
||||||
weight=1 / len(rects)) for r in rects]
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def image_corner_points(im, settings):
|
|
||||||
grayscale = im.convert("L")
|
|
||||||
|
|
||||||
# naive attempt at preventing focal points from collecting at watermarks near the bottom
|
|
||||||
gd = ImageDraw.Draw(grayscale)
|
|
||||||
gd.rectangle([0, im.height * .9, im.width, im.height], fill="#999")
|
|
||||||
|
|
||||||
np_im = np.array(grayscale)
|
|
||||||
|
|
||||||
points = cv2.goodFeaturesToTrack(
|
|
||||||
np_im,
|
|
||||||
maxCorners=100,
|
|
||||||
qualityLevel=0.04,
|
|
||||||
minDistance=min(grayscale.width, grayscale.height) * 0.06,
|
|
||||||
useHarrisDetector=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if points is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
focal_points = []
|
|
||||||
for point in points:
|
|
||||||
x, y = point.ravel()
|
|
||||||
focal_points.append(PointOfInterest(x, y, size=4, weight=1 / len(points)))
|
|
||||||
|
|
||||||
return focal_points
|
|
||||||
|
|
||||||
|
|
||||||
def image_entropy_points(im, settings):
|
|
||||||
landscape = im.height < im.width
|
|
||||||
portrait = im.height > im.width
|
|
||||||
if landscape:
|
|
||||||
move_idx = [0, 2]
|
|
||||||
move_max = im.size[0]
|
|
||||||
elif portrait:
|
|
||||||
move_idx = [1, 3]
|
|
||||||
move_max = im.size[1]
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
e_max = 0
|
|
||||||
crop_current = [0, 0, settings.crop_width, settings.crop_height]
|
|
||||||
crop_best = crop_current
|
|
||||||
while crop_current[move_idx[1]] < move_max:
|
|
||||||
crop = im.crop(tuple(crop_current))
|
|
||||||
e = image_entropy(crop)
|
|
||||||
|
|
||||||
if (e > e_max):
|
|
||||||
e_max = e
|
|
||||||
crop_best = list(crop_current)
|
|
||||||
|
|
||||||
crop_current[move_idx[0]] += 4
|
|
||||||
crop_current[move_idx[1]] += 4
|
|
||||||
|
|
||||||
x_mid = int(crop_best[0] + settings.crop_width / 2)
|
|
||||||
y_mid = int(crop_best[1] + settings.crop_height / 2)
|
|
||||||
|
|
||||||
return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
|
|
||||||
|
|
||||||
|
|
||||||
def image_entropy(im):
|
|
||||||
# greyscale image entropy
|
|
||||||
# band = np.asarray(im.convert("L"))
|
|
||||||
band = np.asarray(im.convert("1"), dtype=np.uint8)
|
|
||||||
hist, _ = np.histogram(band, bins=range(0, 256))
|
|
||||||
hist = hist[hist > 0]
|
|
||||||
return -np.log2(hist / hist.sum()).sum()
|
|
||||||
|
|
||||||
|
|
||||||
def centroid(pois):
|
|
||||||
x = [poi.x for poi in pois]
|
|
||||||
y = [poi.y for poi in pois]
|
|
||||||
return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))
|
|
||||||
|
|
||||||
|
|
||||||
def poi_average(pois, settings):
|
|
||||||
weight = 0.0
|
|
||||||
x = 0.0
|
|
||||||
y = 0.0
|
|
||||||
for poi in pois:
|
|
||||||
weight += poi.weight
|
|
||||||
x += poi.x * poi.weight
|
|
||||||
y += poi.y * poi.weight
|
|
||||||
avg_x = round(weight and x / weight)
|
|
||||||
avg_y = round(weight and y / weight)
|
|
||||||
|
|
||||||
return PointOfInterest(avg_x, avg_y)
|
|
||||||
|
|
||||||
|
|
||||||
def is_landscape(w, h):
|
|
||||||
return w > h
|
|
||||||
|
|
||||||
|
|
||||||
def is_portrait(w, h):
|
|
||||||
return h > w
|
|
||||||
|
|
||||||
|
|
||||||
def is_square(w, h):
|
|
||||||
return w == h
|
|
||||||
|
|
||||||
|
|
||||||
model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv')
|
|
||||||
if parse_version(cv2.__version__) >= parse_version('4.8'):
|
|
||||||
model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx')
|
|
||||||
model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true'
|
|
||||||
else:
|
|
||||||
model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx')
|
|
||||||
model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_cache_models():
|
|
||||||
if not os.path.exists(model_file_path):
|
|
||||||
os.makedirs(model_dir_opencv, exist_ok=True)
|
|
||||||
print(f"downloading face detection model from '{model_url}' to '{model_file_path}'")
|
|
||||||
response = requests.get(model_url)
|
|
||||||
with open(model_file_path, "wb") as f:
|
|
||||||
f.write(response.content)
|
|
||||||
return model_file_path
|
|
||||||
|
|
||||||
|
|
||||||
class PointOfInterest:
|
|
||||||
def __init__(self, x, y, weight=1.0, size=10):
|
|
||||||
self.x = x
|
|
||||||
self.y = y
|
|
||||||
self.weight = weight
|
|
||||||
self.size = size
|
|
||||||
|
|
||||||
def bounding(self, size):
|
|
||||||
return [
|
|
||||||
self.x - size // 2,
|
|
||||||
self.y - size // 2,
|
|
||||||
self.x + size // 2,
|
|
||||||
self.y + size // 2
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class Settings:
|
|
||||||
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
|
|
||||||
self.crop_width = crop_width
|
|
||||||
self.crop_height = crop_height
|
|
||||||
self.corner_points_weight = corner_points_weight
|
|
||||||
self.entropy_points_weight = entropy_points_weight
|
|
||||||
self.face_points_weight = face_points_weight
|
|
||||||
self.annotate_image = annotate_image
|
|
||||||
self.desktop_view_image = False
|
|
||||||
self.dnn_model_path = dnn_model_path
|
|
||||||
@@ -1,243 +0,0 @@
|
|||||||
# import os
|
|
||||||
# import numpy as np
|
|
||||||
# import PIL
|
|
||||||
# import torch
|
|
||||||
# from torch.utils.data import Dataset, DataLoader, Sampler
|
|
||||||
# from torchvision import transforms
|
|
||||||
# from collections import defaultdict
|
|
||||||
# from random import shuffle, choices
|
|
||||||
#
|
|
||||||
# import random
|
|
||||||
# import tqdm
|
|
||||||
# from modules import devices, shared, images
|
|
||||||
# import re
|
|
||||||
#
|
|
||||||
# re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# class DatasetEntry:
|
|
||||||
# def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
|
|
||||||
# self.filename = filename
|
|
||||||
# self.filename_text = filename_text
|
|
||||||
# self.weight = weight
|
|
||||||
# self.latent_dist = latent_dist
|
|
||||||
# self.latent_sample = latent_sample
|
|
||||||
# self.cond = cond
|
|
||||||
# self.cond_text = cond_text
|
|
||||||
# self.pixel_values = pixel_values
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# class PersonalizedBase(Dataset):
|
|
||||||
# def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
|
|
||||||
# re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
|
|
||||||
#
|
|
||||||
# self.placeholder_token = placeholder_token
|
|
||||||
#
|
|
||||||
# self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
|
||||||
#
|
|
||||||
# self.dataset = []
|
|
||||||
#
|
|
||||||
# with open(template_file, "r") as file:
|
|
||||||
# lines = [x.strip() for x in file.readlines()]
|
|
||||||
#
|
|
||||||
# self.lines = lines
|
|
||||||
#
|
|
||||||
# assert data_root, 'dataset directory not specified'
|
|
||||||
# assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
|
||||||
# assert os.listdir(data_root), "Dataset directory is empty"
|
|
||||||
#
|
|
||||||
# self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
|
||||||
#
|
|
||||||
# self.shuffle_tags = shuffle_tags
|
|
||||||
# self.tag_drop_out = tag_drop_out
|
|
||||||
# groups = defaultdict(list)
|
|
||||||
#
|
|
||||||
# print("Preparing dataset...")
|
|
||||||
# for path in tqdm.tqdm(self.image_paths):
|
|
||||||
# alpha_channel = None
|
|
||||||
# if shared.state.interrupted:
|
|
||||||
# raise Exception("interrupted")
|
|
||||||
# try:
|
|
||||||
# image = images.read(path)
|
|
||||||
# #Currently does not work for single color transparency
|
|
||||||
# #We would need to read image.info['transparency'] for that
|
|
||||||
# if use_weight and 'A' in image.getbands():
|
|
||||||
# alpha_channel = image.getchannel('A')
|
|
||||||
# image = image.convert('RGB')
|
|
||||||
# if not varsize:
|
|
||||||
# image = image.resize((width, height), PIL.Image.BICUBIC)
|
|
||||||
# except Exception:
|
|
||||||
# continue
|
|
||||||
#
|
|
||||||
# text_filename = f"{os.path.splitext(path)[0]}.txt"
|
|
||||||
# filename = os.path.basename(path)
|
|
||||||
#
|
|
||||||
# if os.path.exists(text_filename):
|
|
||||||
# with open(text_filename, "r", encoding="utf8") as file:
|
|
||||||
# filename_text = file.read()
|
|
||||||
# else:
|
|
||||||
# filename_text = os.path.splitext(filename)[0]
|
|
||||||
# filename_text = re.sub(re_numbers_at_start, '', filename_text)
|
|
||||||
# if re_word:
|
|
||||||
# tokens = re_word.findall(filename_text)
|
|
||||||
# filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
|
|
||||||
#
|
|
||||||
# npimage = np.array(image).astype(np.uint8)
|
|
||||||
# npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
|
||||||
#
|
|
||||||
# torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
|
||||||
# latent_sample = None
|
|
||||||
#
|
|
||||||
# with devices.autocast():
|
|
||||||
# latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
|
||||||
#
|
|
||||||
# #Perform latent sampling, even for random sampling.
|
|
||||||
# #We need the sample dimensions for the weights
|
|
||||||
# if latent_sampling_method == "deterministic":
|
|
||||||
# if isinstance(latent_dist, DiagonalGaussianDistribution):
|
|
||||||
# # Works only for DiagonalGaussianDistribution
|
|
||||||
# latent_dist.std = 0
|
|
||||||
# else:
|
|
||||||
# latent_sampling_method = "once"
|
|
||||||
# latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
|
||||||
#
|
|
||||||
# if use_weight and alpha_channel is not None:
|
|
||||||
# channels, *latent_size = latent_sample.shape
|
|
||||||
# weight_img = alpha_channel.resize(latent_size)
|
|
||||||
# npweight = np.array(weight_img).astype(np.float32)
|
|
||||||
# #Repeat for every channel in the latent sample
|
|
||||||
# weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
|
|
||||||
# #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
|
|
||||||
# weight -= weight.min()
|
|
||||||
# weight /= weight.mean()
|
|
||||||
# elif use_weight:
|
|
||||||
# #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
|
|
||||||
# weight = torch.ones(latent_sample.shape)
|
|
||||||
# else:
|
|
||||||
# weight = None
|
|
||||||
#
|
|
||||||
# if latent_sampling_method == "random":
|
|
||||||
# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
|
|
||||||
# else:
|
|
||||||
# entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
|
|
||||||
#
|
|
||||||
# if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
|
||||||
# entry.cond_text = self.create_text(filename_text)
|
|
||||||
#
|
|
||||||
# if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
|
||||||
# with devices.autocast():
|
|
||||||
# entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
|
||||||
# groups[image.size].append(len(self.dataset))
|
|
||||||
# self.dataset.append(entry)
|
|
||||||
# del torchdata
|
|
||||||
# del latent_dist
|
|
||||||
# del latent_sample
|
|
||||||
# del weight
|
|
||||||
#
|
|
||||||
# self.length = len(self.dataset)
|
|
||||||
# self.groups = list(groups.values())
|
|
||||||
# assert self.length > 0, "No images have been found in the dataset."
|
|
||||||
# self.batch_size = min(batch_size, self.length)
|
|
||||||
# self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
|
||||||
# self.latent_sampling_method = latent_sampling_method
|
|
||||||
#
|
|
||||||
# if len(groups) > 1:
|
|
||||||
# print("Buckets:")
|
|
||||||
# for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
|
|
||||||
# print(f" {w}x{h}: {len(ids)}")
|
|
||||||
# print()
|
|
||||||
#
|
|
||||||
# def create_text(self, filename_text):
|
|
||||||
# text = random.choice(self.lines)
|
|
||||||
# tags = filename_text.split(',')
|
|
||||||
# if self.tag_drop_out != 0:
|
|
||||||
# tags = [t for t in tags if random.random() > self.tag_drop_out]
|
|
||||||
# if self.shuffle_tags:
|
|
||||||
# random.shuffle(tags)
|
|
||||||
# text = text.replace("[filewords]", ','.join(tags))
|
|
||||||
# text = text.replace("[name]", self.placeholder_token)
|
|
||||||
# return text
|
|
||||||
#
|
|
||||||
# def __len__(self):
|
|
||||||
# return self.length
|
|
||||||
#
|
|
||||||
# def __getitem__(self, i):
|
|
||||||
# entry = self.dataset[i]
|
|
||||||
# if self.tag_drop_out != 0 or self.shuffle_tags:
|
|
||||||
# entry.cond_text = self.create_text(entry.filename_text)
|
|
||||||
# if self.latent_sampling_method == "random":
|
|
||||||
# entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
|
||||||
# return entry
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# class GroupedBatchSampler(Sampler):
|
|
||||||
# def __init__(self, data_source: PersonalizedBase, batch_size: int):
|
|
||||||
# super().__init__(data_source)
|
|
||||||
#
|
|
||||||
# n = len(data_source)
|
|
||||||
# self.groups = data_source.groups
|
|
||||||
# self.len = n_batch = n // batch_size
|
|
||||||
# expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
|
|
||||||
# self.base = [int(e) // batch_size for e in expected]
|
|
||||||
# self.n_rand_batches = nrb = n_batch - sum(self.base)
|
|
||||||
# self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
|
|
||||||
# self.batch_size = batch_size
|
|
||||||
#
|
|
||||||
# def __len__(self):
|
|
||||||
# return self.len
|
|
||||||
#
|
|
||||||
# def __iter__(self):
|
|
||||||
# b = self.batch_size
|
|
||||||
#
|
|
||||||
# for g in self.groups:
|
|
||||||
# shuffle(g)
|
|
||||||
#
|
|
||||||
# batches = []
|
|
||||||
# for g in self.groups:
|
|
||||||
# batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
|
|
||||||
# for _ in range(self.n_rand_batches):
|
|
||||||
# rand_group = choices(self.groups, self.probs)[0]
|
|
||||||
# batches.append(choices(rand_group, k=b))
|
|
||||||
#
|
|
||||||
# shuffle(batches)
|
|
||||||
#
|
|
||||||
# yield from batches
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# class PersonalizedDataLoader(DataLoader):
|
|
||||||
# def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
|
||||||
# super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
|
|
||||||
# if latent_sampling_method == "random":
|
|
||||||
# self.collate_fn = collate_wrapper_random
|
|
||||||
# else:
|
|
||||||
# self.collate_fn = collate_wrapper
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# class BatchLoader:
|
|
||||||
# def __init__(self, data):
|
|
||||||
# self.cond_text = [entry.cond_text for entry in data]
|
|
||||||
# self.cond = [entry.cond for entry in data]
|
|
||||||
# self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
|
||||||
# if all(entry.weight is not None for entry in data):
|
|
||||||
# self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
|
|
||||||
# else:
|
|
||||||
# self.weight = None
|
|
||||||
# #self.emb_index = [entry.emb_index for entry in data]
|
|
||||||
# #print(self.latent_sample.device)
|
|
||||||
#
|
|
||||||
# def pin_memory(self):
|
|
||||||
# self.latent_sample = self.latent_sample.pin_memory()
|
|
||||||
# return self
|
|
||||||
#
|
|
||||||
# def collate_wrapper(batch):
|
|
||||||
# return BatchLoader(batch)
|
|
||||||
#
|
|
||||||
# class BatchLoaderRandom(BatchLoader):
|
|
||||||
# def __init__(self, data):
|
|
||||||
# super().__init__(data)
|
|
||||||
#
|
|
||||||
# def pin_memory(self):
|
|
||||||
# return self
|
|
||||||
#
|
|
||||||
# def collate_wrapper_random(batch):
|
|
||||||
# return BatchLoaderRandom(batch)
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
class LearnScheduleIterator:
|
|
||||||
def __init__(self, learn_rate, max_steps, cur_step=0):
|
|
||||||
"""
|
|
||||||
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
|
|
||||||
"""
|
|
||||||
|
|
||||||
pairs = learn_rate.split(',')
|
|
||||||
self.rates = []
|
|
||||||
self.it = 0
|
|
||||||
self.maxit = 0
|
|
||||||
try:
|
|
||||||
for pair in pairs:
|
|
||||||
if not pair.strip():
|
|
||||||
continue
|
|
||||||
tmp = pair.split(':')
|
|
||||||
if len(tmp) == 2:
|
|
||||||
step = int(tmp[1])
|
|
||||||
if step > cur_step:
|
|
||||||
self.rates.append((float(tmp[0]), min(step, max_steps)))
|
|
||||||
self.maxit += 1
|
|
||||||
if step > max_steps:
|
|
||||||
return
|
|
||||||
elif step == -1:
|
|
||||||
self.rates.append((float(tmp[0]), max_steps))
|
|
||||||
self.maxit += 1
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self.rates.append((float(tmp[0]), max_steps))
|
|
||||||
self.maxit += 1
|
|
||||||
return
|
|
||||||
assert self.rates
|
|
||||||
except (ValueError, AssertionError) as e:
|
|
||||||
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
|
|
||||||
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
if self.it < self.maxit:
|
|
||||||
self.it += 1
|
|
||||||
return self.rates[self.it - 1]
|
|
||||||
else:
|
|
||||||
raise StopIteration
|
|
||||||
|
|
||||||
|
|
||||||
class LearnRateScheduler:
|
|
||||||
def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
|
|
||||||
self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
|
|
||||||
(self.learn_rate, self.end_step) = next(self.schedules)
|
|
||||||
self.verbose = verbose
|
|
||||||
|
|
||||||
if self.verbose:
|
|
||||||
print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
|
|
||||||
|
|
||||||
self.finished = False
|
|
||||||
|
|
||||||
def step(self, step_number):
|
|
||||||
if step_number < self.end_step:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
(self.learn_rate, self.end_step) = next(self.schedules)
|
|
||||||
except StopIteration:
|
|
||||||
self.finished = True
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def apply(self, optimizer, step_number):
|
|
||||||
if not self.step(step_number):
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.verbose:
|
|
||||||
tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
|
|
||||||
|
|
||||||
for pg in optimizer.param_groups:
|
|
||||||
pg['lr'] = self.learn_rate
|
|
||||||
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
import datetime
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
saved_params_shared = {
|
|
||||||
"batch_size",
|
|
||||||
"clip_grad_mode",
|
|
||||||
"clip_grad_value",
|
|
||||||
"create_image_every",
|
|
||||||
"data_root",
|
|
||||||
"gradient_step",
|
|
||||||
"initial_step",
|
|
||||||
"latent_sampling_method",
|
|
||||||
"learn_rate",
|
|
||||||
"log_directory",
|
|
||||||
"model_hash",
|
|
||||||
"model_name",
|
|
||||||
"num_of_dataset_images",
|
|
||||||
"steps",
|
|
||||||
"template_file",
|
|
||||||
"training_height",
|
|
||||||
"training_width",
|
|
||||||
}
|
|
||||||
saved_params_ti = {
|
|
||||||
"embedding_name",
|
|
||||||
"num_vectors_per_token",
|
|
||||||
"save_embedding_every",
|
|
||||||
"save_image_with_stored_embedding",
|
|
||||||
}
|
|
||||||
saved_params_hypernet = {
|
|
||||||
"activation_func",
|
|
||||||
"add_layer_norm",
|
|
||||||
"hypernetwork_name",
|
|
||||||
"layer_structure",
|
|
||||||
"save_hypernetwork_every",
|
|
||||||
"use_dropout",
|
|
||||||
"weight_init",
|
|
||||||
}
|
|
||||||
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
|
|
||||||
saved_params_previews = {
|
|
||||||
"preview_cfg_scale",
|
|
||||||
"preview_height",
|
|
||||||
"preview_negative_prompt",
|
|
||||||
"preview_prompt",
|
|
||||||
"preview_sampler_index",
|
|
||||||
"preview_seed",
|
|
||||||
"preview_steps",
|
|
||||||
"preview_width",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def save_settings_to_file(log_directory, all_params):
|
|
||||||
now = datetime.datetime.now()
|
|
||||||
params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
|
|
||||||
|
|
||||||
keys = saved_params_all
|
|
||||||
if all_params.get('preview_from_txt2img'):
|
|
||||||
keys = keys | saved_params_previews
|
|
||||||
|
|
||||||
params.update({k: v for k, v in all_params.items() if k in keys})
|
|
||||||
|
|
||||||
filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
|
|
||||||
with open(os.path.join(log_directory, filename), "w") as file:
|
|
||||||
json.dump(params, file, indent=4)
|
|
||||||
@@ -13,11 +13,8 @@ import numpy as np
|
|||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
|
|
||||||
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
|
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
|
||||||
import modules.textual_inversion.dataset
|
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
|
||||||
|
|
||||||
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
|
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
|
||||||
from modules.textual_inversion.saving_settings import save_settings_to_file
|
|
||||||
|
|
||||||
|
|
||||||
TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
|
TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
|
||||||
@@ -320,388 +317,3 @@ def create_embedding_from_data(data, name, filename='unknown embedding file', fi
|
|||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
def write_loss(log_directory, filename, step, epoch_len, values):
|
|
||||||
if shared.opts.training_write_csv_every == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
if step % shared.opts.training_write_csv_every != 0:
|
|
||||||
return
|
|
||||||
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
|
|
||||||
|
|
||||||
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
|
|
||||||
csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
|
|
||||||
|
|
||||||
if write_csv_header:
|
|
||||||
csv_writer.writeheader()
|
|
||||||
|
|
||||||
epoch = (step - 1) // epoch_len
|
|
||||||
epoch_step = (step - 1) % epoch_len
|
|
||||||
|
|
||||||
csv_writer.writerow({
|
|
||||||
"step": step,
|
|
||||||
"epoch": epoch,
|
|
||||||
"epoch_step": epoch_step,
|
|
||||||
**values,
|
|
||||||
})
|
|
||||||
|
|
||||||
def tensorboard_setup(log_directory):
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
|
|
||||||
return SummaryWriter(
|
|
||||||
log_dir=os.path.join(log_directory, "tensorboard"),
|
|
||||||
flush_secs=shared.opts.training_tensorboard_flush_every)
|
|
||||||
|
|
||||||
def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
|
|
||||||
tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
|
|
||||||
tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
|
|
||||||
tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
|
|
||||||
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
|
|
||||||
|
|
||||||
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
|
|
||||||
tensorboard_writer.add_scalar(tag=tag,
|
|
||||||
scalar_value=value, global_step=step)
|
|
||||||
|
|
||||||
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
|
|
||||||
# Convert a pil image to a torch tensor
|
|
||||||
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
|
|
||||||
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
|
|
||||||
len(pil_image.getbands()))
|
|
||||||
img_tensor = img_tensor.permute((2, 0, 1))
|
|
||||||
|
|
||||||
tensorboard_writer.add_image(tag, img_tensor, global_step=step)
|
|
||||||
|
|
||||||
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
|
||||||
assert model_name, f"{name} not selected"
|
|
||||||
assert learn_rate, "Learning rate is empty or 0"
|
|
||||||
assert isinstance(batch_size, int), "Batch size must be integer"
|
|
||||||
assert batch_size > 0, "Batch size must be positive"
|
|
||||||
assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
|
|
||||||
assert gradient_step > 0, "Gradient accumulation step must be positive"
|
|
||||||
assert data_root, "Dataset directory is empty"
|
|
||||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
|
||||||
assert os.listdir(data_root), "Dataset directory is empty"
|
|
||||||
assert template_filename, "Prompt template file not selected"
|
|
||||||
assert template_file, f"Prompt template file {template_filename} not found"
|
|
||||||
assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
|
|
||||||
assert steps, "Max steps is empty or 0"
|
|
||||||
assert isinstance(steps, int), "Max steps must be integer"
|
|
||||||
assert steps > 0, "Max steps must be positive"
|
|
||||||
assert isinstance(save_model_every, int), "Save {name} must be integer"
|
|
||||||
assert save_model_every >= 0, "Save {name} must be positive or 0"
|
|
||||||
assert isinstance(create_image_every, int), "Create image must be integer"
|
|
||||||
assert create_image_every >= 0, "Create image must be positive or 0"
|
|
||||||
if save_model_every or create_image_every:
|
|
||||||
assert log_directory, "Log directory is empty"
|
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_name, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
|
||||||
from modules import processing
|
|
||||||
|
|
||||||
save_embedding_every = save_embedding_every or 0
|
|
||||||
create_image_every = create_image_every or 0
|
|
||||||
template_file = textual_inversion_templates.get(template_filename, None)
|
|
||||||
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
|
||||||
template_file = template_file.path
|
|
||||||
|
|
||||||
shared.state.job = "train-embedding"
|
|
||||||
shared.state.textinfo = "Initializing textual inversion training..."
|
|
||||||
shared.state.job_count = steps
|
|
||||||
|
|
||||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
|
||||||
|
|
||||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
|
|
||||||
unload = shared.opts.unload_models_when_training
|
|
||||||
|
|
||||||
if save_embedding_every > 0:
|
|
||||||
embedding_dir = os.path.join(log_directory, "embeddings")
|
|
||||||
os.makedirs(embedding_dir, exist_ok=True)
|
|
||||||
else:
|
|
||||||
embedding_dir = None
|
|
||||||
|
|
||||||
if create_image_every > 0:
|
|
||||||
images_dir = os.path.join(log_directory, "images")
|
|
||||||
os.makedirs(images_dir, exist_ok=True)
|
|
||||||
else:
|
|
||||||
images_dir = None
|
|
||||||
|
|
||||||
if create_image_every > 0 and save_image_with_stored_embedding:
|
|
||||||
images_embeds_dir = os.path.join(log_directory, "image_embeddings")
|
|
||||||
os.makedirs(images_embeds_dir, exist_ok=True)
|
|
||||||
else:
|
|
||||||
images_embeds_dir = None
|
|
||||||
|
|
||||||
hijack = sd_hijack.model_hijack
|
|
||||||
|
|
||||||
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
|
||||||
checkpoint = sd_models.select_checkpoint()
|
|
||||||
|
|
||||||
initial_step = embedding.step or 0
|
|
||||||
if initial_step >= steps:
|
|
||||||
shared.state.textinfo = "Model has already been trained beyond specified max steps"
|
|
||||||
return embedding, filename
|
|
||||||
|
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
|
||||||
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
|
|
||||||
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
|
|
||||||
None
|
|
||||||
if clip_grad:
|
|
||||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
|
||||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
|
||||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
|
||||||
|
|
||||||
tensorboard_writer = None
|
|
||||||
if shared.opts.training_enable_tensorboard:
|
|
||||||
try:
|
|
||||||
tensorboard_writer = tensorboard_setup(log_directory)
|
|
||||||
except ImportError:
|
|
||||||
errors.report("Error initializing tensorboard", exc_info=True)
|
|
||||||
|
|
||||||
pin_memory = shared.opts.pin_memory
|
|
||||||
|
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
|
|
||||||
|
|
||||||
if shared.opts.save_training_settings_to_txt:
|
|
||||||
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
|
|
||||||
|
|
||||||
latent_sampling_method = ds.latent_sampling_method
|
|
||||||
|
|
||||||
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
|
||||||
|
|
||||||
if unload:
|
|
||||||
shared.parallel_processing_allowed = False
|
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
|
||||||
|
|
||||||
embedding.vec.requires_grad = True
|
|
||||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
|
||||||
if shared.opts.save_optimizer_state:
|
|
||||||
optimizer_state_dict = None
|
|
||||||
if os.path.exists(f"{filename}.optim"):
|
|
||||||
optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
|
|
||||||
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
|
|
||||||
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
|
||||||
|
|
||||||
if optimizer_state_dict is not None:
|
|
||||||
optimizer.load_state_dict(optimizer_state_dict)
|
|
||||||
print("Loaded existing optimizer from checkpoint")
|
|
||||||
else:
|
|
||||||
print("No saved optimizer exists in checkpoint")
|
|
||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler()
|
|
||||||
|
|
||||||
batch_size = ds.batch_size
|
|
||||||
gradient_step = ds.gradient_step
|
|
||||||
# n steps = batch_size * gradient_step * n image processed
|
|
||||||
steps_per_epoch = len(ds) // batch_size // gradient_step
|
|
||||||
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
|
||||||
loss_step = 0
|
|
||||||
_loss_step = 0 #internal
|
|
||||||
|
|
||||||
last_saved_file = "<none>"
|
|
||||||
last_saved_image = "<none>"
|
|
||||||
forced_filename = "<none>"
|
|
||||||
embedding_yet_to_be_embedded = False
|
|
||||||
|
|
||||||
is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
|
||||||
img_c = None
|
|
||||||
|
|
||||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
|
||||||
try:
|
|
||||||
sd_hijack_checkpoint.add()
|
|
||||||
|
|
||||||
for _ in range((steps-initial_step) * gradient_step):
|
|
||||||
if scheduler.finished:
|
|
||||||
break
|
|
||||||
if shared.state.interrupted:
|
|
||||||
break
|
|
||||||
for j, batch in enumerate(dl):
|
|
||||||
# works as a drop_last=True for gradient accumulation
|
|
||||||
if j == max_steps_per_epoch:
|
|
||||||
break
|
|
||||||
scheduler.apply(optimizer, embedding.step)
|
|
||||||
if scheduler.finished:
|
|
||||||
break
|
|
||||||
if shared.state.interrupted:
|
|
||||||
break
|
|
||||||
|
|
||||||
if clip_grad:
|
|
||||||
clip_grad_sched.step(embedding.step)
|
|
||||||
|
|
||||||
with devices.autocast():
|
|
||||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
|
||||||
if use_weight:
|
|
||||||
w = batch.weight.to(devices.device, non_blocking=pin_memory)
|
|
||||||
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
|
||||||
|
|
||||||
if is_training_inpainting_model:
|
|
||||||
if img_c is None:
|
|
||||||
img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
|
|
||||||
|
|
||||||
cond = {"c_concat": [img_c], "c_crossattn": [c]}
|
|
||||||
else:
|
|
||||||
cond = c
|
|
||||||
|
|
||||||
if use_weight:
|
|
||||||
loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
|
|
||||||
del w
|
|
||||||
else:
|
|
||||||
loss = shared.sd_model.forward(x, cond)[0] / gradient_step
|
|
||||||
del x
|
|
||||||
|
|
||||||
_loss_step += loss.item()
|
|
||||||
scaler.scale(loss).backward()
|
|
||||||
|
|
||||||
# go back until we reach gradient accumulation steps
|
|
||||||
if (j + 1) % gradient_step != 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if clip_grad:
|
|
||||||
clip_grad(embedding.vec, clip_grad_sched.learn_rate)
|
|
||||||
|
|
||||||
scaler.step(optimizer)
|
|
||||||
scaler.update()
|
|
||||||
embedding.step += 1
|
|
||||||
pbar.update()
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
|
||||||
loss_step = _loss_step
|
|
||||||
_loss_step = 0
|
|
||||||
|
|
||||||
steps_done = embedding.step + 1
|
|
||||||
|
|
||||||
epoch_num = embedding.step // steps_per_epoch
|
|
||||||
epoch_step = embedding.step % steps_per_epoch
|
|
||||||
|
|
||||||
description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
|
|
||||||
pbar.set_description(description)
|
|
||||||
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
|
||||||
# Before saving, change name to match current checkpoint.
|
|
||||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
|
||||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
|
||||||
save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
|
||||||
embedding_yet_to_be_embedded = True
|
|
||||||
|
|
||||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
|
||||||
"loss": f"{loss_step:.7f}",
|
|
||||||
"learn_rate": scheduler.learn_rate
|
|
||||||
})
|
|
||||||
|
|
||||||
if images_dir is not None and steps_done % create_image_every == 0:
|
|
||||||
forced_filename = f'{embedding_name}-{steps_done}'
|
|
||||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
|
||||||
|
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
|
||||||
sd_model=shared.sd_model,
|
|
||||||
do_not_save_grid=True,
|
|
||||||
do_not_save_samples=True,
|
|
||||||
do_not_reload_embeddings=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if preview_from_txt2img:
|
|
||||||
p.prompt = preview_prompt
|
|
||||||
p.negative_prompt = preview_negative_prompt
|
|
||||||
p.steps = preview_steps
|
|
||||||
p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
|
|
||||||
p.cfg_scale = preview_cfg_scale
|
|
||||||
p.seed = preview_seed
|
|
||||||
p.width = preview_width
|
|
||||||
p.height = preview_height
|
|
||||||
else:
|
|
||||||
p.prompt = batch.cond_text[0]
|
|
||||||
p.steps = 20
|
|
||||||
p.width = training_width
|
|
||||||
p.height = training_height
|
|
||||||
|
|
||||||
preview_text = p.prompt
|
|
||||||
|
|
||||||
with closing(p):
|
|
||||||
processed = processing.process_images(p)
|
|
||||||
image = processed.images[0] if len(processed.images) > 0 else None
|
|
||||||
|
|
||||||
if unload:
|
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
|
||||||
|
|
||||||
if image is not None:
|
|
||||||
shared.state.assign_current_image(image)
|
|
||||||
|
|
||||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
|
||||||
last_saved_image += f", prompt: {preview_text}"
|
|
||||||
|
|
||||||
if tensorboard_writer and shared.opts.training_tensorboard_save_images:
|
|
||||||
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
|
|
||||||
|
|
||||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
|
||||||
|
|
||||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
|
||||||
|
|
||||||
info = PngImagePlugin.PngInfo()
|
|
||||||
data = torch.load(last_saved_file)
|
|
||||||
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
|
||||||
|
|
||||||
title = f"<{data.get('name', '???')}>"
|
|
||||||
|
|
||||||
try:
|
|
||||||
vectorSize = list(data['string_to_param'].values())[0].shape[0]
|
|
||||||
except Exception:
|
|
||||||
vectorSize = '?'
|
|
||||||
|
|
||||||
checkpoint = sd_models.select_checkpoint()
|
|
||||||
footer_left = checkpoint.model_name
|
|
||||||
footer_mid = f'[{checkpoint.shorthash}]'
|
|
||||||
footer_right = f'{vectorSize}v {steps_done}s'
|
|
||||||
|
|
||||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
|
||||||
captioned_image = insert_image_data_embed(captioned_image, data)
|
|
||||||
|
|
||||||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
|
||||||
embedding_yet_to_be_embedded = False
|
|
||||||
|
|
||||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
|
||||||
last_saved_image += f", prompt: {preview_text}"
|
|
||||||
|
|
||||||
shared.state.job_no = embedding.step
|
|
||||||
|
|
||||||
shared.state.textinfo = f"""
|
|
||||||
<p>
|
|
||||||
Loss: {loss_step:.7f}<br/>
|
|
||||||
Step: {steps_done}<br/>
|
|
||||||
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
|
||||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
|
||||||
</p>
|
|
||||||
"""
|
|
||||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
|
||||||
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
|
||||||
except Exception:
|
|
||||||
errors.report("Error training embedding", exc_info=True)
|
|
||||||
finally:
|
|
||||||
pbar.leave = False
|
|
||||||
pbar.close()
|
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
|
||||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
|
||||||
sd_hijack_checkpoint.remove()
|
|
||||||
|
|
||||||
return embedding, filename
|
|
||||||
|
|
||||||
|
|
||||||
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
|
||||||
old_embedding_name = embedding.name
|
|
||||||
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|
|
||||||
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
|
||||||
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
|
|
||||||
try:
|
|
||||||
embedding.sd_checkpoint = checkpoint.shorthash
|
|
||||||
embedding.sd_checkpoint_name = checkpoint.model_name
|
|
||||||
if remove_cached_checksum:
|
|
||||||
embedding.cached_checksum = None
|
|
||||||
embedding.name = embedding_name
|
|
||||||
embedding.optimizer_state_dict = optimizer.state_dict()
|
|
||||||
embedding.save(filename)
|
|
||||||
except:
|
|
||||||
embedding.sd_checkpoint = old_sd_checkpoint
|
|
||||||
embedding.sd_checkpoint_name = old_sd_checkpoint_name
|
|
||||||
embedding.name = old_embedding_name
|
|
||||||
embedding.cached_checksum = old_cached_checksum
|
|
||||||
raise
|
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
import base64
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
test_files_path = os.path.dirname(__file__) + "/test_files"
|
|
||||||
test_outputs_path = os.path.dirname(__file__) + "/test_outputs"
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
|
||||||
# We don't want to fail on Py.test command line arguments being
|
|
||||||
# parsed by webui:
|
|
||||||
os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
|
|
||||||
|
|
||||||
|
|
||||||
def file_to_base64(filename):
|
|
||||||
with open(filename, "rb") as file:
|
|
||||||
data = file.read()
|
|
||||||
|
|
||||||
base64_str = str(base64.b64encode(data), "utf-8")
|
|
||||||
return "data:image/png;base64," + base64_str
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session") # session so we don't read this over and over
|
|
||||||
def img2img_basic_image_base64() -> str:
|
|
||||||
return file_to_base64(os.path.join(test_files_path, "img2img_basic.png"))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session") # session so we don't read this over and over
|
|
||||||
def mask_basic_image_base64() -> str:
|
|
||||||
return file_to_base64(os.path.join(test_files_path, "mask_basic.png"))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def initialize() -> None:
|
|
||||||
import webui # noqa: F401
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
def test_simple_upscaling_performed(base_url, img2img_basic_image_base64):
|
|
||||||
payload = {
|
|
||||||
"resize_mode": 0,
|
|
||||||
"show_extras_results": True,
|
|
||||||
"gfpgan_visibility": 0,
|
|
||||||
"codeformer_visibility": 0,
|
|
||||||
"codeformer_weight": 0,
|
|
||||||
"upscaling_resize": 2,
|
|
||||||
"upscaling_resize_w": 128,
|
|
||||||
"upscaling_resize_h": 128,
|
|
||||||
"upscaling_crop": True,
|
|
||||||
"upscaler_1": "Lanczos",
|
|
||||||
"upscaler_2": "None",
|
|
||||||
"extras_upscaler_2_visibility": 0,
|
|
||||||
"image": img2img_basic_image_base64,
|
|
||||||
}
|
|
||||||
assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_png_info_performed(base_url, img2img_basic_image_base64):
|
|
||||||
payload = {
|
|
||||||
"image": img2img_basic_image_base64,
|
|
||||||
}
|
|
||||||
assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_interrogate_performed(base_url, img2img_basic_image_base64):
|
|
||||||
payload = {
|
|
||||||
"image": img2img_basic_image_base64,
|
|
||||||
"model": "clip",
|
|
||||||
}
|
|
||||||
assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
import os
|
|
||||||
from test.conftest import test_files_path, test_outputs_path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("initialize")
|
|
||||||
@pytest.mark.parametrize("restorer_name", ["gfpgan", "codeformer"])
|
|
||||||
@pytest.mark.skip # Skip for forge.
|
|
||||||
def test_face_restorers(restorer_name):
|
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
if restorer_name == "gfpgan":
|
|
||||||
from modules import gfpgan_model
|
|
||||||
gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path)
|
|
||||||
restorer = gfpgan_model.gfpgan_fix_faces
|
|
||||||
elif restorer_name == "codeformer":
|
|
||||||
from modules import codeformer_model
|
|
||||||
codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path)
|
|
||||||
restorer = codeformer_model.codeformer.restore
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("...")
|
|
||||||
img = Image.open(os.path.join(test_files_path, "two-faces.jpg"))
|
|
||||||
np_img = np.array(img, dtype=np.uint8)
|
|
||||||
fixed_image = restorer(np_img)
|
|
||||||
assert fixed_image.shape == np_img.shape
|
|
||||||
assert not np.allclose(fixed_image, np_img) # should have visibly changed
|
|
||||||
Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f"{restorer_name}.png"))
|
|
||||||
Binary file not shown.
Binary file not shown.
|
Before Width: | Height: | Size: 9.7 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 362 B |
Binary file not shown.
|
Before Width: | Height: | Size: 14 KiB |
@@ -1,68 +0,0 @@
|
|||||||
|
|
||||||
import pytest
|
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def url_img2img(base_url):
|
|
||||||
return f"{base_url}/sdapi/v1/img2img"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def simple_img2img_request(img2img_basic_image_base64):
|
|
||||||
return {
|
|
||||||
"batch_size": 1,
|
|
||||||
"cfg_scale": 7,
|
|
||||||
"denoising_strength": 0.75,
|
|
||||||
"eta": 0,
|
|
||||||
"height": 64,
|
|
||||||
"include_init_images": False,
|
|
||||||
"init_images": [img2img_basic_image_base64],
|
|
||||||
"inpaint_full_res": False,
|
|
||||||
"inpaint_full_res_padding": 0,
|
|
||||||
"inpainting_fill": 0,
|
|
||||||
"inpainting_mask_invert": False,
|
|
||||||
"mask": None,
|
|
||||||
"mask_blur": 4,
|
|
||||||
"n_iter": 1,
|
|
||||||
"negative_prompt": "",
|
|
||||||
"override_settings": {},
|
|
||||||
"prompt": "example prompt",
|
|
||||||
"resize_mode": 0,
|
|
||||||
"restore_faces": False,
|
|
||||||
"s_churn": 0,
|
|
||||||
"s_noise": 1,
|
|
||||||
"s_tmax": 0,
|
|
||||||
"s_tmin": 0,
|
|
||||||
"sampler_index": "Euler a",
|
|
||||||
"seed": -1,
|
|
||||||
"seed_resize_from_h": -1,
|
|
||||||
"seed_resize_from_w": -1,
|
|
||||||
"steps": 3,
|
|
||||||
"styles": [],
|
|
||||||
"subseed": -1,
|
|
||||||
"subseed_strength": 0,
|
|
||||||
"tiling": False,
|
|
||||||
"width": 64,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_img2img_simple_performed(url_img2img, simple_img2img_request):
|
|
||||||
assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_inpainting_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64):
|
|
||||||
simple_img2img_request["mask"] = mask_basic_image_base64
|
|
||||||
assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_inpainting_with_inverted_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64):
|
|
||||||
simple_img2img_request["mask"] = mask_basic_image_base64
|
|
||||||
simple_img2img_request["inpainting_mask_invert"] = True
|
|
||||||
assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_img2img_sd_upscale_performed(url_img2img, simple_img2img_request):
|
|
||||||
simple_img2img_request["script_name"] = "sd upscale"
|
|
||||||
simple_img2img_request["script_args"] = ["", 8, "Lanczos", 2.0]
|
|
||||||
assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
import types
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from modules import torch_utils
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("wrapped", [True, False])
|
|
||||||
def test_get_param(wrapped):
|
|
||||||
mod = torch.nn.Linear(1, 1)
|
|
||||||
cpu = torch.device("cpu")
|
|
||||||
mod.to(dtype=torch.float16, device=cpu)
|
|
||||||
if wrapped:
|
|
||||||
# more or less how spandrel wraps a thing
|
|
||||||
mod = types.SimpleNamespace(model=mod)
|
|
||||||
p = torch_utils.get_param(mod)
|
|
||||||
assert p.dtype == torch.float16
|
|
||||||
assert p.device == cpu
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
|
|
||||||
import pytest
|
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def url_txt2img(base_url):
|
|
||||||
return f"{base_url}/sdapi/v1/txt2img"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def simple_txt2img_request():
|
|
||||||
return {
|
|
||||||
"batch_size": 1,
|
|
||||||
"cfg_scale": 7,
|
|
||||||
"denoising_strength": 0,
|
|
||||||
"enable_hr": False,
|
|
||||||
"eta": 0,
|
|
||||||
"firstphase_height": 0,
|
|
||||||
"firstphase_width": 0,
|
|
||||||
"height": 64,
|
|
||||||
"n_iter": 1,
|
|
||||||
"negative_prompt": "",
|
|
||||||
"prompt": "example prompt",
|
|
||||||
"restore_faces": False,
|
|
||||||
"s_churn": 0,
|
|
||||||
"s_noise": 1,
|
|
||||||
"s_tmax": 0,
|
|
||||||
"s_tmin": 0,
|
|
||||||
"sampler_index": "Euler a",
|
|
||||||
"seed": -1,
|
|
||||||
"seed_resize_from_h": -1,
|
|
||||||
"seed_resize_from_w": -1,
|
|
||||||
"steps": 3,
|
|
||||||
"styles": [],
|
|
||||||
"subseed": -1,
|
|
||||||
"subseed_strength": 0,
|
|
||||||
"tiling": False,
|
|
||||||
"width": 64,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_txt2img_simple_performed(url_txt2img, simple_txt2img_request):
|
|
||||||
assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_txt2img_with_negative_prompt_performed(url_txt2img, simple_txt2img_request):
|
|
||||||
simple_txt2img_request["negative_prompt"] = "example negative prompt"
|
|
||||||
assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_txt2img_with_complex_prompt_performed(url_txt2img, simple_txt2img_request):
|
|
||||||
simple_txt2img_request["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]"
|
|
||||||
assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_txt2img_not_square_image_performed(url_txt2img, simple_txt2img_request):
|
|
||||||
simple_txt2img_request["height"] = 128
|
|
||||||
assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_txt2img_with_hrfix_performed(url_txt2img, simple_txt2img_request):
|
|
||||||
simple_txt2img_request["enable_hr"] = True
|
|
||||||
assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_txt2img_with_tiling_performed(url_txt2img, simple_txt2img_request):
|
|
||||||
simple_txt2img_request["tiling"] = True
|
|
||||||
assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip # Skip for forge.
|
|
||||||
def test_txt2img_with_restore_faces_performed(url_txt2img, simple_txt2img_request):
|
|
||||||
simple_txt2img_request["restore_faces"] = True
|
|
||||||
assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sampler", ["PLMS", "DDIM", "UniPC"])
|
|
||||||
def test_txt2img_with_vanilla_sampler_performed(url_txt2img, simple_txt2img_request, sampler):
|
|
||||||
simple_txt2img_request["sampler_index"] = sampler
|
|
||||||
assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_txt2img_multiple_batches_performed(url_txt2img, simple_txt2img_request):
|
|
||||||
simple_txt2img_request["n_iter"] = 2
|
|
||||||
assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
|
|
||||||
|
|
||||||
|
|
||||||
def test_txt2img_batch_performed(url_txt2img, simple_txt2img_request):
|
|
||||||
simple_txt2img_request["batch_size"] = 2
|
|
||||||
assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
def test_options_write(base_url):
|
|
||||||
url_options = f"{base_url}/sdapi/v1/options"
|
|
||||||
response = requests.get(url_options)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
pre_value = response.json()["send_seed"]
|
|
||||||
|
|
||||||
assert requests.post(url_options, json={'send_seed': (not pre_value)}).status_code == 200
|
|
||||||
|
|
||||||
response = requests.get(url_options)
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()['send_seed'] == (not pre_value)
|
|
||||||
|
|
||||||
requests.post(url_options, json={"send_seed": pre_value})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("url", [
|
|
||||||
"sdapi/v1/cmd-flags",
|
|
||||||
"sdapi/v1/samplers",
|
|
||||||
"sdapi/v1/upscalers",
|
|
||||||
"sdapi/v1/sd-models",
|
|
||||||
"sdapi/v1/hypernetworks",
|
|
||||||
"sdapi/v1/face-restorers",
|
|
||||||
"sdapi/v1/realesrgan-models",
|
|
||||||
"sdapi/v1/prompt-styles",
|
|
||||||
"sdapi/v1/embeddings",
|
|
||||||
])
|
|
||||||
def test_get_api_url(base_url, url):
|
|
||||||
assert requests.get(f"{base_url}/{url}").status_code == 200
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
a photo of a [filewords]
|
|
||||||
a rendering of a [filewords]
|
|
||||||
a cropped photo of the [filewords]
|
|
||||||
the photo of a [filewords]
|
|
||||||
a photo of a clean [filewords]
|
|
||||||
a photo of a dirty [filewords]
|
|
||||||
a dark photo of the [filewords]
|
|
||||||
a photo of my [filewords]
|
|
||||||
a photo of the cool [filewords]
|
|
||||||
a close-up photo of a [filewords]
|
|
||||||
a bright photo of the [filewords]
|
|
||||||
a cropped photo of a [filewords]
|
|
||||||
a photo of the [filewords]
|
|
||||||
a good photo of the [filewords]
|
|
||||||
a photo of one [filewords]
|
|
||||||
a close-up photo of the [filewords]
|
|
||||||
a rendition of the [filewords]
|
|
||||||
a photo of the clean [filewords]
|
|
||||||
a rendition of a [filewords]
|
|
||||||
a photo of a nice [filewords]
|
|
||||||
a good photo of a [filewords]
|
|
||||||
a photo of the nice [filewords]
|
|
||||||
a photo of the small [filewords]
|
|
||||||
a photo of the weird [filewords]
|
|
||||||
a photo of the large [filewords]
|
|
||||||
a photo of a cool [filewords]
|
|
||||||
a photo of a small [filewords]
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
picture
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
a painting, art by [name]
|
|
||||||
a rendering, art by [name]
|
|
||||||
a cropped painting, art by [name]
|
|
||||||
the painting, art by [name]
|
|
||||||
a clean painting, art by [name]
|
|
||||||
a dirty painting, art by [name]
|
|
||||||
a dark painting, art by [name]
|
|
||||||
a picture, art by [name]
|
|
||||||
a cool painting, art by [name]
|
|
||||||
a close-up painting, art by [name]
|
|
||||||
a bright painting, art by [name]
|
|
||||||
a cropped painting, art by [name]
|
|
||||||
a good painting, art by [name]
|
|
||||||
a close-up painting, art by [name]
|
|
||||||
a rendition, art by [name]
|
|
||||||
a nice painting, art by [name]
|
|
||||||
a small painting, art by [name]
|
|
||||||
a weird painting, art by [name]
|
|
||||||
a large painting, art by [name]
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
a painting of [filewords], art by [name]
|
|
||||||
a rendering of [filewords], art by [name]
|
|
||||||
a cropped painting of [filewords], art by [name]
|
|
||||||
the painting of [filewords], art by [name]
|
|
||||||
a clean painting of [filewords], art by [name]
|
|
||||||
a dirty painting of [filewords], art by [name]
|
|
||||||
a dark painting of [filewords], art by [name]
|
|
||||||
a picture of [filewords], art by [name]
|
|
||||||
a cool painting of [filewords], art by [name]
|
|
||||||
a close-up painting of [filewords], art by [name]
|
|
||||||
a bright painting of [filewords], art by [name]
|
|
||||||
a cropped painting of [filewords], art by [name]
|
|
||||||
a good painting of [filewords], art by [name]
|
|
||||||
a close-up painting of [filewords], art by [name]
|
|
||||||
a rendition of [filewords], art by [name]
|
|
||||||
a nice painting of [filewords], art by [name]
|
|
||||||
a small painting of [filewords], art by [name]
|
|
||||||
a weird painting of [filewords], art by [name]
|
|
||||||
a large painting of [filewords], art by [name]
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
a photo of a [name]
|
|
||||||
a rendering of a [name]
|
|
||||||
a cropped photo of the [name]
|
|
||||||
the photo of a [name]
|
|
||||||
a photo of a clean [name]
|
|
||||||
a photo of a dirty [name]
|
|
||||||
a dark photo of the [name]
|
|
||||||
a photo of my [name]
|
|
||||||
a photo of the cool [name]
|
|
||||||
a close-up photo of a [name]
|
|
||||||
a bright photo of the [name]
|
|
||||||
a cropped photo of a [name]
|
|
||||||
a photo of the [name]
|
|
||||||
a good photo of the [name]
|
|
||||||
a photo of one [name]
|
|
||||||
a close-up photo of the [name]
|
|
||||||
a rendition of the [name]
|
|
||||||
a photo of the clean [name]
|
|
||||||
a rendition of a [name]
|
|
||||||
a photo of a nice [name]
|
|
||||||
a good photo of a [name]
|
|
||||||
a photo of the nice [name]
|
|
||||||
a photo of the small [name]
|
|
||||||
a photo of the weird [name]
|
|
||||||
a photo of the large [name]
|
|
||||||
a photo of a cool [name]
|
|
||||||
a photo of a small [name]
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
a photo of a [name], [filewords]
|
|
||||||
a rendering of a [name], [filewords]
|
|
||||||
a cropped photo of the [name], [filewords]
|
|
||||||
the photo of a [name], [filewords]
|
|
||||||
a photo of a clean [name], [filewords]
|
|
||||||
a photo of a dirty [name], [filewords]
|
|
||||||
a dark photo of the [name], [filewords]
|
|
||||||
a photo of my [name], [filewords]
|
|
||||||
a photo of the cool [name], [filewords]
|
|
||||||
a close-up photo of a [name], [filewords]
|
|
||||||
a bright photo of the [name], [filewords]
|
|
||||||
a cropped photo of a [name], [filewords]
|
|
||||||
a photo of the [name], [filewords]
|
|
||||||
a good photo of the [name], [filewords]
|
|
||||||
a photo of one [name], [filewords]
|
|
||||||
a close-up photo of the [name], [filewords]
|
|
||||||
a rendition of the [name], [filewords]
|
|
||||||
a photo of the clean [name], [filewords]
|
|
||||||
a rendition of a [name], [filewords]
|
|
||||||
a photo of a nice [name], [filewords]
|
|
||||||
a good photo of a [name], [filewords]
|
|
||||||
a photo of the nice [name], [filewords]
|
|
||||||
a photo of the small [name], [filewords]
|
|
||||||
a photo of the weird [name], [filewords]
|
|
||||||
a photo of the large [name], [filewords]
|
|
||||||
a photo of a cool [name], [filewords]
|
|
||||||
a photo of a small [name], [filewords]
|
|
||||||
Reference in New Issue
Block a user