diff --git a/model.py b/model.py index 4e420a2..fb222d0 100644 --- a/model.py +++ b/model.py @@ -204,7 +204,12 @@ class ModelContainer: 'temperature' (float): Sampling temperature (default: 0.8) 'top_k' (int): Sampling top-K (default: 100) 'top_p' (float): Sampling top-P (default: 0.8) + 'min_p' (float): Sampling min-P (default: 0.0) + 'tfs' (float): Tail-free sampling (default: 0.0) 'typical' (float): Sampling typical (default: 0.0) + 'mirostat' (bool): Use Mirostat (default: False) + 'mirostat_tau' (float) Mirostat tau parameter (default: 1.5) + 'mirostat_eta' (float) Mirostat eta parameter (default: 0.1) 'token_repetition_penalty' (float): Token repetition/presence penalty (default: 1.15) 'token_repetition_range' (int): Repetition penalty range (default: whole context) 'token_repetition_decay' (int): Repetition penalty range (default: same as range) @@ -228,7 +233,12 @@ class ModelContainer: gen_settings.temperature = kwargs.get("temperature", 0.8) gen_settings.top_k = kwargs.get("top_k", 100) gen_settings.top_p = kwargs.get("top_p", 0.8) + gen_settings.min_p = kwargs.get("min_p", 0.0) + gen_settings.tfs = kwargs.get("tfs", 0.0) gen_settings.typical = kwargs.get("typical", 0.0) + gen_settings.mirostat = kwargs.get("mirostat", False) + gen_settings.mirostat_tau = kwargs.get("mirostat_tau", 1.5) + gen_settings.mirostat_eta = kwargs.get("mirostat_eta", 0.1) gen_settings.token_repetition_penalty = kwargs.get("token_repetition_penalty", 1.15) gen_settings.token_repetition_range = kwargs.get("token_repetition_range", self.config.max_seq_len) gen_settings.token_repetition_decay = kwargs.get("token_repetition_decay", gen_settings.token_repetition_range)