Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Allow control image for omnigen training and sampling
@@ -1,8 +1,6 @@
import inspect
import os
from typing import TYPE_CHECKING, List, Optional

import einops
import torch
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
@@ -10,22 +8,29 @@ from toolkit.models.base_model import BaseModel
from diffusers import AutoencoderKL
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype
from .src.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from .src.models.transformers import OmniGen2Transformer2DModel
from .src.models.transformers.repo import OmniGen2RotaryPosEmbed
from .src.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler
from transformers import CLIPProcessor, Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
from .src.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler,
)
from PIL import Image
from transformers import (
    CLIPProcessor,
    Qwen2_5_VLForConditionalGeneration,
)
import torch.nn.functional as F

if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

scheduler_config = {
    "num_train_timesteps": 1000
}
scheduler_config = {"num_train_timesteps": 1000}

BASE_MODEL_PATH = "OmniGen2/OmniGen2"

@@ -34,25 +39,21 @@ class OmniGen2Model(BaseModel):
    arch = "omnigen2"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device,
            model_config,
            dtype,
            custom_pipeline,
            noise_scheduler,
            **kwargs
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['OmniGen2Transformer2DModel']
        self.target_lora_modules = ["OmniGen2Transformer2DModel"]
        self._control_latent = None

    # static method to get the noise scheduler
    @staticmethod
@@ -69,20 +70,16 @@ class OmniGen2Model(BaseModel):
        # will be updated if we detect a existing checkpoint in training folder
        model_path = self.model_config.name_or_path
        extras_path = self.model_config.extras_name_or_path

        scheduler = OmniGen2Model.get_train_scheduler()

        self.print_and_status_update("Loading Qwen2.5 VL")
        processor = CLIPProcessor.from_pretrained(
            extras_path,
            subfolder="processor",
            use_fast=True
            extras_path, subfolder="processor", use_fast=True
        )

        mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            extras_path,
            subfolder="mllm",
            torch_dtype=torch.bfloat16
            extras_path, subfolder="mllm", torch_dtype=torch.bfloat16
        )
        mllm.to(self.device_torch, dtype=dtype)
        if self.model_config.quantize_te:
@@ -90,57 +87,52 @@ class OmniGen2Model(BaseModel):
            quantization_type = get_qtype(self.model_config.qtype_te)
            quantize(mllm, weights=quantization_type)
            freeze(mllm)

        if self.low_vram:
            # unload it for now
            mllm.to('cpu')
            mllm.to("cpu")

        flush()

        self.print_and_status_update("Loading transformer")

        transformer = OmniGen2Transformer2DModel.from_pretrained(
            model_path,
            subfolder="transformer",
            torch_dtype=torch.bfloat16
            model_path, subfolder="transformer", torch_dtype=torch.bfloat16
        )

        if not self.low_vram:
            transformer.to(self.device_torch, dtype=dtype)

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing transformer")
            quantization_type = get_qtype(self.model_config.qtype)
            quantize(transformer, weights=quantization_type)
            freeze(transformer)

        if self.low_vram:
            # unload it for now
            transformer.to('cpu')
            transformer.to("cpu")

        flush()

        self.print_and_status_update("Loading vae")

        vae = AutoencoderKL.from_pretrained(
            extras_path,
            subfolder="vae",
            torch_dtype=torch.bfloat16
            extras_path, subfolder="vae", torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)

        flush()
        self.print_and_status_update("Loading Qwen2.5 VLProcessor")

        flush()

        if self.low_vram:
            self.print_and_status_update("Moving everything to device")
            # move it all back
            transformer.to(self.device_torch, dtype=dtype)
            vae.to(self.device_torch, dtype=dtype)
            mllm.to(self.device_torch, dtype=dtype)

        # set to eval mode
        # transformer.eval()
        vae.eval()
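
The loading code above applies the same pattern to each sub-model: optionally quantize the weights, freeze them, and, in low_vram mode, park the module on CPU until everything is loaded. A minimal standalone sketch of that pattern, using optimum.quanto directly on a toy module instead of the toolkit's quantize/get_qtype wrappers (the toy module and the low_vram flag below are illustrative, not part of the diff):

import torch
import torch.nn as nn
from optimum.quanto import freeze, qint8, quantize

# Stand-in for a large sub-model such as the transformer or the Qwen2.5-VL text encoder.
module = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64))

# Quantize the weights in place, then freeze them so the float originals are dropped.
quantize(module, weights=qint8)
freeze(module)

# In low_vram mode the diff offloads each piece to CPU right after loading it,
# then moves everything back to the training device in one pass at the end.
low_vram = True
device = "cuda" if torch.cuda.is_available() else "cpu"
module.to("cpu" if low_vram else device)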
@@ -149,28 +141,17 @@ class OmniGen2Model(BaseModel):

        pipe: OmniGen2Pipeline = OmniGen2Pipeline(
            transformer=transformer,
            vae=vae,
            vae=vae,
            scheduler=scheduler,
            mllm=mllm,
            processor=processor,
        )

        # pipe: OmniGen2Pipeline = OmniGen2Pipeline.from_pretrained(
        #     model_path,
        #     transformer=transformer,
        #     vae=vae,
        #     scheduler=scheduler,
        #     mllm=mllm,
        #     trust_remote_code=True,
        # )
        # processor = pipe.processor

        flush()

        text_encoder_list = [mllm]
        tokenizer_list = [processor]

        flush()

        # save it to the model class
@@ -179,21 +160,20 @@ class OmniGen2Model(BaseModel):
        self.tokenizer = tokenizer_list  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe

        self.freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
            transformer.config.axes_dim_rope,
            transformer.config.axes_lens,
            theta=10000,
        )

        self.print_and_status_update("Model Loaded")

    def get_generation_pipeline(self):
        scheduler = OmniFlowMatchEuler(
            dynamic_time_shift=True,
            num_train_timesteps=1000
            dynamic_time_shift=True, num_train_timesteps=1000
        )

        pipeline: OmniGen2Pipeline = OmniGen2Pipeline(
            transformer=self.model,
            vae=self.vae,
@@ -215,6 +195,17 @@ class OmniGen2Model(BaseModel):
        generator: torch.Generator,
        extra: dict,
    ):
        input_images = []
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img)
            control_img = control_img.convert("RGB")
            # resize to width and height
            if control_img.size != (gen_config.width, gen_config.height):
                control_img = control_img.resize(
                    (gen_config.width, gen_config.height), Image.BILINEAR
                )
            input_images = [control_img]

        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            prompt_attention_mask=conditional_embeds.attention_mask,
@@ -224,10 +215,12 @@ class OmniGen2Model(BaseModel):
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            text_guidance_scale=gen_config.guidance_scale,
            image_guidance_scale=1.0, # reference image guidance scale. Add this for controls
            image_guidance_scale=1.0,  # reference image guidance scale. Add this for controls
            latents=gen_config.latents,
            align_res=False,
            generator=generator,
            **extra
            input_images=input_images,
            **extra,
        ).images[0]
        return img
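
In the sampling path above, a control image given via gen_config.ctrl_img is loaded, converted to RGB, resized to the sample resolution, and handed to the pipeline as input_images alongside image_guidance_scale=1.0. A small self-contained sketch of just that preparation step; the in-memory dummy image and the target size stand in for the file and the gen_config values:

from PIL import Image

# Dummy stand-in for the file normally read from gen_config.ctrl_img.
control_img = Image.new("RGB", (512, 768), color="gray")

width, height = 1024, 1024  # placeholders for gen_config.width / gen_config.height

control_img = control_img.convert("RGB")
if control_img.size != (width, height):
    # Match the sample resolution so the control image lines up with the generated latent.
    control_img = control_img.resize((width, height), Image.BILINEAR)

# The resized image is what the hunk above passes as input_images=[control_img].
input_images = [control_img]
print(input_images[0].size)  # (1024, 1024)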
@@ -236,18 +229,16 @@ class OmniGen2Model(BaseModel):
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
        **kwargs,
    ):
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        try:
            timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
            timestep = timestep.expand(latent_model_input.shape[0]).to(
                latent_model_input.dtype
            )
        except Exception as e:
            pass

        # optional_kwargs = {}
        # if 'ref_image_hidden_states' in set(inspect.signature(self.model.forward).parameters.keys()):
        #     optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states

        timesteps = timestep / 1000  # convert to 0 to 1 scale
        # timestep for model starts at 0 instead of 1. So we need to reverse them
        timestep = 1 - timesteps
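
The two lines above rescale the scheduler's 0-1000 timestep to a 0-1 range and then flip it, because the model counts time in the opposite direction. A quick numeric check of that mapping with arbitrary sample values:

import torch

timestep = torch.tensor([1000.0, 500.0, 0.0])  # scheduler scale, 0 to 1000
timesteps = timestep / 1000                    # -> tensor([1.0, 0.5, 0.0])
timestep = 1 - timesteps                       # model scale -> tensor([0.0, 0.5, 1.0])
print(timestep)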
@@ -257,18 +248,60 @@ class OmniGen2Model(BaseModel):
            text_embeddings.text_embeds,
            self.freqs_cis,
            text_embeddings.attention_mask,
            ref_image_hidden_states=None,  # todo add ref latent ability
            ref_image_hidden_states=self._control_latent,
        )

        return model_pred

    def condition_noisy_latents(
        self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
    ):
        # reset the control latent
        self._control_latent = None
        with torch.no_grad():
            control_tensor = batch.control_tensor
            if control_tensor is not None:
                self.vae.to(self.device_torch)
                # we are not packed here, so we just need to pass them so we can pack them later
                control_tensor = control_tensor * 2 - 1
                control_tensor = control_tensor.to(
                    self.vae_device_torch, dtype=self.torch_dtype
                )

                # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
                # todo, we may not need to do this, check
                if batch.tensor is not None:
                    target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
                else:
                    # When caching latents, batch.tensor is None. We get the size from the file_items instead.
                    target_h = batch.file_items[0].crop_height
                    target_w = batch.file_items[0].crop_width

                if (
                    control_tensor.shape[2] != target_h
                    or control_tensor.shape[3] != target_w
                ):
                    control_tensor = F.interpolate(
                        control_tensor, size=(target_h, target_w), mode="bilinear"
                    )

                control_latent = self.encode_images(control_tensor).to(
                    latents.device, latents.dtype
                )
                self._control_latent = [
                    [x.squeeze(0)]
                    for x in torch.chunk(control_latent, control_latent.shape[0], dim=0)
                ]

        return latents.detach()

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [self.pipeline._apply_chat_template(_prompt) for _prompt in prompt]
        self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
        max_sequence_length = 256
        prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
            prompt = prompt,
            prompt=prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
            max_sequence_length=max_sequence_length,
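
In condition_noisy_latents above, the encoded control latent of shape (bs, ch, h, w) is repacked into one entry per batch item, each holding a list of per-image latents with the batch dimension squeezed away; that nested-list form is what the diff passes to ref_image_hidden_states. A toy sketch of the repacking, with a random tensor and arbitrary shapes standing in for the VAE output:

import torch

# Stand-in for self.encode_images(control_tensor): (bs, ch, h, w); shapes are arbitrary here.
control_latent = torch.randn(2, 16, 32, 32)

control_latent_list = [
    [x.squeeze(0)]  # one reference latent per sample; shape (ch, h, w)
    for x in torch.chunk(control_latent, control_latent.shape[0], dim=0)
]

print(len(control_latent_list))         # 2, one entry per batch item
print(control_latent_list[0][0].shape)  # torch.Size([16, 32, 32])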
@@ -276,7 +309,7 @@ class OmniGen2Model(BaseModel):
        pe = PromptEmbeds(prompt_embeds)
        pe.attention_mask = prompt_attention_mask
        return pe

    def get_model_has_grad(self):
        # return from a weight if it has grad
        return False
@@ -284,30 +317,31 @@ class OmniGen2Model(BaseModel):
    def get_te_has_grad(self):
        # assume no one wants to finetune 4 text encoders.
        return False

    def save_model(self, output_path, meta, save_dtype):
        # only save the transformer
        transformer: OmniGen2Transformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            save_directory=os.path.join(output_path, "transformer"),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
        meta_path = os.path.join(output_path, "aitk_meta.yaml")
        with open(meta_path, "w") as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get('noise')
        batch = kwargs.get('batch')
        noise = kwargs.get("noise")
        batch = kwargs.get("batch")
        # return (noise - batch.latents).detach()
        return (batch.latents - noise).detach()

    def get_transformer_block_names(self) -> Optional[List[str]]:
        # omnigen2 had a few blocks for things like noise_refiner, ref_image_refiner, context_refiner, and layers.
        # lets do all but image refiner until we add it
        return ['noise_refiner', 'context_refiner', 'layers']
        # return ['layers']
        if self.model_config.model_kwargs.get("use_image_refiner", False):
            return ["noise_refiner", "context_refiner", "ref_image_refiner", "layers"]
        return ["noise_refiner", "context_refiner", "layers"]

    def convert_lora_weights_before_save(self, state_dict):
        # currently starte with transformer. but needs to start with diffusion_model. for comfyui
@@ -324,7 +358,6 @@ class OmniGen2Model(BaseModel):
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd

    def get_base_model_version(self):
        return "omnigen2"
@@ -676,7 +676,8 @@ class OmniGen2Pipeline(DiffusionPipeline):
                    prompt_embeds=negative_prompt_embeds,
                    freqs_cis=freqs_cis,
                    prompt_attention_mask=negative_prompt_attention_mask,
                    ref_image_hidden_states=None,
                    ref_image_hidden_states=ref_latents,
                    # ref_image_hidden_states=None,
                )
                model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
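
The last line of this hunk is standard classifier-free guidance: the unconditional prediction is pushed toward the conditional one by text_guidance_scale. A tiny numeric illustration with stand-in tensors (values are arbitrary):

import torch

model_pred = torch.tensor([1.0, 2.0])         # conditional prediction
model_pred_uncond = torch.tensor([0.5, 0.5])  # unconditional prediction
text_guidance_scale = 3.0

guided = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
print(guided)  # tensor([2.0, 5.0])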
@@ -310,6 +310,7 @@ export default function SimpleJob({
                  { value: 'sigmoid', label: 'Sigmoid' },
                  { value: 'linear', label: 'Linear' },
                  { value: 'shift', label: 'Shift' },
                  { value: 'weighted', label: 'Weighted' },
                ]}
              />
            )}
@@ -541,13 +542,12 @@ export default function SimpleJob({
                  { value: 'ddpm', label: 'DDPM' },
                ]}
              />
            </div>
            <div>
              <NumberInput
                label="Guidance Scale"
                value={jobConfig.config.process[0].sample.guidance_scale}
                onChange={value => setJobConfig(value, 'config.process[0].sample.guidance_scale')}
                placeholder="eg. 1.0"
                className="pt-2"
                min={0}
                required
              />
@@ -579,6 +579,26 @@ export default function SimpleJob({
                min={0}
                required
              />
              {isVideoModel && (
                <div>
                  <NumberInput
                    label="Num Frames"
                    value={jobConfig.config.process[0].sample.num_frames}
                    onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
                    placeholder="eg. 0"
                    min={0}
                    required
                  />
                  <NumberInput
                    label="FPS"
                    value={jobConfig.config.process[0].sample.fps}
                    onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
                    placeholder="eg. 0"
                    min={0}
                    required
                  />
                </div>
              )}
            </div>

            <div>
@@ -597,40 +617,36 @@ export default function SimpleJob({
                  onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
                />
              </div>
              {isVideoModel && (
                <div>
                  <NumberInput
                    label="Num Frames"
                    value={jobConfig.config.process[0].sample.num_frames}
                    onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
                    placeholder="eg. 0"
                    min={0}
                    required
                  />
                  <NumberInput
                    label="FPS"
                    value={jobConfig.config.process[0].sample.fps}
                    onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
                    placeholder="eg. 0"
                    min={0}
                    required
                  />
                </div>
              )}
              <div>
                <FormGroup label="Advanced Sampling" className="pt-2">
                  <div>
                    <Checkbox
                      label="Skip First Sample"
                      className="pt-4"
                      checked={jobConfig.config.process[0].train.skip_first_sample || false}
                      onChange={value => setJobConfig(value, 'config.process[0].train.skip_first_sample')}
                    />
                  </div>
                  <div>
                    <Checkbox
                      label="Disable Sampling"
                      className="pt-1"
                      checked={jobConfig.config.process[0].train.disable_sampling || false}
                      onChange={value => setJobConfig(value, 'config.process[0].train.disable_sampling')}
                    />
                  </div>
                </FormGroup>
              </div>
            </div>
            <FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`} className="pt-2">
              {
                modelArch?.additionalSections?.includes('sample.ctrl_img') && (
                  <div className='text-sm text-gray-100 mb-2 py-2 px-4 bg-yellow-700 rounded-lg'>
                    <p className='font-semibold mb-1'>
                      Control Images
                    </p>
                    To use control images on samples, add --ctrl_img to the prompts below.
                    <br />
                    Example: <code className='bg-yellow-900 p-1'>make this a cartoon --ctrl_img /path/to/image.png</code>
                  </div>
                )
              }
              {modelArch?.additionalSections?.includes('sample.ctrl_img') && (
                <div className="text-sm text-gray-100 mb-2 py-2 px-4 bg-yellow-700 rounded-lg">
                  <p className="font-semibold mb-1">Control Images</p>
                  To use control images on samples, add --ctrl_img to the prompts below.
                  <br />
                  Example: <code className="bg-yellow-900 p-1">make this a cartoon --ctrl_img /path/to/image.png</code>
                </div>
              )}
              {jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
                <div key={i} className="flex items-center space-x-2">
                  <div className="flex-1">
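
The Control Images hint above tells users to append --ctrl_img <path> to a sample prompt. As a rough illustration of how such a suffix can be split off a prompt string (a hypothetical parser for illustration only, not the toolkit's actual prompt parsing):

import shlex

def split_ctrl_img(prompt: str):
    # Return (clean_prompt, ctrl_img_path_or_None) for a '--ctrl_img <path>' suffix.
    tokens = shlex.split(prompt)
    if "--ctrl_img" in tokens:
        i = tokens.index("--ctrl_img")
        path = tokens[i + 1] if i + 1 < len(tokens) else None
        tokens = tokens[:i] + tokens[i + 2:]
        return " ".join(tokens), path
    return prompt, None

print(split_ctrl_img("make this a cartoon --ctrl_img /path/to/image.png"))
# ('make this a cartoon', '/path/to/image.png')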
@@ -68,6 +68,8 @@ export const defaultJobConfig: JobConfig = {
      use_ema: false,
      ema_decay: 0.99,
    },
    skip_first_sample: false,
    disable_sampling: false,
    dtype: 'bf16',
    diff_output_preservation: false,
    diff_output_preservation_multiplier: 1.0,
@@ -200,6 +200,7 @@ export const modelArchs: ModelArch[] = [
      'config.process[0].model.quantize_te': [true, false],
    },
    disableSections: ['network.conv'],
    additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
  },
].sort((a, b) => {
  // Sort by label, case-insensitive
@@ -110,6 +110,8 @@ export interface TrainConfig {
  optimizer_params: {
    weight_decay: number;
  };
  skip_first_sample: boolean;
  disable_sampling: boolean;
  diff_output_preservation: boolean;
  diff_output_preservation_multiplier: number;
  diff_output_preservation_class: string;