Files
lang-agent/garbage_demo.py
2025-11-28 18:37:10 +08:00

116 lines
3.3 KiB
Python

import time, threading
from typing import TypedDict, Annotated
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langchain.chat_models import init_chat_model
import dotenv, os
from langchain_core.messages.base import BaseMessageChunk
from langchain_core.messages import BaseMessage, AIMessage
from langchain.agents import create_agent
# Load .env values into the environment before the chat-model setup reads ALI_API_KEY from os.environ.
dotenv.load_dotenv()
# --- Shared state ---
def _either(current, update):
    """Reducer for boolean flags: keep ``current`` if truthy, else take ``update``."""
    return current or update


def _concat(current, update):
    """Reducer for text channels: append ``update`` to ``current``."""
    return current + update


class State(TypedDict):
    """State schema passed between the graph's nodes."""

    # Plain channel (no reducer attached).
    chat_active: bool
    # Merged with a logical OR of the existing value and the update.
    tool_done: Annotated[bool, _either]
    # Plain channel holding the tool's result text.
    tool_result: str
    # Updates are appended to the existing text rather than replacing it.
    chat_output: Annotated[str, _concat]
# Module-level flag outside the graph state: tool_node sets it, chat_node polls it.
RUN_STATE: dict[str, bool] = {"tool_done": False}

# Chat model addressed through the OpenAI-compatible provider at the DashScope
# base URL; the API key is read from the ALI_API_KEY environment variable.
CHAT_MODEL = init_chat_model(
    model="qwen-flash",
    model_provider="openai",
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    api_key=os.getenv("ALI_API_KEY"),
    temperature=0,
)
# CHAT_MODEL = create_agent(CHAT_MODEL, [], checkpointer=MemorySaver())
# --- Nodes ---
def tool_node(state: State):
    """Simulate a slow tool and publish its result.

    Prints three yellow progress lines, sleeping two seconds after each, then
    sets ``RUN_STATE["tool_done"]`` so that ``chat_node``'s polling loop stops.
    The work runs synchronously inside this node.

    Args:
        state: Current graph state; not read by this node.

    Returns:
        Partial state update marking the tool as done and carrying its result.
    """
    for step in range(1, 4):
        # Newline first, then an intact ANSI sequence (ESC[33m ... ESC[0m).
        # A newline placed between ESC and "[33m" breaks the colour code.
        print(f"\n\033[33m[Tool] Working... {step}/3\033[0m")
        time.sleep(2)
    print("[Tool] Done.")
    RUN_STATE["tool_done"] = True

    # The work is finished at this point, so record completion in the graph
    # state as well; tool_done is merged with OR, so False would be a no-op.
    return {"tool_done": True, "tool_result": "WEEEEEE, TOOL WORKS BABAY!"}
def chat_node(state: State):
    """Ask the chat model for a status line once per second until the tool is done.

    Polls ``RUN_STATE["tool_done"]``; while it is false, sends a system prompt
    plus a numbered user message to ``CHAT_MODEL`` and records the reply.

    Args:
        state: Current graph state; not read by this node.

    Returns:
        Partial state update whose ``chat_output`` is the collected transcript.
    """
    transcript = []
    step = 0
    while not RUN_STATE["tool_done"]:
        step += 1
        conversation = [
            {"role": "system", "content": "Keep the user updated."},
            {"role": "user", "content": f"Tool not done yet (step {step})"},
        ]
        reply = CHAT_MODEL.invoke(conversation)
        # Fall back to the reply's string form when it has no .content attribute.
        text = getattr(reply, "content", str(reply))
        transcript.append(f"[Chat] {text}\n")
        time.sleep(1)

    transcript.append("\n[Chat] Tool done detected.\n")
    return {"chat_output": "".join(transcript)}
def handoff_node(state: State):
    """Print the tool result and append it to the chat transcript.

    Args:
        state: Current graph state; ``tool_result`` is read, defaulting to
            ``"unknown"`` when absent.

    Returns:
        Partial state update containing only the text to append to
        ``chat_output``.
    """
    final = f"Tool result: {state.get('tool_result', 'unknown')}"
    print("[Handoff]", final)
    # chat_output is declared with an additive reducer (old + update), so only
    # the new suffix is returned; returning old + suffix would store the
    # existing transcript twice.
    return {"chat_output": final}
# --- Graph ---
# START fans out to tool_node and chat_node; both lead into handoff_node,
# which leads to END.
builder = StateGraph(State)

for _name, _fn in (
    ("tool_node", tool_node),
    ("chat_node", chat_node),
    ("handoff_node", handoff_node),
):
    builder.add_node(_name, _fn)

for _src, _dst in (
    (START, "tool_node"),
    (START, "chat_node"),
    ("chat_node", "handoff_node"),
    ("tool_node", "handoff_node"),
    ("handoff_node", END),
):
    builder.add_edge(_src, _dst)

# Compile with an in-memory checkpointer.
graph = builder.compile(checkpointer=MemorySaver())
# from PIL import Image
# import matplotlib.pyplot as plt
# from io import BytesIO
# img = Image.open(BytesIO(graph.get_graph().draw_mermaid_png()))
# plt.imshow(img)
# plt.show()
# --- Streaming Run ---
print("\n=== STREAM MODE: messages ===\n")

# Positional arguments for graph.stream: (initial state, run config).
state0 = (
    {"chat_active": True, "tool_done": False, "tool_result": "", "chat_output": ""},
    {"configurable": {"thread_id": 1}},
)

# stream_mode="messages" yields (message_chunk, metadata) tuples for model output
# produced inside the nodes. The previous "updates" mode yields {node: update}
# dicts, which never satisfy the isinstance check below, so nothing was printed.
for message, _metadata in graph.stream(*state0, stream_mode="messages"):
    # Print only chunks that actually carry content.
    if isinstance(message, (BaseMessageChunk, BaseMessage)) and getattr(message, "content", None):
        print(message.content, end="", flush=True)

print("\n=== Run Finished ===")