mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
WIP. just need to put it here
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import typing
|
||||
from typing import Union, OrderedDict
|
||||
import sys
|
||||
import os
|
||||
@@ -36,7 +37,15 @@ class PromptEmbeds:
|
||||
return self
|
||||
|
||||
|
||||
# if is type checking
|
||||
if typing.TYPE_CHECKING:
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline
|
||||
|
||||
|
||||
class StableDiffusion:
|
||||
pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
@@ -44,7 +53,8 @@ class StableDiffusion:
|
||||
text_encoder,
|
||||
unet,
|
||||
noise_scheduler,
|
||||
is_xl=False
|
||||
is_xl=False,
|
||||
pipeline=None,
|
||||
):
|
||||
# text encoder has a list of 2 for xl
|
||||
self.vae = vae
|
||||
@@ -53,6 +63,7 @@ class StableDiffusion:
|
||||
self.unet = unet
|
||||
self.noise_scheduler = noise_scheduler
|
||||
self.is_xl = is_xl
|
||||
self.pipeline = pipeline
|
||||
|
||||
def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
|
||||
prompt = prompt
|
||||
|
||||
Reference in New Issue
Block a user