Remove ip adapter submodule

Jaret Burkett
2025-04-18 09:59:42 -06:00
parent c90615f8bb
commit 5f312cd46b
13 changed files with 709 additions and 70 deletions


@@ -2,10 +2,14 @@
# Convert Diffusers Flux/Flex to diffusion model ComfyUI safetensors file
# This will only have the transformer weights, not the TEs and VAE
# You can save the transformer weights as bf16 or 8-bit with the --do_8_bit flag
# You can also save with scaled 8-bit using the --do_8bit_scaled flag
#
# Call like this for 8-bit transformer weights with stochastic rounding:
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8_bit
#
# Call like this for 8-bit transformer weights with scaling:
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8bit_scaled
#
# Call like this for bf16 transformer weights:
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors
#
@@ -33,7 +37,9 @@ parser.add_argument("diffusers_path", type=str,
parser.add_argument("flux_path", type=str,
help="Output path for the Flux safetensors file.")
parser.add_argument("--do_8_bit", action="store_true",
help="Use 8-bit weights instead of bf16.")
help="Use 8-bit weights with stochastic rounding instead of bf16.")
parser.add_argument("--do_8bit_scaled", action="store_true",
help="Use scaled 8-bit weights instead of bf16.")
args = parser.parse_args()
flux_path = Path(args.flux_path)
@@ -43,6 +49,12 @@ if os.path.exists(os.path.join(diffusers_path, "transformer")):
    diffusers_path = Path(os.path.join(diffusers_path, "transformer"))
do_8_bit = args.do_8_bit
do_8bit_scaled = args.do_8bit_scaled
# Don't allow both flags to be active simultaneously
if do_8_bit and do_8bit_scaled:
    print("Error: Cannot use both --do_8_bit and --do_8bit_scaled at the same time.")
    exit(1)  # non-zero exit code so callers can detect the failure
if not os.path.exists(flux_path.parent):
    os.makedirs(flux_path.parent)
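
An aside, not the script's approach: argparse can enforce the same mutual exclusivity declaratively, failing fast with a usage message instead of a manual check. A minimal sketch:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--do_8_bit", action="store_true")
group.add_argument("--do_8bit_scaled", action="store_true")
args = parser.parse_args()  # passing both flags now exits with an argparse error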
@@ -373,16 +385,64 @@ def stochastic_round_to(tensor, dtype=torch.float8_e4m3fn):
    return rounded.to(dtype)
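
Aside, not part of the commit: the hunk only shows the tail of stochastic_round_to, so here is a toy illustration of why unbiased rounding matters when casting to fp8. Plain round-to-nearest snaps every copy of a value onto the same grid point, so the quantization bias never averages out; stochastic rounding picks the upper or lower neighbor at random, making the result correct in expectation.

import torch

w = torch.full((10000,), 0.0123)  # sits between two float8_e4m3fn grid points
rne = w.to(torch.float8_e4m3fn).to(torch.float32)
print(w.mean().item())    # 0.0123
print(rne.mean().item())  # ~0.0127: every element snapped to the same neighbor,
                          # so the bias survives averaging; stochastic rounding
                          # keeps the mean near 0.0123 in expectation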
# List of keys that should not be scaled (usually embedding layers and biases)
blacklist = []
for key in flux.keys():
    if do_8bit_scaled:  # the blacklist is only consumed by the scaled-8-bit path
        if not key.endswith(".weight") or "embed" in key:
            blacklist.append(key)
# Function to scale weights for 8-bit quantization
def scale_weights_to_8bit(tensor, max_value=416.0, dtype=torch.float8_e4m3fn):
    # Get the limits of the dtype
    min_val = torch.finfo(dtype).min
    max_val = torch.finfo(dtype).max
    # Only scale 2D tensors; blacklist filtering happens at the call site
    if tensor.dim() == 2:
        # Calculate the per-tensor scaling factor
        abs_max = torch.max(torch.abs(tensor))
        scale = abs_max / max_value
        # Scale the tensor and clip to the float8 range
        scaled_tensor = (tensor / scale).clip(min=min_val, max=max_val).to(dtype)
        return scaled_tensor, scale
    else:
        # Tensors that shouldn't be scaled are converted to float8 directly
        return tensor.clip(min=min_val, max=max_val).to(dtype), None
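
To make the scale bookkeeping concrete, a round-trip sketch (my example, not code from this file). The target max of 416 leaves headroom under the float8_e4m3fn limit of 448; a consumer recovers the weight by upcasting and multiplying the stored scale back in:

import torch

w = torch.randn(64, 64)
scale = w.abs().max() / 416.0
w8 = (w / scale).clip(-448.0, 448.0).to(torch.float8_e4m3fn)  # saved as ".weight"
w_recovered = w8.to(torch.float32) * scale                    # what a loader does
print((w - w_recovered).abs().max().item())  # bounded by one fp8 step times scale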
# set all the keys to the appropriate dtype
if do_8_bit:
    print("Converting to 8-bit with stochastic rounding...")
    for key in flux.keys():
        flux[key] = stochastic_round_to(
            flux[key], torch.float8_e4m3fn).to('cpu')
elif do_8bit_scaled:
    print("Converting to scaled 8-bit...")
    scales = {}
    for key in tqdm.tqdm(flux.keys()):
        if key.endswith(".weight") and key not in blacklist:
            flux[key], scale = scale_weights_to_8bit(flux[key])
            if scale is not None:
                scale_key = key[:-len(".weight")] + ".scale_weight"
                scales[scale_key] = scale
        else:
            # For non-weight tensors or blacklisted ones, just convert without scaling
            min_val = torch.finfo(torch.float8_e4m3fn).min
            max_val = torch.finfo(torch.float8_e4m3fn).max
            flux[key] = flux[key].clip(min=min_val, max=max_val).to(torch.float8_e4m3fn).to('cpu')
    # Add all the scales to the flux dictionary
    flux.update(scales)
    # Add a marker tensor to indicate this is a scaled fp8 model
    flux["scaled_fp8"] = torch.tensor([]).to(torch.float8_e4m3fn)
else:
    print("Converting to bfloat16...")
    for key in flux.keys():
        flux[key] = flux[key].clone().to('cpu', torch.bfloat16)
meta = OrderedDict()
meta['format'] = 'pt'
# date in YYYY-MM-DD format, e.g. 2024-08-01
@@ -394,4 +454,4 @@ print(f"Saving to {flux_path}")
safetensors.torch.save_file(flux, flux_path, metadata=meta)
print("Done.")
print("Done.")