mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-20 06:18:59 +00:00
[fix](kt-sft): fix peft adaptations for RL tasks (#1674)
This commit is contained in:
@@ -33,7 +33,7 @@ from peft.utils import ModulesToSaveWrapper, _get_submodules
|
||||
from peft.tuners.tuners_utils import check_target_module_exists
|
||||
from peft.config import PeftConfig
|
||||
|
||||
from ktransformers.sft.peft_utils.lora_layer import dispatch_default, LoraLayer
|
||||
from ktransformers.sft.peft_utils.lora_layer import dispatch_default, LoraLayer, BaseTunerLayer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -372,6 +372,28 @@ class LoraModel(nn.Module, ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def _set_adapter_layers(self, enabled: bool = True) -> None:
    """Switch every tuner layer in the wrapped model on or off in-place.

    Walks ``self.model`` and calls ``enable_adapters(enabled)`` on each
    module that is a ``BaseTunerLayer``; all other modules are untouched.
    """
    tuner_layers = (m for m in self.model.modules() if isinstance(m, BaseTunerLayer))
    for layer in tuner_layers:
        layer.enable_adapters(enabled)
|
||||
|
||||
def disable_adapter_layers(self) -> None:
    """
    Disable all adapters in-place.

    When disabling all adapters, the model output corresponds to the output of the base model.
    """
    # TODO: deprecate in favor of enable_adapters
    # Delegates to the shared toggle helper with the adapters switched off.
    self._set_adapter_layers(enabled=False)
|
||||
|
||||
def enable_adapter_layers(self) -> None:
    """
    Enable all adapters in-place
    """
    # TODO: deprecate in favor of enable_adapters
    # Delegates to the shared toggle helper with the adapters switched on.
    self._set_adapter_layers(enabled=True)
|
||||
|
||||
|
||||
# def set_adapter(self, adapter_names: str | list[str]) -> None:
|
||||
# """Set the active adapter(s).
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ from transformers.utils import PushToHubMixin
|
||||
from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING
|
||||
|
||||
from peft.config import PeftConfig
|
||||
# from .tuners.tuners_utils import BaseTuner, BaseTunerLayer
|
||||
from .lora_layer import BaseTunerLayer
|
||||
from peft.utils import (
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
|
||||
@@ -1695,11 +1695,6 @@ def get_layer_status(model: torch.nn.Module) -> list[TunerLayerStatus]:
|
||||
"""
|
||||
if isinstance(model, PeftModel):
|
||||
base_model = model.base_model
|
||||
if not isinstance(base_model, BaseTuner):
|
||||
raise TypeError(
|
||||
"get_layer_status() got an invalid PeftModel instance; prefix tuning and adaption prompt are not "
|
||||
"supported."
|
||||
)
|
||||
else:
|
||||
base_model = model
|
||||
|
||||
@@ -1827,11 +1822,6 @@ def get_model_status(model: torch.nn.Module) -> TunerModelStatus:
|
||||
|
||||
"""
|
||||
if isinstance(model, PeftModel):
|
||||
if not isinstance(model.base_model, BaseTuner):
|
||||
raise TypeError(
|
||||
"get_model_status() got an invalid PeftModel instance; prefix tuning and adaption prompt are not "
|
||||
"supported."
|
||||
)
|
||||
base_model_type = model.get_base_model().__class__.__name__
|
||||
trainable_params, total_params = model.get_nb_trainable_parameters()
|
||||
base_model = model.base_model
|
||||
|
||||
Reference in New Issue
Block a user