mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Added chroma model to the ui. Added logic to easily pull latest, use local, or use a specific version of chroma. Allow ustom name or path in the ui for custom models
This commit is contained in:
@@ -22,6 +22,7 @@ import torch.nn.functional as F
|
||||
from .src.model import Chroma, chroma_params
|
||||
from safetensors.torch import load_file, save_file
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
import huggingface_hub
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
@@ -100,7 +101,44 @@ class ChromaModel(BaseModel):
|
||||
# will be updated if we detect a existing checkpoint in training folder
|
||||
model_path = self.model_config.name_or_path
|
||||
|
||||
extras_path = 'black-forest-labs/FLUX.1-schnell'
|
||||
if model_path == "lodestones/Chroma":
|
||||
print("Looking for latest Chroma checkpoint")
|
||||
# get the latest checkpoint
|
||||
files_list = huggingface_hub.list_repo_files(model_path)
|
||||
print(files_list)
|
||||
latest_version = 28 # current latest version at time of writing
|
||||
while True:
|
||||
if f"chroma-unlocked-v{latest_version}.safetensors" not in files_list:
|
||||
latest_version -= 1
|
||||
break
|
||||
else:
|
||||
latest_version += 1
|
||||
print(f"Using latest Chroma version: v{latest_version}")
|
||||
|
||||
# make sure we have it
|
||||
model_path = huggingface_hub.hf_hub_download(
|
||||
repo_id=model_path,
|
||||
filename=f"chroma-unlocked-v{latest_version}.safetensors",
|
||||
)
|
||||
elif model_path.startswith("lodestones/Chroma/v"):
|
||||
# get the version number
|
||||
version = model_path.split("/")[-1].split("v")[-1]
|
||||
print(f"Using Chroma version: v{version}")
|
||||
# make sure we have it
|
||||
model_path = huggingface_hub.hf_hub_download(
|
||||
repo_id='lodestones/Chroma',
|
||||
filename=f"chroma-unlocked-v{version}.safetensors",
|
||||
)
|
||||
else:
|
||||
# check if the model path is a local file
|
||||
if os.path.exists(model_path):
|
||||
print(f"Using local model: {model_path}")
|
||||
else:
|
||||
raise ValueError(f"Model path {model_path} does not exist")
|
||||
|
||||
# extras_path = 'black-forest-labs/FLUX.1-schnell'
|
||||
# schnell model is gated now, use flex instead
|
||||
extras_path = 'ostris/Flex.1-alpha'
|
||||
|
||||
self.print_and_status_update("Loading transformer")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user