Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-05-01 03:31:35 +00:00
Added better optimizer choices and param support
@@ -50,7 +50,8 @@ class Critic:
             lambda_gp=10,
             start_step=0,
             warmup_steps=1000,
-            process=None
+            process=None,
+            optimizer_params=None,
     ):
         self.learning_rate = learning_rate
         self.device = device
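The first hunk threads a new optimizer_params keyword through the Critic constructor. A minimal usage sketch follows, assuming the remaining keyword arguments keep the defaults shown above; the train_process object and the dict contents (betas, weight_decay) are illustrative, not taken from this commit:

# Hypothetical call site; 'train_process' and the dict values are assumptions.
critic = Critic(
    lambda_gp=10,
    start_step=0,
    warmup_steps=1000,
    process=train_process,
    optimizer_params={"betas": (0.5, 0.999), "weight_decay": 1e-2},
)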
@@ -65,6 +66,10 @@ class Critic:
         self.warmup_steps = warmup_steps
         self.start_step = start_step
         self.lambda_gp = lambda_gp
+
+        if optimizer_params is None:
+            optimizer_params = {}
+        self.optimizer_params = optimizer_params
         self.print = self.process.print
         print(f" Critic config: {self.__dict__}")

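The None default plus in-body {} replacement sidesteps Python's shared-mutable-default pitfall: a literal {} default would be created once at function definition and shared across every instance. A standalone sketch of the difference:

def fresh(params=None):
    # new dict per call, as the commit does
    if params is None:
        params = {}
    return params

def shared(params={}):
    # anti-pattern: one dict object reused across all calls
    return params

assert fresh() is not fresh()
assert shared() is shared()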
@@ -75,7 +80,8 @@ class Critic:
         self.model.train()
         self.model.requires_grad_(True)
         params = self.model.parameters()
-        self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
+        self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
+                                       optimizer_params=self.optimizer_params)
         self.scheduler = torch.optim.lr_scheduler.ConstantLR(
             self.optimizer,
             total_iters=self.process.max_steps * self.num_critic_per_gen,
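get_optimizer is the toolkit's own factory; its real signature and supported optimizers are not shown in this diff. A hedged sketch of what an optimizer_params pass-through can look like, using only stock torch.optim classes:

import torch

# Sketch only; the real get_optimizer in ai-toolkit may differ.
def get_optimizer_sketch(params, optimizer_type, learning_rate, optimizer_params=None):
    optimizer_params = optimizer_params or {}
    optimizer_type = optimizer_type.lower()
    if optimizer_type == "adam":
        return torch.optim.Adam(params, lr=learning_rate, **optimizer_params)
    if optimizer_type == "adamw":
        return torch.optim.AdamW(params, lr=learning_rate, **optimizer_params)
    if optimizer_type == "sgd":
        return torch.optim.SGD(params, lr=learning_rate, **optimizer_params)
    raise ValueError(f"Unknown optimizer type: {optimizer_type}")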
@@ -196,6 +202,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
         self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
+        self.optimizer_params = self.get_conf('optimizer_params', {})

         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.torch_dtype = get_torch_dtype(self.dtype)
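On the training-process side, optimizer_params now arrives via process config through get_conf, defaulting to an empty dict. A hypothetical config fragment; key names other than the ones read in the hunk above are assumptions:

config = {
    "learning_rate": 1e-4,            # assumed key; not shown in this hunk
    "optimizer": "adamw",             # assumed key for optimizer_type
    "optimizer_params": {"weight_decay": 1e-2, "betas": (0.9, 0.99)},
    "critic_weight": 1,
    "pattern_weight": 1,
    "tv_weight": 1.0,
}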
@@ -342,7 +349,8 @@ class TrainVAEProcess(BaseTrainProcess):

     def get_pattern_loss(self, pred, target):
         if self._pattern_loss is None:
-            self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device, dtype=self.torch_dtype)
+            self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device,
+                                                                                        dtype=self.torch_dtype)
         loss = torch.mean(self._pattern_loss(pred, target))
         return loss

@@ -504,7 +512,8 @@ class TrainVAEProcess(BaseTrainProcess):
         if self.use_critic:
             self.critic.setup()

-        optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
+        optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
+                                  optimizer_params=self.optimizer_params)

         # setup scheduler
         # todo allow other schedulers
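Taken together with the ConstantLR scheduler from the Critic hunk, the training-side wiring looks roughly like the following self-contained sketch; the model, step count, and loss are stand-ins, not from the repo:

import torch

model = torch.nn.Linear(8, 8)                      # stand-in model
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-4,
    betas=(0.9, 0.99), weight_decay=1e-2,          # what optimizer_params expands to
)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=10)
for _ in range(10):
    loss = model(torch.randn(4, 8)).pow(2).mean()  # dummy loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()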