|
|
@@ -0,0 +1,462 @@
|
|
|
+#1.引入numpy
|
|
|
+import numpy as np
|
|
|
+#2.引入模型
|
|
|
+from models.model_loader import model_loader
|
|
|
+#3.引入scikit-learn的余弦相似度函数
|
|
|
+from sklearn.metrics.pairwise import cosine_similarity
|
|
|
+#4.引入json
|
|
|
+import json
|
|
|
+#5.引入缓存机制
|
|
|
+from cachetools import TTLCache
|
|
|
+#6.引入哈希值
|
|
|
+import hashlib
|
|
|
+#7.引入标题提取
|
|
|
+import utils.matching as TitleMatcher
|
|
|
+#8.引入模板数据
|
|
|
+from services.templateFun import templateData, SectorScheduler
|
|
|
+# 引入样式数据
|
|
|
+from services.style.style1to4data import sectorStyle1to4Data
|
|
|
+from services.style.style5data import sectorStyle5Data
|
|
|
+#9.引入随机数
|
|
|
+import random
|
|
|
+#10.引入请求
|
|
|
+from utils.request import post_request,generate_error
|
|
|
+#11.引入时间
|
|
|
+import time
|
|
|
+#12.引入输入类型
|
|
|
+import utils.inputType as inputType
|
|
|
+
|
|
|
+# 创建AI处理类
|
|
|
+class AIService:
|
|
|
+ """第一步:初始化"""
|
|
|
+ def __init__(self):
|
|
|
+ # 1.1加载文本处理模型
|
|
|
+ self.text_model = model_loader.load_text_model()
|
|
|
+ # 1.2使用TTLCache,最大缓存50个模板,1小时过期,缓存模板嵌入向量
|
|
|
+ self.template_embeddings_cache = TTLCache(maxsize=50, ttl=3600)
|
|
|
+ # 1.3用户输入文本
|
|
|
+ self.user_input = ""
|
|
|
+ # 1.4组件样式数据
|
|
|
+ self.component_style_data = []
|
|
|
+ # 1.5推理过程
|
|
|
+ self.reasoning = []
|
|
|
+ # 1.6 当前样式数据对象
|
|
|
+ self.current_style_data = None
|
|
|
+
|
|
|
+ """第二步:调用模型能力"""
|
|
|
+ #2.1 文字转换成向量
|
|
|
+ def text_to_embedding(self, text):
|
|
|
+ # 2.1.1如果text是字符串,则转换成列表
|
|
|
+ if isinstance(text, str):
|
|
|
+ text = [text]
|
|
|
+
|
|
|
+ embeddings = self.text_model.encode(text)
|
|
|
+ # 2.2.2返回numpy数组,不要转成list
|
|
|
+ return embeddings
|
|
|
+
|
|
|
+ #2.2生成文字摘要
|
|
|
+ def userMessage_to_title(self, userMessage):
|
|
|
+ return TitleMatcher.title_matcher.generate_title(userMessage)
|
|
|
+
|
|
|
+ #2.3筛选通栏
|
|
|
+ def filter_sectors(self, userMessage,matched_template_id):
|
|
|
+ return inputType.inputType_matcher.generate_inputType(userMessage,matched_template_id)
|
|
|
+
|
|
|
+ """第三步:查找最匹配的皮肤模板"""
|
|
|
+ def find_best_matching_template(self, user_text, skin_data):
|
|
|
+ # 3.1如果皮肤数据为空,则返回None
|
|
|
+ if not skin_data:
|
|
|
+ return None
|
|
|
+
|
|
|
+ # 3.2提取用户输入中的颜色关键词
|
|
|
+ color_keywords = self.extract_color_keywords(user_text)
|
|
|
+ print(f"提取到用户颜色关键词: {color_keywords}")
|
|
|
+ # 更新推理过程
|
|
|
+ self.reasoning.append(f"我提取到用户需求中的颜色关键词是: {color_keywords}")
|
|
|
+
|
|
|
+ # 3.3根据颜色关键词过滤模板
|
|
|
+ filtered_skin_data = skin_data
|
|
|
+ if color_keywords:
|
|
|
+ filtered_skin_data = self.filter_templates_by_color(skin_data, color_keywords)
|
|
|
+ print(f"根据颜色过滤后模板数量: {len(filtered_skin_data)}")
|
|
|
+ # 更新推理过程
|
|
|
+ self.reasoning.append(f"根据颜色过滤数据以后有: {len(filtered_skin_data)}个符合该需求的模板")
|
|
|
+
|
|
|
+ # 3.4获取模板信息(使用缓存)
|
|
|
+ template_infos, template_embeddings = self.get_template_embeddings(filtered_skin_data)
|
|
|
+
|
|
|
+ # 3.5如果模板信息为空,则返回None
|
|
|
+ if not template_infos:
|
|
|
+ return None
|
|
|
+
|
|
|
+ # 3.6用户输入的文本转化为向量
|
|
|
+ user_embedding = self.text_to_embedding(user_text)
|
|
|
+
|
|
|
+ # 3.7如果用户输入的文本的维度是1,则转换成2维
|
|
|
+ if len(user_embedding.shape) == 1:
|
|
|
+ user_embedding = user_embedding.reshape(1, -1)
|
|
|
+ if len(template_embeddings.shape) == 1:
|
|
|
+ template_embeddings = template_embeddings.reshape(1, -1)
|
|
|
+
|
|
|
+ # 3.8通过余弦函数批量计算相似度
|
|
|
+ similarities = cosine_similarity(user_embedding, template_embeddings)[0]
|
|
|
+ best_index = np.argmax(similarities)
|
|
|
+ best_similarity = similarities[best_index]
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'template_info': template_infos[best_index],
|
|
|
+ 'similarity_score': best_similarity
|
|
|
+ }
|
|
|
+
|
|
|
+ # 3.9提取用户输入中的颜色关键词
|
|
|
+ def extract_color_keywords(self, text):
|
|
|
+ color_map = {
|
|
|
+ "红": "红色",
|
|
|
+ "绿": "绿色",
|
|
|
+ "蓝": "蓝色",
|
|
|
+ "黄": "黄色",
|
|
|
+ "橙": "橙色",
|
|
|
+ "紫": "紫色",
|
|
|
+ "黑": "黑色",
|
|
|
+ "白": "白色",
|
|
|
+ "灰": "灰色"
|
|
|
+ }
|
|
|
+
|
|
|
+ found_colors = []
|
|
|
+ for char, color in color_map.items():
|
|
|
+ if char in text:
|
|
|
+ found_colors.append(color)
|
|
|
+
|
|
|
+ return found_colors
|
|
|
+
|
|
|
+ # 3.10根据颜色关键词过滤模板
|
|
|
+ def filter_templates_by_color(self, skin_data, color_keywords):
|
|
|
+ filtered = []
|
|
|
+ for template in skin_data:
|
|
|
+ try:
|
|
|
+ keywords = json.loads(template.get('template_keyword', '[]'))
|
|
|
+ keyword_text = ' '.join(keywords)
|
|
|
+
|
|
|
+ # 检查模板关键词是否包含任何颜色关键词
|
|
|
+ for color in color_keywords:
|
|
|
+ if color in keyword_text:
|
|
|
+ filtered.append(template)
|
|
|
+ break
|
|
|
+ except:
|
|
|
+ continue
|
|
|
+
|
|
|
+ return filtered
|
|
|
+
|
|
|
+ """第四步:缓存"""
|
|
|
+ # 4.1读取或者新建缓存
|
|
|
+ def get_template_embeddings(self, skin_data):
|
|
|
+
|
|
|
+ if not skin_data:
|
|
|
+ return None, []
|
|
|
+
|
|
|
+ # 4.1.2准备模板关键词文本
|
|
|
+ template_texts = []
|
|
|
+ template_infos = []
|
|
|
+
|
|
|
+ for template in skin_data:
|
|
|
+ try:
|
|
|
+ keywords = json.loads(template.get('template_keyword', '[]'))
|
|
|
+ keyword_text = ' '.join(keywords)
|
|
|
+ template_texts.append(keyword_text)
|
|
|
+ template_infos.append(template)
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ continue
|
|
|
+
|
|
|
+ if not template_texts:
|
|
|
+ return None, []
|
|
|
+
|
|
|
+ # 4.1.3检查缓存
|
|
|
+ cache_key = self._get_template_cache_key(skin_data)
|
|
|
+ if cache_key in self.template_embeddings_cache:
|
|
|
+ print("📦 从缓存加载模板嵌入向量")
|
|
|
+ template_embeddings = self.template_embeddings_cache[cache_key]
|
|
|
+ return template_infos, template_embeddings
|
|
|
+
|
|
|
+ print("🔄 计算新的模板嵌入向量")
|
|
|
+ template_embeddings = self.text_model.encode(template_texts)
|
|
|
+
|
|
|
+ # 4.1.4存储到缓存(cachetools自动处理过期)
|
|
|
+ self.template_embeddings_cache[cache_key] = template_embeddings
|
|
|
+ return template_infos, template_embeddings
|
|
|
+
|
|
|
+ # 4.2 生成模板数据的缓存键
|
|
|
+ def _get_template_cache_key(self, skin_data):
|
|
|
+ template_keys = []
|
|
|
+ for template in skin_data:
|
|
|
+ template_keys.append(f"{template.get('id')}_{template.get('template_keyword')}")
|
|
|
+ combined_key = '|||'.join(sorted(template_keys))
|
|
|
+ return hashlib.md5(combined_key.encode('utf-8')).hexdigest()
|
|
|
+
|
|
|
+ # 4.3 清空模板缓存
|
|
|
+ def clear_cache(self):
|
|
|
+ self.template_embeddings_cache.clear()
|
|
|
+ print("🗑️ 模板嵌入向量缓存已清空")
|
|
|
+
|
|
|
+ # 4.4 获取模板缓存
|
|
|
+ def get_cache_info(self):
|
|
|
+ return {
|
|
|
+ 'template_embeddings_cache_size': len(self.template_embeddings_cache),
|
|
|
+ 'cached_templates_count': len(self.template_embeddings_cache),
|
|
|
+ 'cache_maxsize': self.template_embeddings_cache.maxsize,
|
|
|
+ 'cache_ttl': self.template_embeddings_cache.ttl
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ """第五步:生成通栏数据"""
|
|
|
+ # 5.1 随机生成通栏数据
|
|
|
+ def get_sectors(self,website_id,matched_template_id,token):
|
|
|
+ print("开始生成通栏数据!")
|
|
|
+ print(matched_template_id)
|
|
|
+ # 根据matched_template_id选择样式数据
|
|
|
+ if int(matched_template_id) == 5:
|
|
|
+ print("正在使用第5套皮肤的数据!")
|
|
|
+ self.current_style_data = sectorStyle5Data
|
|
|
+ else:
|
|
|
+ print("正在使用前4套皮肤的数据!")
|
|
|
+ self.current_style_data = sectorStyle1to4Data
|
|
|
+
|
|
|
+ sectors_config = self.current_style_data.sectors_config
|
|
|
+ scheduler = SectorScheduler(sectors_config)
|
|
|
+ result = scheduler.generate_sequence(10)
|
|
|
+
|
|
|
+ # 处理强制位置(倒数位置)
|
|
|
+ forced_map = {}
|
|
|
+ for sector in sectors_config:
|
|
|
+ if 'force_pos' in sector:
|
|
|
+ # force_pos: 1(倒数第一), 2(倒数第二), 3(倒数第三)
|
|
|
+ forced_map[sector['force_pos']] = sector['name']
|
|
|
+
|
|
|
+ # 替换结果中的对应位置
|
|
|
+ if forced_map and result:
|
|
|
+ for pos, name in forced_map.items():
|
|
|
+ idx = -1 * pos
|
|
|
+ if abs(idx) <= len(result):
|
|
|
+ print(f"强制调整位置: 将 {name} 放入索引 {idx}")
|
|
|
+ result[idx] = name
|
|
|
+
|
|
|
+ # print("结果:", result)
|
|
|
+ sectorNames = []
|
|
|
+
|
|
|
+ # 第一步:获取到筛选结果
|
|
|
+ for sector_name in result:
|
|
|
+ for sector in sectors_config:
|
|
|
+ if sector['name'] == sector_name:
|
|
|
+ sectorNames.append(sector['name'])
|
|
|
+ break
|
|
|
+
|
|
|
+ # 第二步:使用筛选结果把sectors_webSiteData重组为模板数据
|
|
|
+ for name in sectorNames:
|
|
|
+ # 从字典中直接获取对应的sector数据
|
|
|
+ if name in self.current_style_data.sectors_data:
|
|
|
+ matched_sector = self.current_style_data.sectors_data[name]
|
|
|
+ # 添加到templateData.sectors_webSiteData.index
|
|
|
+ # templateData.sectors_webSiteData["template"]["index"].append(matched_sector)
|
|
|
+ # print(matched_sector)
|
|
|
+ # print(templateData.sectors_webSiteData["template"]["index"])
|
|
|
+ templateData.sectors_webSiteData["template"]["index"].append(matched_sector)
|
|
|
+
|
|
|
+ #print("模板数据:" + json.dumps(templateData.sectors_webSiteData["template"]["index"], ensure_ascii=False))
|
|
|
+ #获得通栏名称
|
|
|
+ ai_service.test_normal_case_cnames()
|
|
|
+
|
|
|
+ # 第三步:选择组件样式
|
|
|
+ ai_service.updata_component_style(website_id,matched_template_id,token)
|
|
|
+
|
|
|
+ # 第四步:依照用户需求修改样式
|
|
|
+ ai_service.updata_component_style_to_user(website_id,matched_template_id,token)
|
|
|
+
|
|
|
+ # 第五步:生成画布数据
|
|
|
+ # 循环templateData.sectors_webSiteData["template"]["index"],生成画布数据
|
|
|
+ # templateData.canvas_data["template"]["index"].append()
|
|
|
+ all_y = 0 # 初始化累计高度
|
|
|
+ for index,sectorData in enumerate(templateData.sectors_webSiteData["template"]["index"]):
|
|
|
+ # 获取当前通栏的高度
|
|
|
+ sector_height = sectorData["sectorCanvasHeight"]
|
|
|
+ # 生成画布数据并添加到canvas_data
|
|
|
+ canvas_item = ai_service.get_canvans_data(sector_height, all_y, sectorData,index)
|
|
|
+ templateData.canvas_data["template"]["index"].append(canvas_item)
|
|
|
+ # 更新累计高度
|
|
|
+ all_y += sector_height
|
|
|
+
|
|
|
+ # 5.2 根据通栏数据生成画布数据
|
|
|
+ def get_canvans_data(self,sector_height,all_y,sectorData,sort):
|
|
|
+ data = {
|
|
|
+ "i": int(time.time() * 1000) * 1000 + random.randint(0, 9999), # 毫秒级时间戳后加随机4位数
|
|
|
+ "x": 0,#通栏的x坐标恒定为0
|
|
|
+ "y": all_y,#通栏的y坐标为当前所有通栏的高度之和
|
|
|
+ "w": 12,#通栏的宽度恒定为12
|
|
|
+ "h": sector_height,#通栏的高度为sectorData["sectorCanvasHeight"]
|
|
|
+ "type": sectorData["sectorName"],
|
|
|
+ "content":sectorData,
|
|
|
+ "dataSort": sort,
|
|
|
+ "moved": False
|
|
|
+ }
|
|
|
+
|
|
|
+ return data
|
|
|
+
|
|
|
+ # 5.3 生成随机的样式数据
|
|
|
+ def updata_component_style(self,website_id,matched_template_id,token):
|
|
|
+ getComponentStyle = post_request("public/getAllSectorComponentStyle", {
|
|
|
+ "website_id":website_id,
|
|
|
+ "template_id": matched_template_id,
|
|
|
+ },token)
|
|
|
+ # 请补全此方法
|
|
|
+ # 检查响应是否成功
|
|
|
+ if getComponentStyle.get('code') != 200:
|
|
|
+ print(f"获取组件样式失败: {getComponentStyle.get('message')}")
|
|
|
+ return
|
|
|
+
|
|
|
+ # 获取组件样式数据
|
|
|
+ style_data = getComponentStyle.get('data', [])
|
|
|
+ self.component_style_data = style_data
|
|
|
+
|
|
|
+ # 需要跳过的通栏类型
|
|
|
+ skip_sectors = ["adSector", "linkSector", "linkCxfwSector"]
|
|
|
+
|
|
|
+ # 更新推理过程
|
|
|
+ self.reasoning.append(f"通栏创建完毕")
|
|
|
+
|
|
|
+ # 遍历模板中的每个通栏
|
|
|
+ for sector in templateData.sectors_webSiteData["template"]["index"]:
|
|
|
+ sector_name = sector["sectorName"]
|
|
|
+
|
|
|
+ # 跳过不需要修改样式的通栏
|
|
|
+ if sector_name in skip_sectors:
|
|
|
+ print(f"跳过 {sector_name} 的组件样式修改")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 在样式数据中查找匹配的通栏
|
|
|
+ matched_style = next((s for s in style_data if s["sector_id"] == sector_name), None)
|
|
|
+ if not matched_style:
|
|
|
+ print(f"警告: 未找到 {sector_name} 的组件样式配置")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 获取该通栏的组件列表配置
|
|
|
+ component_list_config = matched_style.get("component_list", [])
|
|
|
+
|
|
|
+ # 遍历通栏中的每个组件
|
|
|
+ for idx, component in enumerate(sector["componentList"]):
|
|
|
+ if idx < len(component_list_config):
|
|
|
+ comp_config = component_list_config[idx]
|
|
|
+ if comp_config and isinstance(comp_config, list) and len(comp_config) > 0:
|
|
|
+ comp_style_data = comp_config[0].get("component_style_data", [])
|
|
|
+
|
|
|
+ # 如果有可用的样式数据,随机选择一个样式
|
|
|
+ if comp_style_data:
|
|
|
+ style_count = len(comp_style_data)
|
|
|
+ random_style = random.randint(1, style_count)
|
|
|
+ component["component_style"] = random_style
|
|
|
+ print(f"初始化 {sector_name} 的第{idx+1}个组件样式为: {random_style}")
|
|
|
+ # 更新推理过程
|
|
|
+ #self.reasoning.append(f"初始化 {sector_name} 的第{idx+1}个组件样式为: {random_style}号样式")
|
|
|
+ else:
|
|
|
+ print(f"警告: {sector_name} 的第{idx+1}个组件没有可用的样式数据")
|
|
|
+ else:
|
|
|
+ print(f"警告: {sector_name} 的第{idx+1}个组件配置无效")
|
|
|
+ else:
|
|
|
+ print(f"警告: {sector_name} 的组件数量({len(sector['componentList'])})超过配置数量({len(component_list_config)})")
|
|
|
+
|
|
|
+ # 5.4 根据用户需求修改样式数据
|
|
|
+ def updata_component_style_to_user(self, website_id, matched_template_id, token):
|
|
|
+ # 需要跳过的通栏类型
|
|
|
+ skip_sectors = ["adSector", "linkSector", "linkCxfwSector"]
|
|
|
+
|
|
|
+ # 1. 将用户输入转换为向量
|
|
|
+ user_embedding = self.text_to_embedding(self.user_input)
|
|
|
+ if len(user_embedding.shape) == 1:
|
|
|
+ user_embedding = user_embedding.reshape(1, -1)
|
|
|
+
|
|
|
+ # 更新推理过程
|
|
|
+ self.reasoning.append(f"最后让我挑选一下符合用户需求的组件样式")
|
|
|
+
|
|
|
+ # 2. 遍历每个通栏
|
|
|
+ for sector in templateData.sectors_webSiteData["template"]["index"]:
|
|
|
+ sector_name = sector["sectorName"]
|
|
|
+
|
|
|
+ # 跳过不需要修改样式的通栏
|
|
|
+ if sector_name in skip_sectors:
|
|
|
+ print(f"跳过 {sector_name} 的组件样式优化")
|
|
|
+ # 更新推理过程
|
|
|
+ #self.reasoning.append(f"跳过 {sector_name} 的组件样式优化")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 3. 在样式数据中查找匹配的通栏
|
|
|
+ matched_style = next((s for s in self.component_style_data if s["sector_id"] == sector_name), None)
|
|
|
+ if not matched_style:
|
|
|
+ print(f"警告: 未找到 {sector_name} 的组件样式配置")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 4. 遍历通栏中的每个组件
|
|
|
+ for idx, component in enumerate(sector["componentList"]):
|
|
|
+ comp_config_list = matched_style.get("component_list", [])
|
|
|
+ if idx >= len(comp_config_list) or not comp_config_list[idx]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ comp_config = comp_config_list[idx][0] # 取第一个配置
|
|
|
+ comp_style_data = comp_config.get("component_style_data", [])
|
|
|
+
|
|
|
+ if not comp_style_data:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 5. 提取所有样式的描述文本
|
|
|
+ style_descriptions = [style["img_name"] for style in comp_style_data]
|
|
|
+
|
|
|
+ # 6. 将样式描述转换为向量
|
|
|
+ style_embeddings = self.text_to_embedding(style_descriptions)
|
|
|
+
|
|
|
+ # 7. 计算用户输入与每个样式的相似度
|
|
|
+ similarities = cosine_similarity(user_embedding, style_embeddings)[0]
|
|
|
+ best_index = np.argmax(similarities)
|
|
|
+ best_style = comp_style_data[best_index]
|
|
|
+
|
|
|
+ # 8. 更新组件样式
|
|
|
+ component["component_style"] = best_style["img_id"]
|
|
|
+ print(f"优化 {sector_name} 的第{idx+1}个组件样式为: {best_style['img_name']} (ID: {best_style['img_id']})")
|
|
|
+ # 更新推理过程
|
|
|
+ self.reasoning.append(f"选择 {sector_name} 的第{idx+1}个组件样式为: {best_style['img_name']} (ID: {best_style['img_id']})")
|
|
|
+
|
|
|
+
|
|
|
+ """其他功能 模拟推理过程"""
|
|
|
+ # 获得选择通栏的中文名称
|
|
|
+ def test_normal_case_cnames(self):
|
|
|
+ #获得生成的模板数据
|
|
|
+ result = templateData.sectors_webSiteData["template"]["index"]
|
|
|
+ #更新推理过程
|
|
|
+ self.reasoning.append(f"接着我要创建一套符合用户需求的通栏组合")
|
|
|
+ #遍历result获得所有通栏的Cname
|
|
|
+ cnnames = []
|
|
|
+
|
|
|
+ # 遍历每个通栏,使用enumerate获取索引
|
|
|
+ for i, sector in enumerate(result):
|
|
|
+ sector_name = sector["sectorName"]
|
|
|
+ # 在sectors_config中查找对应的中文名称
|
|
|
+ if self.current_style_data:
|
|
|
+ for config in self.current_style_data.sectors_config:
|
|
|
+ if config["name"] == sector_name:
|
|
|
+ cnnames.append(config["CNname"])
|
|
|
+ # 使用通栏在结果列表中的索引+1作为序号
|
|
|
+ self.reasoning.append(f"第{i+1}个通栏选择: {config['CNname']}")
|
|
|
+ break # 找到后跳出内层循环
|
|
|
+
|
|
|
+ # 清理数据
|
|
|
+ def clear_data(self):
|
|
|
+ #清理模板数据
|
|
|
+ templateData.canvas_data["template"]["index"] = []
|
|
|
+ templateData.sectors_webSiteData["template"]["index"] = []
|
|
|
+ #清理推理数据
|
|
|
+ self.reasoning = []
|
|
|
+ self.current_style_data = None
|
|
|
+ #初始化权重
|
|
|
+ sectorStyle1to4Data.reset_sectors_config()
|
|
|
+ sectorStyle5Data.reset_sectors_config()
|
|
|
+
|
|
|
+ #添加推理过程
|
|
|
+ def add_reasoning(self, reasoning):
|
|
|
+ self.reasoning.append(reasoning)
|
|
|
+
|
|
|
+ai_service = AIService()
|