mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Minor fixes
This commit is contained in:
@@ -7,16 +7,21 @@ from safetensors.torch import load_file, save_file
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import json
|
import json
|
||||||
|
|
||||||
model_path = "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/objective_reality_v2.safetensors"
|
model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000"
|
||||||
te_path = "google/flan-t5-xl"
|
te_path = "google/flan-t5-xl"
|
||||||
te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
|
te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
|
||||||
output_path = "/home/jaret/Dev/models/hf/t5xl_sd15_v1"
|
output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
|
||||||
|
|
||||||
print("Loading te adapter")
|
print("Loading te adapter")
|
||||||
te_aug_sd = load_file(te_aug_path)
|
te_aug_sd = load_file(te_aug_path)
|
||||||
|
|
||||||
print("Loading model")
|
print("Loading model")
|
||||||
sd = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16)
|
is_diffusers = (not os.path.exists(model_path)) or os.path.isdir(model_path)
|
||||||
|
|
||||||
|
if is_diffusers:
|
||||||
|
sd = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||||
|
else:
|
||||||
|
sd = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16)
|
||||||
|
|
||||||
print("Loading Text Encoder")
|
print("Loading Text Encoder")
|
||||||
# Load the text encoder
|
# Load the text encoder
|
||||||
@@ -74,6 +79,7 @@ for name in sd.unet.attn_processors.keys():
|
|||||||
|
|
||||||
|
|
||||||
print("Saving unmodified model")
|
print("Saving unmodified model")
|
||||||
|
sd = sd.to("cpu", torch.float16)
|
||||||
sd.save_pretrained(
|
sd.save_pretrained(
|
||||||
output_path,
|
output_path,
|
||||||
safe_serialization=True,
|
safe_serialization=True,
|
||||||
|
|||||||
@@ -339,7 +339,7 @@ class ModelConfig:
|
|||||||
self.is_v2: bool = kwargs.get('is_v2', False)
|
self.is_v2: bool = kwargs.get('is_v2', False)
|
||||||
self.is_xl: bool = kwargs.get('is_xl', False)
|
self.is_xl: bool = kwargs.get('is_xl', False)
|
||||||
self.is_pixart: bool = kwargs.get('is_pixart', False)
|
self.is_pixart: bool = kwargs.get('is_pixart', False)
|
||||||
self.is_pixart_sigma: bool = kwargs.get('is_pixart', False)
|
self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False)
|
||||||
self.is_v3: bool = kwargs.get('is_v3', False)
|
self.is_v3: bool = kwargs.get('is_v3', False)
|
||||||
if self.is_pixart_sigma:
|
if self.is_pixart_sigma:
|
||||||
self.is_pixart = True
|
self.is_pixart = True
|
||||||
|
|||||||
@@ -240,7 +240,7 @@ def get_direct_guidance_loss(
|
|||||||
|
|
||||||
noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0)
|
noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0)
|
||||||
|
|
||||||
guidance_scale = 1.25
|
guidance_scale = 1.1
|
||||||
guidance_pred = noise_pred_uncond + guidance_scale * (
|
guidance_pred = noise_pred_uncond + guidance_scale * (
|
||||||
noise_pred_cond - noise_pred_uncond
|
noise_pred_cond - noise_pred_uncond
|
||||||
)
|
)
|
||||||
@@ -552,39 +552,17 @@ def get_guided_tnt(
|
|||||||
reduction="none"
|
reduction="none"
|
||||||
)
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
tnt_loss = this_loss - that_loss
|
|
||||||
|
|
||||||
# create a mask by scaling loss from 0 to mean to 1 to 0
|
|
||||||
# this will act to regularize unchanged areas to prior prediction
|
|
||||||
loss_min = tnt_loss.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=3, keepdim=True)[0]
|
|
||||||
loss_mean = tnt_loss.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
|
|
||||||
mask = value_map(
|
|
||||||
torch.abs(tnt_loss),
|
|
||||||
loss_min,
|
|
||||||
loss_mean,
|
|
||||||
0.0,
|
|
||||||
1.0
|
|
||||||
).clamp(0.0, 1.0).detach()
|
|
||||||
|
|
||||||
prior_mask = 1.0 - mask
|
|
||||||
|
|
||||||
|
|
||||||
this_loss = this_loss * mask
|
|
||||||
that_loss = that_loss * prior_mask
|
|
||||||
|
|
||||||
this_loss = this_loss.mean([1, 2, 3])
|
this_loss = this_loss.mean([1, 2, 3])
|
||||||
that_loss = that_loss.mean([1, 2, 3])
|
# negative loss on that
|
||||||
|
that_loss = -that_loss.mean([1, 2, 3])
|
||||||
|
|
||||||
prior_loss = torch.nn.functional.mse_loss(
|
with torch.no_grad():
|
||||||
this_prediction.float(),
|
# match that loss with this loss so it is not a negative value and same scale
|
||||||
prior_pred.detach().float(),
|
that_loss_scaler = torch.abs(this_loss) / torch.abs(that_loss)
|
||||||
reduction="none"
|
|
||||||
)
|
|
||||||
prior_loss = prior_loss * prior_mask
|
|
||||||
prior_loss = prior_loss.mean([1, 2, 3])
|
|
||||||
|
|
||||||
loss = prior_loss + this_loss - that_loss
|
that_loss = that_loss * that_loss_scaler * 0.01
|
||||||
|
|
||||||
|
loss = this_loss + that_loss
|
||||||
|
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user