# Source file: lang-agent/lang_agent/graphs/vision_routing.py
# (exported listing metadata: 383 lines, 14 KiB, Python)
"""
Vision-enabled routing graph that:
1. First passes input to qwen-plus to decide if a photo should be taken
2. If photo taken -> passes to qwen-vl-max for image description
3. If no photo -> passes back to qwen-plus for conversation response
"""
from dataclasses import dataclass, field
from typing import Type, TypedDict, List, Dict, Any, Tuple, Optional
import tyro
import base64
import json
from loguru import logger
from lang_agent.config import LLMNodeConfig
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.base import GraphBase, ToolNodeBase
from lang_agent.components.client_tool_manager import ClientToolManagerConfig
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain.agents import create_agent
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode
# ==================== SYSTEM PROMPTS ====================
# Hardcoded default prompts. At runtime these are registered with the prompt
# store in VisionRoutingGraph._build_modules(), where DB-backed prompts (if a
# pipeline_id is configured) take precedence over these defaults.

# First pass: the tool-decision LLM's only job is to decide whether to call
# the self_camera_take_photo tool.
CAMERA_DECISION_PROMPT = """You are an intelligent assistant that decides whether to take a photo using the camera.
Your ONLY job in this step is to determine if the user's request requires taking a photo.
You should use the self_camera_take_photo tool when:
- User explicitly asks to take a photo, picture, or image
- User asks "what do you see", "what's in front of you", "look at this", "describe what you see"
- User asks about their surroundings, environment, or current scene
- User wants you to identify, recognize, or analyze something visually
- User asks questions that require visual information to answer
You should NOT use the tool when:
- User is asking general knowledge questions
- User wants to have a normal conversation
- User is asking about text, code, or non-visual topics
- The request can be answered without visual information
If you decide to take a photo, call the self_camera_take_photo tool. Otherwise, respond that no photo is needed."""

# Vision path: system prompt for the image-description LLM.
VISION_DESCRIPTION_PROMPT = """You are a highly accurate visual analysis assistant powered by qwen-vl-max.
Your task is to provide detailed, accurate descriptions of images. Focus on:
1. ACCURACY: Describe only what you can clearly see. Do not make assumptions or hallucinate details.
2. DETAIL: Include relevant details about:
- Objects and their positions
- People (if any) and their actions/expressions
- Colors, textures, and lighting
- Text visible in the image
- Environment and context
3. STRUCTURE: Organize your description logically (foreground to background, or left to right)
4. RELEVANCE: If the user asked a specific question, prioritize information relevant to that question
Be precise and factual. If something is unclear or ambiguous, say so rather than guessing."""

# Conversation path: system prompt for the no-photo second pass.
CONVERSATION_PROMPT = """You are a friendly, helpful conversational assistant.
Your role is to:
1. Engage naturally with the user
2. Provide helpful, accurate, and thoughtful responses
3. Be concise but thorough
4. Maintain a warm and approachable tone
5. Ask clarifying questions when needed
Focus on the quality of the conversation. Be engaging, informative, and helpful."""
# ==================== STATE DEFINITION ====================
class VisionRoutingState(TypedDict):
    """LangGraph state carried through the vision-routing workflow."""
    # (input-messages mapping, invocation config) tuple as supplied by the
    # caller; _get_human_msg reads inp[0]["messages"].
    inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
               Dict[str, Dict[str, str | int]]]
    # Transcript accumulated by the graph nodes during one invocation.
    messages: List[SystemMessage | HumanMessage | AIMessage]
    image_base64: str | None # Captured image data
    has_image: bool # Flag indicating if image was captured
# ==================== CONFIG ====================
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class VisionRoutingConfig(LLMNodeConfig):
    """Tyro-configurable settings for VisionRoutingGraph."""
    # Factory target so the LLMNodeConfig machinery can instantiate the graph.
    _target: Type = field(default_factory=lambda: VisionRoutingGraph)
    tool_llm_name: str = "qwen-flash"
    """LLM for tool decisions and conversation"""
    vision_llm_name: str = "qwen-vl-max"
    """LLM for vision/image analysis"""
    # Tool manager providing the camera tool; defaults to the client-side one.
    tool_manager_config: ToolManagerConfig = field(default_factory=ClientToolManagerConfig)
