Remove ip adapter submodule
convert_diffusers_to_comfy_transformer_only.py
@@ -2,10 +2,14 @@
 # Convert Diffusers Flux/Flex to diffusion model ComfyUI safetensors file
 # This will only have the transformer weights, not the TEs and VAE
 # You can save the transformer weights as bf16 or 8-bit with the --do_8_bit flag
+# You can also save with scaled 8-bit using the --do_8bit_scaled flag
 #
-# Call like this for 8-bit transformer weights:
+# Call like this for 8-bit transformer weights with stochastic rounding:
 # python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8_bit
+#
+# Call like this for 8-bit transformer weights with scaling:
+# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8bit_scaled
 #
 # Call like this for bf16 transformer weights:
 # python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors
 #
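Note: the diff below only touches the tail of stochastic_round_to(), so the rounding itself is not shown here. The idea behind stochastic rounding is to round each value up or down with probability proportional to its distance from the two nearest representable values, so the quantization error is zero in expectation. A minimal sketch of the concept on an integer grid (hypothetical helper, not the script's actual fp8 implementation):

import torch

def stochastic_round_int(t: torch.Tensor) -> torch.Tensor:
    # Round down, then bump up with probability equal to the fractional part,
    # so the expected value of the rounded tensor equals the input.
    floor = torch.floor(t)
    frac = t - floor
    return floor + (torch.rand_like(t) < frac).to(t.dtype)

# 2.3 rounds to 2 with p=0.7 and to 3 with p=0.3; the mean stays near 2.3
print(stochastic_round_int(torch.full((10000,), 2.3)).mean())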
@@ -33,7 +37,9 @@ parser.add_argument("diffusers_path", type=str,
 parser.add_argument("flux_path", type=str,
                     help="Output path for the Flux safetensors file.")
 parser.add_argument("--do_8_bit", action="store_true",
-                    help="Use 8-bit weights instead of bf16.")
+                    help="Use 8-bit weights with stochastic rounding instead of bf16.")
+parser.add_argument("--do_8bit_scaled", action="store_true",
+                    help="Use scaled 8-bit weights instead of bf16.")
 args = parser.parse_args()
 
 flux_path = Path(args.flux_path)
@@ -43,6 +49,12 @@ if os.path.exists(os.path.join(diffusers_path, "transformer")):
     diffusers_path = Path(os.path.join(diffusers_path, "transformer"))
 
 do_8_bit = args.do_8_bit
+do_8bit_scaled = args.do_8bit_scaled
+
+# Don't allow both flags to be active simultaneously
+if do_8_bit and do_8bit_scaled:
+    print("Error: Cannot use both --do_8_bit and --do_8bit_scaled at the same time.")
+    exit()
 
 if not os.path.exists(flux_path.parent):
     os.makedirs(flux_path.parent)
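Aside: the manual exclusivity check added above could also be enforced by argparse itself. A hedged alternative sketch using the stdlib's mutually exclusive groups (not what this commit does, just the built-in option; argparse then rejects combined usage with a usage error automatically):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--do_8_bit", action="store_true",
                   help="Use 8-bit weights with stochastic rounding instead of bf16.")
group.add_argument("--do_8bit_scaled", action="store_true",
                   help="Use scaled 8-bit weights instead of bf16.")
args = parser.parse_args()

The explicit check in the commit has the advantage of a plain-language error message, at the cost of a few extra lines.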
@@ -373,16 +385,64 @@ def stochastic_round_to(tensor, dtype=torch.float8_e4m3fn):
     return rounded.to(dtype)
 
 
-# set all the keys to bf16
+# List of keys that should not be scaled (usually embedding layers and biases)
+blacklist = []
+for key in flux.keys():
+    if do_8_bit:
+        if not key.endswith(".weight") or "embed" in key:
+            blacklist.append(key)
+
+# Function to scale weights for 8-bit quantization
+def scale_weights_to_8bit(tensor, max_value=416.0, dtype=torch.float8_e4m3fn):
+    # Get the limits of the dtype
+    min_val = torch.finfo(dtype).min
+    max_val = torch.finfo(dtype).max
+
+    # Only process 2D tensors that are not in the blacklist
+    if tensor.dim() == 2:
+        # Calculate the scaling factor
+        abs_max = torch.max(torch.abs(tensor))
+        scale = abs_max / max_value
+
+        # Scale the tensor and clip to float8 range
+        scaled_tensor = (tensor / scale).clip(min=min_val, max=max_val).to(dtype)
+
+        return scaled_tensor, scale
+    else:
+        # For tensors that shouldn't be scaled, just convert to float8
+        return tensor.clip(min=min_val, max=max_val).to(dtype), None
+
+
+# set all the keys to appropriate dtype
 if do_8_bit:
     print("Converting to 8-bit with stochastic rounding...")
     for key in flux.keys():
         flux[key] = stochastic_round_to(
             flux[key], torch.float8_e4m3fn).to('cpu')
-else:
+elif do_8bit_scaled:
+    print("Converting to scaled 8-bit...")
+    scales = {}
+    for key in tqdm.tqdm(flux.keys()):
+        if key.endswith(".weight") and key not in blacklist:
+            flux[key], scale = scale_weights_to_8bit(flux[key])
+            if scale is not None:
+                scale_key = key[:-len(".weight")] + ".scale_weight"
+                scales[scale_key] = scale
+        else:
+            # For non-weight tensors or blacklisted ones, just convert without scaling
+            min_val = torch.finfo(torch.float8_e4m3fn).min
+            max_val = torch.finfo(torch.float8_e4m3fn).max
+            flux[key] = flux[key].clip(min=min_val, max=max_val).to(torch.float8_e4m3fn).to('cpu')
+
+    # Add all the scales to the flux dictionary
+    flux.update(scales)
+
+    # Add a marker tensor to indicate this is a scaled fp8 model
+    flux["scaled_fp8"] = torch.tensor([]).to(torch.float8_e4m3fn)
+else:
     print("Converting to bfloat16...")
     for key in flux.keys():
         flux[key] = flux[key].clone().to('cpu', torch.bfloat16)
 
 
 
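For context on how the .scale_weight tensors and the scaled_fp8 marker are meant to be consumed: a loader that finds the marker can recover an approximation of each original weight by multiplying the fp8 tensor back by its stored scale, inverting the (tensor / scale) step in scale_weights_to_8bit(). A minimal consumer-side sketch (hypothetical helper; ComfyUI's actual loader may differ):

import torch

def dequantize_scaled_fp8(weight_fp8: torch.Tensor, scale: torch.Tensor,
                          dtype=torch.bfloat16) -> torch.Tensor:
    # Inverse of scale_weights_to_8bit: w ~= fp8_weight * scale
    return weight_fp8.to(dtype) * scale.to(dtype)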
 meta = OrderedDict()
 meta['format'] = 'pt'
 # date format like 2024-08-01 YYYY-MM-DD
@@ -394,4 +454,4 @@ print(f"Saving to {flux_path}")
 
 safetensors.torch.save_file(flux, flux_path, metadata=meta)
 
-print("Done.")
+print("Done.")
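To sanity-check an exported file, the safetensors API can read the header metadata and key list without loading full tensors. A small usage sketch (the path is whichever flux_path was passed to the script, e.g. the my_finetune.safetensors from the usage comments above):

from safetensors import safe_open

with safe_open("my_finetune.safetensors", framework="pt") as f:
    print(f.metadata())              # e.g. {'format': 'pt', ...}
    print("scaled_fp8" in f.keys())  # True only for scaled 8-bit exports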