Optimize offline translation model loading

Physton
2023-06-28 11:24:50 +08:00
parent c7b4a8378f
commit e19b3e238d


@@ -1,4 +1,5 @@
 import os
+import time
 from scripts.physton_prompt.get_lang import get_lang

 model = None
@@ -10,23 +11,34 @@ loading = False
 def initialize(reload=False):
     global model, tokenizer, model_name, cache_dir, loading
     if loading:
-        raise Exception(get_lang('model_is_loading'))
+        # Another caller is already loading the model: wait for it to
+        # finish instead of failing immediately.
+        while loading:
+            time.sleep(0.1)
+        if model is None or tokenizer is None:
+            raise Exception('error')
+        # raise Exception(get_lang('model_is_loading'))
+        return
     if not reload and model is not None:
         return
     loading = True
     model = None
     tokenizer = None
     model_path = os.path.join(cache_dir, "mbart-large-50-many-to-many-mmt")
     model_file = os.path.join(model_path, "pytorch_model.bin")
     if os.path.exists(model_path) and os.path.exists(model_file):
         model_name = model_path
-    from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
-    print(f'[sd-webui-prompt-all-in-one] Loading model {model_name} from {cache_dir}...')
-    model = MBartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
-    tokenizer = MBart50TokenizerFast.from_pretrained(model_name, cache_dir=cache_dir)
-    print(f'[sd-webui-prompt-all-in-one] Model {model_name} loaded.')
-    loading = False
+    try:
+        from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
+        print(f'[sd-webui-prompt-all-in-one] Loading model {model_name} from {cache_dir}...')
+        model = MBartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
+        tokenizer = MBart50TokenizerFast.from_pretrained(model_name, cache_dir=cache_dir)
+        print(f'[sd-webui-prompt-all-in-one] Model {model_name} loaded.')
+        loading = False
+    except Exception as e:
+        # Clear the flag on failure so later calls can retry the load.
+        loading = False
+        raise e

 def translate(text, src_lang, target_lang):
     global model, tokenizer
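
For reference, a minimal sketch of how translate can drive the model loaded above, following the standard transformers mBART-50 API. The body of translate is truncated in this diff, so this is an assumption rather than the repository's actual code; language codes such as "en_XX" and "zh_CN" are mBART-50 conventions:

def translate(text, src_lang, target_lang):
    global model, tokenizer
    initialize()  # lazy-load on first use, as implemented above
    # mBART-50 needs the source language set on the tokenizer, e.g. "en_XX".
    tokenizer.src_lang = src_lang
    encoded = tokenizer(text, return_tensors="pt")
    # Force the first generated token to be the target-language code.
    generated = model.generate(
        **encoded,
        forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

On the design choice: the boolean loading flag plus polling avoids a second concurrent load, but it is not thread-safe, since two threads can both observe loading as False and start loading. A lock would be the usual fix; a sketch under that assumption (the _load_lock name is hypothetical, not in the repository):

import threading

_load_lock = threading.Lock()  # hypothetical name, not in the repository

def initialize(reload=False):
    global model, tokenizer
    with _load_lock:  # only one thread loads; others block until it finishes
        if not reload and model is not None:
            return
        ...  # load model and tokenizer as in the diff above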