完善向量运算逻辑(np.array替代list), 增加redis的更新脚本
This commit is contained in:
parent
e41d0c8dbc
commit
76d92944a1
|
@ -399,13 +399,10 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re
|
|||
distance = record[2]
|
||||
ori_measure_id = record[3]
|
||||
measure_id = record[4]
|
||||
measure_vector = redis_service.read_from_redis(redis_client,ori_measure_id)
|
||||
|
||||
measure_vector = redis_service.read_from_redis(redis_client, ori_measure_id, redis_key="measure_config_new")
|
||||
|
||||
measure_list = ast.literal_eval(measure_vector)
|
||||
|
||||
filter_str = 'file_id == "'+file_id+'"'
|
||||
res = vector_storage.search_similar_vectors(measure_vector, top_k=3, similarity_threshold=distance)
|
||||
res = vector_storage.search_similar_vectors(measure_list, top_k=3, similarity_threshold=distance)
|
||||
# 返回格式为 [(similarity, {metadata}), ...]
|
||||
for i in range(len(res)):
|
||||
vector_distance = float(res[i][0])
|
||||
|
|
|
@ -5,12 +5,12 @@ def read_from_file_and_write_to_redis(redis_client,ori_measure_id,measure_vector
|
|||
redis_client.hset('measure_config',ori_measure_id, measure_vector)
|
||||
|
||||
# 从 Redis 中读取数据
|
||||
def read_from_redis(redis_client,ori_measure_id):
|
||||
def read_from_redis(redis_client, ori_measure_id, redis_key="measure_config"):
|
||||
# 获取所有键
|
||||
return redis_client.hget('measure_config',ori_measure_id).decode()
|
||||
return redis_client.hget(redis_key, ori_measure_id).decode()
|
||||
|
||||
if __name__ == "__main__":
|
||||
redis_client = redis.Redis(host='192.168.0.175', port=6379, password='Xgf_redis', db=6)
|
||||
|
||||
value = read_from_redis(redis_client,"bb3cf43f3dba147373c706c6567b5a")
|
||||
value = read_from_redis(redis_client,"7c2dc91c1df9ab657b24e070431b401d", "measure_config_new")
|
||||
print(value)
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
import redis
|
||||
import utils
|
||||
import mysql.connector
|
||||
from config import *
|
||||
|
||||
def insert_measure_vector(conn,cursor):
|
||||
|
||||
redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=6)
|
||||
# redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=6)
|
||||
# 执行SQL语句,更新数据
|
||||
select_query = '''
|
||||
SELECT ori_measure_id,ori_measure_name,measure_name FROM measure_config_half_year where year='2025'
|
||||
'''
|
||||
# select_query = '''
|
||||
# SELECT ori_measure_id,ori_measure_name,measure_name FROM measure_config where year='2024'
|
||||
# '''
|
||||
cursor.execute(select_query)
|
||||
records = cursor.fetchall()
|
||||
print(records[:8])
|
||||
#return
|
||||
index = 1
|
||||
for record in records:
|
||||
if redis_client.hexists('measure_config_new', record[0]):
|
||||
measure_vector = redis_client.hget('measure_config_new', record[0])
|
||||
else:
|
||||
print('新增指标',record[1])
|
||||
vector = utils.embed_with_str_local(record[1])
|
||||
measure_vector = str(vector)
|
||||
#print(f'新增指标{index} 对应归一化指标为{record[2]}',record[1])
|
||||
#index += 1
|
||||
#vector = utils.embed_with_str_local(record[1])
|
||||
#measure_vector = str(vector)
|
||||
redis_client.hset('measure_config_new', record[0], measure_vector)
|
||||
redis_client.close()
|
||||
conn.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
conn = mysql.connector.connect(
|
||||
host=MYSQL_HOST,
|
||||
user=MYSQL_USER,
|
||||
password=MYSQL_PASSWORD,
|
||||
database="financial_report_prod"
|
||||
)
|
||||
cursor = conn.cursor()
|
||||
|
||||
insert_measure_vector(conn,cursor)
|
|
@ -108,7 +108,7 @@ class VectorStorage:
|
|||
def search_similar_vectors(self, query_vector: np.ndarray,
|
||||
top_k: int = 10, similarity_threshold: float = 0.0) -> List[Tuple[float, Dict[str, Any]]]:
|
||||
"""搜索相似向量"""
|
||||
query_vector = query_vector.astype(np.float32)
|
||||
query_vector = np.asarray(query_vector).astype(np.float32)
|
||||
current_count = self.data_count[0]
|
||||
if current_count == 0:
|
||||
return []
|
||||
|
|
Loading…
Reference in New Issue