Merge branch 'master' into hn-activation

This commit is contained in:
guaneec
2022-10-26 14:58:04 +08:00
committed by GitHub
11 changed files with 876 additions and 32 deletions

View File

@@ -7,6 +7,7 @@ import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
from typing import List
import json
import io
import base64
@@ -15,12 +16,12 @@ from PIL import Image
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
@@ -65,7 +66,7 @@ class Api:
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
@@ -111,7 +112,11 @@ class Api:
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))
if (not img2imgreq.include_init_images):
img2imgreq.init_images = None
img2imgreq.mask = None
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
def extrasapi(self):
raise NotImplementedError

View File

@@ -31,6 +31,7 @@ class ModelDef(BaseModel):
field_alias: str
field_type: Any
field_value: Any
field_exclude: bool = False
class PydanticModelGenerator:
@@ -78,7 +79,8 @@ class PydanticModelGenerator:
field=underscore(fields["key"]),
field_alias=fields["key"],
field_type=fields["type"],
field_value=fields["default"]))
field_value=fields["default"],
field_exclude=fields["exclude"] if "exclude" in fields else False))
def generate_model(self):
"""
@@ -86,7 +88,7 @@ class PydanticModelGenerator:
from the json and overrides provided at initialization
"""
fields = {
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
}
DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True
@@ -102,5 +104,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model()

View File

@@ -5,6 +5,7 @@ import html
import os
import sys
import traceback
import inspect
import modules.textual_inversion.dataset
import torch
@@ -15,10 +16,12 @@ from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
from collections import defaultdict, deque
from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -26,9 +29,12 @@ class HypernetworkModule(torch.nn.Module):
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
}
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@@ -65,9 +71,24 @@ class HypernetworkModule(torch.nn.Module):
else:
for layer in self.linear:
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
w, b = layer.weight.data, layer.bias.data
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
normal_(w, mean=0.0, std=0.01)
normal_(b, mean=0.0, std=0.005)
elif weight_init == 'XavierUniform':
xavier_uniform_(w)
zeros_(b)
elif weight_init == 'XavierNormal':
xavier_normal_(w)
zeros_(b)
elif weight_init == 'KaimingUniform':
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
zeros_(b)
elif weight_init == 'KaimingNormal':
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
zeros_(b)
else:
raise KeyError(f"Key {weight_init} is not defined as initialization!")
self.to(devices.device)
def fix_old_state_dict(self, state_dict):
@@ -105,7 +126,7 @@ class Hypernetwork:
filename = None
name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False)
self.filename = None
self.name = name
self.layers = {}
@@ -114,14 +135,15 @@ class Hypernetwork:
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
self.activation_func = activation_func
self.weight_init = weight_init
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
self.activate_output = activate_output
for size in enable_sizes or []:
self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
)
def weights(self):
@@ -145,6 +167,7 @@ class Hypernetwork:
state_dict['layer_structure'] = self.layer_structure
state_dict['activation_func'] = self.activation_func
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['weight_initialization'] = self.weight_init
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -160,16 +183,22 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
print(self.layer_structure)
self.activation_func = state_dict.get('activation_func', None)
print(f"Activation function is {self.activation_func}")
self.weight_init = state_dict.get('weight_initialization', 'Normal')
print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.use_dropout = state_dict.get('use_dropout', False)
print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False
print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True)
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
)
self.name = state_dict.get('name', self.name)

View File

@@ -8,8 +8,9 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -25,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
activation_func=activation_func,
weight_init=weight_init,
add_layer_norm=add_layer_norm,
use_dropout=use_dropout,
)

View File

@@ -277,7 +277,7 @@ invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
re_pattern = re.compile(r"([^\[\]]+|\[([^]]+)]|[\[\]]*)")
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128
@@ -343,7 +343,7 @@ class FilenameGenerator:
def datetime(self, *args):
time_datetime = datetime.datetime.now()
time_format = args[0] if len(args) > 0 else self.default_time_format
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
try:
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
except pytz.exceptions.UnknownTimeZoneError as _:
@@ -362,9 +362,9 @@ class FilenameGenerator:
for m in re_pattern.finditer(x):
text, pattern = m.groups()
res += text
if pattern is None:
res += text
continue
pattern_args = []
@@ -385,12 +385,9 @@ class FilenameGenerator:
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if replacement is None:
res += f'[{pattern}]'
else:
if replacement is not None:
res += str(replacement)
continue
continue
res += f'[{pattern}]'

View File

