diffusion in fp8 landed

This commit is contained in:
lllyasviel
2024-08-06 16:47:39 -07:00
committed by GitHub
parent dd8997ee2e
commit 71c94799d1
7 changed files with 96 additions and 46 deletions

View File

@@ -2,10 +2,13 @@ import os
import torch
import logging
import importlib
import huggingface_guess
from diffusers import DiffusionPipeline
from transformers import modeling_utils
from backend import memory_management
from backend.state_dict import try_filter_state_dict, load_state_dict
from backend.operations import using_forge_operations
from backend.nn.vae import IntegratedAutoencoderKL
@@ -57,9 +60,13 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
return model
if cls_name == 'UNet2DConditionModel':
unet_config = guess.unet_config.copy()
state_dict_size = memory_management.state_dict_size(state_dict)
unet_config['dtype'] = memory_management.unet_dtype(model_params=state_dict_size)
with using_forge_operations():
model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
model._internal_dict = guess.unet_config
model = IntegratedUNet2DConditionModel.from_config(unet_config)
model._internal_dict = unet_config
load_state_dict(model, state_dict)
return model