From 86dcf39eee294ed99c3b35f70bb1fa3d8f6420f1 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 29 Mar 2026 13:34:46 -0600 Subject: [PATCH] Allow user to set a training seed via env vars for repeat result testing --- run.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/run.py b/run.py index e3cbdef4..c717c555 100644 --- a/run.py +++ b/run.py @@ -1,11 +1,16 @@ import os import sys -from typing import Union, OrderedDict from dotenv import load_dotenv # Load the .env file if it exists load_dotenv() os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = os.getenv("HF_HUB_ENABLE_HF_TRANSFER", "1") os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1" +seed = None +if "SEED" in os.environ: + try: + seed = int(os.environ["SEED"]) + except ValueError: + print(f"Invalid SEED value: {os.environ['SEED']}. SEED must be an integer.") sys.path.insert(0, os.getcwd()) # must come before ANY torch or fastai imports @@ -14,11 +19,21 @@ sys.path.insert(0, os.getcwd()) # turn off diffusers telemetry until I can figure out how to make it opt-in os.environ['DISABLE_TELEMETRY'] = 'YES' +# set torch to trace mode +import torch + # check if we have DEBUG_TOOLKIT in env if os.environ.get("DEBUG_TOOLKIT", "0") == "1": - # set torch to trace mode - import torch torch.autograd.set_detect_anomaly(True) + +if seed is not None: + import random + import numpy as np + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + import argparse from toolkit.job import get_job from toolkit.accelerator import get_accelerator