Added chroma model to the ui. Added logic to easily pull latest, use local, or use a specific version of chroma. Allow ustom name or path in the ui for custom models

This commit is contained in:
Jaret Burkett
2025-05-07 12:06:30 -06:00
parent d9700bdb99
commit 43cb5603ad
11 changed files with 830 additions and 181 deletions

View File

@@ -22,6 +22,7 @@ import torch.nn.functional as F
from .src.model import Chroma, chroma_params
from safetensors.torch import load_file, save_file
from toolkit.metadata import get_meta_for_safetensors
import huggingface_hub
if TYPE_CHECKING:
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
@@ -100,7 +101,44 @@ class ChromaModel(BaseModel):
# will be updated if we detect a existing checkpoint in training folder
model_path = self.model_config.name_or_path
extras_path = 'black-forest-labs/FLUX.1-schnell'
if model_path == "lodestones/Chroma":
print("Looking for latest Chroma checkpoint")
# get the latest checkpoint
files_list = huggingface_hub.list_repo_files(model_path)
print(files_list)
latest_version = 28 # current latest version at time of writing
while True:
if f"chroma-unlocked-v{latest_version}.safetensors" not in files_list:
latest_version -= 1
break
else:
latest_version += 1
print(f"Using latest Chroma version: v{latest_version}")
# make sure we have it
model_path = huggingface_hub.hf_hub_download(
repo_id=model_path,
filename=f"chroma-unlocked-v{latest_version}.safetensors",
)
elif model_path.startswith("lodestones/Chroma/v"):
# get the version number
version = model_path.split("/")[-1].split("v")[-1]
print(f"Using Chroma version: v{version}")
# make sure we have it
model_path = huggingface_hub.hf_hub_download(
repo_id='lodestones/Chroma',
filename=f"chroma-unlocked-v{version}.safetensors",
)
else:
# check if the model path is a local file
if os.path.exists(model_path):
print(f"Using local model: {model_path}")
else:
raise ValueError(f"Model path {model_path} does not exist")
# extras_path = 'black-forest-labs/FLUX.1-schnell'
# schnell model is gated now, use flex instead
extras_path = 'ostris/Flex.1-alpha'
self.print_and_status_update("Loading transformer")