#1.引入numpy import numpy as np #2.引入模型 from models.model_loader import model_loader #3.引入scikit-learn的余弦相似度函数 from sklearn.metrics.pairwise import cosine_similarity #4.引入json import json #5.引入缓存机制 from cachetools import TTLCache #6.引入哈希值 import hashlib #7.引入标题提取 import utils.matching as TitleMatcher #8.引入模板数据 from services.templateFun import templateData, SectorScheduler # 引入样式数据 from services.style.style1to4data import sectorStyle1to4Data from services.style.style5data import sectorStyle5Data #9.引入随机数 import random #10.引入请求 from utils.request import post_request,generate_error #11.引入时间 import time #12.引入输入类型 import utils.inputType as inputType # 创建AI处理类 class AIService: """第一步:初始化""" def __init__(self): # 1.1加载文本处理模型 self.text_model = model_loader.load_text_model() # 1.2使用TTLCache,最大缓存50个模板,1小时过期,缓存模板嵌入向量 self.template_embeddings_cache = TTLCache(maxsize=50, ttl=3600) # 1.3用户输入文本 self.user_input = "" # 1.4组件样式数据 self.component_style_data = [] # 1.5推理过程 self.reasoning = [] # 1.6 当前样式数据对象 self.current_style_data = None """第二步:调用模型能力""" #2.1 文字转换成向量 def text_to_embedding(self, text): # 2.1.1如果text是字符串,则转换成列表 if isinstance(text, str): text = [text] embeddings = self.text_model.encode(text) # 2.2.2返回numpy数组,不要转成list return embeddings #2.2生成文字摘要 def userMessage_to_title(self, userMessage): return TitleMatcher.title_matcher.generate_title(userMessage) #2.3筛选通栏 def filter_sectors(self, userMessage,matched_template_id): return inputType.inputType_matcher.generate_inputType(userMessage,matched_template_id) """第三步:查找最匹配的皮肤模板""" def find_best_matching_template(self, user_text, skin_data): # 3.1如果皮肤数据为空,则返回None if not skin_data: return None # 3.2提取用户输入中的颜色关键词 color_keywords = self.extract_color_keywords(user_text) print(f"提取到用户颜色关键词: {color_keywords}") # 更新推理过程 self.reasoning.append(f"我提取到用户需求中的颜色关键词是: {color_keywords}") # 3.3根据颜色关键词过滤模板 filtered_skin_data = skin_data if color_keywords: filtered_skin_data = self.filter_templates_by_color(skin_data, color_keywords) print(f"根据颜色过滤后模板数量: {len(filtered_skin_data)}") # 更新推理过程 self.reasoning.append(f"根据颜色过滤数据以后有: {len(filtered_skin_data)}个符合该需求的模板") # 3.4获取模板信息(使用缓存) template_infos, template_embeddings = self.get_template_embeddings(filtered_skin_data) # 3.5如果模板信息为空,则返回None if not template_infos: return None # 3.6用户输入的文本转化为向量 user_embedding = self.text_to_embedding(user_text) # 3.7如果用户输入的文本的维度是1,则转换成2维 if len(user_embedding.shape) == 1: user_embedding = user_embedding.reshape(1, -1) if len(template_embeddings.shape) == 1: template_embeddings = template_embeddings.reshape(1, -1) # 3.8通过余弦函数批量计算相似度 similarities = cosine_similarity(user_embedding, template_embeddings)[0] best_index = np.argmax(similarities) best_similarity = similarities[best_index] return { 'template_info': template_infos[best_index], 'similarity_score': best_similarity } # 3.9提取用户输入中的颜色关键词 def extract_color_keywords(self, text): color_map = { "红": "红色", "绿": "绿色", "蓝": "蓝色", "黄": "黄色", "橙": "橙色", "紫": "紫色", "黑": "黑色", "白": "白色", "灰": "灰色" } found_colors = [] for char, color in color_map.items(): if char in text: found_colors.append(color) return found_colors # 3.10根据颜色关键词过滤模板 def filter_templates_by_color(self, skin_data, color_keywords): filtered = [] for template in skin_data: try: keywords = json.loads(template.get('template_keyword', '[]')) keyword_text = ' '.join(keywords) # 检查模板关键词是否包含任何颜色关键词 for color in color_keywords: if color in keyword_text: filtered.append(template) break except: continue return filtered """第四步:缓存""" # 4.1读取或者新建缓存 def get_template_embeddings(self, skin_data): if not skin_data: return None, [] # 4.1.2准备模板关键词文本 template_texts = [] template_infos = [] for template in skin_data: try: keywords = json.loads(template.get('template_keyword', '[]')) keyword_text = ' '.join(keywords) template_texts.append(keyword_text) template_infos.append(template) except json.JSONDecodeError: continue if not template_texts: return None, [] # 4.1.3检查缓存 cache_key = self._get_template_cache_key(skin_data) if cache_key in self.template_embeddings_cache: print("📦 从缓存加载模板嵌入向量") template_embeddings = self.template_embeddings_cache[cache_key] return template_infos, template_embeddings print("🔄 计算新的模板嵌入向量") template_embeddings = self.text_model.encode(template_texts) # 4.1.4存储到缓存(cachetools自动处理过期) self.template_embeddings_cache[cache_key] = template_embeddings return template_infos, template_embeddings # 4.2 生成模板数据的缓存键 def _get_template_cache_key(self, skin_data): template_keys = [] for template in skin_data: template_keys.append(f"{template.get('id')}_{template.get('template_keyword')}") combined_key = '|||'.join(sorted(template_keys)) return hashlib.md5(combined_key.encode('utf-8')).hexdigest() # 4.3 清空模板缓存 def clear_cache(self): self.template_embeddings_cache.clear() print("🗑️ 模板嵌入向量缓存已清空") # 4.4 获取模板缓存 def get_cache_info(self): return { 'template_embeddings_cache_size': len(self.template_embeddings_cache), 'cached_templates_count': len(self.template_embeddings_cache), 'cache_maxsize': self.template_embeddings_cache.maxsize, 'cache_ttl': self.template_embeddings_cache.ttl } """第五步:生成通栏数据""" # 5.1 随机生成通栏数据 def get_sectors(self,website_id,matched_template_id,token): print("开始生成通栏数据!") print(matched_template_id) # 根据matched_template_id选择样式数据 if int(matched_template_id) == 5: print("正在使用第5套皮肤的数据!") self.current_style_data = sectorStyle5Data else: print("正在使用前4套皮肤的数据!") self.current_style_data = sectorStyle1to4Data sectors_config = self.current_style_data.sectors_config scheduler = SectorScheduler(sectors_config) result = scheduler.generate_sequence(10) # 处理强制位置(倒数位置) forced_map = {} for sector in sectors_config: if 'force_pos' in sector: # force_pos: 1(倒数第一), 2(倒数第二), 3(倒数第三) forced_map[sector['force_pos']] = sector['name'] # 替换结果中的对应位置 if forced_map and result: for pos, name in forced_map.items(): idx = -1 * pos if abs(idx) <= len(result): print(f"强制调整位置: 将 {name} 放入索引 {idx}") result[idx] = name # print("结果:", result) sectorNames = [] # 第一步:获取到筛选结果 for sector_name in result: for sector in sectors_config: if sector['name'] == sector_name: sectorNames.append(sector['name']) break # 第二步:使用筛选结果把sectors_webSiteData重组为模板数据 for name in sectorNames: # 从字典中直接获取对应的sector数据 if name in self.current_style_data.sectors_data: matched_sector = self.current_style_data.sectors_data[name] # 添加到templateData.sectors_webSiteData.index # templateData.sectors_webSiteData["template"]["index"].append(matched_sector) # print(matched_sector) # print(templateData.sectors_webSiteData["template"]["index"]) templateData.sectors_webSiteData["template"]["index"].append(matched_sector) #print("模板数据:" + json.dumps(templateData.sectors_webSiteData["template"]["index"], ensure_ascii=False)) #获得通栏名称 ai_service.test_normal_case_cnames() # 第三步:选择组件样式 ai_service.updata_component_style(website_id,matched_template_id,token) # 第四步:依照用户需求修改样式 ai_service.updata_component_style_to_user(website_id,matched_template_id,token) # 第五步:生成画布数据 # 循环templateData.sectors_webSiteData["template"]["index"],生成画布数据 # templateData.canvas_data["template"]["index"].append() all_y = 0 # 初始化累计高度 for index,sectorData in enumerate(templateData.sectors_webSiteData["template"]["index"]): # 获取当前通栏的高度 sector_height = sectorData["sectorCanvasHeight"] # 生成画布数据并添加到canvas_data canvas_item = ai_service.get_canvans_data(sector_height, all_y, sectorData,index) templateData.canvas_data["template"]["index"].append(canvas_item) # 更新累计高度 all_y += sector_height # 5.2 根据通栏数据生成画布数据 def get_canvans_data(self,sector_height,all_y,sectorData,sort): data = { "i": int(time.time() * 1000) * 1000 + random.randint(0, 9999), # 毫秒级时间戳后加随机4位数 "x": 0,#通栏的x坐标恒定为0 "y": all_y,#通栏的y坐标为当前所有通栏的高度之和 "w": 12,#通栏的宽度恒定为12 "h": sector_height,#通栏的高度为sectorData["sectorCanvasHeight"] "type": sectorData["sectorName"], "content":sectorData, "dataSort": sort, "moved": False } return data # 5.3 生成随机的样式数据 def updata_component_style(self,website_id,matched_template_id,token): getComponentStyle = post_request("public/getAllSectorComponentStyle", { "website_id":website_id, "template_id": matched_template_id, },token) # 请补全此方法 # 检查响应是否成功 if getComponentStyle.get('code') != 200: print(f"获取组件样式失败: {getComponentStyle.get('message')}") return # 获取组件样式数据 style_data = getComponentStyle.get('data', []) self.component_style_data = style_data # 需要跳过的通栏类型 skip_sectors = ["adSector", "linkSector", "linkCxfwSector"] # 更新推理过程 self.reasoning.append(f"通栏创建完毕") # 遍历模板中的每个通栏 for sector in templateData.sectors_webSiteData["template"]["index"]: sector_name = sector["sectorName"] # 跳过不需要修改样式的通栏 if sector_name in skip_sectors: print(f"跳过 {sector_name} 的组件样式修改") continue # 在样式数据中查找匹配的通栏 matched_style = next((s for s in style_data if s["sector_id"] == sector_name), None) if not matched_style: print(f"警告: 未找到 {sector_name} 的组件样式配置") continue # 获取该通栏的组件列表配置 component_list_config = matched_style.get("component_list", []) # 遍历通栏中的每个组件 for idx, component in enumerate(sector["componentList"]): if idx < len(component_list_config): comp_config = component_list_config[idx] if comp_config and isinstance(comp_config, list) and len(comp_config) > 0: comp_style_data = comp_config[0].get("component_style_data", []) # 如果有可用的样式数据,随机选择一个样式 if comp_style_data: style_count = len(comp_style_data) random_style = random.randint(1, style_count) component["component_style"] = random_style print(f"初始化 {sector_name} 的第{idx+1}个组件样式为: {random_style}") # 更新推理过程 #self.reasoning.append(f"初始化 {sector_name} 的第{idx+1}个组件样式为: {random_style}号样式") else: print(f"警告: {sector_name} 的第{idx+1}个组件没有可用的样式数据") else: print(f"警告: {sector_name} 的第{idx+1}个组件配置无效") else: print(f"警告: {sector_name} 的组件数量({len(sector['componentList'])})超过配置数量({len(component_list_config)})") # 5.4 根据用户需求修改样式数据 def updata_component_style_to_user(self, website_id, matched_template_id, token): # 需要跳过的通栏类型 skip_sectors = ["adSector", "linkSector", "linkCxfwSector"] # 1. 将用户输入转换为向量 user_embedding = self.text_to_embedding(self.user_input) if len(user_embedding.shape) == 1: user_embedding = user_embedding.reshape(1, -1) # 更新推理过程 self.reasoning.append(f"最后让我挑选一下符合用户需求的组件样式") # 2. 遍历每个通栏 for sector in templateData.sectors_webSiteData["template"]["index"]: sector_name = sector["sectorName"] # 跳过不需要修改样式的通栏 if sector_name in skip_sectors: print(f"跳过 {sector_name} 的组件样式优化") # 更新推理过程 #self.reasoning.append(f"跳过 {sector_name} 的组件样式优化") continue # 3. 在样式数据中查找匹配的通栏 matched_style = next((s for s in self.component_style_data if s["sector_id"] == sector_name), None) if not matched_style: print(f"警告: 未找到 {sector_name} 的组件样式配置") continue # 4. 遍历通栏中的每个组件 for idx, component in enumerate(sector["componentList"]): comp_config_list = matched_style.get("component_list", []) if idx >= len(comp_config_list) or not comp_config_list[idx]: continue comp_config = comp_config_list[idx][0] # 取第一个配置 comp_style_data = comp_config.get("component_style_data", []) if not comp_style_data: continue # 5. 提取所有样式的描述文本 style_descriptions = [style["img_name"] for style in comp_style_data] # 6. 将样式描述转换为向量 style_embeddings = self.text_to_embedding(style_descriptions) # 7. 计算用户输入与每个样式的相似度 similarities = cosine_similarity(user_embedding, style_embeddings)[0] best_index = np.argmax(similarities) best_style = comp_style_data[best_index] # 8. 更新组件样式 component["component_style"] = best_style["img_id"] print(f"优化 {sector_name} 的第{idx+1}个组件样式为: {best_style['img_name']} (ID: {best_style['img_id']})") # 更新推理过程 self.reasoning.append(f"选择 {sector_name} 的第{idx+1}个组件样式为: {best_style['img_name']} (ID: {best_style['img_id']})") """其他功能 模拟推理过程""" # 获得选择通栏的中文名称 def test_normal_case_cnames(self): #获得生成的模板数据 result = templateData.sectors_webSiteData["template"]["index"] #更新推理过程 self.reasoning.append(f"接着我要创建一套符合用户需求的通栏组合") #遍历result获得所有通栏的Cname cnnames = [] # 遍历每个通栏,使用enumerate获取索引 for i, sector in enumerate(result): sector_name = sector["sectorName"] # 在sectors_config中查找对应的中文名称 if self.current_style_data: for config in self.current_style_data.sectors_config: if config["name"] == sector_name: cnnames.append(config["CNname"]) # 使用通栏在结果列表中的索引+1作为序号 self.reasoning.append(f"第{i+1}个通栏选择: {config['CNname']}") break # 找到后跳出内层循环 # 清理数据 def clear_data(self): #清理模板数据 templateData.canvas_data["template"]["index"] = [] templateData.sectors_webSiteData["template"]["index"] = [] #清理推理数据 self.reasoning = [] self.current_style_data = None #初始化权重 sectorStyle1to4Data.reset_sectors_config() sectorStyle5Data.reset_sectors_config() #添加推理过程 def add_reasoning(self, reasoning): self.reasoning.append(reasoning) ai_service = AIService()