mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added a guidance burning loss. Modified DFE to work with new model. Bug fixes
This commit is contained in:
@@ -1012,7 +1012,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
imgs = None
|
||||
is_reg = any(batch.get_is_reg_list())
|
||||
cfm_batch = None
|
||||
if batch.tensor is not None:
|
||||
imgs = batch.tensor
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
|
||||
@@ -113,6 +113,8 @@ class GenerateProcess(BaseProcess):
|
||||
prompt_image_configs = []
|
||||
for _ in range(self.generate_config.num_repeats):
|
||||
for prompt in self.generate_config.prompts:
|
||||
# remove --
|
||||
prompt = prompt.replace('--', '').strip()
|
||||
width = self.generate_config.width
|
||||
height = self.generate_config.height
|
||||
# prompt = self.clean_prompt(prompt)
|
||||
|
||||
Reference in New Issue
Block a user