@@ -0,0 +1,341 @@
import cv2
import requests
import os
from collections import defaultdict
from math import log, sqrt
import numpy as np
from PIL import Image, ImageDraw
GREEN = "#0F0"
BLUE = "#00F"
RED = "#F00"
def crop_image(im, settings):
""" Intelligently crop an image to the subject matter """
scale_by = 1
if is_landscape(im.width, im.height):
scale_by = settings.crop_height / im.height
elif is_portrait(im.width, im.height):
scale_by = settings.crop_width / im.width
elif is_square(im.width, im.height):
if is_square(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width
elif is_landscape(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width
elif is_portrait(settings.crop_width, settings.crop_height):
scale_by = settings.crop_height / im.height
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
im_debug = im.copy()
focus = focal_point(im_debug, settings)
# take the focal point and turn it into crop coordinates that try to center over the focal
# point but then get adjusted back into the frame
y_half = int(settings.crop_height / 2)
x_half = int(settings.crop_width / 2)
x1 = focus.x - x_half
if x1 < 0:
x1 = 0
elif x1 + settings.crop_width > im.width:
x1 = im.width - settings.crop_width
y1 = focus.y - y_half
if y1 < 0:
y1 = 0
elif y1 + settings.crop_height > im.height:
y1 = im.height - settings.crop_height
x2 = x1 + settings.crop_width
y2 = y1 + settings.crop_height
crop = [x1, y1, x2, y2]
results = []
results.append(im.crop(tuple(crop)))
if settings.annotate_image:
d = ImageDraw.Draw(im_debug)
rect = list(crop)
rect[2] -= 1
rect[3] -= 1
d.rectangle(rect, outline=GREEN)
results.append(im_debug)
if settings.destop_view_image:
im_debug.show()
return results
def focal_point(im, settings):
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
pois = []
weight_pref_total = 0
if len(corner_points) > 0:
weight_pref_total += settings.corner_points_weight
if len(entropy_points) > 0:
weight_pref_total += settings.entropy_points_weight
if len(face_points) > 0:
weight_pref_total += settings.face_points_weight
corner_centroid = None
if len(corner_points) > 0:
corner_centroid = centroid(corner_points)
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
pois.append(corner_centroid)
entropy_centroid = None
if len(entropy_points) > 0:
entropy_centroid = centroid(entropy_points)
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
pois.append(entropy_centroid)
face_centroid = None
if len(face_points) > 0:
face_centroid = centroid(face_points)
face_centroid.weight = settings.face_points_weight / weight_pref_total
pois.append(face_centroid)
average_point = poi_average(pois, settings)
if settings.annotate_image:
d = ImageDraw.Draw(im)
max_size = min(im.width, im.height) * 0.07
if corner_centroid is not None:
color = BLUE
box = corner_centroid.bounding(max_size * corner_centroid.weight)
d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
d.ellipse(box, outline=color)
if len(corner_points) > 1:
for f in corner_points:
d.rectangle(f.bounding(4), outline=color)
if entropy_centroid is not None:
color = "#ff0"
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
d.ellipse(box, outline=color)
if len(entropy_points) > 1:
for f in entropy_points:
d.rectangle(f.bounding(4), outline=color)
if face_centroid is not None:
color = RED
box = face_centroid.bounding(max_size * face_centroid.weight)
d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
d.ellipse(box, outline=color)
if len(face_points) > 1:
for f in face_points:
d.rectangle(f.bounding(4), outline=color)
d.ellipse(average_point.bounding(max_size), outline=GREEN)
return average_point
def image_face_points(im, settings):
if settings.dnn_model_path is not None:
detector = cv2.FaceDetectorYN.create(
settings.dnn_model_path,
"",
(im.width, im.height),
0.9, # score threshold
0.3, # nms threshold
5000 # keep top k before nms
)
faces = detector.detect(np.array(im))
results = []
if faces[1] is not None:
for face in faces[1]:
x = face[0]
y = face[1]
w = face[2]
h = face[3]
results.append(
PointOfInterest(
int(x + (w * 0.5)), # face focus left/right is center
int(y + (h * 0.33)), # face focus up/down is close to the top of the head
size = w,
weight = 1/len(faces[1])
)
)
return results
else:
np_im = np.array(im)
gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
tries = [
[ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
[ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
[ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
[ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
[ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
]
for t in tries:
classifier = cv2.CascadeClassifier(t[0])
minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
try:
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
except:
continue
if len(faces) > 0:
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
return []
def image_corner_points(im, settings):
grayscale = im.convert("L")
# naive attempt at preventing focal points from collecting at watermarks near the bottom
gd = ImageDraw.Draw(grayscale)
gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
np_im = np.array(grayscale)
points = cv2.goodFeaturesToTrack(
np_im,
maxCorners=100,
qualityLevel=0.04,
minDistance=min(grayscale.width, grayscale.height)*0.06,
useHarrisDetector=False,
)
if points is None:
return []
focal_points = []
for point in points:
x, y = point.ravel()
focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
return focal_points
def image_entropy_points(im, settings):
landscape = im.height < im.width
portrait = im.height > im.width
if landscape:
move_idx = [0, 2]
move_max = im.size[0]
elif portrait:
move_idx = [1, 3]
move_max = im.size[1]
else:
return []
e_max = 0
crop_current = [0, 0, settings.crop_width, settings.crop_height]
crop_best = crop_current
while crop_current[move_idx[1]] < move_max:
crop = im.crop(tuple(crop_current))
e = image_entropy(crop)
if (e > e_max):
e_max = e
crop_best = list(crop_current)
crop_current[move_idx[0]] += 4
crop_current[move_idx[1]] += 4
x_mid = int(crop_best[0] + settings.crop_width/2)
y_mid = int(crop_best[1] + settings.crop_height/2)
return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
def image_entropy(im):
# greyscale image entropy
# band = np.asarray(im.convert("L"))
band = np.asarray(im.convert("1"), dtype=np.uint8)
hist, _ = np.histogram(band, bins=range(0, 256))
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
def centroid(pois):
x = [poi.x for poi in pois]
y = [poi.y for poi in pois]
return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
def poi_average(pois, settings):
weight = 0.0
x = 0.0
y = 0.0
for poi in pois:
weight += poi.weight
x += poi.x * poi.weight
y += poi.y * poi.weight
avg_x = round(x / weight)
avg_y = round(y / weight)
return PointOfInterest(avg_x, avg_y)
def is_landscape(w, h):
return w > h
def is_portrait(w, h):
return h > w
def is_square(w, h):
return w == h
def download_and_cache_models(dirname):
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
model_file_name = 'face_detection_yunet.onnx'
if not os.path.exists(dirname):
os.makedirs(dirname)
cache_file = os.path.join(dirname, model_file_name)
if not os.path.exists(cache_file):
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
response = requests.get(download_url)
with open(cache_file, "wb") as f:
f.write(response.content)
if os.path.exists(cache_file):
return cache_file
return None
class PointOfInterest:
def __init__(self, x, y, weight=1.0, size=10):
self.x = x
self.y = y
self.weight = weight
self.size = size
def bounding(self, size):
return [
self.x - size//2,
self.y - size//2,
self.x + size//2,
self.y + size//2
]
class Settings:
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
self.crop_width = crop_width
self.crop_height = crop_height
self.corner_points_weight = corner_points_weight
self.entropy_points_weight = entropy_points_weight
self.face_points_weight = face_points_weight
self.annotate_image = annotate_image
self.destop_view_image = False
self.dnn_model_path = dnn_model_path

View File

@@ -7,12 +7,14 @@ import tqdm
import time
from modules import shared, images
from modules.paths import models_path
from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
try:
if process_caption:
shared.interrogator.load()
@@ -22,7 +24,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio)
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
finally:
@@ -34,7 +36,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -113,6 +115,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
splitted = image.crop((0, y, to_w, y + to_h))
yield splitted
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
@@ -137,11 +140,36 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
ratio = (img.height * width) / (img.width * height)
inverse_xy = True
process_default_resize = True
if process_split and ratio < 1.0 and ratio <= split_threshold:
for splitted in split_pic(img, inverse_xy):
save_pic(splitted, index, existing_caption=existing_caption)
else:
process_default_resize = False
if process_focal_crop and img.height != img.width:
dnn_model_path = None
try:
dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
except Exception as e:
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
autocrop_settings = autocrop.Settings(
crop_width = width,
crop_height = height,
face_points_weight = process_focal_crop_face_weight,
entropy_points_weight = process_focal_crop_entropy_weight,
corner_points_weight = process_focal_crop_edges_weight,
annotate_image = process_focal_crop_debug,
dnn_model_path = dnn_model_path,
)
for focal in autocrop.crop_image(img, autocrop_settings):
save_pic(focal, index, existing_caption=existing_caption)
process_default_resize = False
if process_default_resize:
img = images.resize_image(1, img, width, height)
save_pic(img, index, existing_caption=existing_caption)
shared.state.nextjob()
shared.state.nextjob()

View File

@@ -157,6 +157,9 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)

View File

@@ -1238,7 +1238,8 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
@@ -1260,6 +1261,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
process_split = gr.Checkbox(label='Split oversized images')
process_focal_crop = gr.Checkbox(label='Auto focal point crop')
process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
@@ -1267,6 +1269,12 @@ def create_ui(wrap_gradio_gpu_call):
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
with gr.Row(visible=False) as process_focal_crop_row:
process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05)
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05)
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
process_focal_crop_debug = gr.Checkbox(label='Create debug image')
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1280,6 +1288,12 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[process_split_extra_row],
)
process_focal_crop.change(
fn=lambda show: gr_show(show),
inputs=[process_focal_crop],
outputs=[process_focal_crop_row],
)
with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with gr.Row():
@@ -1342,6 +1356,7 @@ def create_ui(wrap_gradio_gpu_call):
overwrite_old_hypernetwork,
new_hypernetwork_layer_structure,
new_hypernetwork_activation_func,
new_hypernetwork_initialization_option,
new_hypernetwork_add_layer_norm,
new_hypernetwork_use_dropout
],
@@ -1367,6 +1382,11 @@ def create_ui(wrap_gradio_gpu_call):
process_caption_deepbooru,
process_split_threshold,
process_overlap_ratio,
process_focal_crop,
process_focal_crop_face_weight,
process_focal_crop_entropy_weight,
process_focal_crop_edges_weight,
process_focal_crop_debug,
],
outputs=[
ti_output,