unified constants
This commit is contained in:
@@ -48,6 +48,7 @@ You should NOT use the tool when:
|
||||
|
||||
If you decide to take a photo, call the self_camera_take_photo tool. Otherwise, respond that no photo is needed."""
|
||||
|
||||
|
||||
VISION_DESCRIPTION_PROMPT = """You are a highly accurate visual analysis assistant powered by qwen-vl-max.
|
||||
|
||||
Your task is to provide detailed, accurate descriptions of images. Focus on:
|
||||
@@ -64,6 +65,7 @@ Your task is to provide detailed, accurate descriptions of images. Focus on:
|
||||
|
||||
Be precise and factual. If something is unclear or ambiguous, say so rather than guessing."""
|
||||
|
||||
|
||||
CONVERSATION_PROMPT = """You are a friendly, helpful conversational assistant.
|
||||
|
||||
Your role is to:
|
||||
@@ -78,9 +80,11 @@ Focus on the quality of the conversation. Be engaging, informative, and helpful.
|
||||
|
||||
# ==================== STATE DEFINITION ====================
|
||||
|
||||
|
||||
class VisionRoutingState(TypedDict):
|
||||
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
|
||||
Dict[str, Dict[str, str | int]]]
|
||||
inp: Tuple[
|
||||
Dict[str, List[SystemMessage | HumanMessage]], Dict[str, Dict[str, str | int]]
|
||||
]
|
||||
messages: List[SystemMessage | HumanMessage | AIMessage]
|
||||
image_base64: str | None # Captured image data
|
||||
has_image: bool # Flag indicating if image was captured
|
||||
@@ -88,6 +92,7 @@ class VisionRoutingState(TypedDict):
|
||||
|
||||
# ==================== CONFIG ====================
|
||||
|
||||
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@dataclass
|
||||
class VisionRoutingConfig(LLMNodeConfig):
|
||||
@@ -99,11 +104,14 @@ class VisionRoutingConfig(LLMNodeConfig):
|
||||
vision_llm_name: str = "qwen-vl-max"
|
||||
"""LLM for vision/image analysis"""
|
||||
|
||||
tool_manager_config: ToolManagerConfig = field(default_factory=ClientToolManagerConfig)
|
||||
tool_manager_config: ToolManagerConfig = field(
|
||||
default_factory=ClientToolManagerConfig
|
||||
)
|
||||
|
||||
|
||||
# ==================== GRAPH IMPLEMENTATION ====================
|
||||
|
||||
|
||||
class VisionRoutingGraph(GraphBase):
|
||||
def __init__(self, config: VisionRoutingConfig):
|
||||
self.config = config
|
||||
@@ -120,19 +128,19 @@ class VisionRoutingGraph(GraphBase):
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.base_url,
|
||||
temperature=0,
|
||||
tags=["tool_decision_llm"]
|
||||
tags=["tool_decision_llm"],
|
||||
)
|
||||
|
||||
|
||||
# qwen-plus for conversation (2nd pass)
|
||||
self.conversation_llm = init_chat_model(
|
||||
model='qwen-plus',
|
||||
model="qwen-plus",
|
||||
model_provider=self.config.llm_provider,
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.base_url,
|
||||
temperature=0.7,
|
||||
tags=["conversation_llm"]
|
||||
tags=["conversation_llm"],
|
||||
)
|
||||
|
||||
|
||||
# qwen-vl-max for vision (no tools)
|
||||
self.vision_llm = init_chat_model(
|
||||
model=self.config.vision_llm_name,
|
||||
@@ -152,13 +160,15 @@ class VisionRoutingGraph(GraphBase):
|
||||
# Get tools and bind to tool_llm
|
||||
tool_manager: ToolManager = self.config.tool_manager_config.setup()
|
||||
self.tools = tool_manager.get_tools()
|
||||
|
||||
|
||||
# Filter to only get camera tool
|
||||
self.camera_tools = [t for t in self.tools if t.name == "self_camera_take_photo"]
|
||||
|
||||
self.camera_tools = [
|
||||
t for t in self.tools if t.name == "self_camera_take_photo"
|
||||
]
|
||||
|
||||
# Bind tools to qwen-plus only
|
||||
self.tool_llm_with_tools = self.tool_llm.bind_tools(self.camera_tools)
|
||||
|
||||
|
||||
# Create tool node for executing tools
|
||||
self.tool_node = ToolNode(self.camera_tools)
|
||||
|
||||
@@ -184,73 +194,81 @@ class VisionRoutingGraph(GraphBase):
|
||||
def _camera_decision_call(self, state: VisionRoutingState):
|
||||
"""First pass: qwen-plus decides if photo should be taken"""
|
||||
human_msg = self._get_human_msg(state)
|
||||
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=self.prompt_store.get("camera_decision_prompt")),
|
||||
human_msg
|
||||
human_msg,
|
||||
]
|
||||
|
||||
|
||||
response = self.tool_llm_with_tools.invoke(messages)
|
||||
|
||||
return {
|
||||
"messages": [response],
|
||||
"has_image": False,
|
||||
"image_base64": None
|
||||
}
|
||||
|
||||
return {"messages": [response], "has_image": False, "image_base64": None}
|
||||
|
||||
def _execute_tool(self, state: VisionRoutingState):
|
||||
"""Execute the camera tool if called"""
|
||||
last_msg = state["messages"][-1]
|
||||
|
||||
|
||||
if not hasattr(last_msg, "tool_calls") or not last_msg.tool_calls:
|
||||
return {"has_image": False}
|
||||
|
||||
|
||||
# Execute tool calls
|
||||
tool_messages = []
|
||||
image_data = None
|
||||
|
||||
|
||||
for tool_call in last_msg.tool_calls:
|
||||
if tool_call["name"] == "self_camera_take_photo":
|
||||
# Find and execute the camera tool
|
||||
camera_tool = next((t for t in self.camera_tools if t.name == "self_camera_take_photo"), None)
|
||||
camera_tool = next(
|
||||
(
|
||||
t
|
||||
for t in self.camera_tools
|
||||
if t.name == "self_camera_take_photo"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if camera_tool:
|
||||
result = camera_tool.invoke(tool_call)
|
||||
|
||||
|
||||
# Parse result to extract image
|
||||
if isinstance(result, ToolMessage):
|
||||
content = result.content
|
||||
else:
|
||||
content = result
|
||||
|
||||
|
||||
try:
|
||||
result_data = json.loads(content) if isinstance(content, str) else content
|
||||
if isinstance(result_data, dict) and "image_base64" in result_data:
|
||||
result_data = (
|
||||
json.loads(content) if isinstance(content, str) else content
|
||||
)
|
||||
if (
|
||||
isinstance(result_data, dict)
|
||||
and "image_base64" in result_data
|
||||
):
|
||||
image_data = result_data["image_base64"]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
|
||||
tool_messages.append(
|
||||
ToolMessage(content=content, tool_call_id=tool_call["id"])
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"messages": state["messages"] + tool_messages,
|
||||
"has_image": image_data is not None,
|
||||
"image_base64": image_data
|
||||
"image_base64": image_data,
|
||||
}
|
||||
|
||||
def _check_image_taken(self, state: VisionRoutingState) -> str:
|
||||
"""Conditional: check if image was captured"""
|
||||
last_msg = state["messages"][-1]
|
||||
|
||||
|
||||
# Check if there are tool calls
|
||||
if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
|
||||
return "execute_tool"
|
||||
|
||||
|
||||
# Check if we have an image after tool execution
|
||||
if state.get("has_image"):
|
||||
return "vision"
|
||||
|
||||
|
||||
return "conversation"
|
||||
|
||||
def _post_tool_check(self, state: VisionRoutingState) -> str:
|
||||
@@ -263,47 +281,45 @@ class VisionRoutingGraph(GraphBase):
|
||||
"""Pass image to qwen-vl-max for description"""
|
||||
human_msg = self._get_human_msg(state)
|
||||
image_base64 = state.get("image_base64")
|
||||
|
||||
|
||||
if not image_base64:
|
||||
logger.warning("No image data available for vision call")
|
||||
return self._conversation_call(state)
|
||||
|
||||
|
||||
# Format message with image for vision model
|
||||
vision_message = HumanMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
}
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"User's request: {human_msg.content}\n\nPlease describe what you see and respond to the user's request."
|
||||
}
|
||||
"text": f"User's request: {human_msg.content}\n\nPlease describe what you see and respond to the user's request.",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=self.prompt_store.get("vision_description_prompt")),
|
||||
vision_message
|
||||
vision_message,
|
||||
]
|
||||
|
||||
|
||||
response = self.vision_llm.invoke(messages)
|
||||
|
||||
|
||||
return {"messages": state["messages"] + [response]}
|
||||
|
||||
def _conversation_call(self, state: VisionRoutingState):
|
||||
"""2nd pass to qwen-plus for conversation quality"""
|
||||
human_msg = self._get_human_msg(state)
|
||||
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=self.prompt_store.get("conversation_prompt")),
|
||||
human_msg
|
||||
human_msg,
|
||||
]
|
||||
|
||||
|
||||
response = self.conversation_llm.invoke(messages)
|
||||
|
||||
|
||||
return {"messages": state["messages"] + [response]}
|
||||
|
||||
def _build_graph(self):
|
||||
@@ -317,7 +333,7 @@ class VisionRoutingGraph(GraphBase):
|
||||
|
||||
# Add edges
|
||||
builder.add_edge(START, "camera_decision")
|
||||
|
||||
|
||||
# After camera decision, check if tool should be executed
|
||||
builder.add_conditional_edges(
|
||||
"camera_decision",
|
||||
@@ -325,20 +341,17 @@ class VisionRoutingGraph(GraphBase):
|
||||
{
|
||||
"execute_tool": "execute_tool",
|
||||
"vision": "vision_call",
|
||||
"conversation": "conversation_call"
|
||||
}
|
||||
"conversation": "conversation_call",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# After tool execution, route based on whether image was captured
|
||||
builder.add_conditional_edges(
|
||||
"execute_tool",
|
||||
self._post_tool_check,
|
||||
{
|
||||
"vision": "vision_call",
|
||||
"conversation": "conversation_call"
|
||||
}
|
||||
{"vision": "vision_call", "conversation": "conversation_call"},
|
||||
)
|
||||
|
||||
|
||||
# Both vision and conversation go to END
|
||||
builder.add_edge("vision_call", END)
|
||||
builder.add_edge("conversation_call", END)
|
||||
@@ -350,23 +363,27 @@ class VisionRoutingGraph(GraphBase):
|
||||
|
||||
if __name__ == "__main__":
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
config = VisionRoutingConfig()
|
||||
graph = VisionRoutingGraph(config)
|
||||
|
||||
|
||||
# Test with a conversation request
|
||||
print("\n=== Test 1: Conversation (no photo needed) ===")
|
||||
nargs = {
|
||||
"messages": [
|
||||
SystemMessage("You are a helpful assistant"),
|
||||
HumanMessage("Hello, how are you today?")
|
||||
]
|
||||
}, {"configurable": {"thread_id": "1"}}
|
||||
|
||||
nargs = (
|
||||
{
|
||||
"messages": [
|
||||
SystemMessage("You are a helpful assistant"),
|
||||
HumanMessage("Hello, how are you today?"),
|
||||
]
|
||||
},
|
||||
{"configurable": {"thread_id": "1"}},
|
||||
)
|
||||
|
||||
result = graph.invoke(*nargs)
|
||||
print(f"Result: {result}")
|
||||
|
||||
|
||||
# Test with a photo request
|
||||
# print("\n=== Test 2: Photo request ===")
|
||||
# nargs = {
|
||||
@@ -375,8 +392,8 @@ if __name__ == "__main__":
|
||||
# HumanMessage("Take a photo and tell me what you see")
|
||||
# ]
|
||||
# }, {"configurable": {"thread_id": "2"}}
|
||||
|
||||
|
||||
# result = graph.invoke(*nargs)
|
||||
# print(f"\033[32mResult: {result}\033[0m")
|
||||
|
||||
|
||||
# print(f"Result: {result}")
|
||||
|
||||
Reference in New Issue
Block a user