ai_service.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. #1.引入numpy
  2. import numpy as np
  3. #2.引入模型
  4. from models.model_loader import model_loader
  5. #3.引入scikit-learn的余弦相似度函数
  6. from sklearn.metrics.pairwise import cosine_similarity
  7. #4.引入json
  8. import json
  9. #5.引入缓存机制
  10. from cachetools import TTLCache
  11. #6.引入哈希值
  12. import hashlib
  13. #7.引入标题提取
  14. import utils.matching as TitleMatcher
  15. #8.引入模板数据
  16. from services.templateFun import templateData, SectorScheduler
  17. # 引入样式数据
  18. from services.style.style1to4data import sectorStyle1to4Data
  19. from services.style.style5data import sectorStyle5Data
  20. #9.引入随机数
  21. import random
  22. #10.引入请求
  23. from utils.request import post_request,generate_error
  24. #11.引入时间
  25. import time
  26. #12.引入输入类型
  27. import utils.inputType as inputType
  28. # 创建AI处理类
  29. class AIService:
  30. """第一步:初始化"""
  31. def __init__(self):
  32. # 1.1加载文本处理模型
  33. self.text_model = model_loader.load_text_model()
  34. # 1.2使用TTLCache,最大缓存50个模板,1小时过期,缓存模板嵌入向量
  35. self.template_embeddings_cache = TTLCache(maxsize=50, ttl=3600)
  36. # 1.3用户输入文本
  37. self.user_input = ""
  38. # 1.4组件样式数据
  39. self.component_style_data = []
  40. # 1.5推理过程
  41. self.reasoning = []
  42. # 1.6 当前样式数据对象
  43. self.current_style_data = None
  44. """第二步:调用模型能力"""
  45. #2.1 文字转换成向量
  46. def text_to_embedding(self, text):
  47. # 2.1.1如果text是字符串,则转换成列表
  48. if isinstance(text, str):
  49. text = [text]
  50. embeddings = self.text_model.encode(text)
  51. # 2.2.2返回numpy数组,不要转成list
  52. return embeddings
  53. #2.2生成文字摘要
  54. def userMessage_to_title(self, userMessage):
  55. return TitleMatcher.title_matcher.generate_title(userMessage)
  56. #2.3筛选通栏
  57. def filter_sectors(self, userMessage,matched_template_id):
  58. return inputType.inputType_matcher.generate_inputType(userMessage,matched_template_id)
  59. """第三步:查找最匹配的皮肤模板"""
  60. def find_best_matching_template(self, user_text, skin_data):
  61. # 3.1如果皮肤数据为空,则返回None
  62. if not skin_data:
  63. return None
  64. # 3.2提取用户输入中的颜色关键词
  65. color_keywords = self.extract_color_keywords(user_text)
  66. print(f"提取到用户颜色关键词: {color_keywords}")
  67. # 更新推理过程
  68. self.reasoning.append(f"我提取到用户需求中的颜色关键词是: {color_keywords}")
  69. # 3.3根据颜色关键词过滤模板
  70. filtered_skin_data = skin_data
  71. if color_keywords:
  72. filtered_skin_data = self.filter_templates_by_color(skin_data, color_keywords)
  73. print(f"根据颜色过滤后模板数量: {len(filtered_skin_data)}")
  74. # 更新推理过程
  75. self.reasoning.append(f"根据颜色过滤数据以后有: {len(filtered_skin_data)}个符合该需求的模板")
  76. # 3.4获取模板信息(使用缓存)
  77. template_infos, template_embeddings = self.get_template_embeddings(filtered_skin_data)
  78. # 3.5如果模板信息为空,则返回None
  79. if not template_infos:
  80. return None
  81. # 3.6用户输入的文本转化为向量
  82. user_embedding = self.text_to_embedding(user_text)
  83. # 3.7如果用户输入的文本的维度是1,则转换成2维
  84. if len(user_embedding.shape) == 1:
  85. user_embedding = user_embedding.reshape(1, -1)
  86. if len(template_embeddings.shape) == 1:
  87. template_embeddings = template_embeddings.reshape(1, -1)
  88. # 3.8通过余弦函数批量计算相似度
  89. similarities = cosine_similarity(user_embedding, template_embeddings)[0]
  90. best_index = np.argmax(similarities)
  91. best_similarity = similarities[best_index]
  92. return {
  93. 'template_info': template_infos[best_index],
  94. 'similarity_score': best_similarity
  95. }
  96. # 3.9提取用户输入中的颜色关键词
  97. def extract_color_keywords(self, text):
  98. color_map = {
  99. "红": "红色",
  100. "绿": "绿色",
  101. "蓝": "蓝色",
  102. "黄": "黄色",
  103. "橙": "橙色",
  104. "紫": "紫色",
  105. "黑": "黑色",
  106. "白": "白色",
  107. "灰": "灰色"
  108. }
  109. found_colors = []
  110. for char, color in color_map.items():
  111. if char in text:
  112. found_colors.append(color)
  113. return found_colors
  114. # 3.10根据颜色关键词过滤模板
  115. def filter_templates_by_color(self, skin_data, color_keywords):
  116. filtered = []
  117. for template in skin_data:
  118. try:
  119. keywords = json.loads(template.get('template_keyword', '[]'))
  120. keyword_text = ' '.join(keywords)
  121. # 检查模板关键词是否包含任何颜色关键词
  122. for color in color_keywords:
  123. if color in keyword_text:
  124. filtered.append(template)
  125. break
  126. except:
  127. continue
  128. return filtered
  129. """第四步:缓存"""
  130. # 4.1读取或者新建缓存
  131. def get_template_embeddings(self, skin_data):
  132. if not skin_data:
  133. return None, []
  134. # 4.1.2准备模板关键词文本
  135. template_texts = []
  136. template_infos = []
  137. for template in skin_data:
  138. try:
  139. keywords = json.loads(template.get('template_keyword', '[]'))
  140. keyword_text = ' '.join(keywords)
  141. template_texts.append(keyword_text)
  142. template_infos.append(template)
  143. except json.JSONDecodeError:
  144. continue
  145. if not template_texts:
  146. return None, []
  147. # 4.1.3检查缓存
  148. cache_key = self._get_template_cache_key(skin_data)
  149. if cache_key in self.template_embeddings_cache:
  150. print("📦 从缓存加载模板嵌入向量")
  151. template_embeddings = self.template_embeddings_cache[cache_key]
  152. return template_infos, template_embeddings
  153. print("🔄 计算新的模板嵌入向量")
  154. template_embeddings = self.text_model.encode(template_texts)
  155. # 4.1.4存储到缓存(cachetools自动处理过期)
  156. self.template_embeddings_cache[cache_key] = template_embeddings
  157. return template_infos, template_embeddings
  158. # 4.2 生成模板数据的缓存键
  159. def _get_template_cache_key(self, skin_data):
  160. template_keys = []
  161. for template in skin_data:
  162. template_keys.append(f"{template.get('id')}_{template.get('template_keyword')}")
  163. combined_key = '|||'.join(sorted(template_keys))
  164. return hashlib.md5(combined_key.encode('utf-8')).hexdigest()
  165. # 4.3 清空模板缓存
  166. def clear_cache(self):
  167. self.template_embeddings_cache.clear()
  168. print("🗑️ 模板嵌入向量缓存已清空")
  169. # 4.4 获取模板缓存
  170. def get_cache_info(self):
  171. return {
  172. 'template_embeddings_cache_size': len(self.template_embeddings_cache),
  173. 'cached_templates_count': len(self.template_embeddings_cache),
  174. 'cache_maxsize': self.template_embeddings_cache.maxsize,
  175. 'cache_ttl': self.template_embeddings_cache.ttl
  176. }
  177. """第五步:生成通栏数据"""
  178. # 5.1 随机生成通栏数据
  179. def get_sectors(self,website_id,matched_template_id,token):
  180. print("开始生成通栏数据!")
  181. print(matched_template_id)
  182. # 根据matched_template_id选择样式数据
  183. if int(matched_template_id) == 5:
  184. print("正在使用第5套皮肤的数据!")
  185. self.current_style_data = sectorStyle5Data
  186. else:
  187. print("正在使用前4套皮肤的数据!")
  188. self.current_style_data = sectorStyle1to4Data
  189. sectors_config = self.current_style_data.sectors_config
  190. scheduler = SectorScheduler(sectors_config)
  191. result = scheduler.generate_sequence(10)
  192. # 处理强制位置(倒数位置)
  193. forced_map = {}
  194. for sector in sectors_config:
  195. if 'force_pos' in sector:
  196. # force_pos: 1(倒数第一), 2(倒数第二), 3(倒数第三)
  197. forced_map[sector['force_pos']] = sector['name']
  198. # 替换结果中的对应位置
  199. if forced_map and result:
  200. for pos, name in forced_map.items():
  201. idx = -1 * pos
  202. if abs(idx) <= len(result):
  203. print(f"强制调整位置: 将 {name} 放入索引 {idx}")
  204. result[idx] = name
  205. # print("结果:", result)
  206. sectorNames = []
  207. # 第一步:获取到筛选结果
  208. for sector_name in result:
  209. for sector in sectors_config:
  210. if sector['name'] == sector_name:
  211. sectorNames.append(sector['name'])
  212. break
  213. # 第二步:使用筛选结果把sectors_webSiteData重组为模板数据
  214. for name in sectorNames:
  215. # 从字典中直接获取对应的sector数据
  216. if name in self.current_style_data.sectors_data:
  217. matched_sector = self.current_style_data.sectors_data[name]
  218. # 添加到templateData.sectors_webSiteData.index
  219. # templateData.sectors_webSiteData["template"]["index"].append(matched_sector)
  220. # print(matched_sector)
  221. # print(templateData.sectors_webSiteData["template"]["index"])
  222. templateData.sectors_webSiteData["template"]["index"].append(matched_sector)
  223. #print("模板数据:" + json.dumps(templateData.sectors_webSiteData["template"]["index"], ensure_ascii=False))
  224. #获得通栏名称
  225. ai_service.test_normal_case_cnames()
  226. # 第三步:选择组件样式
  227. ai_service.updata_component_style(website_id,matched_template_id,token)
  228. # 第四步:依照用户需求修改样式
  229. ai_service.updata_component_style_to_user(website_id,matched_template_id,token)
  230. # 第五步:生成画布数据
  231. # 循环templateData.sectors_webSiteData["template"]["index"],生成画布数据
  232. # templateData.canvas_data["template"]["index"].append()
  233. all_y = 0 # 初始化累计高度
  234. for index,sectorData in enumerate(templateData.sectors_webSiteData["template"]["index"]):
  235. # 获取当前通栏的高度
  236. sector_height = sectorData["sectorCanvasHeight"]
  237. # 生成画布数据并添加到canvas_data
  238. canvas_item = ai_service.get_canvans_data(sector_height, all_y, sectorData,index)
  239. templateData.canvas_data["template"]["index"].append(canvas_item)
  240. # 更新累计高度
  241. all_y += sector_height
  242. # 5.2 根据通栏数据生成画布数据
  243. def get_canvans_data(self,sector_height,all_y,sectorData,sort):
  244. data = {
  245. "i": int(time.time() * 1000) * 1000 + random.randint(0, 9999), # 毫秒级时间戳后加随机4位数
  246. "x": 0,#通栏的x坐标恒定为0
  247. "y": all_y,#通栏的y坐标为当前所有通栏的高度之和
  248. "w": 12,#通栏的宽度恒定为12
  249. "h": sector_height,#通栏的高度为sectorData["sectorCanvasHeight"]
  250. "type": sectorData["sectorName"],
  251. "content":sectorData,
  252. "dataSort": sort,
  253. "moved": False
  254. }
  255. return data
  256. # 5.3 生成随机的样式数据
  257. def updata_component_style(self,website_id,matched_template_id,token):
  258. getComponentStyle = post_request("public/getAllSectorComponentStyle", {
  259. "website_id":website_id,
  260. "template_id": matched_template_id,
  261. },token)
  262. # 请补全此方法
  263. # 检查响应是否成功
  264. if getComponentStyle.get('code') != 200:
  265. print(f"获取组件样式失败: {getComponentStyle.get('message')}")
  266. return
  267. # 获取组件样式数据
  268. style_data = getComponentStyle.get('data', [])
  269. self.component_style_data = style_data
  270. # 需要跳过的通栏类型
  271. skip_sectors = ["adSector", "linkSector", "linkCxfwSector"]
  272. # 更新推理过程
  273. self.reasoning.append(f"通栏创建完毕")
  274. # 遍历模板中的每个通栏
  275. for sector in templateData.sectors_webSiteData["template"]["index"]:
  276. sector_name = sector["sectorName"]
  277. # 跳过不需要修改样式的通栏
  278. if sector_name in skip_sectors:
  279. print(f"跳过 {sector_name} 的组件样式修改")
  280. continue
  281. # 在样式数据中查找匹配的通栏
  282. matched_style = next((s for s in style_data if s["sector_id"] == sector_name), None)
  283. if not matched_style:
  284. print(f"警告: 未找到 {sector_name} 的组件样式配置")
  285. continue
  286. # 获取该通栏的组件列表配置
  287. component_list_config = matched_style.get("component_list", [])
  288. # 遍历通栏中的每个组件
  289. for idx, component in enumerate(sector["componentList"]):
  290. if idx < len(component_list_config):
  291. comp_config = component_list_config[idx]
  292. if comp_config and isinstance(comp_config, list) and len(comp_config) > 0:
  293. comp_style_data = comp_config[0].get("component_style_data", [])
  294. # 如果有可用的样式数据,随机选择一个样式
  295. if comp_style_data:
  296. style_count = len(comp_style_data)
  297. random_style = random.randint(1, style_count)
  298. component["component_style"] = random_style
  299. print(f"初始化 {sector_name} 的第{idx+1}个组件样式为: {random_style}")
  300. # 更新推理过程
  301. #self.reasoning.append(f"初始化 {sector_name} 的第{idx+1}个组件样式为: {random_style}号样式")
  302. else:
  303. print(f"警告: {sector_name} 的第{idx+1}个组件没有可用的样式数据")
  304. else:
  305. print(f"警告: {sector_name} 的第{idx+1}个组件配置无效")
  306. else:
  307. print(f"警告: {sector_name} 的组件数量({len(sector['componentList'])})超过配置数量({len(component_list_config)})")
  308. # 5.4 根据用户需求修改样式数据
  309. def updata_component_style_to_user(self, website_id, matched_template_id, token):
  310. # 需要跳过的通栏类型
  311. skip_sectors = ["adSector", "linkSector", "linkCxfwSector"]
  312. # 1. 将用户输入转换为向量
  313. user_embedding = self.text_to_embedding(self.user_input)
  314. if len(user_embedding.shape) == 1:
  315. user_embedding = user_embedding.reshape(1, -1)
  316. # 更新推理过程
  317. self.reasoning.append(f"最后让我挑选一下符合用户需求的组件样式")
  318. # 2. 遍历每个通栏
  319. for sector in templateData.sectors_webSiteData["template"]["index"]:
  320. sector_name = sector["sectorName"]
  321. # 跳过不需要修改样式的通栏
  322. if sector_name in skip_sectors:
  323. print(f"跳过 {sector_name} 的组件样式优化")
  324. # 更新推理过程
  325. #self.reasoning.append(f"跳过 {sector_name} 的组件样式优化")
  326. continue
  327. # 3. 在样式数据中查找匹配的通栏
  328. matched_style = next((s for s in self.component_style_data if s["sector_id"] == sector_name), None)
  329. if not matched_style:
  330. print(f"警告: 未找到 {sector_name} 的组件样式配置")
  331. continue
  332. # 4. 遍历通栏中的每个组件
  333. for idx, component in enumerate(sector["componentList"]):
  334. comp_config_list = matched_style.get("component_list", [])
  335. if idx >= len(comp_config_list) or not comp_config_list[idx]:
  336. continue
  337. comp_config = comp_config_list[idx][0] # 取第一个配置
  338. comp_style_data = comp_config.get("component_style_data", [])
  339. if not comp_style_data:
  340. continue
  341. # 5. 提取所有样式的描述文本
  342. style_descriptions = [style["img_name"] for style in comp_style_data]
  343. # 6. 将样式描述转换为向量
  344. style_embeddings = self.text_to_embedding(style_descriptions)
  345. # 7. 计算用户输入与每个样式的相似度
  346. similarities = cosine_similarity(user_embedding, style_embeddings)[0]
  347. best_index = np.argmax(similarities)
  348. best_style = comp_style_data[best_index]
  349. # 8. 更新组件样式
  350. component["component_style"] = best_style["img_id"]
  351. print(f"优化 {sector_name} 的第{idx+1}个组件样式为: {best_style['img_name']} (ID: {best_style['img_id']})")
  352. # 更新推理过程
  353. self.reasoning.append(f"选择 {sector_name} 的第{idx+1}个组件样式为: {best_style['img_name']} (ID: {best_style['img_id']})")
  354. """其他功能 模拟推理过程"""
  355. # 获得选择通栏的中文名称
  356. def test_normal_case_cnames(self):
  357. #获得生成的模板数据
  358. result = templateData.sectors_webSiteData["template"]["index"]
  359. #更新推理过程
  360. self.reasoning.append(f"接着我要创建一套符合用户需求的通栏组合")
  361. #遍历result获得所有通栏的Cname
  362. cnnames = []
  363. # 遍历每个通栏,使用enumerate获取索引
  364. for i, sector in enumerate(result):
  365. sector_name = sector["sectorName"]
  366. # 在sectors_config中查找对应的中文名称
  367. if self.current_style_data:
  368. for config in self.current_style_data.sectors_config:
  369. if config["name"] == sector_name:
  370. cnnames.append(config["CNname"])
  371. # 使用通栏在结果列表中的索引+1作为序号
  372. self.reasoning.append(f"第{i+1}个通栏选择: {config['CNname']}")
  373. break # 找到后跳出内层循环
  374. # 清理数据
  375. def clear_data(self):
  376. #清理模板数据
  377. templateData.canvas_data["template"]["index"] = []
  378. templateData.sectors_webSiteData["template"]["index"] = []
  379. #清理推理数据
  380. self.reasoning = []
  381. self.current_style_data = None
  382. #初始化权重
  383. sectorStyle1to4Data.reset_sectors_config()
  384. sectorStyle5Data.reset_sectors_config()
  385. #添加推理过程
  386. def add_reasoning(self, reasoning):
  387. self.reasoning.append(reasoning)
  388. ai_service = AIService()