diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index ba44ff8d..27e57025 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -501,13 +501,22 @@ class SDTrainer(BaseSDTrainProcess): loss = wavelet_loss(pred, batch.latents, noise) else: loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") + + do_weighted_timesteps = False + if self.sd.is_flow_matching: + if self.train_config.linear_timesteps or self.train_config.linear_timesteps2: + do_weighted_timesteps = True + if self.train_config.timestep_type == "weighted": + # use the noise scheduler to get the weights for the timesteps + do_weighted_timesteps = True # handle linear timesteps and only adjust the weight of the timesteps - if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2): + if do_weighted_timesteps: # calculate the weights for the timesteps timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps( timesteps, - v2=self.train_config.linear_timesteps2 + v2=self.train_config.linear_timesteps2, + timestep_type=self.train_config.timestep_type ).to(loss.device, dtype=loss.dtype) timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() loss = loss * timestep_weight diff --git a/scripts/calculate_timestep_weighing_flex.py b/scripts/calculate_timestep_weighing_flex.py new file mode 100644 index 00000000..05a21766 --- /dev/null +++ b/scripts/calculate_timestep_weighing_flex.py @@ -0,0 +1,228 @@ +import gc +import os, sys +from tqdm import tqdm +import numpy as np +import json +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# set visible devices to 0 +# os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +# protect from formatting +if True: + import torch + from optimum.quanto import freeze, qfloat8, QTensor, qint4 + from diffusers import FluxTransformer2DModel, FluxPipeline, AutoencoderKL, FlowMatchEulerDiscreteScheduler + from toolkit.util.quantize import quantize, get_qtype + from transformers import T5EncoderModel, T5TokenizerFast, CLIPTextModel, CLIPTokenizer + from torchvision import transforms + +qtype = "qfloat8" +dtype = torch.bfloat16 +# base_model_path = "black-forest-labs/FLUX.1-dev" +base_model_path = "ostris/Flex.1-alpha" +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print("Loading Transformer...") +prompt = "Photo of a man and a woman in a park, sunny day" + +output_root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "output") +output_path = os.path.join(output_root, "flex_timestep_weights.json") +img_output_path = os.path.join(output_root, "flex_timestep_weights.png") + +quantization_type = get_qtype(qtype) + +def flush(): + torch.cuda.empty_cache() + gc.collect() + +pil_to_tensor = transforms.ToTensor() + +with torch.no_grad(): + transformer = FluxTransformer2DModel.from_pretrained( + base_model_path, + subfolder='transformer', + torch_dtype=dtype + ) + + transformer.to(device, dtype=dtype) + + print("Quantizing Transformer...") + quantize(transformer, weights=quantization_type) + freeze(transformer) + flush() + + print("Loading Scheduler...") + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") + + print("Loading Autoencoder...") + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) + + vae.to(device, dtype=dtype) + + flush() + print("Loading Text 
Encoder...") + tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) + text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype) + text_encoder_2.to(device, dtype=dtype) + + print("Quantizing Text Encoder...") + quantize(text_encoder_2, weights=get_qtype(qtype)) + freeze(text_encoder_2) + flush() + + print("Loading CLIP") + text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) + tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder.to(device, dtype=dtype) + + print("Making pipe") + + pipe: FluxPipeline = FluxPipeline( + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + ) + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + pipe.to(device, dtype=dtype) + + print("Encoding prompt...") + + prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt( + prompt, + prompt_2=prompt, + device=device + ) + + + generator = torch.manual_seed(42) + + height = 1024 + width = 1024 + + print("Generating image...") + + # Fix a bug in diffusers/torch + def callback_on_step_end(pipe, i, t, callback_kwargs): + latents = callback_kwargs["latents"] + if latents.dtype != dtype: + latents = latents.to(dtype) + return {"latents": latents} + img = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + height=height, + width=height, + num_inference_steps=30, + guidance_scale=3.5, + generator=generator, + callback_on_step_end=callback_on_step_end, + ).images[0] + + img.save(img_output_path) + print(f"Image saved to {img_output_path}") + + print("Encoding image...") + # img is a PIL image. 
convert it to a -1 to 1 tensor
+    img = pil_to_tensor(img)
+    img = img.unsqueeze(0) # add batch dimension
+    img = img * 2 - 1 # convert to -1 to 1 range
+    img = img.to(device, dtype=dtype)
+    latents = vae.encode(img).latent_dist.sample()
+
+    shift = vae.config['shift_factor'] if vae.config['shift_factor'] is not None else 0
+    latents = vae.config['scaling_factor'] * (latents - shift)
+
+    num_channels_latents = pipe.transformer.config.in_channels // 4
+
+    l_height = 2 * (int(height) // (pipe.vae_scale_factor * 2))
+    l_width = 2 * (int(width) // (pipe.vae_scale_factor * 2))
+    packed_latents = pipe._pack_latents(latents, 1, num_channels_latents, l_height, l_width)
+
+    packed_latents, latent_image_ids = pipe.prepare_latents(
+        1,
+        num_channels_latents,
+        height,
+        width,
+        prompt_embeds.dtype,
+        device,
+        generator,
+        packed_latents,
+    )
+
+    print("Calculating timestep weights...")
+
+    torch.manual_seed(8675309)
+    noise = torch.randn_like(packed_latents, device=device, dtype=dtype)
+
+    # Create linear timesteps from 1000 to 1
+    # (index 0 corresponds to t=1000, i.e. pure noise; index 999 to t=1)
+    num_train_timesteps = 1000
+    timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
+    timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
+
+    timestep_weights = torch.zeros(num_train_timesteps, dtype=torch.float32, device=device)
+
+    guidance = torch.full([1], 1.0, device=device, dtype=torch.float32)
+    guidance = guidance.expand(latents.shape[0])
+
+    pbar = tqdm(range(num_train_timesteps), desc="loss: 0.000000 scaler: 0.0000")
+    for i in pbar:
+        timestep = timesteps[i:i+1].to(device)
+        t_01 = (timestep / 1000).to(device)
+        t_01 = t_01.reshape(-1, 1, 1)
+        # rectified-flow interpolation: x_t = (1 - t) * x_0 + t * noise
+        noisy_latents = (1.0 - t_01) * packed_latents + t_01 * noise
+
+        noise_pred = pipe.transformer(
+            hidden_states=noisy_latents, # torch.Size([1, 4096, 64])
+            timestep=timestep / 1000,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=latent_image_ids,
+            return_dict=False,
+        )[0]
+
+        target = noise - packed_latents
+
+        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())
+
+        # determine the scaler that would bring this timestep's loss to 1.0;
+        # this inverse of the loss becomes the weight for the timestep
+        scaler = 1.0 / (loss + 1e-6)
+
+        timestep_weights[i] = scaler
+        pbar.set_description(f"loss: {loss.item():.6f} scaler: {scaler.item():.4f}")
+
+    print("Normalizing timestep weights...")
+    # normalize the timestep weights so they have a mean of 1.0
+    timestep_weights = timestep_weights / timestep_weights.mean()
+    timestep_weights = timestep_weights.cpu().numpy().tolist()
+
+    print("Saving timestep weights...")
+
+    with open(output_path, 'w') as f:
+        json.dump(timestep_weights, f)
+
+
+print(f"Timestep weights saved to {output_path}")
+print("Done!")
+flush()
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index aa3d84a6..0202eb8e 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -437,7 +437,10 @@ class TrainConfig:
         # adds an additional loss to the network to encourage it output a normalized standard deviation
         self.target_norm_std = kwargs.get('target_norm_std', None)
         self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
-        self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample
+        self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample, weighted
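+        # 'weighted' keeps the linear timestep schedule but scales each sample's
+        # loss by an empirically measured per-timestep weight (see
+        # toolkit/timestep_weighing); e.g. a train config could set timestep_type: "weighted"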
         self.next_sample_timesteps = kwargs.get('next_sample_timesteps', 8)
         self.linear_timesteps = kwargs.get('linear_timesteps', False)
         self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 7985bfe7..5f419ff0 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -1773,6 +1773,99 @@ class LatentCachingMixin:
             self.sd.restore_device_state()
 
 
+
+class TextEmbeddingCachingMixin:
+    def __init__(self: 'AiToolkitDataset', **kwargs):
+        # if we have super, call it
+        if hasattr(super(), '__init__'):
+            super().__init__(**kwargs)
+        self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
+
+    def cache_text_embeddings(self: 'AiToolkitDataset'):
+        # NOTE: this body currently mirrors LatentCachingMixin (latent paths,
+        # encode_images, is_latent_cached); the text-embedding-specific encode
+        # and save logic has not been wired in yet.
+        with accelerator.main_process_first():
+            print_acc(f"Caching text embeddings for {self.dataset_path}")
+            # cache all latents to disk
+            to_disk = self.is_caching_latents_to_disk
+            to_memory = self.is_caching_latents_to_memory
+            print_acc(" - Saving text embeddings to disk")
+            # move sd items to cpu except for vae
+            self.sd.set_device_state_preset('cache_latents')
+
+            # use tqdm to show progress
+            i = 0
+            for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
+                # set latent space version
+                if self.sd.model_config.latent_space_version is not None:
+                    file_item.latent_space_version = self.sd.model_config.latent_space_version
+                elif self.sd.is_xl:
+                    file_item.latent_space_version = 'sdxl'
+                elif self.sd.is_v3:
+                    file_item.latent_space_version = 'sd3'
+                elif self.sd.is_auraflow:
+                    file_item.latent_space_version = 'sdxl'
+                elif self.sd.is_flux:
+                    file_item.latent_space_version = 'flux1'
+                elif self.sd.model_config.is_pixart_sigma:
+                    file_item.latent_space_version = 'sdxl'
+                else:
+                    file_item.latent_space_version = self.sd.model_config.arch
+                file_item.is_caching_to_disk = to_disk
+                file_item.is_caching_to_memory = to_memory
+                file_item.latent_load_device = self.sd.device
+
+                latent_path = file_item.get_latent_path(recalculate=True)
+                # check if it is saved to disk already
+                if os.path.exists(latent_path):
+                    if to_memory:
+                        # load it into memory
+                        state_dict = load_file(latent_path, device='cpu')
+                        file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
+                else:
+                    # not saved to disk, calculate
+                    # load the image first
+                    file_item.load_and_process_image(self.transform, only_load_latents=True)
+                    dtype = self.sd.torch_dtype
+                    device = self.sd.device_torch
+                    # add batch dimension
+                    try:
+                        imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
+                        latent = self.sd.encode_images(imgs).squeeze(0)
+                    except Exception as e:
+                        print_acc(f"Error processing image: {file_item.path}")
+                        print_acc(f"Error: {str(e)}")
+                        raise e
+                    # save_latent
+                    if to_disk:
+                        state_dict = OrderedDict([
+                            ('latent', latent.clone().detach().cpu()),
+                        ])
+                        # metadata
+                        meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
+                        os.makedirs(os.path.dirname(latent_path), exist_ok=True)
+                        save_file(state_dict, latent_path, metadata=meta)
+
+                    if to_memory:
+                        # keep it in memory
+                        file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
+
+                    del imgs
+                    del latent
+                    del file_item.tensor
+
+                # flush(garbage_collect=False)
+                file_item.is_latent_cached = True
+                i += 1
+                # flush every 100
+                # if i % 100 == 0:
+                #     flush()
+
+            # restore device state
+            self.sd.restore_device_state()
+
+
 class CLIPCachingMixin:
     def __init__(self: 'AiToolkitDataset', **kwargs):
         # if we have super, call it
diff --git 
a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py
index 1e0ae2ab..bac7f3fd 100644
--- a/toolkit/samplers/custom_flowmatch_sampler.py
+++ b/toolkit/samplers/custom_flowmatch_sampler.py
@@ -4,6 +4,7 @@ from torch.distributions import LogNormal
 from diffusers import FlowMatchEulerDiscreteScheduler
 import torch
 import numpy as np
+from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme
 
 
 def calculate_shift(
@@ -47,20 +48,27 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
 
         hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()
 
-        # Create linear timesteps from 1000 to 0
-        timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
+        # Create linear timesteps from 1000 to 1
+        timesteps = torch.linspace(1000, 1, num_timesteps, device='cpu')
 
         self.linear_timesteps = timesteps
         self.linear_timesteps_weights = bsmntw_weighing
         self.linear_timesteps_weights2 = hbsmntw_weighing
         pass
 
 
-    def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
+    def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False, timestep_type="linear") -> torch.Tensor:
         # Get the indices of the timesteps
         step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
 
         # Get the weights for the timesteps
+        if timestep_type == "weighted":
+            # look up the empirically measured weight for each sampled timestep
+            weights = torch.tensor(
+                [default_weighing_scheme[i] for i in step_indices],
+                device=timesteps.device,
+                dtype=timesteps.dtype
+            )
-        if v2:
+        elif v2:
             weights = self.linear_timesteps_weights2[step_indices].flatten()
         else:
@@ -106,8 +113,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
         patch_size=1
     ):
         self.timestep_type = timestep_type
-        if timestep_type == 'linear':
-            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
+        if timestep_type == 'linear' or timestep_type == 'weighted':
+            timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
             self.timesteps = timesteps
             return timesteps
         elif timestep_type == 'sigmoid':
@@ -198,7 +205,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
         t1 = ((1 - t1/t1.max()) * 1000)
 
         # add half of linear
-        t2 = torch.linspace(1000, 0, int(
+        t2 = torch.linspace(1000, 1, int(
             num_timesteps * (1 - alpha)), device=device)
 
         timesteps = torch.cat((t1, t2))
diff --git a/toolkit/timestep_weighing/__init__.py b/toolkit/timestep_weighing/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/toolkit/timestep_weighing/default_weighing_scheme.py b/toolkit/timestep_weighing/default_weighing_scheme.py
new file mode 100644
index 00000000..72699744
--- /dev/null
+++ b/toolkit/timestep_weighing/default_weighing_scheme.py
@@ -0,0 +1,1019 @@
+# These weights were calculated using Flex.1-alpha; a similar weighting curve has been observed with other flow-matching models as well.
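+#
+# Each of the 1000 entries below is the reciprocal of the MSE measured at that
+# timestep index (index 0 corresponds to t=1000, pure noise; index 999 to t=1),
+# normalized so the weights average 1.0. They were derived from a run of
+# scripts/calculate_timestep_weighing_flex.py.
+#
+# A minimal sketch of how the trainer applies them (see the SDTrainer.py hunk
+# above), assuming `scheduler` is a CustomFlowMatchEulerDiscreteScheduler and
+# F is torch.nn.functional:
+#
+#     weights = scheduler.get_weights_for_timesteps(
+#         timesteps, timestep_type="weighted"
+#     ).view(-1, 1, 1, 1)  # one weight per batch item
+#     loss = F.mse_loss(pred, target, reduction="none") * weights
+#     loss = loss.mean()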
+ +default_weighing_scheme = [ + 0.905706524848938, + 0.9097874164581299, + 0.9251009821891785, + 0.9399133920669556, + 0.9497355818748474, + 0.962195873260498, + 0.9691638946533203, + 0.9927358627319336, + 1.01659095287323, + 1.0392684936523438, + 1.0429067611694336, + 1.0677919387817383, + 1.092896580696106, + 1.1149251461029053, + 1.1015851497650146, + 1.1209120750427246, + 1.1399472951889038, + 1.1559219360351562, + 1.143755555152893, + 1.160578727722168, + 1.1761468648910522, + 1.1895899772644043, + 1.1867371797561646, + 1.1993824243545532, + 1.2113513946533203, + 1.2201216220855713, + 1.221794605255127, + 1.2329251766204834, + 1.2426395416259766, + 1.251118540763855, + 1.2562459707260132, + 1.2668883800506592, + 1.2760146856307983, + 1.2822540998458862, + 1.2892439365386963, + 1.2972776889801025, + 1.3042182922363281, + 1.3104143142700195, + 1.3190876245498657, + 1.3253265619277954, + 1.3295484781265259, + 1.3324649333953857, + 1.3436106443405151, + 1.3482736349105835, + 1.351183533668518, + 1.3556060791015625, + 1.359933614730835, + 1.3625402450561523, + 1.3641390800476074, + 1.3723753690719604, + 1.3748365640640259, + 1.3776681423187256, + 1.3802807331085205, + 1.3848408460617065, + 1.3877429962158203, + 1.390969157218933, + 1.3927756547927856, + 1.4031920433044434, + 1.4064390659332275, + 1.4097232818603516, + 1.4132285118103027, + 1.4208053350448608, + 1.4254504442214966, + 1.4294012784957886, + 1.4323071241378784, + 1.44380521774292, + 1.4465110301971436, + 1.4490883350372314, + 1.451446533203125, + 1.458960771560669, + 1.460111379623413, + 1.4603426456451416, + 1.4595434665679932, + 1.4694839715957642, + 1.470526099205017, + 1.4705318212509155, + 1.4697961807250977, + 1.477265477180481, + 1.4771337509155273, + 1.4768965244293213, + 1.4760032892227173, + 1.4871419668197632, + 1.4877341985702515, + 1.488408088684082, + 1.489357829093933, + 1.4884581565856934, + 1.4882304668426514, + 1.4878745079040527, + 1.4949846267700195, + 1.4957915544509888, + 1.4951081275939941, + 1.4950779676437378, + 1.4995663166046143, + 1.4994632005691528, + 1.4986882209777832, + 1.4976568222045898, + 1.5038518905639648, + 1.503743052482605, + 1.502982258796692, + 1.502566933631897, + 1.5068002939224243, + 1.506575584411621, + 1.5067178010940552, + 1.5050104856491089, + 1.508437991142273, + 1.5073161125183105, + 1.506223440170288, + 1.504751205444336, + 1.512485384941101, + 1.5121595859527588, + 1.5116060972213745, + 1.5102864503860474, + 1.5156408548355103, + 1.5157924890518188, + 1.5142825841903687, + 1.5141233205795288, + 1.5191627740859985, + 1.518998622894287, + 1.5177574157714844, + 1.516713261604309, + 1.5186537504196167, + 1.5179635286331177, + 1.5159885883331299, + 1.5150446891784668, + 1.5226575136184692, + 1.5215232372283936, + 1.5201103687286377, + 1.5225480794906616, + 1.5215368270874023, + 1.520209789276123, + 1.5184354782104492, + 1.5220975875854492, + 1.5209708213806152, + 1.519667625427246, + 1.5177332162857056, + 1.5203157663345337, + 1.519271731376648, + 1.5175830125808716, + 1.5161852836608887, + 1.517741322517395, + 1.5163099765777588, + 1.5150357484817505, + 1.513055682182312, + 1.5154107809066772, + 1.513993501663208, + 1.5126358270645142, + 1.510875940322876, + 1.5125701427459717, + 1.510590672492981, + 1.5086392164230347, + 1.5068501234054565, + 1.5094420909881592, + 1.5080277919769287, + 1.5060293674468994, + 1.5040230751037598, + 1.504892110824585, + 1.5031654834747314, + 1.5013214349746704, + 1.499170184135437, + 1.500420331954956, + 1.4979915618896484, + 1.4959659576416016, + 
1.494127631187439, + 1.4961415529251099, + 1.494597315788269, + 1.4926577806472778, + 1.4904958009719849, + 1.4913445711135864, + 1.489889144897461, + 1.4880430698394775, + 1.4873780012130737, + 1.485144019126892, + 1.4830609560012817, + 1.48123037815094, + 1.483134150505066, + 1.4808504581451416, + 1.4793643951416016, + 1.4773707389831543, + 1.4774049520492554, + 1.4752962589263916, + 1.4733281135559082, + 1.4712235927581787, + 1.470260500907898, + 1.4684780836105347, + 1.4657323360443115, + 1.4639259576797485, + 1.4644291400909424, + 1.4621423482894897, + 1.4604870080947876, + 1.4585931301116943, + 1.458650827407837, + 1.456871747970581, + 1.4548134803771973, + 1.4528228044509888, + 1.4519706964492798, + 1.4501826763153076, + 1.448083519935608, + 1.446316123008728, + 1.4454604387283325, + 1.4432463645935059, + 1.4412567615509033, + 1.4390138387680054, + 1.4388699531555176, + 1.437005877494812, + 1.4348008632659912, + 1.4327361583709717, + 1.4320861101150513, + 1.4304683208465576, + 1.4287059307098389, + 1.4264065027236938, + 1.424929141998291, + 1.4226492643356323, + 1.4205398559570312, + 1.4207192659378052, + 1.4187930822372437, + 1.4170515537261963, + 1.415097713470459, + 1.4142191410064697, + 1.412431001663208, + 1.4104442596435547, + 1.4082318544387817, + 1.4076896905899048, + 1.4057838916778564, + 1.4034852981567383, + 1.4016139507293701, + 1.4003900289535522, + 1.3983466625213623, + 1.3962912559509277, + 1.3940908908843994, + 1.3933204412460327, + 1.3910188674926758, + 1.3888089656829834, + 1.3865182399749756, + 1.3853325843811035, + 1.3832392692565918, + 1.3812776803970337, + 1.3789170980453491, + 1.3797601461410522, + 1.3777130842208862, + 1.3756908178329468, + 1.3735847473144531, + 1.3711339235305786, + 1.3690725564956665, + 1.366849660873413, + 1.364676833152771, + 1.364690899848938, + 1.362541675567627, + 1.3608990907669067, + 1.358725666999817, + 1.3566731214523315, + 1.3549315929412842, + 1.3522361516952515, + 1.3527694940567017, + 1.3506978750228882, + 1.3486748933792114, + 1.3464853763580322, + 1.344880223274231, + 1.3429863452911377, + 1.3410741090774536, + 1.3392194509506226, + 1.3374866247177124, + 1.3356633186340332, + 1.333699345588684, + 1.3316830396652222, + 1.3293700218200684, + 1.327602744102478, + 1.325606107711792, + 1.323636770248413, + 1.322487473487854, + 1.3201935291290283, + 1.3185865879058838, + 1.3163570165634155, + 1.315348505973816, + 1.3135069608688354, + 1.3115581274032593, + 1.309351921081543, + 1.3080940246582031, + 1.306084394454956, + 1.3041918277740479, + 1.3022172451019287, + 1.30052649974823, + 1.2984906435012817, + 1.296433925628662, + 1.294618844985962, + 1.2924813032150269, + 1.290629267692566, + 1.288609504699707, + 1.286437749862671, + 1.2852808237075806, + 1.2831010818481445, + 1.281022071838379, + 1.2789161205291748, + 1.2785669565200806, + 1.2766060829162598, + 1.274585485458374, + 1.2728400230407715, + 1.2709832191467285, + 1.2691309452056885, + 1.2671318054199219, + 1.265442132949829, + 1.2635501623153687, + 1.2614946365356445, + 1.2593908309936523, + 1.257880449295044, + 1.2560313940048218, + 1.254082441329956, + 1.2522804737091064, + 1.2505096197128296, + 1.2482692003250122, + 1.2462091445922852, + 1.2445822954177856, + 1.2432236671447754, + 1.2414650917053223, + 1.2396503686904907, + 1.2376699447631836, + 1.2357380390167236, + 1.2339240312576294, + 1.2320566177368164, + 1.2299892902374268, + 1.2286840677261353, + 1.226925015449524, + 1.2250070571899414, + 1.223126769065857, + 1.2215166091918945, + 1.2196996212005615, + 
1.2178195714950562, + 1.2158279418945312, + 1.2140803337097168, + 1.2121261358261108, + 1.210100769996643, + 1.2083507776260376, + 1.2063525915145874, + 1.2046012878417969, + 1.2027149200439453, + 1.201154112815857, + 1.1992254257202148, + 1.1971834897994995, + 1.1951549053192139, + 1.1935709714889526, + 1.191764235496521, + 1.1898930072784424, + 1.187896728515625, + 1.186535120010376, + 1.184600591659546, + 1.1826894283294678, + 1.1809728145599365, + 1.1789331436157227, + 1.1774684190750122, + 1.1756458282470703, + 1.1736308336257935, + 1.1719911098480225, + 1.1701127290725708, + 1.1681468486785889, + 1.165921926498413, + 1.16463041305542, + 1.1627451181411743, + 1.1608567237854004, + 1.1590938568115234, + 1.1575335264205933, + 1.1555901765823364, + 1.1538552045822144, + 1.1518657207489014, + 1.14994215965271, + 1.1481153964996338, + 1.14644455909729, + 1.1444120407104492, + 1.1428122520446777, + 1.1410313844680786, + 1.1391836404800415, + 1.137330412864685, + 1.135434627532959, + 1.1336791515350342, + 1.131978154182434, + 1.1300874948501587, + 1.128359317779541, + 1.1264809370040894, + 1.1248611211776733, + 1.122762680053711, + 1.1209162473678589, + 1.1190710067749023, + 1.1172044277191162, + 1.1158984899520874, + 1.1140459775924683, + 1.1124012470245361, + 1.110682487487793, + 1.1087219715118408, + 1.106826901435852, + 1.1050584316253662, + 1.1034021377563477, + 1.1011031866073608, + 1.0996853113174438, + 1.0978131294250488, + 1.0963127613067627, + 1.0944904088974, + 1.0927494764328003, + 1.0910944938659668, + 1.0892736911773682, + 1.0878331661224365, + 1.0860958099365234, + 1.0842169523239136, + 1.0826303958892822, + 1.0806686878204346, + 1.078961730003357, + 1.0773676633834839, + 1.0755786895751953, + 1.073934555053711, + 1.0721861124038696, + 1.0704376697540283, + 1.0689181089401245, + 1.067183256149292, + 1.0654473304748535, + 1.0637754201889038, + 1.0620981454849243, + 1.0604465007781982, + 1.0587077140808105, + 1.0570865869522095, + 1.0553107261657715, + 1.053688883781433, + 1.0520380735397339, + 1.0502020120620728, + 1.048741340637207, + 1.046962022781372, + 1.0453627109527588, + 1.0439050197601318, + 1.041886806488037, + 1.0405514240264893, + 1.0387938022613525, + 1.0370451211929321, + 1.035706877708435, + 1.03403902053833, + 1.0325669050216675, + 1.0308712720870972, + 1.0291008949279785, + 1.0275760889053345, + 1.0258709192276, + 1.024153470993042, + 1.022807240486145, + 1.0211118459701538, + 1.019489049911499, + 1.0178107023239136, + 1.0159832239151, + 1.0143824815750122, + 1.0128840208053589, + 1.0111985206604004, + 1.0098742246627808, + 1.0081874132156372, + 1.0064918994903564, + 1.0050266981124878, + 1.0036821365356445, + 1.0018991231918335, + 1.0004172325134277, + 0.9988566637039185, + 0.9969817399978638, + 0.9954714179039001, + 0.9939242005348206, + 0.9923979640007019, + 0.9910774230957031, + 0.9894015789031982, + 0.9880895614624023, + 0.9861252903938293, + 0.9846389889717102, + 0.9831112027168274, + 0.9815076589584351, + 0.9799305200576782, + 0.9784950017929077, + 0.976923942565918, + 0.975475549697876, + 0.9737277626991272, + 0.9722781181335449, + 0.9707712531089783, + 0.9693742394447327, + 0.9677569270133972, + 0.9663806557655334, + 0.9648120999336243, + 0.963326096534729, + 0.9619874358177185, + 0.9605197906494141, + 0.9590029120445251, + 0.9575618505477905, + 0.9558634757995605, + 0.9542866945266724, + 0.9530059099197388, + 0.9513764977455139, + 0.9499674439430237, + 0.948621392250061, + 0.947046160697937, + 0.945502519607544, + 0.9441988468170166, + 
0.9427464604377747, + 0.9413387179374695, + 0.9397821426391602, + 0.9385508894920349, + 0.9372508525848389, + 0.9356773495674133, + 0.9340954422950745, + 0.9325379133224487, + 0.9311357140541077, + 0.9296550154685974, + 0.9283716082572937, + 0.9268398880958557, + 0.9254037141799927, + 0.9239259362220764, + 0.9225856065750122, + 0.921108603477478, + 0.9197893142700195, + 0.9185012578964233, + 0.9169778823852539, + 0.9154301881790161, + 0.9140625, + 0.9127756357192993, + 0.9113842844963074, + 0.9101965427398682, + 0.9088224172592163, + 0.9074375629425049, + 0.9061430096626282, + 0.9046499133110046, + 0.9033547043800354, + 0.9018712639808655, + 0.9006990790367126, + 0.8993589878082275, + 0.8980291485786438, + 0.8965833187103271, + 0.8953617811203003, + 0.8940249681472778, + 0.8928234577178955, + 0.8914735913276672, + 0.8900470733642578, + 0.8885773420333862, + 0.887448251247406, + 0.8860753178596497, + 0.8848751783370972, + 0.8835704326629639, + 0.8822427988052368, + 0.8808343410491943, + 0.8794860243797302, + 0.8782272338867188, + 0.876940131187439, + 0.8755697011947632, + 0.8743593096733093, + 0.8731096982955933, + 0.8717764019966125, + 0.870373547077179, + 0.869137704372406, + 0.8679963946342468, + 0.8665465116500854, + 0.8653771281242371, + 0.8643192052841187, + 0.8630129098892212, + 0.8618021011352539, + 0.8606610894203186, + 0.8596193194389343, + 0.8584977984428406, + 0.8571111559867859, + 0.8558118343353271, + 0.854767382144928, + 0.8535858392715454, + 0.8525562882423401, + 0.851208508014679, + 0.8500548601150513, + 0.8489854335784912, + 0.8476380109786987, + 0.8465084433555603, + 0.8454263806343079, + 0.8440982699394226, + 0.8429536819458008, + 0.8419493436813354, + 0.8406177759170532, + 0.8395005464553833, + 0.83843994140625, + 0.8372390866279602, + 0.836262583732605, + 0.8351759910583496, + 0.8340833187103271, + 0.8330100178718567, + 0.8318305611610413, + 0.8307360410690308, + 0.8296796083450317, + 0.8287205696105957, + 0.8275678753852844, + 0.8264811038970947, + 0.8253570795059204, + 0.8243551254272461, + 0.8232539296150208, + 0.822137176990509, + 0.8212800025939941, + 0.8199703097343445, + 0.8190608024597168, + 0.8179953098297119, + 0.8167867064476013, + 0.8158150315284729, + 0.8149182200431824, + 0.8140754699707031, + 0.8131433129310608, + 0.8118599057197571, + 0.8109708428382874, + 0.8099024891853333, + 0.8090004324913025, + 0.8079776763916016, + 0.807029664516449, + 0.8058684468269348, + 0.8049055337905884, + 0.8039948344230652, + 0.803061306476593, + 0.8021382689476013, + 0.8012913465499878, + 0.8002091646194458, + 0.7992268204689026, + 0.7981467247009277, + 0.7973214983940125, + 0.7964017987251282, + 0.7954541444778442, + 0.7945792078971863, + 0.7938122153282166, + 0.7926003932952881, + 0.7917800545692444, + 0.7908596396446228, + 0.7899304628372192, + 0.7890149354934692, + 0.7882192730903625, + 0.7870058417320251, + 0.7863731980323792, + 0.7852027416229248, + 0.7844488024711609, + 0.783501386642456, + 0.7827003598213196, + 0.7819803357124329, + 0.7808201909065247, + 0.7800688147544861, + 0.7791293263435364, + 0.7784658670425415, + 0.7775732278823853, + 0.7768633961677551, + 0.7760342359542847, + 0.775243878364563, + 0.7743030786514282, + 0.7735926508903503, + 0.7724748849868774, + 0.7718163728713989, + 0.77097487449646, + 0.7702510356903076, + 0.7693900465965271, + 0.7687169313430786, + 0.7678922414779663, + 0.7672128081321716, + 0.7663589715957642, + 0.7657037377357483, + 0.7647771239280701, + 0.7640203237533569, + 0.7633466720581055, + 0.7625623941421509, + 
0.7617509961128235, + 0.7610896229743958, + 0.760379433631897, + 0.7596492767333984, + 0.7588953971862793, + 0.7581916451454163, + 0.7573999166488647, + 0.7568274736404419, + 0.756077229976654, + 0.7554765939712524, + 0.7546539306640625, + 0.7539674043655396, + 0.753139853477478, + 0.7525543570518494, + 0.7519160509109497, + 0.7513154149055481, + 0.7505142688751221, + 0.7497125864028931, + 0.74923175573349, + 0.7484207153320312, + 0.7479155659675598, + 0.7473617792129517, + 0.7468436360359192, + 0.7462318539619446, + 0.7456430792808533, + 0.7447810769081116, + 0.7442206144332886, + 0.7435954809188843, + 0.7431489825248718, + 0.7422271370887756, + 0.7418114542961121, + 0.7412892580032349, + 0.740713357925415, + 0.7401546239852905, + 0.7396021485328674, + 0.7390599846839905, + 0.73844313621521, + 0.7377904653549194, + 0.737305223941803, + 0.7368288636207581, + 0.7363747358322144, + 0.7358483076095581, + 0.7354381680488586, + 0.7348212003707886, + 0.7343763709068298, + 0.7336553335189819, + 0.7332231402397156, + 0.73262619972229, + 0.7321929931640625, + 0.7315752506256104, + 0.7312256693840027, + 0.7306149005889893, + 0.7302426695823669, + 0.7299467325210571, + 0.7294563055038452, + 0.728706419467926, + 0.7283353209495544, + 0.7279900312423706, + 0.7276231646537781, + 0.7273217439651489, + 0.7269001007080078, + 0.7265130877494812, + 0.7261000871658325, + 0.7257733345031738, + 0.725188672542572, + 0.724976658821106, + 0.7242119908332825, + 0.7238465547561646, + 0.7236427664756775, + 0.7232236266136169, + 0.7227877974510193, + 0.7226144075393677, + 0.7221262454986572, + 0.7218216061592102, + 0.7215451002120972, + 0.7213869094848633, + 0.7209206819534302, + 0.7207257747650146, + 0.7203994989395142, + 0.7200448513031006, + 0.7197679281234741, + 0.7195186018943787, + 0.7190226912498474, + 0.7188836932182312, + 0.7186117768287659, + 0.7185105681419373, + 0.718199610710144, + 0.7180152535438538, + 0.7175536155700684, + 0.7173341512680054, + 0.7171205878257751, + 0.7168837189674377, + 0.7163654565811157, + 0.7162774801254272, + 0.7161651253700256, + 0.7160663604736328, + 0.7159175872802734, + 0.7157440185546875, + 0.7154026031494141, + 0.7153436541557312, + 0.715220034122467, + 0.7150475978851318, + 0.7150062322616577, + 0.7149052619934082, + 0.7147804498672485, + 0.7147180438041687, + 0.7146442532539368, + 0.7143230438232422, + 0.7142894268035889, + 0.7143252491950989, + 0.7141361236572266, + 0.7140751481056213, + 0.7138771414756775, + 0.7138750553131104, + 0.7138450145721436, + 0.7138748168945312, + 0.7137607336044312, + 0.7137340903282166, + 0.7137055993080139, + 0.7136792540550232, + 0.7135677337646484, + 0.7134214639663696, + 0.7135778069496155, + 0.7136402130126953, + 0.713737964630127, + 0.7136131525039673, + 0.7135958075523376, + 0.713367760181427, + 0.7136869430541992, + 0.7137601971626282, + 0.7137682437896729, + 0.7137079834938049, + 0.7138195633888245, + 0.713512122631073, + 0.7136629819869995, + 0.7137271761894226, + 0.7138593792915344, + 0.714098334312439, + 0.714293360710144, + 0.7142848372459412, + 0.7144456505775452, + 0.714730978012085, + 0.7147268652915955, + 0.7149925231933594, + 0.7151434421539307, + 0.7151704430580139, + 0.7152837514877319, + 0.7152866125106812, + 0.7155521512031555, + 0.7158286571502686, + 0.7161067724227905, + 0.7161760926246643, + 0.716302752494812, + 0.7165276408195496, + 0.716739296913147, + 0.7168811559677124, + 0.7171459794044495, + 0.7174181342124939, + 0.7176154255867004, + 0.7177509069442749, + 0.7180988192558289, + 0.718157172203064, + 
0.7184935212135315, + 0.7185874581336975, + 0.7188709378242493, + 0.7191057801246643, + 0.7192631959915161, + 0.7195265293121338, + 0.7197085618972778, + 0.720137357711792, + 0.7203015089035034, + 0.7204228639602661, + 0.720592737197876, + 0.7209603786468506, + 0.7212156057357788, + 0.7214911580085754, + 0.72171950340271, + 0.7221818566322327, + 0.7225285768508911, + 0.7227242588996887, + 0.7229769825935364, + 0.7232232689857483, + 0.7235181927680969, + 0.7238690853118896, + 0.7240993976593018, + 0.7244613170623779, + 0.7245984673500061, + 0.7250016331672668, + 0.7251067161560059, + 0.7255734205245972, + 0.7258168458938599, + 0.7260231375694275, + 0.7262008786201477, + 0.7266826629638672, + 0.7266356945037842, + 0.7272213697433472, + 0.7274652719497681, + 0.7279736399650574, + 0.7281084656715393, + 0.7283049821853638, + 0.7285397052764893, + 0.7289696931838989, + 0.7293713688850403, + 0.72969651222229, + 0.7298030853271484, + 0.7300551533699036, + 0.7303762435913086, + 0.7306288480758667, + 0.7307636141777039, + 0.7312272787094116, + 0.7314295172691345, + 0.7316324710845947, + 0.7314282655715942, + 0.7316331267356873, + 0.7319273352622986, + 0.732114315032959, + 0.732216477394104, + 0.7322894930839539, + 0.7325413227081299, + 0.7325369715690613, + 0.7325401902198792, + 0.7323561906814575, + 0.7322503924369812, + 0.7322250008583069, + 0.7320146560668945, + 0.7321474552154541, + 0.732187032699585, + 0.7320796251296997, + 0.7315171360969543, + 0.7313243746757507, + 0.7310792803764343, + 0.7308080196380615, + 0.7306304574012756, + 0.7296295166015625, + 0.7289745807647705, + 0.7288649082183838, + 0.7281786799430847, + 0.7277448773384094, + 0.7272322177886963, + 0.7265949845314026, + 0.725848376750946, + 0.7252488136291504, + 0.7245422601699829, + 0.723800003528595, + 0.7228619456291199, + 0.7218728065490723, + 0.720892071723938, + 0.7198144793510437, + 0.7186352610588074, + 0.717454731464386, + 0.7162171602249146, + 0.7149246335029602, + 0.7136655449867249, + 0.7121627926826477, + 0.7108365297317505, + 0.7092090249061584, + 0.7076245546340942, + 0.7060236930847168, + 0.704273521900177, + 0.7023670673370361, + 0.7007937431335449, + 0.698858916759491, + 0.696875810623169, + 0.69483882188797, + 0.6926799416542053, + 0.6907360553741455, + 0.6884570121765137, + 0.68642657995224, + 0.6837813258171082, + 0.6816286444664001, + 0.6790634989738464, + 0.6767467260360718, + 0.6742716431617737, + 0.6715688705444336, + 0.6689924001693726, + 0.6662940382957458, + 0.6634176969528198, + 0.6603904962539673, + 0.6574862599372864, + 0.6544369459152222, + 0.651530385017395, + 0.6485568284988403, + 0.6453983187675476, + 0.6423068046569824, + 0.6392328143119812, + 0.6360040307044983, + 0.6325976252555847, + 0.6291093826293945, + 0.6256147623062134, + 0.6223041415214539, + 0.6188730001449585, + 0.615329921245575, + 0.6118036508560181, + 0.6081951260566711, + 0.604539155960083, + 0.6009404063224792, + 0.5972440242767334, + 0.5937116146087646, + 0.5897051692008972, + 0.5859677195549011, + 0.5821571946144104, + 0.578381359577179, + 0.5747998952865601, + 0.5709896683692932, + 0.5671953558921814, + 0.5633583068847656, + 0.5594204068183899, + 0.5555382370948792, + 0.5519280433654785, + 0.5482934713363647, + 0.544551432132721, + 0.5410515666007996, + 0.5374910831451416, + 0.5340041518211365, + 0.5304144024848938, + 0.5269584655761719, + 0.5235306620597839, + 0.520039975643158, + 0.516674280166626, + 0.513296902179718, + 0.5098193883895874, + 0.5064578652381897, + 0.5030517578125, + 0.4997297525405884, + 0.4967145025730133, + 
0.49335765838623047, + 0.4902186989784241, + 0.486634224653244, + 0.48311659693717957, + 0.4792158007621765, + 0.4755136966705322, + 0.4720709025859833, + 0.4689248502254486, + 0.4660993814468384, + 0.46355342864990234, + 0.46058982610702515, + 0.45763304829597473, + 0.45535609126091003, + 0.45405313372612, + 0.45241352915763855, + 0.45207348465919495, + 0.45095735788345337, + 0.45052871108055115, + 0.449806272983551, + 0.4484655559062958, + 0.44648951292037964, + 0.44580715894699097, + 0.4447800815105438, + 0.4453802704811096, + 0.4472601115703583 +] \ No newline at end of file diff --git a/toolkit/timestep_weighing/flex_timestep_weights_plot.png b/toolkit/timestep_weighing/flex_timestep_weights_plot.png new file mode 100644 index 00000000..b77632e2 Binary files /dev/null and b/toolkit/timestep_weighing/flex_timestep_weights_plot.png differ