mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 08:20:35 +00:00
Allow user to set a training seed via env vars for repeat result testing
This commit is contained in:
21
run.py
21
run.py
@@ -1,11 +1,16 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Union, OrderedDict
|
||||
from dotenv import load_dotenv
|
||||
# Load the .env file if it exists
|
||||
load_dotenv()
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = os.getenv("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
|
||||
seed = None
|
||||
if "SEED" in os.environ:
|
||||
try:
|
||||
seed = int(os.environ["SEED"])
|
||||
except ValueError:
|
||||
print(f"Invalid SEED value: {os.environ['SEED']}. SEED must be an integer.")
|
||||
|
||||
sys.path.insert(0, os.getcwd())
|
||||
# must come before ANY torch or fastai imports
|
||||
@@ -14,11 +19,21 @@ sys.path.insert(0, os.getcwd())
|
||||
# turn off diffusers telemetry until I can figure out how to make it opt-in
|
||||
os.environ['DISABLE_TELEMETRY'] = 'YES'
|
||||
|
||||
# set torch to trace mode
|
||||
import torch
|
||||
|
||||
# check if we have DEBUG_TOOLKIT in env
|
||||
if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
|
||||
# set torch to trace mode
|
||||
import torch
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
if seed is not None:
|
||||
import random
|
||||
import numpy as np
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
import argparse
|
||||
from toolkit.job import get_job
|
||||
from toolkit.accelerator import get_accelerator
|
||||
|
||||
Reference in New Issue
Block a user