diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 4fc263e..447e9f4 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -238,7 +238,7 @@ class ExllamaV2Container(BaseModelContainer): base_seq_len = hf_model.hf_config.max_position_embeddings # Set the target seq len if present - target_seq_len = unwrap(kwargs.get("max_seq_len"), 4096) + target_seq_len = unwrap(kwargs.get("max_seq_len"), base_seq_len) # Set the rope scale self.config.scale_pos_emb = unwrap( @@ -289,16 +289,7 @@ class ExllamaV2Container(BaseModelContainer): # Set k/v cache size # cache_size is only relevant when paged mode is enabled if self.paged: - cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len) - - if cache_size < self.config.max_seq_len: - logger.warning( - f"The given cache_size ({cache_size}) is smaller than the " - "desired context length.\n" - "Overriding cache_size to max_seq_len. " - ) - - cache_size = self.config.max_seq_len + cache_size = unwrap(kwargs.get("cache_size"), 4096) # Enforce a multiple of 256 for cache size # Overestimate to ensure that the cache isn't below max_seq_len @@ -317,6 +308,13 @@ class ExllamaV2Container(BaseModelContainer): cache_size = rounded_cache_size + if self.config.max_seq_len > cache_size: + logger.warning( + f"The given max_seq_len ({self.config.max_seq_len}) is larger than the " + f"cache size and will be limited to {cache_size} tokens." + ) + self.config.max_seq_len = cache_size + # Warn user if cache size may be inadequate for CFG if cache_size < 2 * self.config.max_seq_len: logger.warning( diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 4427df4..7e8402f 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -105,12 +105,7 @@ class ExllamaV3Container(BaseModelContainer): self = cls() # Make sure ExllamaV3 is up to date - check_package_version("exllamav3", "0.0.4") - - logger.warning( - "ExllamaV3 is currently in an alpha state. " - "Please note that all config options may not work." - ) + check_package_version("exllamav3", "0.0.7") self.model_dir = model_directory self.hf_model = hf_model @@ -131,9 +126,6 @@ class ExllamaV3Container(BaseModelContainer): self.vision_model = None self.use_vision = False - # Fallback to 4096 since exl3 can't fetch from HF's config.json - self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096) - # Prepare the draft model config if necessary draft_args = unwrap(kwargs.get("draft_model"), {}) draft_model_name = draft_args.get("draft_model_name") @@ -231,11 +223,15 @@ class ExllamaV3Container(BaseModelContainer): raise RuntimeError(gpu_unsupported_message) # Cache - user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len) + user_cache_size = unwrap(kwargs.get("cache_size"), 4096) self.cache_size = self.adjust_cache_size(user_cache_size) self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16") self.cache = self.create_cache(self.cache_mode, self.model) + # Limit max_seq_len to prevent sequences larger than the cache + max_seq_len = unwrap(kwargs.get("max_seq_len"), hf_model.hf_config.max_position_embeddings) + self.max_seq_len = self.adjust_max_seq_len(max_seq_len) + # Draft cache if self.use_draft_model: # Set draft cache mode @@ -274,21 +270,11 @@ class ExllamaV3Container(BaseModelContainer): return self def adjust_cache_size(self, cache_size): - if cache_size < self.max_seq_len: - logger.warning( - f"The given cache_size ({cache_size}) is smaller than the " - "desired context length.\n" - "Overriding cache_size to max_seq_len. " - ) - - cache_size = self.max_seq_len - # Enforce a multiple of 256 for cache size # Overestimate to ensure that the cache isn't below max_seq_len cache_remainder = cache_size % 256 if cache_remainder != 0: rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1)) - logger.warning( f"The given cache size ({cache_size}) is " "not a multiple of 256.\n" @@ -298,22 +284,22 @@ class ExllamaV3Container(BaseModelContainer): cache_size = rounded_cache_size - # Warn user if cache size may be inadequate for CFG - if cache_size < 2 * self.max_seq_len: - logger.warning( - f"The given cache_size ({cache_size}) is less than 2 * max_seq_len " - "and may be too small for requests using CFG. \n" - "Ignore this warning if you do not plan on using CFG." - ) - return cache_size - def adjust_chunk_size(self, user_chunk_size: int): - chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1] - chunk_remainder = chunk_size % 256 - if chunk_remainder != 0: - rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1)) + def adjust_max_seq_len(self, max_seq_len): + if max_seq_len > self.cache_size: + logger.warning( + f"The given max_seq_len ({max_seq_len}) is larger than the cache size " + f"and will be limited to {self.cache_size} tokens." + ) + max_seq_len = self.cache_size + return max_seq_len + + def adjust_chunk_size(self, user_chunk_size: int): + chunk_size = max(256, user_chunk_size) + rounded_chunk_size = (chunk_size + 255) // 256 * 256 + if chunk_size != rounded_chunk_size: logger.warning( f"The given chunk size ({chunk_size}) is " "not a multiple of 256.\n" @@ -950,24 +936,18 @@ class ExllamaV3Container(BaseModelContainer): context_len = input_ids[0].size(dim=-1) # Automatically set max_tokens to fill up the context - # This should be an OK default, but may be changed in the future max_tokens = unwrap( - params.max_tokens, - self.max_seq_len - context_len, + params.max_tokens if params.max_tokens > 0 else None, + self.max_seq_len - context_len - 1, ) if max_tokens < 1: logger.warning("max_tokens must be a positive integer, setting to 1.") max_tokens = 1 - # Determine if the negative context or the context length is bigger - context_to_check = context_len - # Check total length of prompt against max context length - if context_to_check > self.max_seq_len: - preamble = "Prompt" - + if context_len > self.max_seq_len: raise ValueError( - f"{preamble} length {context_to_check} is greater than " + f"Prompt length {context_len} is greater than " f"max_seq_len {self.max_seq_len}" ) diff --git a/common/model.py b/common/model.py index 16138a3..1eea5eb 100644 --- a/common/model.py +++ b/common/model.py @@ -157,10 +157,8 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): # Override the max sequence length based on user max_seq_len = kwargs.get("max_seq_len") - if max_seq_len == -1: + if max_seq_len == -1 or max_seq_len is None: kwargs["max_seq_len"] = hf_model.hf_config.max_position_embeddings - elif max_seq_len is None: - kwargs["max_seq_len"] = 4096 # Create a new container and check if the right dependencies are installed backend = unwrap(kwargs.get("backend"), detect_backend(hf_model)) diff --git a/config_sample.yml b/config_sample.yml index a294cfc..7781083 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -78,8 +78,7 @@ model: # Options: exllamav2, exllamav3 backend: - # Max sequence length (default: 4096). - # Set to -1 to fetch from the model's config.json + # Max sequence length (default: fetch from the model's config.json). max_seq_len: # Load model with tensor parallelism. @@ -124,9 +123,8 @@ model: # For exllamav3, specify the pair k_bits,v_bits where k_bits and v_bits are integers from 2-8 (i.e. 8,8). cache_mode: FP16 - # Size of the prompt cache to allocate (default: max_seq_len). - # Must be a multiple of 256 and can't be less than max_seq_len. - # For CFG, set this to 2 * max_seq_len. + # Size of the key/value cache to allocate, in tokens (default: 4096). + # Must be a multiple of 256. cache_size: # Chunk size for prompt ingestion (default: 2048). diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 202b74b..c9fbf1b 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -85,7 +85,7 @@ class ModelLoadRequest(BaseModel): examples=[4096], ) cache_size: Optional[int] = Field( - description=("Number in tokens, must be greater than or equal to max_seq_len"), + description="Number in tokens, must be multiple of 256", default=None, examples=[4096], )