126 lines
4.3 KiB
Python
126 lines
4.3 KiB
Python
|
from zhipuai import ZhipuAI
|
|||
|
import json
|
|||
|
|
|||
|
client = ZhipuAI(api_key="5b80671206182be69a928f10ddf419f3.nOG2l5xSW542rkAM")
|
|||
|
messages = []
|
|||
|
tools = [
|
|||
|
{
|
|||
|
"type": "function",
|
|||
|
"function": {
|
|||
|
"name": "get_flight_number",
|
|||
|
"description": "根据始发地、目的地和日期,查询对应日期的航班号",
|
|||
|
"parameters": {
|
|||
|
"type": "object",
|
|||
|
"properties": {
|
|||
|
"departure": {
|
|||
|
"description": "出发地",
|
|||
|
"type": "string"
|
|||
|
},
|
|||
|
"destination": {
|
|||
|
"description": "目的地",
|
|||
|
"type": "string"
|
|||
|
},
|
|||
|
"date": {
|
|||
|
"description": "日期",
|
|||
|
"type": "string",
|
|||
|
}
|
|||
|
},
|
|||
|
"required": [ "departure", "destination", "date" ]
|
|||
|
},
|
|||
|
}
|
|||
|
},
|
|||
|
{
|
|||
|
"type": "function",
|
|||
|
"function": {
|
|||
|
"name": "get_ticket_price",
|
|||
|
"description": "查询某航班在某日的票价",
|
|||
|
"parameters": {
|
|||
|
"type": "object",
|
|||
|
"properties": {
|
|||
|
"flight_number": {
|
|||
|
"description": "航班号",
|
|||
|
"type": "string"
|
|||
|
},
|
|||
|
"date": {
|
|||
|
"description": "日期",
|
|||
|
"type": "string",
|
|||
|
}
|
|||
|
},
|
|||
|
"required": [ "flight_number", "date"]
|
|||
|
},
|
|||
|
}
|
|||
|
},
|
|||
|
]
|
|||
|
|
|||
|
def get_flight_number(date:str , departure:str , destination:str):
|
|||
|
flight_number = {
|
|||
|
"北京":{
|
|||
|
"上海" : "1234",
|
|||
|
"广州" : "8321",
|
|||
|
},
|
|||
|
"上海":{
|
|||
|
"北京" : "1233",
|
|||
|
"广州" : "8123",
|
|||
|
}
|
|||
|
}
|
|||
|
return { "flight_number":flight_number[departure][destination] }
|
|||
|
|
|||
|
def get_ticket_price(date:str , flight_number:str):
|
|||
|
return {"ticket_price": "1000"}
|
|||
|
|
|||
|
def parse_function_call(model_response,messages):
|
|||
|
# 处理函数调用结果,根据模型返回参数,调用对应的函数。
|
|||
|
# 调用函数返回结果后构造tool message,再次调用模型,将函数结果输入模型
|
|||
|
# 模型会将函数调用结果以自然语言格式返回给用户。
|
|||
|
if model_response.choices[0].message.tool_calls:
|
|||
|
tool_call = model_response.choices[0].message.tool_calls[0]
|
|||
|
args = tool_call.function.arguments
|
|||
|
function_result = {}
|
|||
|
if tool_call.function.name == "get_flight_number":
|
|||
|
function_result = get_flight_number(**json.loads(args))
|
|||
|
if tool_call.function.name == "get_ticket_price":
|
|||
|
function_result = get_ticket_price(**json.loads(args))
|
|||
|
messages.append({
|
|||
|
"role": "tool",
|
|||
|
"content": f"{json.dumps(function_result)}",
|
|||
|
"tool_call_id":tool_call.id
|
|||
|
})
|
|||
|
response = client.chat.completions.create(
|
|||
|
model="glm-4", # 填写需要调用的模型名称
|
|||
|
messages=messages,
|
|||
|
tools=tools,
|
|||
|
)
|
|||
|
print(response.choices[0].message)
|
|||
|
messages.append(response.choices[0].message.model_dump())
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
#查询北京到广州的航班
|
|||
|
# 清空对话
|
|||
|
messages = []
|
|||
|
|
|||
|
messages.append({"role": "system", "content": "如果用户的问题中没有提供完整的参数,请要求用户提供必要信息"})
|
|||
|
messages.append({"role": "user", "content": "帮我查询北京到上海的航班"})
|
|||
|
|
|||
|
response = client.chat.completions.create(
|
|||
|
model="glm-4", # 填写需要调用的模型名称
|
|||
|
messages=messages,
|
|||
|
tools=tools,
|
|||
|
)
|
|||
|
print(response)
|
|||
|
# print(response.choices[0].message)
|
|||
|
# messages.append(response.choices[0].message.model_dump())
|
|||
|
|
|||
|
# parse_function_call(response,messages)
|
|||
|
|
|||
|
# # 查询航班价格
|
|||
|
# messages.append({"role": "user", "content": "这趟航班的价格是多少?"})
|
|||
|
# response = client.chat.completions.create(
|
|||
|
# model="glm-4", # 填写需要调用的模型名称
|
|||
|
# messages=messages,
|
|||
|
# tools=tools,
|
|||
|
# )
|
|||
|
# print(response.choices[0].message)
|
|||
|
# messages.append(response.choices[0].message.model_dump())
|
|||
|
|
|||
|
# parse_function_call(response,messages)
|