Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-29 10:41:28 +00:00
Complete rework of how slider training works and optimized it heavily. The entire algorithm can now run in 1 batch, with less than a quarter of the VRAM it used to take.
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import numpy as np
+from torch.utils.checkpoint import checkpoint
 
 
 class ReductionKernel(nn.Module):
@@ -29,3 +30,15 @@ class ReductionKernel(nn.Module):
 
     def forward(self, x):
         return nn.functional.conv2d(x, self.kernel, stride=self.kernel_size, padding=0, groups=1)
+
+
+class CheckpointGradients(nn.Module):
+    def __init__(self, is_gradient_checkpointing=True):
+        super(CheckpointGradients, self).__init__()
+        self.is_gradient_checkpointing = is_gradient_checkpointing
+
+    def forward(self, module, *args):
+        if self.is_gradient_checkpointing:
+            return checkpoint(module, *args)
+        else:
+            return module(*args)
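For readers unfamiliar with the pattern, here is a minimal sketch of how the new CheckpointGradients wrapper might be used. Only the CheckpointGradients class itself comes from the diff above (note that torch.utils.checkpoint.checkpoint takes no num_chunks argument, and self.num_chunks was never set, so the forward signature shown here drops that parameter); the `block` module and the toy forward/backward pass are illustrative assumptions, not code from this commit.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointGradients(nn.Module):
    # From the diff above: recompute the wrapped module's forward pass
    # during backward instead of storing its activations, trading extra
    # compute for lower VRAM.
    def __init__(self, is_gradient_checkpointing=True):
        super(CheckpointGradients, self).__init__()
        self.is_gradient_checkpointing = is_gradient_checkpointing

    def forward(self, module, *args):
        if self.is_gradient_checkpointing:
            return checkpoint(module, *args)
        return module(*args)

# Hypothetical stand-in for any sub-network whose activations dominate VRAM.
block = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
ckpt = CheckpointGradients(is_gradient_checkpointing=True)

x = torch.randn(8, 512, requires_grad=True)  # input must require grad so gradients flow through the checkpoint
y = ckpt(block, x)   # activations inside `block` are not kept
y.sum().backward()   # backward re-runs block's forward to recompute them

This recompute-instead-of-store trade is presumably what lets the reworked slider training run the whole algorithm in one batch at a fraction of the previous VRAM, as the commit message claims.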