Complete rework of how slider training works, heavily optimized. The entire algorithm can now run in a single batch, using less than a quarter of the VRAM it used to take

Jaret Burkett
2023-08-05 18:46:08 -06:00
parent 7e4e660663
commit 8c90fa86c6
10 changed files with 944 additions and 379 deletions


@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import numpy as np
from torch.utils.checkpoint import checkpoint
class ReductionKernel(nn.Module):
@@ -29,3 +30,15 @@ class ReductionKernel(nn.Module):
    def forward(self, x):
        return nn.functional.conv2d(x, self.kernel, stride=self.kernel_size, padding=0, groups=1)
class CheckpointGradients(nn.Module):
    def __init__(self, is_gradient_checkpointing=True):
        super(CheckpointGradients, self).__init__()
        self.is_gradient_checkpointing = is_gradient_checkpointing

    def forward(self, module, *args, num_chunks=1):
        if self.is_gradient_checkpointing:
            # torch.utils.checkpoint.checkpoint takes no num_chunks argument
            # (chunking belongs to checkpoint_sequential); it recomputes the
            # module's activations in the backward pass instead of storing them
            return checkpoint(module, *args)
        else:
            return module(*args)
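For context, a minimal sketch of how this wrapper might be used; the small toy network and sizes here are assumptions for illustration, not part of this commit:

import torch
import torch.nn as nn

# Hypothetical stand-in for the real slider-training network.
net = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
ckpt = CheckpointGradients(is_gradient_checkpointing=True)

# The input must require grad so checkpointing actually defers activation storage.
x = torch.randn(16, 64, requires_grad=True)
out = ckpt(net, x)    # forward runs without keeping intermediate activations
out.sum().backward()  # activations are recomputed here, trading compute for VRAM

The trade-off is the usual one for gradient checkpointing: roughly one extra forward pass of compute in exchange for a much smaller peak activation footprint, which is what allows the whole algorithm to fit in a single batch.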