WIP. just need to put it here

This commit is contained in:
Jaret Burkett
2023-07-27 01:46:30 -06:00
parent 2305e55c82
commit 6ab8b8b0f1
4 changed files with 279 additions and 63 deletions

View File

@@ -1,3 +1,4 @@
import typing
from typing import Union, OrderedDict
import sys
import os
@@ -36,7 +37,15 @@ class PromptEmbeds:
return self
# if is type checking
if typing.TYPE_CHECKING:
from diffusers import StableDiffusionPipeline
from toolkit.pipelines import CustomStableDiffusionXLPipeline
class StableDiffusion:
pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
def __init__(
self,
vae,
@@ -44,7 +53,8 @@ class StableDiffusion:
text_encoder,
unet,
noise_scheduler,
is_xl=False
is_xl=False,
pipeline=None,
):
# text encoder has a list of 2 for xl
self.vae = vae
@@ -53,6 +63,7 @@ class StableDiffusion:
self.unet = unet
self.noise_scheduler = noise_scheduler
self.is_xl = is_xl
self.pipeline = pipeline
def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
prompt = prompt