ai-toolkit/toolkit/basic.py

import gc

import torch


def value_map(inputs, min_in, max_in, min_out, max_out):
    # Linearly remap inputs from the range [min_in, max_in] to [min_out, max_out].
    return (inputs - min_in) * (max_out - min_out) / (max_in - min_in) + min_out
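
# Example (illustrative values, not from the original module):
#   value_map(0.5, 0.0, 1.0, -1.0, 1.0)  -> 0.0   (midpoint maps to midpoint)
#   value_map(torch.tensor([0.0, 1.0]), 0.0, 1.0, 0.0, 255.0)  -> tensor([0., 255.])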


def flush(garbage_collect=True):
    # Release cached CUDA memory back to the driver, then (optionally)
    # run Python's garbage collector to drop unreachable objects.
    torch.cuda.empty_cache()
    if garbage_collect:
        gc.collect()
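
# Typical usage sketch (`model` is a hypothetical large GPU-resident object):
#   del model   # drop the last reference so its tensors can be deallocated
#   flush()     # then return the now-unused cached blocks to the driver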


def adain(content_features, style_features):
    # Assumes content and style features of shape (batch_size, channels, height, width).
    # Step 1: Calculate mean and variance of content features
    content_mean = torch.mean(content_features, dim=[2, 3], keepdim=True)
    content_var = torch.var(content_features, dim=[2, 3], keepdim=True)
    # Step 2: Calculate mean and variance of style features
    style_mean = torch.mean(style_features, dim=[2, 3], keepdim=True)
    style_var = torch.var(style_features, dim=[2, 3], keepdim=True)
    # Step 3: Normalize content features
    content_std = torch.sqrt(content_var + 1e-5)
    normalized_content = (content_features - content_mean) / content_std
    # Step 4: Scale and shift normalized content with style's statistics
    style_std = torch.sqrt(style_var + 1e-5)
    stylized_content = normalized_content * style_std + style_mean
    return stylized_content
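

# A minimal smoke test for adain(), assuming random NCHW feature maps; the
# shapes below are illustrative and not part of the original module.
if __name__ == "__main__":
    content = torch.randn(2, 64, 32, 32)
    style = torch.randn(2, 64, 32, 32)
    out = adain(content, style)
    # The output keeps the content's shape while adopting the style's
    # per-channel mean and standard deviation.
    print(out.shape)  # torch.Size([2, 64, 32, 32])
    print(torch.allclose(out.mean(dim=[2, 3]), style.mean(dim=[2, 3]), atol=1e-3))  # True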