move to new backend - part 2

This commit is contained in:
layerdiffusion
2024-08-03 15:10:37 -07:00
parent 8a01b2c5db
commit 4add428e25
9 changed files with 61 additions and 69 deletions

View File

@@ -1,22 +1,23 @@
from modules import sd_samplers_kdiffusion, sd_samplers_common
from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
def __init__(self, sd_model, sampler_name):
self.sampler_name = sampler_name
self.unet = sd_model.forge_objects.unet
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
super().__init__(sampler_function, sd_model, None)
def build_constructor(sampler_name):
def constructor(m):
return AlterSampler(m, sampler_name)
return constructor
# from modules import sd_samplers_kdiffusion, sd_samplers_common
# from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
#
#
# class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
# def __init__(self, sd_model, sampler_name):
# self.sampler_name = sampler_name
# self.unet = sd_model.forge_objects.unet
# sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
# super().__init__(sampler_function, sd_model, None)
#
#
# def build_constructor(sampler_name):
# def constructor(m):
# return AlterSampler(m, sampler_name)
#
# return constructor
#
#
samplers_data_alter = [
sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}),
# sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}),
]