[fix](kt-sft): fix peft adaptations for RL tasks (#1674)

This commit is contained in:
mrhaoxx
2025-12-09 14:28:51 +08:00
committed by GitHub
parent 503295fc88
commit f992de55da
2 changed files with 24 additions and 12 deletions

View File

@@ -33,7 +33,7 @@ from peft.utils import ModulesToSaveWrapper, _get_submodules
from peft.tuners.tuners_utils import check_target_module_exists
from peft.config import PeftConfig
from ktransformers.sft.peft_utils.lora_layer import dispatch_default, LoraLayer
from ktransformers.sft.peft_utils.lora_layer import dispatch_default, LoraLayer, BaseTunerLayer
logger = logging.getLogger(__name__)
@@ -372,6 +372,28 @@ class LoraModel(nn.Module, ABC):
"""
pass
def _set_adapter_layers(self, enabled: bool = True) -> None:
for module in self.model.modules():
if isinstance(module, BaseTunerLayer):
module.enable_adapters(enabled)
def disable_adapter_layers(self) -> None:
"""
Disable all adapters in-place.
When disabling all adapters, the model output corresponds to the output of the base model.
"""
# TODO: deprecate in favor of enable_adapters
self._set_adapter_layers(enabled=False)
def enable_adapter_layers(self) -> None:
"""
Enable all adapters in-place
"""
# TODO: deprecate in favor of enable_adapters
self._set_adapter_layers(enabled=True)
# def set_adapter(self, adapter_names: str | list[str]) -> None:
# """Set the active adapter(s).

View File

@@ -42,7 +42,7 @@ from transformers.utils import PushToHubMixin
from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING
from peft.config import PeftConfig
# from .tuners.tuners_utils import BaseTuner, BaseTunerLayer
from .lora_layer import BaseTunerLayer
from peft.utils import (
SAFETENSORS_WEIGHTS_NAME,
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
@@ -1695,11 +1695,6 @@ def get_layer_status(model: torch.nn.Module) -> list[TunerLayerStatus]:
"""
if isinstance(model, PeftModel):
base_model = model.base_model
if not isinstance(base_model, BaseTuner):
raise TypeError(
"get_layer_status() got an invalid PeftModel instance; prefix tuning and adaption prompt are not "
"supported."
)
else:
base_model = model
@@ -1827,11 +1822,6 @@ def get_model_status(model: torch.nn.Module) -> TunerModelStatus:
"""
if isinstance(model, PeftModel):
if not isinstance(model.base_model, BaseTuner):
raise TypeError(
"get_model_status() got an invalid PeftModel instance; prefix tuning and adaption prompt are not "
"supported."
)
base_model_type = model.get_base_model().__class__.__name__
trainable_params, total_params = model.get_nb_trainable_parameters()
base_model = model.base_model