pass in device_id
This commit is contained in:
@@ -153,8 +153,15 @@ class Pipeline:
|
||||
# NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
|
||||
u = DEFAULT_PROMPT
|
||||
|
||||
device_id = "0"
|
||||
spl_ls = thread_id.split("_")
|
||||
assert len(spl_ls) <= 2, "something wrong!"
|
||||
if len(spl_ls) == 2:
|
||||
thread_id, device_id = spl_ls
|
||||
|
||||
inp = {"messages":[SystemMessage(u),
|
||||
HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id}}
|
||||
HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
||||
"device_id":device_id}}
|
||||
|
||||
out = self.invoke(*inp, as_stream=as_stream, as_raw=as_raw)
|
||||
|
||||
@@ -197,8 +204,16 @@ class Pipeline:
|
||||
# NOTE: this prompt will be overwritten by 'configs/route_sys_prompts/chat_prompt.txt' for route graph
|
||||
u = DEFAULT_PROMPT
|
||||
|
||||
device_id = "0"
|
||||
spl_ls = thread_id.split("_")
|
||||
assert len(spl_ls) <= 2, "something wrong!"
|
||||
if len(spl_ls) == 2:
|
||||
thread_id, device_id = spl_ls
|
||||
print(f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m")
|
||||
|
||||
inp_data = {"messages":[SystemMessage(u),
|
||||
HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id}}
|
||||
HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
||||
"device_id":device_id}}
|
||||
|
||||
if as_stream:
|
||||
# Return async generator for streaming
|
||||
|
||||
Reference in New Issue
Block a user