mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-05 13:09:48 +00:00
use trange
This commit is contained in:
@@ -2,17 +2,17 @@ import torch
|
||||
import math
|
||||
import itertools
|
||||
|
||||
from tqdm import tqdm
|
||||
from tqdm import trange
|
||||
from backend import memory_management
|
||||
from backend.patcher.base import ModelPatcher
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
|
||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu"):
|
||||
dims = len(tile)
|
||||
output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
|
||||
|
||||
for b in range(samples.shape[0]):
|
||||
for b in trange(samples.shape[0]):
|
||||
s = samples[b:b + 1]
|
||||
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
@@ -45,9 +45,6 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
o += ps * mask
|
||||
o_d += mask
|
||||
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
|
||||
output[b:b + 1] = out / out_div
|
||||
return output
|
||||
|
||||
@@ -56,29 +53,8 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
||||
|
||||
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
|
||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar)
|
||||
|
||||
|
||||
class ProgressBar:
|
||||
def __init__(self, total, title=None):
|
||||
self.total = total
|
||||
self.current = 0
|
||||
self.tqdm = tqdm(total=total, desc=title)
|
||||
|
||||
def update_absolute(self, value, total=None, preview=None):
|
||||
if total is not None:
|
||||
self.total = total
|
||||
if value > self.total:
|
||||
value = self.total
|
||||
inc = value - self.current
|
||||
self.tqdm.update(inc)
|
||||
self.current = value
|
||||
if self.current >= self.total:
|
||||
self.tqdm.close()
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu"):
|
||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device)
|
||||
|
||||
|
||||
class VAE:
|
||||
@@ -129,12 +105,11 @@ class VAE:
|
||||
steps = samples.shape[0] * get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += samples.shape[0] * get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = ProgressBar(steps, title='VAE tiled decode')
|
||||
|
||||
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
|
||||
output = torch.clamp(((tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device, pbar=pbar) +
|
||||
tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device, pbar=pbar) +
|
||||
tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device, pbar=pbar))
|
||||
output = torch.clamp(((tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device) +
|
||||
tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device) +
|
||||
tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device))
|
||||
/ 3.0) / 2.0, min=0.0, max=1.0)
|
||||
return output
|
||||
|
||||
@@ -142,12 +117,11 @@ class VAE:
|
||||
steps = pixel_samples.shape[0] * get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += pixel_samples.shape[0] * get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = ProgressBar(steps, title='VAE tiled encode')
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
|
||||
samples = tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples = tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
samples += tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
samples += tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
samples /= 3.0
|
||||
return samples
|
||||
|
||||
|
||||
Reference in New Issue
Block a user