temp demo [TO REMOVE]
This commit is contained in:
115
garbage_demo.py
Normal file
115
garbage_demo.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import time, threading
|
||||
from typing import TypedDict, Annotated
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langchain.chat_models import init_chat_model
|
||||
import dotenv, os
|
||||
|
||||
from langchain_core.messages.base import BaseMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain.agents import create_agent
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
# --- Shared state ---
|
||||
class State(TypedDict):
|
||||
chat_active: bool
|
||||
tool_done: Annotated[bool, lambda a, b: a or b]
|
||||
tool_result: str
|
||||
chat_output: Annotated[str, lambda a, b: a + b]
|
||||
|
||||
|
||||
RUN_STATE = {"tool_done": False}
|
||||
|
||||
CHAT_MODEL = init_chat_model(
|
||||
model="qwen-flash",
|
||||
model_provider="openai",
|
||||
api_key=os.environ.get("ALI_API_KEY"),
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
|
||||
# CHAT_MODEL = create_agent(CHAT_MODEL, [], checkpointer=MemorySaver())
|
||||
|
||||
# --- Nodes ---
|
||||
def tool_node(state: State):
|
||||
"""Simulate background work."""
|
||||
def work():
|
||||
for i in range(3):
|
||||
print(f"\033\n[33m[Tool] Working... {i+1}/3\033[0m")
|
||||
time.sleep(2)
|
||||
print("[Tool] Done.")
|
||||
RUN_STATE["tool_done"] = True
|
||||
|
||||
work()
|
||||
return {"tool_done": False, "tool_result": "WEEEEEE, TOOL WORKS BABAY!"}
|
||||
|
||||
|
||||
def chat_node(state: State):
|
||||
"""Chat repeatedly until tool is finished."""
|
||||
|
||||
output = ""
|
||||
i = 1
|
||||
while not RUN_STATE["tool_done"]:
|
||||
msg = {"role": "user", "content": f"Tool not done yet (step {i})"}
|
||||
resp = CHAT_MODEL.invoke([{"role": "system", "content": "Keep the user updated."}, msg])
|
||||
text = getattr(resp, "content", str(resp))
|
||||
output += f"[Chat] {text}\n"
|
||||
i += 1
|
||||
time.sleep(1)
|
||||
|
||||
output += "\n[Chat] Tool done detected.\n"
|
||||
return {"chat_output": output}
|
||||
|
||||
|
||||
def handoff_node(state: State):
|
||||
final = f"Tool result: {state.get('tool_result', 'unknown')}"
|
||||
print("[Handoff]", final)
|
||||
return {"chat_output": state.get("chat_output", "") + final}
|
||||
|
||||
|
||||
# --- Graph ---
|
||||
builder = StateGraph(State)
|
||||
builder.add_node("tool_node", tool_node)
|
||||
builder.add_node("chat_node", chat_node)
|
||||
builder.add_node("handoff_node", handoff_node)
|
||||
|
||||
builder.add_edge(START, "tool_node")
|
||||
builder.add_edge(START, "chat_node")
|
||||
builder.add_edge("chat_node", "handoff_node")
|
||||
builder.add_edge("tool_node", "handoff_node")
|
||||
builder.add_edge("handoff_node", END)
|
||||
|
||||
graph = builder.compile(checkpointer=MemorySaver())
|
||||
|
||||
# from PIL import Image
|
||||
# import matplotlib.pyplot as plt
|
||||
# from io import BytesIO
|
||||
|
||||
# img = Image.open(BytesIO(graph.get_graph().draw_mermaid_png()))
|
||||
# plt.imshow(img)
|
||||
# plt.show()
|
||||
|
||||
# --- Streaming Run ---
|
||||
print("\n=== STREAM MODE: messages ===\n")
|
||||
state0 = (
|
||||
{"chat_active": True, "tool_done": False, "tool_result": "", "chat_output": ""},
|
||||
{"configurable": {"thread_id": 1}},
|
||||
)
|
||||
|
||||
# for event in graph.stream(*state0, stream_mode="messages"):
|
||||
# print("[STREAM UPDATE]:", event)
|
||||
|
||||
|
||||
for chunk, metadata in graph.stream(*state0, stream_mode="messages"):
|
||||
# node = metadata.get("langgraph_node")
|
||||
# if node not in ("model"):
|
||||
# continue # skip router or other intermediate nodes
|
||||
|
||||
# Print only the final message content
|
||||
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||
print(chunk.content, end="", flush=True)
|
||||
|
||||
|
||||
print("\n=== Run Finished ===")
|
||||
Reference in New Issue
Block a user