# ==================== GRAPH IMPLEMENTATION ====================
class VisionRoutingGraph(GraphBase):
    """Two-stage routing graph.

    Flow:
        START -> camera_decision -> [execute_tool ->] vision_call | conversation_call -> END

    1. ``camera_decision``: a tool-bound LLM decides whether the request needs
       a photo (by emitting a ``self_camera_take_photo`` tool call).
    2. ``execute_tool``: runs the camera tool and extracts the base64 image.
    3. ``vision_call``: the vision LLM describes the captured image; otherwise
       ``conversation_call``: the conversation LLM answers without an image.
    """

    def __init__(self, config: VisionRoutingConfig):
        self.config = config
        self._build_modules()
        self.workflow = self._build_graph()
        # LLM tags whose token streams are exposed to the caller; the
        # tool-decision pass is intentionally not streamed.
        self.streamable_tags: List[List[str]] = [["vision_llm"], ["conversation_llm"]]
        self.textreleaser_delay_keys = (None, None)

    def _build_modules(self):
        """Create the LLM clients, camera tool bindings, memory, and prompt store."""
        # LLM that makes the camera-or-not decision (tools are bound below).
        self.tool_llm = init_chat_model(
            model=self.config.tool_llm_name,
            model_provider=self.config.llm_provider,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            temperature=0,  # deterministic routing decision
            tags=["tool_decision_llm"],
        )
        # LLM for the plain-conversation second pass.
        # FIX: the model name was hardcoded to 'qwen-plus', silently ignoring
        # VisionRoutingConfig.tool_llm_name, which is documented as the
        # "LLM for tool decisions and conversation". Use the configured name.
        self.conversation_llm = init_chat_model(
            model=self.config.tool_llm_name,
            model_provider=self.config.llm_provider,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            temperature=0.7,  # some creativity for conversation quality
            tags=["conversation_llm"],
        )
        # Vision LLM for image description (no tools bound).
        self.vision_llm = init_chat_model(
            model=self.config.vision_llm_name,
            model_provider=self.config.llm_provider,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            temperature=0,  # factual, low-hallucination descriptions
            tags=["vision_llm"],
        )
        self.memory = MemorySaver()
        # Discover available tools and keep only the camera tool.
        tool_manager: ToolManager = self.config.tool_manager_config.setup()
        self.tools = tool_manager.get_tools()
        self.camera_tools = [t for t in self.tools if t.name == "self_camera_take_photo"]
        if not self.camera_tools:
            # Without the camera tool every request routes to conversation;
            # warn loudly instead of failing silently.
            logger.warning("self_camera_take_photo tool not found; vision path is disabled")
        # Bind the camera tool to the decision LLM only.
        self.tool_llm_with_tools = self.tool_llm.bind_tools(self.camera_tools)
        # Kept for compatibility; _execute_tool runs the tool manually so it
        # can extract the image payload from the result.
        self.tool_node = ToolNode(self.camera_tools)
        # Prompt store: DB-backed prompts (if pipeline_id) override the
        # hardcoded defaults.
        self.prompt_store = build_prompt_store(
            pipeline_id=self.config.pipeline_id,
            prompt_set_id=self.config.prompt_set_id,
            hardcoded={
                "camera_decision_prompt": CAMERA_DECISION_PROMPT,
                "vision_description_prompt": VISION_DESCRIPTION_PROMPT,
                "conversation_prompt": CONVERSATION_PROMPT,
            },
        )

    def _get_human_msg(self, state: VisionRoutingState) -> HumanMessage:
        """Return the most recent HumanMessage from the current invocation.

        Raises:
            ValueError: if the input contains no HumanMessage.
        """
        msgs = state["inp"][0]["messages"]
        for msg in reversed(msgs):
            if isinstance(msg, HumanMessage):
                return msg
        raise ValueError("No HumanMessage found in input")

    def _camera_decision_call(self, state: VisionRoutingState):
        """First pass: decide (via optional tool call) whether a photo is needed."""
        human_msg = self._get_human_msg(state)
        messages = [
            SystemMessage(content=self.prompt_store.get("camera_decision_prompt")),
            human_msg,
        ]
        response = self.tool_llm_with_tools.invoke(messages)
        # Reset per-turn image state so a previous turn's photo never leaks
        # into this one.
        return {
            "messages": [response],
            "has_image": False,
            "image_base64": None,
        }

    def _execute_tool(self, state: VisionRoutingState):
        """Execute any camera tool calls and capture the returned image.

        The tool result is expected to be JSON (or a dict) containing an
        ``image_base64`` key; anything else is treated as a plain-text result
        with no image.
        """
        last_msg = state["messages"][-1]
        if not getattr(last_msg, "tool_calls", None):
            return {"has_image": False}
        tool_messages = []
        image_data = None
        for tool_call in last_msg.tool_calls:
            if tool_call["name"] != "self_camera_take_photo":
                continue
            camera_tool = next(
                (t for t in self.camera_tools if t.name == "self_camera_take_photo"),
                None,
            )
            if camera_tool is None:
                continue
            result = camera_tool.invoke(tool_call)
            content = result.content if isinstance(result, ToolMessage) else result
            # Try to pull the base64 image out of a JSON/dict tool result.
            try:
                result_data = json.loads(content) if isinstance(content, str) else content
                if isinstance(result_data, dict) and "image_base64" in result_data:
                    image_data = result_data["image_base64"]
            except (json.JSONDecodeError, TypeError):
                pass  # not JSON; treat as a plain-text result without an image
            # FIX: ToolMessage.content must be a string or a list of content
            # blocks; serialize dict results instead of passing them through raw.
            if not isinstance(content, (str, list)):
                content = json.dumps(content, ensure_ascii=False, default=str)
            tool_messages.append(
                ToolMessage(content=content, tool_call_id=tool_call["id"])
            )
        return {
            "messages": state["messages"] + tool_messages,
            "has_image": image_data is not None,
            "image_base64": image_data,
        }

    def _check_image_taken(self, state: VisionRoutingState) -> str:
        """Route after the decision pass: execute the tool, or answer directly."""
        last_msg = state["messages"][-1]
        if getattr(last_msg, "tool_calls", None):
            return "execute_tool"
        # Defensive: has_image is always False right after camera_decision,
        # but keep the branch so the routing map stays total.
        if state.get("has_image"):
            return "vision"
        return "conversation"

    def _post_tool_check(self, state: VisionRoutingState) -> str:
        """Route after tool execution based on whether an image was captured."""
        return "vision" if state.get("has_image") else "conversation"

    def _vision_call(self, state: VisionRoutingState):
        """Describe the captured image with the vision LLM."""
        human_msg = self._get_human_msg(state)
        image_base64 = state.get("image_base64")
        if not image_base64:
            # The tool ran but produced no image; degrade to plain conversation.
            logger.warning("No image data available for vision call")
            return self._conversation_call(state)
        # Multimodal message: image first, then the user's original request.
        vision_message = HumanMessage(
            content=[
                {
                    "type": "image_url",
                    "image_url": {
                        # assumes the camera tool returns JPEG data — TODO confirm
                        "url": f"data:image/jpeg;base64,{image_base64}"
                    },
                },
                {
                    "type": "text",
                    "text": f"User's request: {human_msg.content}\n\nPlease describe what you see and respond to the user's request.",
                },
            ]
        )
        messages = [
            SystemMessage(content=self.prompt_store.get("vision_description_prompt")),
            vision_message,
        ]
        response = self.vision_llm.invoke(messages)
        return {"messages": state["messages"] + [response]}

    def _conversation_call(self, state: VisionRoutingState):
        """Second pass: answer the request with the conversation LLM (no image)."""
        human_msg = self._get_human_msg(state)
        messages = [
            SystemMessage(content=self.prompt_store.get("conversation_prompt")),
            human_msg,
        ]
        response = self.conversation_llm.invoke(messages)
        return {"messages": state["messages"] + [response]}

    def _build_graph(self):
        """Wire the nodes and edges and compile the workflow."""
        builder = StateGraph(VisionRoutingState)
        builder.add_node("camera_decision", self._camera_decision_call)
        builder.add_node("execute_tool", self._execute_tool)
        builder.add_node("vision_call", self._vision_call)
        builder.add_node("conversation_call", self._conversation_call)
        builder.add_edge(START, "camera_decision")
        # After the decision pass, optionally execute the camera tool.
        builder.add_conditional_edges(
            "camera_decision",
            self._check_image_taken,
            {
                "execute_tool": "execute_tool",
                "vision": "vision_call",
                "conversation": "conversation_call",
            },
        )
        # After tool execution, route on whether an image was captured.
        builder.add_conditional_edges(
            "execute_tool",
            self._post_tool_check,
            {
                "vision": "vision_call",
                "conversation": "conversation_call",
            },
        )
        builder.add_edge("vision_call", END)
        builder.add_edge("conversation_call", END)
        # FIX: self.memory was created but never attached, so the per-thread_id
        # checkpointing implied by the __main__ example never happened.
        return builder.compile(checkpointer=self.memory)
# ==================== MAIN ====================
if __name__ == "__main__":
    from dotenv import load_dotenv

    # Pull API credentials / endpoints from a local .env file.
    load_dotenv()

    graph = VisionRoutingGraph(VisionRoutingConfig())

    # Smoke test: a plain chat request that should route to the
    # conversation path (no photo taken). A photo-request test can be run
    # the same way with e.g. "Take a photo and tell me what you see" and a
    # different thread_id.
    print("\n=== Test 1: Conversation (no photo needed) ===")
    state_arg = {
        "messages": [
            SystemMessage("You are a helpful assistant"),
            HumanMessage("Hello, how are you today?"),
        ]
    }
    invoke_config = {"configurable": {"thread_id": "1"}}
    result = graph.invoke(state_arg, invoke_config)
    print(f"Result: {result}")