mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Added prompt flag to adjust network multiplier
This commit is contained in:
@@ -67,6 +67,7 @@ class SampleConfig:
|
||||
self.walk_seed = kwargs.get('walk_seed', False)
|
||||
self.guidance_scale = kwargs.get('guidance_scale', 7)
|
||||
self.sample_steps = kwargs.get('sample_steps', 20)
|
||||
self.network_multiplier = kwargs.get('network_multiplier', 1)
|
||||
|
||||
|
||||
class NetworkConfig:
|
||||
@@ -214,6 +215,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
start_seed = self.sample_config.seed
|
||||
start_multiplier = self.network.multiplier
|
||||
current_seed = start_seed
|
||||
|
||||
pipeline.to(self.device_torch)
|
||||
@@ -228,12 +230,22 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
|
||||
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}"):
|
||||
raw_prompt = self.sample_config.prompts[i]
|
||||
prompt = raw_prompt
|
||||
|
||||
neg = self.sample_config.neg
|
||||
p_split = raw_prompt.split('--n')
|
||||
multiplier = self.sample_config.network_multiplier
|
||||
p_split = raw_prompt.split('--')
|
||||
prompt = p_split[0].strip()
|
||||
|
||||
if len(p_split) > 1:
|
||||
prompt = p_split[0].strip()
|
||||
neg = p_split[1].strip()
|
||||
for split in p_split:
|
||||
flag = split[:1]
|
||||
content = split[1:].strip()
|
||||
if flag == 'n':
|
||||
neg = content
|
||||
elif flag == 'm':
|
||||
# multiplier
|
||||
multiplier = float(content)
|
||||
|
||||
height = self.sample_config.height
|
||||
width = self.sample_config.width
|
||||
height = max(64, height - height % 8) # round to divisible by 8
|
||||
@@ -242,6 +254,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
if self.sample_config.walk_seed:
|
||||
current_seed += i
|
||||
|
||||
self.network.multiplier = multiplier
|
||||
torch.manual_seed(current_seed)
|
||||
torch.cuda.manual_seed(current_seed)
|
||||
|
||||
@@ -278,6 +291,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
self.sd.unet.to(original_device_dict['unet'])
|
||||
self.sd.text_encoder.to(original_device_dict['text_encoder'])
|
||||
self.network.train()
|
||||
self.network.multiplier = start_multiplier
|
||||
# self.sd.tokenizer.to(original_device_dict['tokenizer'])
|
||||
|
||||
def update_training_metadata(self):
|
||||
|
||||
Reference in New Issue
Block a user