pdf_code/zzb_data_word/zhipu_agent.py

126 lines
4.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from zhipuai import ZhipuAI
import json
client = ZhipuAI(api_key="5b80671206182be69a928f10ddf419f3.nOG2l5xSW542rkAM")
messages = []
tools = [
{
"type": "function",
"function": {
"name": "get_flight_number",
"description": "根据始发地、目的地和日期,查询对应日期的航班号",
"parameters": {
"type": "object",
"properties": {
"departure": {
"description": "出发地",
"type": "string"
},
"destination": {
"description": "目的地",
"type": "string"
},
"date": {
"description": "日期",
"type": "string",
}
},
"required": [ "departure", "destination", "date" ]
},
}
},
{
"type": "function",
"function": {
"name": "get_ticket_price",
"description": "查询某航班在某日的票价",
"parameters": {
"type": "object",
"properties": {
"flight_number": {
"description": "航班号",
"type": "string"
},
"date": {
"description": "日期",
"type": "string",
}
},
"required": [ "flight_number", "date"]
},
}
},
]
def get_flight_number(date:str , departure:str , destination:str):
flight_number = {
"北京":{
"上海" : "1234",
"广州" : "8321",
},
"上海":{
"北京" : "1233",
"广州" : "8123",
}
}
return { "flight_number":flight_number[departure][destination] }
def get_ticket_price(date:str , flight_number:str):
return {"ticket_price": "1000"}
def parse_function_call(model_response,messages):
# 处理函数调用结果,根据模型返回参数,调用对应的函数。
# 调用函数返回结果后构造tool message再次调用模型将函数结果输入模型
# 模型会将函数调用结果以自然语言格式返回给用户。
if model_response.choices[0].message.tool_calls:
tool_call = model_response.choices[0].message.tool_calls[0]
args = tool_call.function.arguments
function_result = {}
if tool_call.function.name == "get_flight_number":
function_result = get_flight_number(**json.loads(args))
if tool_call.function.name == "get_ticket_price":
function_result = get_ticket_price(**json.loads(args))
messages.append({
"role": "tool",
"content": f"{json.dumps(function_result)}",
"tool_call_id":tool_call.id
})
response = client.chat.completions.create(
model="glm-4", # 填写需要调用的模型名称
messages=messages,
tools=tools,
)
print(response.choices[0].message)
messages.append(response.choices[0].message.model_dump())
if __name__ == "__main__":
#查询北京到广州的航班
# 清空对话
messages = []
messages.append({"role": "system", "content": "如果用户的问题中没有提供完整的参数,请要求用户提供必要信息"})
messages.append({"role": "user", "content": "帮我查询北京到上海的航班"})
response = client.chat.completions.create(
model="glm-4", # 填写需要调用的模型名称
messages=messages,
tools=tools,
)
print(response)
# print(response.choices[0].message)
# messages.append(response.choices[0].message.model_dump())
# parse_function_call(response,messages)
# # 查询航班价格
# messages.append({"role": "user", "content": "这趟航班的价格是多少?"})
# response = client.chat.completions.create(
# model="glm-4", # 填写需要调用的模型名称
# messages=messages,
# tools=tools,
# )
# print(response.choices[0].message)
# messages.append(response.choices[0].message.model_dump())
# parse_function_call(response,messages)