mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
Diffirential guidance working, but I may have a better way
This commit is contained in:
@@ -11,3 +11,25 @@ def flush(garbage_collect=True):
|
||||
torch.cuda.empty_cache()
|
||||
if garbage_collect:
|
||||
gc.collect()
|
||||
|
||||
|
||||
def adain(content_features, style_features):
|
||||
# Assumes that the content and style features are of shape (batch_size, channels, width, height)
|
||||
|
||||
# Step 1: Calculate mean and variance of content features
|
||||
content_mean, content_var = torch.mean(content_features, dim=[2, 3], keepdim=True), torch.var(content_features,
|
||||
dim=[2, 3],
|
||||
keepdim=True)
|
||||
# Step 2: Calculate mean and variance of style features
|
||||
style_mean, style_var = torch.mean(style_features, dim=[2, 3], keepdim=True), torch.var(style_features, dim=[2, 3],
|
||||
keepdim=True)
|
||||
|
||||
# Step 3: Normalize content features
|
||||
content_std = torch.sqrt(content_var + 1e-5)
|
||||
normalized_content = (content_features - content_mean) / content_std
|
||||
|
||||
# Step 4: Scale and shift normalized content with style's statistics
|
||||
style_std = torch.sqrt(style_var + 1e-5)
|
||||
stylized_content = normalized_content * style_std + style_mean
|
||||
|
||||
return stylized_content
|
||||
|
||||
Reference in New Issue
Block a user