8 bit training working on flux

This commit is contained in:
Jaret Burkett
2024-08-06 11:53:27 -06:00
parent 272c8608c2
commit c2424087d6
7 changed files with 82 additions and 31 deletions

View File

@@ -28,12 +28,14 @@ RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers
# diffusers specific stuff
LINEAR_MODULES = [
'Linear',
'LoRACompatibleLinear'
'LoRACompatibleLinear',
'QLinear',
# 'GroupNorm',
]
CONV_MODULES = [
'Conv2d',
'LoRACompatibleConv'
'LoRACompatibleConv',
'QConv2d',
]
class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):