126 lines
4.3 KiB
Python
126 lines
4.3 KiB
Python
from zhipuai import ZhipuAI
|
||
import json
|
||
|
||
client = ZhipuAI(api_key="5b80671206182be69a928f10ddf419f3.nOG2l5xSW542rkAM")
|
||
messages = []
|
||
tools = [
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": "get_flight_number",
|
||
"description": "根据始发地、目的地和日期,查询对应日期的航班号",
|
||
"parameters": {
|
||
"type": "object",
|
||
"properties": {
|
||
"departure": {
|
||
"description": "出发地",
|
||
"type": "string"
|
||
},
|
||
"destination": {
|
||
"description": "目的地",
|
||
"type": "string"
|
||
},
|
||
"date": {
|
||
"description": "日期",
|
||
"type": "string",
|
||
}
|
||
},
|
||
"required": [ "departure", "destination", "date" ]
|
||
},
|
||
}
|
||
},
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": "get_ticket_price",
|
||
"description": "查询某航班在某日的票价",
|
||
"parameters": {
|
||
"type": "object",
|
||
"properties": {
|
||
"flight_number": {
|
||
"description": "航班号",
|
||
"type": "string"
|
||
},
|
||
"date": {
|
||
"description": "日期",
|
||
"type": "string",
|
||
}
|
||
},
|
||
"required": [ "flight_number", "date"]
|
||
},
|
||
}
|
||
},
|
||
]
|
||
|
||
def get_flight_number(date:str , departure:str , destination:str):
|
||
flight_number = {
|
||
"北京":{
|
||
"上海" : "1234",
|
||
"广州" : "8321",
|
||
},
|
||
"上海":{
|
||
"北京" : "1233",
|
||
"广州" : "8123",
|
||
}
|
||
}
|
||
return { "flight_number":flight_number[departure][destination] }
|
||
|
||
def get_ticket_price(date:str , flight_number:str):
|
||
return {"ticket_price": "1000"}
|
||
|
||
def parse_function_call(model_response,messages):
|
||
# 处理函数调用结果,根据模型返回参数,调用对应的函数。
|
||
# 调用函数返回结果后构造tool message,再次调用模型,将函数结果输入模型
|
||
# 模型会将函数调用结果以自然语言格式返回给用户。
|
||
if model_response.choices[0].message.tool_calls:
|
||
tool_call = model_response.choices[0].message.tool_calls[0]
|
||
args = tool_call.function.arguments
|
||
function_result = {}
|
||
if tool_call.function.name == "get_flight_number":
|
||
function_result = get_flight_number(**json.loads(args))
|
||
if tool_call.function.name == "get_ticket_price":
|
||
function_result = get_ticket_price(**json.loads(args))
|
||
messages.append({
|
||
"role": "tool",
|
||
"content": f"{json.dumps(function_result)}",
|
||
"tool_call_id":tool_call.id
|
||
})
|
||
response = client.chat.completions.create(
|
||
model="glm-4", # 填写需要调用的模型名称
|
||
messages=messages,
|
||
tools=tools,
|
||
)
|
||
print(response.choices[0].message)
|
||
messages.append(response.choices[0].message.model_dump())
|
||
|
||
if __name__ == "__main__":
|
||
#查询北京到广州的航班
|
||
# 清空对话
|
||
messages = []
|
||
|
||
messages.append({"role": "system", "content": "如果用户的问题中没有提供完整的参数,请要求用户提供必要信息"})
|
||
messages.append({"role": "user", "content": "帮我查询北京到上海的航班"})
|
||
|
||
response = client.chat.completions.create(
|
||
model="glm-4", # 填写需要调用的模型名称
|
||
messages=messages,
|
||
tools=tools,
|
||
)
|
||
print(response)
|
||
# print(response.choices[0].message)
|
||
# messages.append(response.choices[0].message.model_dump())
|
||
|
||
# parse_function_call(response,messages)
|
||
|
||
# # 查询航班价格
|
||
# messages.append({"role": "user", "content": "这趟航班的价格是多少?"})
|
||
# response = client.chat.completions.create(
|
||
# model="glm-4", # 填写需要调用的模型名称
|
||
# messages=messages,
|
||
# tools=tools,
|
||
# )
|
||
# print(response.choices[0].message)
|
||
# messages.append(response.choices[0].message.model_dump())
|
||
|
||
# parse_function_call(response,messages)
|