mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added reference adapters, many bug fixes, more ip adapter work and customizability
This commit is contained in:
@@ -31,12 +31,18 @@ def get_mean_std(tensor):
|
||||
def adain(content_features, style_features):
|
||||
# Assumes that the content and style features are of shape (batch_size, channels, width, height)
|
||||
|
||||
dims = [2, 3]
|
||||
if len(content_features.shape) == 3:
|
||||
# content_features = content_features.unsqueeze(0)
|
||||
# style_features = style_features.unsqueeze(0)
|
||||
dims = [1]
|
||||
|
||||
# Step 1: Calculate mean and variance of content features
|
||||
content_mean, content_var = torch.mean(content_features, dim=[2, 3], keepdim=True), torch.var(content_features,
|
||||
dim=[2, 3],
|
||||
content_mean, content_var = torch.mean(content_features, dim=dims, keepdim=True), torch.var(content_features,
|
||||
dim=dims,
|
||||
keepdim=True)
|
||||
# Step 2: Calculate mean and variance of style features
|
||||
style_mean, style_var = torch.mean(style_features, dim=[2, 3], keepdim=True), torch.var(style_features, dim=[2, 3],
|
||||
style_mean, style_var = torch.mean(style_features, dim=dims, keepdim=True), torch.var(style_features, dim=dims,
|
||||
keepdim=True)
|
||||
|
||||
# Step 3: Normalize content features
|
||||
|
||||
Reference in New Issue
Block a user