Update Pixtral experiment
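Replace the transformers-based vision pipeline in the experiment (PixtralImageProcessor, PixtralVisionModel and the base_model_prefix workaround) with the built-in ExLlamaV2VisionTower, fetch the test image from a URL instead of a local file, rename image_features to embeddings, switch generation to greedy sampling and raise max_new_tokens from 200 to 500.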

turboderp
2024-11-10 11:17:21 +01:00
parent 193a6b2b36
commit 7c876ef091
2 changed files with 35 additions and 56 deletions


@@ -2,22 +2,14 @@ import sys, os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers import (
-    PixtralImageProcessor,
-    PixtralVisionModel,
-)
-from PIL import Image
-import requests
-import safetensors
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Config,
     ExLlamaV2Cache,
     ExLlamaV2Tokenizer,
-    ExLlamaV2MultimodalProjector
+    ExLlamaV2MultimodalProjector,
+    ExLlamaV2VisionTower
 )
 from exllamav2.generator import (
@@ -26,55 +18,42 @@ from exllamav2.generator import (
     ExLlamaV2MMEmbedding
 )
-# Unquantized model used for this experiment:
+from PIL import Image
+import requests
+
+# Get an input image
+
+url = "https://pbs.twimg.com/media/BAeuBsnCIAAUITV.jpg:large"
+image = Image.open(requests.get(url, stream = True).raw)
+
+# Unquantized model used for experiment:
 #
 # https://huggingface.co/mistral-community/pixtral-12b/
 
 model_directory = "/mnt/str/models/pixtral-12b"
 config = ExLlamaV2Config(model_directory)
-
-# PixtralVisionModel expects vision tower keys to be prefixed with "vision_encoder", but the checkpoint prefixes
-# them with "vision_tower". Patch the model implementation to allow the model to load with from_pretrained.
-PixtralVisionModel.base_model_prefix = "vision_tower"
-
 config.max_seq_len = 32768  # default is 1M
 
 # Load multimodal projector
 
 multimodal_projector = ExLlamaV2MultimodalProjector(config)
 multimodal_projector.load()
 
 with torch.inference_mode():
 
-    # Initialize preprocessor, vision model and multimodal projector
-
-    image_processor = PixtralImageProcessor.from_pretrained(model_directory, device_map = "cuda:0")
-    vision_model = PixtralVisionModel.from_pretrained(
-        model_directory,
-        device_map = "cuda:0",
-        hidden_act = "silu"
-    )
-
-    # multimodal_projector = ExLlamaV2MultimodalProjector()
-    # safetensors.torch.load_model(
-    #     multimodal_projector,
-    #     os.path.join(model_directory, "model-00001-of-00006.safetensors"),
-    #     strict = False,
-    # )
-    # multimodal_projector.half().to("cuda:0")
-
-    # Get an input image and process it
-
-    # url = "https://i.imgur.com/JMDz9pC.jpeg"
-    # image = Image.open(requests.get(url, stream = True).raw)
-    image_path = "car2.jpg"
-    image = Image.open(image_path)
-
-    inputs = image_processor(image, return_tensors = "pt")
-    pixel_values = [inputs["pixel_values"][0][0].to("cuda:0", torch.half)]
-    image_features = vision_model(pixel_values)
-    image_features = multimodal_projector.forward(image_features.hidden_states[0].half())
-    image_features = image_features[0]
-    image_size = inputs["image_sizes"][0][0]
+    # Load vision tower and preprocessor
+
+    vision_tower = ExLlamaV2VisionTower(config)
+    vision_tower.load(progress = True)
+
+    # Preprocess
+
+    image_tensor = vision_tower.preprocess(image)
+    image_tensor = image_tensor.cuda()
+    image_size = tuple(image_tensor.shape[1:])
+
+    # Produce embeddings
+
+    embeddings = vision_tower.process(image_tensor)
+    embeddings = multimodal_projector.forward(embeddings)[0]
 
 # Load EXL2 model
@@ -94,12 +73,12 @@ id_end = tokenizer.single_id("[IMG_END]")
 
 img_break = model.modules[0].forward(torch.tensor([id_break], dtype = torch.long)).to("cuda:0")
 img_end = model.modules[0].forward(torch.tensor([id_end], dtype = torch.long)).to("cuda:0")
 
-dim = image_features.shape[-1]
-image_features = image_features.view((features_y, features_x, dim))
+dim = embeddings.shape[-1]
+embeddings = embeddings.view((features_y, features_x, dim))
 break_col = img_break.expand(features_y, -1, -1)
-image_features = torch.cat((image_features, break_col), dim = 1)
-image_features = image_features.view((features_y * (features_x + 1)), dim)
-image_features = torch.cat((image_features, img_end), dim = 0)
+embeddings = torch.cat((embeddings, break_col), dim = 1)
+embeddings = embeddings.view((features_y * (features_x + 1)), dim)
+embeddings = torch.cat((embeddings, img_end), dim = 0)
 
 # Create generator
@@ -111,25 +90,25 @@ generator = ExLlamaV2DynamicGenerator(
 
 # Create an MMEmbedding for the image features and a prompt containing the placeholder string
 
-image_tokens = ExLlamaV2MMEmbedding(
+image_tokens_a = ExLlamaV2MMEmbedding(
     model = model,
-    embeddings = image_features,
-    text_alias = "{{EMBED_HERE}}"
+    embeddings = embeddings,
+    text_alias = "{{EMBED_A}}"
 )
 
-prompt = "[INST] {{EMBED_HERE}}\nDescribe the image. [/INST]"
+prompt = "[INST]{{EMBED_A}}\nDescribe the image.[/INST]"
 
 # Pass embeddings to generator
 
 output = generator.generate(
     prompt = prompt,
-    max_new_tokens = 200,
+    max_new_tokens = 500,
     add_bos = True,
     encode_special_tokens = True,
     decode_special_tokens = True,
     stop_conditions = [tokenizer.eos_token_id],
-    # gen_settings = ExLlamaV2Sampler.Settings.greedy(),
-    embeddings = [image_tokens],
+    gen_settings = ExLlamaV2Sampler.Settings.greedy(),
+    embeddings = [image_tokens_a],
 )
 
 print(output)
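For readers following the embedding assembly in the third hunk: the projected image embeddings are reshaped into a (features_y, features_x) grid, one [IMG_BREAK] embedding is appended to the end of each row, and the flattened sequence is terminated with a single [IMG_END] embedding, matching Pixtral's image token layout. Below is a self-contained sketch of that reshape/concat sequence using dummy tensors; the grid size, dim, and the stand-in shapes for img_break and img_end are illustrative assumptions, not values from the commit.

import torch

# Illustrative sizes (assumptions): a 3 x 4 grid of patch embeddings, dim 8
features_y, features_x, dim = 3, 4, 8

# Stand-ins for the projected image embeddings and the [IMG_BREAK]/[IMG_END]
# token embeddings; shapes are chosen so the operations below line up
embeddings = torch.randn(features_y * features_x, dim)
img_break = torch.randn(1, 1, dim)
img_end = torch.randn(1, dim)

# Reshape the flat embedding sequence into the 2D patch grid
embeddings = embeddings.view((features_y, features_x, dim))

# Append one [IMG_BREAK] embedding to the end of every grid row
break_col = img_break.expand(features_y, -1, -1)
embeddings = torch.cat((embeddings, break_col), dim = 1)

# Flatten back into a token sequence and terminate it with [IMG_END]
embeddings = embeddings.view((features_y * (features_x + 1)), dim)
embeddings = torch.cat((embeddings, img_end), dim = 0)

print(embeddings.shape)  # torch.Size([16, 8]): 3 * (4 + 1) + 1 tokens

The break embeddings give the language model an explicit row-boundary marker, so the 2D structure of the image survives flattening into a 1D token sequence.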
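And a condensed sketch of the full flow after this commit, assembled from the added lines above. The EXL2 model, cache, tokenizer and generator setup sits in unchanged context the diff does not show, so that part is an assumption based on exllamav2's usual example pattern, as is eliding the [IMG_BREAK]/[IMG_END] assembly (see the previous sketch); it is not this file's exact contents.

import torch
import requests
from PIL import Image
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
    ExLlamaV2MultimodalProjector,
    ExLlamaV2VisionTower
)
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2Sampler,
    ExLlamaV2MMEmbedding
)

config = ExLlamaV2Config("/mnt/str/models/pixtral-12b")
config.max_seq_len = 32768

# Load the multimodal projector and vision tower, then embed the image
multimodal_projector = ExLlamaV2MultimodalProjector(config)
multimodal_projector.load()

with torch.inference_mode():
    vision_tower = ExLlamaV2VisionTower(config)
    vision_tower.load(progress = True)
    url = "https://pbs.twimg.com/media/BAeuBsnCIAAUITV.jpg:large"
    image = Image.open(requests.get(url, stream = True).raw)
    image_tensor = vision_tower.preprocess(image).cuda()
    embeddings = multimodal_projector.forward(vision_tower.process(image_tensor))[0]

# Standard model/generator setup (assumed, not shown in the diff)
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache, progress = True)
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2DynamicGenerator(model = model, cache = cache, tokenizer = tokenizer)

# Alias the image embeddings to a placeholder string in the prompt
# (embeddings should first get the [IMG_BREAK]/[IMG_END] treatment above)
image_tokens_a = ExLlamaV2MMEmbedding(
    model = model,
    embeddings = embeddings,
    text_alias = "{{EMBED_A}}"
)

output = generator.generate(
    prompt = "[INST]{{EMBED_A}}\nDescribe the image.[/INST]",
    max_new_tokens = 500,
    add_bos = True,
    encode_special_tokens = True,
    decode_special_tokens = True,
    stop_conditions = [tokenizer.eos_token_id],
    gen_settings = ExLlamaV2Sampler.Settings.greedy(),
    embeddings = [image_tokens_a],
)
print(output)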