# Differential Revision: D90237984
# fbshipit-source-id: 526fd760f303bf31be4f743bdcd77760496de0de
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

import base64
import os
from typing import Any, Optional

from openai import OpenAI


# Maps lowercase file extensions to MIME types for image payloads.
# Hoisted to module level: the table is constant, so there is no need to
# rebuild it on every call.
_MIME_TYPES = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".gif": "image/gif",
    ".webp": "image/webp",
    ".bmp": "image/bmp",
}


def get_image_base64_and_mime(
    image_path: str,
) -> tuple[Optional[str], Optional[str]]:
    """Convert an image file to a base64 string and look up its MIME type.

    Args:
        image_path: Filesystem path to the image file.

    Returns:
        A ``(base64_data, mime_type)`` pair on success, or ``(None, None)``
        if the file cannot be read. Errors are printed rather than raised —
        a best-effort contract the callers in this module rely on (they
        check for ``None``).
    """
    try:
        # Derive the MIME type from the file extension; unknown extensions
        # default to JPEG.
        ext = os.path.splitext(image_path)[1].lower()
        mime_type = _MIME_TYPES.get(ext, "image/jpeg")  # Default to JPEG

        # Read the raw bytes and encode as base64 text.
        with open(image_path, "rb") as image_file:
            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
        return base64_data, mime_type
    except Exception as e:
        # Broad catch is deliberate: any failure (missing file, permissions,
        # bad path type) degrades to the (None, None) sentinel.
        print(f"Error converting image to base64: {e}")
        return None, None


def send_generate_request(
    messages,
    server_url=None,
    model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
    api_key=None,
    max_tokens=4096,
):
    """
    Send a chat request to an OpenAI-compatible API endpoint.

    Local image references in user messages — parts of the form
    ``{"type": "image", "image": <path>}`` — are converted to base64
    data-URL ``image_url`` parts before the request is sent. Images that
    cannot be read are skipped with a warning rather than failing the call.

    Args:
        messages (list): Message dicts with role and content (OpenAI format).
        server_url (str): Base URL of the server, e.g. "http://127.0.0.1:8000".
        model (str): Model name to use for generation (default:
            "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8").
        api_key (str): API key forwarded to the OpenAI client.
        max_tokens (int): Maximum number of completion tokens (default: 4096).

    Returns:
        str: The generated response text, or None if the request fails or
        the response has no choices.
    """
    # Process messages to convert image paths to base64.
    processed_messages = []
    for message in messages:
        processed_message = message.copy()
        content = message.get("content")
        # Only user messages with structured (list) content need image
        # handling. Guarding on isinstance(list) also fixes a latent bug:
        # iterating a plain-string content would have exploded it into
        # per-character items.
        if message["role"] == "user" and isinstance(content, list):
            processed_content = []
            for c in content:
                if isinstance(c, dict) and c.get("type") == "image":
                    # Convert image path to base64 format.
                    image_path = c["image"]

                    print("image_path", image_path)
                    new_image_path = image_path.replace(
                        "?", "%3F"
                    )  # Escape ? in the path

                    # Read the image file and convert to base64.
                    try:
                        base64_image, mime_type = get_image_base64_and_mime(
                            new_image_path
                        )
                        if base64_image is None:
                            print(
                                f"Warning: Could not convert image to base64: {new_image_path}"
                            )
                            continue

                        # Create the proper image_url structure with base64 data.
                        processed_content.append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{mime_type};base64,{base64_image}",
                                    "detail": "high",
                                },
                            }
                        )

                    except FileNotFoundError:
                        print(f"Warning: Image file not found: {new_image_path}")
                        continue
                    except Exception as e:
                        print(f"Warning: Error processing image {new_image_path}: {e}")
                        continue
                else:
                    # Non-image parts (e.g. text) pass through unchanged.
                    processed_content.append(c)

            processed_message["content"] = processed_content
        processed_messages.append(processed_message)

    # Create OpenAI client with custom base URL.
    client = OpenAI(api_key=api_key, base_url=server_url)

    try:
        print(f"🔍 Calling model {model}...")
        response = client.chat.completions.create(
            model=model,
            messages=processed_messages,
            max_completion_tokens=max_tokens,
            n=1,
        )

        # Extract the response content from the first choice.
        if response.choices and len(response.choices) > 0:
            return response.choices[0].message.content
        print(f"Unexpected response format: {response}")
        return None

    except Exception as e:
        # Network/API failures degrade to None rather than raising.
        print(f"Request failed: {e}")
        return None


def send_direct_request(
    llm: Any,
    messages: list[dict[str, Any]],
    sampling_params: Any,
) -> Optional[str]:
    """
    Run inference on a vLLM model instance directly without using a server.

    Local image references in user messages — parts of the form
    ``{"type": "image", "image": <path>}`` — are converted to base64
    data-URL ``image_url`` parts, which is the format vLLM's chat
    interface expects. Unreadable images are skipped with a warning.

    Args:
        llm: Initialized vLLM LLM instance (passed from external initialization)
        messages: List of message dicts with role and content (OpenAI format)
        sampling_params: vLLM SamplingParams instance (initialized externally)

    Returns:
        str: Generated response text, or None if inference fails
    """
    try:
        # Process messages to handle images (convert to base64 if needed).
        processed_messages = []
        for message in messages:
            processed_message = message.copy()
            content = message.get("content")
            # Only user messages with structured (list) content need image
            # handling. Guarding on isinstance(list) also fixes a latent
            # bug: iterating a plain-string content would have exploded it
            # into per-character items.
            if message["role"] == "user" and isinstance(content, list):
                processed_content = []
                for c in content:
                    if isinstance(c, dict) and c.get("type") == "image":
                        # Convert image path to base64 format.
                        image_path = c["image"]
                        new_image_path = image_path.replace("?", "%3F")

                        try:
                            base64_image, mime_type = get_image_base64_and_mime(
                                new_image_path
                            )
                            if base64_image is None:
                                print(
                                    f"Warning: Could not convert image: {new_image_path}"
                                )
                                continue

                            # vLLM expects image_url format.
                            processed_content.append(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime_type};base64,{base64_image}"
                                    },
                                }
                            )
                        except Exception as e:
                            print(
                                f"Warning: Error processing image {new_image_path}: {e}"
                            )
                            continue
                    else:
                        # Non-image parts (e.g. text) pass through unchanged.
                        processed_content.append(c)

                processed_message["content"] = processed_content
            processed_messages.append(processed_message)

        print("🔍 Running direct inference with vLLM...")

        # Run inference using vLLM's chat interface.
        outputs = llm.chat(
            messages=processed_messages,
            sampling_params=sampling_params,
        )

        # Extract the generated text from the first output.
        if outputs and len(outputs) > 0:
            return outputs[0].outputs[0].text
        print(f"Unexpected output format: {outputs}")
        return None

    except Exception as e:
        # Any failure (bad message shape, vLLM error) degrades to None.
        print(f"Direct inference failed: {e}")
        return None