mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
Added support for caching text embeddings. This is just initial support and will probably fail for some models. Still needs to be ompimized
This commit is contained in:
@@ -209,6 +209,8 @@ class StableDiffusion:
|
||||
# todo update this based on the model
|
||||
self.is_transformer = False
|
||||
|
||||
self.sample_prompts_cache = None
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
def is_xl(self):
|
||||
@@ -1426,18 +1428,22 @@ class StableDiffusion:
|
||||
quad_count=4
|
||||
)
|
||||
|
||||
# encode the prompt ourselves so we can do fun stuff with embeddings
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
self.adapter.is_unconditional_run = False
|
||||
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
|
||||
if self.sample_prompts_cache is not None:
|
||||
conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
|
||||
unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
|
||||
else:
|
||||
# encode the prompt ourselves so we can do fun stuff with embeddings
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
self.adapter.is_unconditional_run = False
|
||||
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
|
||||
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
self.adapter.is_unconditional_run = True
|
||||
unconditional_embeds = self.encode_prompt(
|
||||
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
|
||||
)
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
self.adapter.is_unconditional_run = False
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
self.adapter.is_unconditional_run = True
|
||||
unconditional_embeds = self.encode_prompt(
|
||||
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
|
||||
)
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
self.adapter.is_unconditional_run = False
|
||||
|
||||
# allow any manipulations to take place to embeddings
|
||||
gen_config.post_process_embeddings(
|
||||
|
||||
Reference in New Issue
Block a user