Update mbart50.py

This commit is contained in:
Physton
2023-06-25 00:08:32 +08:00
parent 1cd5ff3161
commit 6976a4ce0d

View File

@@ -15,6 +15,12 @@ def initialize(reload=False):
if not reload and model is not None:
return
loading = True
model_path = os.path.join(cache_dir, "mbart-large-50-many-to-many-mmt")
model_file = os.path.join(model_path, "pytorch_model.bin")
if os.path.exists(model_path) and os.path.exists(model_file):
model_name = model_path
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
print(f'[sd-webui-prompt-all-in-one] Loading model {model_name} from {cache_dir}...')
model = MBartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)