pdf_code/zzb_data_word/zhipu_agent.py

126 lines
4.3 KiB
Python
Raw Normal View History

2024-12-30 17:51:12 +08:00
from zhipuai import ZhipuAI
import json
client = ZhipuAI(api_key="5b80671206182be69a928f10ddf419f3.nOG2l5xSW542rkAM")
messages = []
tools = [
{
"type": "function",
"function": {
"name": "get_flight_number",
"description": "根据始发地、目的地和日期,查询对应日期的航班号",
"parameters": {
"type": "object",
"properties": {
"departure": {
"description": "出发地",
"type": "string"
},
"destination": {
"description": "目的地",
"type": "string"
},
"date": {
"description": "日期",
"type": "string",
}
},
"required": [ "departure", "destination", "date" ]
},
}
},
{
"type": "function",
"function": {
"name": "get_ticket_price",
"description": "查询某航班在某日的票价",
"parameters": {
"type": "object",
"properties": {
"flight_number": {
"description": "航班号",
"type": "string"
},
"date": {
"description": "日期",
"type": "string",
}
},
"required": [ "flight_number", "date"]
},
}
},
]
def get_flight_number(date:str , departure:str , destination:str):
flight_number = {
"北京":{
"上海" : "1234",
"广州" : "8321",
},
"上海":{
"北京" : "1233",
"广州" : "8123",
}
}
return { "flight_number":flight_number[departure][destination] }
def get_ticket_price(date:str , flight_number:str):
return {"ticket_price": "1000"}
def parse_function_call(model_response,messages):
# 处理函数调用结果,根据模型返回参数,调用对应的函数。
# 调用函数返回结果后构造tool message再次调用模型将函数结果输入模型
# 模型会将函数调用结果以自然语言格式返回给用户。
if model_response.choices[0].message.tool_calls:
tool_call = model_response.choices[0].message.tool_calls[0]
args = tool_call.function.arguments
function_result = {}
if tool_call.function.name == "get_flight_number":
function_result = get_flight_number(**json.loads(args))
if tool_call.function.name == "get_ticket_price":
function_result = get_ticket_price(**json.loads(args))
messages.append({
"role": "tool",
"content": f"{json.dumps(function_result)}",
"tool_call_id":tool_call.id
})
response = client.chat.completions.create(
model="glm-4", # 填写需要调用的模型名称
messages=messages,
tools=tools,
)
print(response.choices[0].message)
messages.append(response.choices[0].message.model_dump())
if __name__ == "__main__":
#查询北京到广州的航班
# 清空对话
messages = []
messages.append({"role": "system", "content": "如果用户的问题中没有提供完整的参数,请要求用户提供必要信息"})
messages.append({"role": "user", "content": "帮我查询北京到上海的航班"})
response = client.chat.completions.create(
model="glm-4", # 填写需要调用的模型名称
messages=messages,
tools=tools,
)
print(response)
# print(response.choices[0].message)
# messages.append(response.choices[0].message.model_dump())
# parse_function_call(response,messages)
# # 查询航班价格
# messages.append({"role": "user", "content": "这趟航班的价格是多少?"})
# response = client.chat.completions.create(
# model="glm-4", # 填写需要调用的模型名称
# messages=messages,
# tools=tools,
# )
# print(response.choices[0].message)
# messages.append(response.choices[0].message.model_dump())
# parse_function_call(response,messages)