| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462 |
- #1.引入numpy
- import numpy as np
- #2.引入模型
- from models.model_loader import model_loader
- #3.引入scikit-learn的余弦相似度函数
- from sklearn.metrics.pairwise import cosine_similarity
- #4.引入json
- import json
- #5.引入缓存机制
- from cachetools import TTLCache
- #6.引入哈希值
- import hashlib
- #7.引入标题提取
- import utils.matching as TitleMatcher
- #8.引入模板数据
- from services.templateFun import templateData, SectorScheduler
- # 引入样式数据
- from services.style.style1to4data import sectorStyle1to4Data
- from services.style.style5data import sectorStyle5Data
- #9.引入随机数
- import random
- #10.引入请求
- from utils.request import post_request,generate_error
- #11.引入时间
- import time
- #12.引入输入类型
- import utils.inputType as inputType
- # 创建AI处理类
- class AIService:
- """第一步:初始化"""
- def __init__(self):
- # 1.1加载文本处理模型
- self.text_model = model_loader.load_text_model()
- # 1.2使用TTLCache,最大缓存50个模板,1小时过期,缓存模板嵌入向量
- self.template_embeddings_cache = TTLCache(maxsize=50, ttl=3600)
- # 1.3用户输入文本
- self.user_input = ""
- # 1.4组件样式数据
- self.component_style_data = []
- # 1.5推理过程
- self.reasoning = []
- # 1.6 当前样式数据对象
- self.current_style_data = None
- """第二步:调用模型能力"""
- #2.1 文字转换成向量
- def text_to_embedding(self, text):
- # 2.1.1如果text是字符串,则转换成列表
- if isinstance(text, str):
- text = [text]
-
- embeddings = self.text_model.encode(text)
- # 2.2.2返回numpy数组,不要转成list
- return embeddings
- #2.2生成文字摘要
- def userMessage_to_title(self, userMessage):
- return TitleMatcher.title_matcher.generate_title(userMessage)
- #2.3筛选通栏
- def filter_sectors(self, userMessage,matched_template_id):
- return inputType.inputType_matcher.generate_inputType(userMessage,matched_template_id)
-
- """第三步:查找最匹配的皮肤模板"""
- def find_best_matching_template(self, user_text, skin_data):
- # 3.1如果皮肤数据为空,则返回None
- if not skin_data:
- return None
-
- # 3.2提取用户输入中的颜色关键词
- color_keywords = self.extract_color_keywords(user_text)
- print(f"提取到用户颜色关键词: {color_keywords}")
- # 更新推理过程
- self.reasoning.append(f"我提取到用户需求中的颜色关键词是: {color_keywords}")
-
- # 3.3根据颜色关键词过滤模板
- filtered_skin_data = skin_data
- if color_keywords:
- filtered_skin_data = self.filter_templates_by_color(skin_data, color_keywords)
- print(f"根据颜色过滤后模板数量: {len(filtered_skin_data)}")
- # 更新推理过程
- self.reasoning.append(f"根据颜色过滤数据以后有: {len(filtered_skin_data)}个符合该需求的模板")
-
- # 3.4获取模板信息(使用缓存)
- template_infos, template_embeddings = self.get_template_embeddings(filtered_skin_data)
-
- # 3.5如果模板信息为空,则返回None
- if not template_infos:
- return None
-
- # 3.6用户输入的文本转化为向量
- user_embedding = self.text_to_embedding(user_text)
-
- # 3.7如果用户输入的文本的维度是1,则转换成2维
- if len(user_embedding.shape) == 1:
- user_embedding = user_embedding.reshape(1, -1)
- if len(template_embeddings.shape) == 1:
- template_embeddings = template_embeddings.reshape(1, -1)
-
- # 3.8通过余弦函数批量计算相似度
- similarities = cosine_similarity(user_embedding, template_embeddings)[0]
- best_index = np.argmax(similarities)
- best_similarity = similarities[best_index]
-
- return {
- 'template_info': template_infos[best_index],
- 'similarity_score': best_similarity
- }
- # 3.9提取用户输入中的颜色关键词
- def extract_color_keywords(self, text):
- color_map = {
- "红": "红色",
- "绿": "绿色",
- "蓝": "蓝色",
- "黄": "黄色",
- "橙": "橙色",
- "紫": "紫色",
- "黑": "黑色",
- "白": "白色",
- "灰": "灰色"
- }
-
- found_colors = []
- for char, color in color_map.items():
- if char in text:
- found_colors.append(color)
-
- return found_colors
-
- # 3.10根据颜色关键词过滤模板
- def filter_templates_by_color(self, skin_data, color_keywords):
- filtered = []
- for template in skin_data:
- try:
- keywords = json.loads(template.get('template_keyword', '[]'))
- keyword_text = ' '.join(keywords)
-
- # 检查模板关键词是否包含任何颜色关键词
- for color in color_keywords:
- if color in keyword_text:
- filtered.append(template)
- break
- except:
- continue
-
- return filtered
- """第四步:缓存"""
- # 4.1读取或者新建缓存
- def get_template_embeddings(self, skin_data):
-
- if not skin_data:
- return None, []
-
- # 4.1.2准备模板关键词文本
- template_texts = []
- template_infos = []
-
- for template in skin_data:
- try:
- keywords = json.loads(template.get('template_keyword', '[]'))
- keyword_text = ' '.join(keywords)
- template_texts.append(keyword_text)
- template_infos.append(template)
- except json.JSONDecodeError:
- continue
-
- if not template_texts:
- return None, []
-
- # 4.1.3检查缓存
- cache_key = self._get_template_cache_key(skin_data)
- if cache_key in self.template_embeddings_cache:
- print("📦 从缓存加载模板嵌入向量")
- template_embeddings = self.template_embeddings_cache[cache_key]
- return template_infos, template_embeddings
-
- print("🔄 计算新的模板嵌入向量")
- template_embeddings = self.text_model.encode(template_texts)
-
- # 4.1.4存储到缓存(cachetools自动处理过期)
- self.template_embeddings_cache[cache_key] = template_embeddings
- return template_infos, template_embeddings
- # 4.2 生成模板数据的缓存键
- def _get_template_cache_key(self, skin_data):
- template_keys = []
- for template in skin_data:
- template_keys.append(f"{template.get('id')}_{template.get('template_keyword')}")
- combined_key = '|||'.join(sorted(template_keys))
- return hashlib.md5(combined_key.encode('utf-8')).hexdigest()
- # 4.3 清空模板缓存
- def clear_cache(self):
- self.template_embeddings_cache.clear()
- print("🗑️ 模板嵌入向量缓存已清空")
-
- # 4.4 获取模板缓存
- def get_cache_info(self):
- return {
- 'template_embeddings_cache_size': len(self.template_embeddings_cache),
- 'cached_templates_count': len(self.template_embeddings_cache),
- 'cache_maxsize': self.template_embeddings_cache.maxsize,
- 'cache_ttl': self.template_embeddings_cache.ttl
- }
- """第五步:生成通栏数据"""
- # 5.1 随机生成通栏数据
- def get_sectors(self,website_id,matched_template_id,token):
- print("开始生成通栏数据!")
- print(matched_template_id)
- # 根据matched_template_id选择样式数据
- if int(matched_template_id) == 5:
- print("正在使用第5套皮肤的数据!")
- self.current_style_data = sectorStyle5Data
- else:
- print("正在使用前4套皮肤的数据!")
- self.current_style_data = sectorStyle1to4Data
- sectors_config = self.current_style_data.sectors_config
- scheduler = SectorScheduler(sectors_config)
- result = scheduler.generate_sequence(10)
-
- # 处理强制位置(倒数位置)
- forced_map = {}
- for sector in sectors_config:
- if 'force_pos' in sector:
- # force_pos: 1(倒数第一), 2(倒数第二), 3(倒数第三)
- forced_map[sector['force_pos']] = sector['name']
-
- # 替换结果中的对应位置
- if forced_map and result:
- for pos, name in forced_map.items():
- idx = -1 * pos
- if abs(idx) <= len(result):
- print(f"强制调整位置: 将 {name} 放入索引 {idx}")
- result[idx] = name
- # print("结果:", result)
- sectorNames = []
- # 第一步:获取到筛选结果
- for sector_name in result:
- for sector in sectors_config:
- if sector['name'] == sector_name:
- sectorNames.append(sector['name'])
- break
- # 第二步:使用筛选结果把sectors_webSiteData重组为模板数据
- for name in sectorNames:
- # 从字典中直接获取对应的sector数据
- if name in self.current_style_data.sectors_data:
- matched_sector = self.current_style_data.sectors_data[name]
- # 添加到templateData.sectors_webSiteData.index
- # templateData.sectors_webSiteData["template"]["index"].append(matched_sector)
- # print(matched_sector)
- # print(templateData.sectors_webSiteData["template"]["index"])
- templateData.sectors_webSiteData["template"]["index"].append(matched_sector)
- #print("模板数据:" + json.dumps(templateData.sectors_webSiteData["template"]["index"], ensure_ascii=False))
- #获得通栏名称
- ai_service.test_normal_case_cnames()
- # 第三步:选择组件样式
- ai_service.updata_component_style(website_id,matched_template_id,token)
- # 第四步:依照用户需求修改样式
- ai_service.updata_component_style_to_user(website_id,matched_template_id,token)
- # 第五步:生成画布数据
- # 循环templateData.sectors_webSiteData["template"]["index"],生成画布数据
- # templateData.canvas_data["template"]["index"].append()
- all_y = 0 # 初始化累计高度
- for index,sectorData in enumerate(templateData.sectors_webSiteData["template"]["index"]):
- # 获取当前通栏的高度
- sector_height = sectorData["sectorCanvasHeight"]
- # 生成画布数据并添加到canvas_data
- canvas_item = ai_service.get_canvans_data(sector_height, all_y, sectorData,index)
- templateData.canvas_data["template"]["index"].append(canvas_item)
- # 更新累计高度
- all_y += sector_height
- # 5.2 根据通栏数据生成画布数据
- def get_canvans_data(self,sector_height,all_y,sectorData,sort):
- data = {
- "i": int(time.time() * 1000) * 1000 + random.randint(0, 9999), # 毫秒级时间戳后加随机4位数
- "x": 0,#通栏的x坐标恒定为0
- "y": all_y,#通栏的y坐标为当前所有通栏的高度之和
- "w": 12,#通栏的宽度恒定为12
- "h": sector_height,#通栏的高度为sectorData["sectorCanvasHeight"]
- "type": sectorData["sectorName"],
- "content":sectorData,
- "dataSort": sort,
- "moved": False
- }
- return data
- # 5.3 生成随机的样式数据
- def updata_component_style(self,website_id,matched_template_id,token):
- getComponentStyle = post_request("public/getAllSectorComponentStyle", {
- "website_id":website_id,
- "template_id": matched_template_id,
- },token)
- # 请补全此方法
- # 检查响应是否成功
- if getComponentStyle.get('code') != 200:
- print(f"获取组件样式失败: {getComponentStyle.get('message')}")
- return
-
- # 获取组件样式数据
- style_data = getComponentStyle.get('data', [])
- self.component_style_data = style_data
-
- # 需要跳过的通栏类型
- skip_sectors = ["adSector", "linkSector", "linkCxfwSector"]
- # 更新推理过程
- self.reasoning.append(f"通栏创建完毕")
-
- # 遍历模板中的每个通栏
- for sector in templateData.sectors_webSiteData["template"]["index"]:
- sector_name = sector["sectorName"]
-
- # 跳过不需要修改样式的通栏
- if sector_name in skip_sectors:
- print(f"跳过 {sector_name} 的组件样式修改")
- continue
-
- # 在样式数据中查找匹配的通栏
- matched_style = next((s for s in style_data if s["sector_id"] == sector_name), None)
- if not matched_style:
- print(f"警告: 未找到 {sector_name} 的组件样式配置")
- continue
-
- # 获取该通栏的组件列表配置
- component_list_config = matched_style.get("component_list", [])
-
- # 遍历通栏中的每个组件
- for idx, component in enumerate(sector["componentList"]):
- if idx < len(component_list_config):
- comp_config = component_list_config[idx]
- if comp_config and isinstance(comp_config, list) and len(comp_config) > 0:
- comp_style_data = comp_config[0].get("component_style_data", [])
-
- # 如果有可用的样式数据,随机选择一个样式
- if comp_style_data:
- style_count = len(comp_style_data)
- random_style = random.randint(1, style_count)
- component["component_style"] = random_style
- print(f"初始化 {sector_name} 的第{idx+1}个组件样式为: {random_style}")
- # 更新推理过程
- #self.reasoning.append(f"初始化 {sector_name} 的第{idx+1}个组件样式为: {random_style}号样式")
- else:
- print(f"警告: {sector_name} 的第{idx+1}个组件没有可用的样式数据")
- else:
- print(f"警告: {sector_name} 的第{idx+1}个组件配置无效")
- else:
- print(f"警告: {sector_name} 的组件数量({len(sector['componentList'])})超过配置数量({len(component_list_config)})")
-
- # 5.4 根据用户需求修改样式数据
- def updata_component_style_to_user(self, website_id, matched_template_id, token):
- # 需要跳过的通栏类型
- skip_sectors = ["adSector", "linkSector", "linkCxfwSector"]
-
- # 1. 将用户输入转换为向量
- user_embedding = self.text_to_embedding(self.user_input)
- if len(user_embedding.shape) == 1:
- user_embedding = user_embedding.reshape(1, -1)
- # 更新推理过程
- self.reasoning.append(f"最后让我挑选一下符合用户需求的组件样式")
-
- # 2. 遍历每个通栏
- for sector in templateData.sectors_webSiteData["template"]["index"]:
- sector_name = sector["sectorName"]
-
- # 跳过不需要修改样式的通栏
- if sector_name in skip_sectors:
- print(f"跳过 {sector_name} 的组件样式优化")
- # 更新推理过程
- #self.reasoning.append(f"跳过 {sector_name} 的组件样式优化")
- continue
-
- # 3. 在样式数据中查找匹配的通栏
- matched_style = next((s for s in self.component_style_data if s["sector_id"] == sector_name), None)
- if not matched_style:
- print(f"警告: 未找到 {sector_name} 的组件样式配置")
- continue
-
- # 4. 遍历通栏中的每个组件
- for idx, component in enumerate(sector["componentList"]):
- comp_config_list = matched_style.get("component_list", [])
- if idx >= len(comp_config_list) or not comp_config_list[idx]:
- continue
-
- comp_config = comp_config_list[idx][0] # 取第一个配置
- comp_style_data = comp_config.get("component_style_data", [])
-
- if not comp_style_data:
- continue
-
- # 5. 提取所有样式的描述文本
- style_descriptions = [style["img_name"] for style in comp_style_data]
-
- # 6. 将样式描述转换为向量
- style_embeddings = self.text_to_embedding(style_descriptions)
-
- # 7. 计算用户输入与每个样式的相似度
- similarities = cosine_similarity(user_embedding, style_embeddings)[0]
- best_index = np.argmax(similarities)
- best_style = comp_style_data[best_index]
-
- # 8. 更新组件样式
- component["component_style"] = best_style["img_id"]
- print(f"优化 {sector_name} 的第{idx+1}个组件样式为: {best_style['img_name']} (ID: {best_style['img_id']})")
- # 更新推理过程
- self.reasoning.append(f"选择 {sector_name} 的第{idx+1}个组件样式为: {best_style['img_name']} (ID: {best_style['img_id']})")
-
-
- """其他功能 模拟推理过程"""
- # 获得选择通栏的中文名称
- def test_normal_case_cnames(self):
- #获得生成的模板数据
- result = templateData.sectors_webSiteData["template"]["index"]
- #更新推理过程
- self.reasoning.append(f"接着我要创建一套符合用户需求的通栏组合")
- #遍历result获得所有通栏的Cname
- cnnames = []
-
- # 遍历每个通栏,使用enumerate获取索引
- for i, sector in enumerate(result):
- sector_name = sector["sectorName"]
- # 在sectors_config中查找对应的中文名称
- if self.current_style_data:
- for config in self.current_style_data.sectors_config:
- if config["name"] == sector_name:
- cnnames.append(config["CNname"])
- # 使用通栏在结果列表中的索引+1作为序号
- self.reasoning.append(f"第{i+1}个通栏选择: {config['CNname']}")
- break # 找到后跳出内层循环
-
- # 清理数据
- def clear_data(self):
- #清理模板数据
- templateData.canvas_data["template"]["index"] = []
- templateData.sectors_webSiteData["template"]["index"] = []
- #清理推理数据
- self.reasoning = []
- self.current_style_data = None
- #初始化权重
- sectorStyle1to4Data.reset_sectors_config()
- sectorStyle5Data.reset_sectors_config()
- #添加推理过程
- def add_reasoning(self, reasoning):
- self.reasoning.append(reasoning)
- ai_service = AIService()
|