verifyInput.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import re
  2. from services.ai_service import ai_service
  3. from sklearn.metrics.pairwise import cosine_similarity
  4. class HybridRequestValidator:
  5. def __init__(self):
  6. # 规则1:检查是否包含模板设计相关词汇
  7. self.template_patterns = [
  8. r'.*模板.*',
  9. r'.*设计.*',
  10. r'.*推荐.*(样式|风格|主题)',
  11. r'.*(颜色|色彩|配色).*',
  12. r'.*(去掉|移除|删除|隐藏).*',
  13. r'.*(组件|元素|部件).*',
  14. r'.*(采用|使用|设置为).*'
  15. ]
  16. # 规则2:明显的非模板请求模式
  17. self.reject_patterns = [
  18. r'^告诉我.*',
  19. r'.*天气.*',
  20. r'.*时间.*',
  21. r'.*播放.*',
  22. r'.*新闻.*',
  23. r'^查询.*',
  24. r'^搜索.*'
  25. ]
  26. # 语义检查的参考语句
  27. self.reference_phrases = [
  28. "模板设计需求,包含样式和组件要求",
  29. "视觉设计请求,涉及颜色和布局调整",
  30. "UI模板定制,需要修改特定元素"
  31. ]
  32. self.ref_embeddings = ai_service.text_to_embedding(self.reference_phrases)
  33. def validate(self, user_input):
  34. # 1. 快速拒绝明显无关的请求
  35. for pattern in self.reject_patterns:
  36. if re.search(pattern, user_input, re.IGNORECASE):
  37. return False
  38. # 2. 模式匹配检查
  39. match_count = 0
  40. for pattern in self.template_patterns:
  41. if re.search(pattern, user_input, re.IGNORECASE):
  42. match_count += 1
  43. # 修改:只要包含关键词即可,不再强制要求多个关键词或语义验证
  44. # >=1 只要包含1个关键词就通过验证 >=2 必须包含2个关键词 以此类推
  45. if match_count >= 1:
  46. return True
  47. # 3. 语义验证(仅作为补充,用于没有任何关键词匹配但语义相关的情况)
  48. try:
  49. user_embedding = ai_service.text_to_embedding(user_input)
  50. # 确保维度匹配
  51. if len(user_embedding.shape) == 1:
  52. user_embedding = user_embedding.reshape(1, -1)
  53. similarity = cosine_similarity(
  54. user_embedding,
  55. self.ref_embeddings
  56. ).max()
  57. if similarity > 0.4: #稍微提高阈值,因为没有关键词支撑
  58. return True
  59. except Exception as e:
  60. print(f"语义验证出错: {e}")
  61. pass
  62. return False
  63. messageValidator = HybridRequestValidator()