dual proof of concept

This commit is contained in:
2026-01-06 00:26:07 +08:00
parent 62f56827b5
commit c0849220ec

View File

@@ -0,0 +1,109 @@
from dataclasses import dataclass, field
from typing import Type, TypedDict, Literal, Dict, List
import tyro
from pydantic import BaseModel, Field
from loguru import logger
from langchain.chat_models import init_chat_model
from lang_agent.config import LLMKeyConfig
from lang_agent.base import GraphBase
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.graphs.graph_states import State
from langchain.agents import create_agent
from langchain.messages import SystemMessage, HumanMessage
from langchain.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
SYS_PROMPT = "you are a helpful helper who will have a fun conversation with the user"
TOOL_SYS_PROMPT = "base on the user's speech, identify their emotions and change the light color to its appropriate colors. If it sounds neutral, do nothing"
@dataclass
class DualConfig(LLMKeyConfig):
_target: Type = field(default_factory=lambda:Dual)
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
from langchain.tools import tool
@tool
def turn_lights(col:Literal["red", "green", "yellow", "blue"]):
"""
Turn on the color of the lights
"""
print(f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
class Dual(GraphBase):
def __init__(self, config:DualConfig):
self.config = config
self._build_modules()
self.workflow = self._build_graph()
self.streamable_tags = ["dual_chat_llm"]
def _build_modules(self):
self.chat_llm = init_chat_model(model=self.config.llm_name,
model_provider=self.config.llm_provider,
api_key=self.config.api_key,
base_url=self.config.base_url,
temperature=0,
tags=["dual_chat_llm"])
self.tool_llm = init_chat_model(model='qwen-flash',
model_provider='openai',
api_key=self.config.api_key,
base_url=self.config.base_url,
temperature=0,
tags=["dual_tool_llm"])
self.memory = MemorySaver()
self.tool_manager: ToolManager = self.config.tool_manager_config.setup()
self.chat_agent = create_agent(self.chat_llm, [], checkpointer=self.memory)
# self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools())
self.tool_agent = create_agent(self.tool_llm, [turn_lights])
self.streamable_tags = [["dual_chat_llm"]]
def _chat_call(self, state:State):
return self._agent_call_template(SYS_PROMPT, self.chat_agent, state)
def _tool_call(self, state:State):
self._agent_call_template(TOOL_SYS_PROMPT, self.tool_agent, state)
return {}
def _join(self, state:State):
return {}
def _build_graph(self):
builder = StateGraph(State)
builder.add_node("chat_call", self._chat_call)
builder.add_node("tool_call", self._tool_call)
builder.add_node("join", self._join)
builder.add_edge(START, "chat_call")
builder.add_edge(START, "tool_call")
builder.add_edge("chat_call", "join")
builder.add_edge("tool_call", "join")
builder.add_edge("join", END)
return builder.compile()
if __name__ == "__main__":
dual:Dual = DualConfig().setup()
nargs = {"messages": [SystemMessage("you are a helpful bot named jarvis"),
HumanMessage("I feel very very sad")]
}, {"configurable": {"thread_id": "3"}}
out = dual.invoke(*nargs)
print(out)