mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Added Model rescale and prepared a release upgrade
This commit is contained in:
@@ -19,7 +19,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2D
|
||||
DDIMScheduler, DDPMScheduler
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import gc
|
||||
|
||||
@@ -192,6 +192,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
num_inference_steps=sample_config.sample_steps,
|
||||
guidance_scale=sample_config.guidance_scale,
|
||||
negative_prompt=neg,
|
||||
guidance_rescale=0.7,
|
||||
).images[0]
|
||||
else:
|
||||
img = pipeline(
|
||||
@@ -236,21 +237,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# self.sd.tokenizer.to(original_device_dict['tokenizer'])
|
||||
|
||||
def update_training_metadata(self):
|
||||
dict = OrderedDict({
|
||||
o_dict = OrderedDict({
|
||||
"training_info": self.get_training_info()
|
||||
})
|
||||
if self.model_config.is_v2:
|
||||
dict['ss_v2'] = True
|
||||
dict['ss_base_model_version'] = 'sd_2.1'
|
||||
o_dict['ss_v2'] = True
|
||||
o_dict['ss_base_model_version'] = 'sd_2.1'
|
||||
|
||||
elif self.model_config.is_xl:
|
||||
dict['ss_base_model_version'] = 'sdxl_1.0'
|
||||
o_dict['ss_base_model_version'] = 'sdxl_1.0'
|
||||
else:
|
||||
dict['ss_base_model_version'] = 'sd_1.5'
|
||||
o_dict['ss_base_model_version'] = 'sd_1.5'
|
||||
|
||||
dict['ss_output_name'] = self.job.name
|
||||
o_dict = add_base_model_info_to_meta(
|
||||
o_dict,
|
||||
is_v2=self.model_config.is_v2,
|
||||
is_xl=self.model_config.is_xl,
|
||||
)
|
||||
o_dict['ss_output_name'] = self.job.name
|
||||
|
||||
self.add_meta(dict)
|
||||
self.add_meta(o_dict)
|
||||
|
||||
def get_training_info(self):
|
||||
info = OrderedDict({
|
||||
@@ -381,7 +387,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
text_embeddings: PromptEmbeds,
|
||||
timestep: int,
|
||||
guidance_scale=7.5,
|
||||
guidance_rescale=0.7,
|
||||
guidance_rescale=0, # 0.7
|
||||
add_time_ids=None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -389,7 +395,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.sd.is_xl:
|
||||
if add_time_ids is None:
|
||||
add_time_ids = self.get_time_ids_from_latents(latents)
|
||||
# todo LECOs code looks like it is omitting noise_pred
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
|
||||
@@ -500,13 +505,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
# TODO handle other schedulers
|
||||
sch = KDPM2DiscreteScheduler
|
||||
# sch = KDPM2DiscreteScheduler
|
||||
sch = DDPMScheduler
|
||||
# do our own scheduler
|
||||
prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
|
||||
scheduler = sch(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.00085,
|
||||
beta_end=0.0120,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
prediction_type=prediction_type,
|
||||
)
|
||||
if self.model_config.is_xl:
|
||||
if self.custom_pipeline is not None:
|
||||
|
||||
Reference in New Issue
Block a user