添加三个提示词及模型温度
This commit is contained in:
@@ -113,7 +113,8 @@ class RoutingGraph(GraphBase):
|
||||
self.llm = init_chat_model(model=self.config.llm_name,
|
||||
model_provider=self.config.llm_provider,
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.base_url)
|
||||
base_url=self.config.base_url,
|
||||
temperature=0)
|
||||
self.memory = MemorySaver() # shared memory between the two branch
|
||||
self.router = self.llm.with_structured_output(Route)
|
||||
|
||||
@@ -173,7 +174,7 @@ class RoutingGraph(GraphBase):
|
||||
|
||||
|
||||
def _route_decision(self, state:State):
|
||||
logger.info(f"decision:{state["decision"]}")
|
||||
logger.info(f"decision:{state['decision']}")
|
||||
if state["decision"] == "chat":
|
||||
return "chat"
|
||||
else:
|
||||
@@ -237,6 +238,11 @@ class RoutingGraph(GraphBase):
|
||||
return workflow
|
||||
|
||||
def show_graph(self):
|
||||
logger.info("creating image")
|
||||
img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))
|
||||
plt.imshow(img)
|
||||
plt.show()
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
route = RoutingConfig().setup()
|
||||
route.show_graph()
|
||||
Reference in New Issue
Block a user