Added Model rescale and prepared a release upgrade

This commit is contained in:
Jaret Burkett
2023-08-01 13:49:54 -06:00
parent 63cacf4362
commit 8b8d53888d
15 changed files with 388 additions and 64 deletions

View File

@@ -19,7 +19,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2D
DDIMScheduler, DDPMScheduler
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
@@ -192,6 +192,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
num_inference_steps=sample_config.sample_steps,
guidance_scale=sample_config.guidance_scale,
negative_prompt=neg,
guidance_rescale=0.7,
).images[0]
else:
img = pipeline(
@@ -236,21 +237,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
# self.sd.tokenizer.to(original_device_dict['tokenizer'])
def update_training_metadata(self):
dict = OrderedDict({
o_dict = OrderedDict({
"training_info": self.get_training_info()
})
if self.model_config.is_v2:
dict['ss_v2'] = True
dict['ss_base_model_version'] = 'sd_2.1'
o_dict['ss_v2'] = True
o_dict['ss_base_model_version'] = 'sd_2.1'
elif self.model_config.is_xl:
dict['ss_base_model_version'] = 'sdxl_1.0'
o_dict['ss_base_model_version'] = 'sdxl_1.0'
else:
dict['ss_base_model_version'] = 'sd_1.5'
o_dict['ss_base_model_version'] = 'sd_1.5'
dict['ss_output_name'] = self.job.name
o_dict = add_base_model_info_to_meta(
o_dict,
is_v2=self.model_config.is_v2,
is_xl=self.model_config.is_xl,
)
o_dict['ss_output_name'] = self.job.name
self.add_meta(dict)
self.add_meta(o_dict)
def get_training_info(self):
info = OrderedDict({
@@ -381,7 +387,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
text_embeddings: PromptEmbeds,
timestep: int,
guidance_scale=7.5,
guidance_rescale=0.7,
guidance_rescale=0, # 0.7
add_time_ids=None,
**kwargs,
):
@@ -389,7 +395,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.sd.is_xl:
if add_time_ids is None:
add_time_ids = self.get_time_ids_from_latents(latents)
# todo LECOs code looks like it is omitting noise_pred
latent_model_input = torch.cat([latents] * 2)
@@ -500,13 +505,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
# TODO handle other schedulers
sch = KDPM2DiscreteScheduler
# sch = KDPM2DiscreteScheduler
sch = DDPMScheduler
# do our own scheduler
prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
scheduler = sch(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.0120,
beta_schedule="scaled_linear",
clip_sample=False,
prediction_type=prediction_type,
)
if self.model_config.is_xl:
if self.custom_pipeline is not None: