model_loader.py 900 B

123456789101112131415161718192021222324
  1. import torch
  2. from sentence_transformers import SentenceTransformer
  3. import joblib
  4. import os
  5. class ModelLoader:
  6. def __init__(self):
  7. self.text_model = None
  8. self.classification_model = None
  9. self.summarization_model = None
  10. self.summarization_tokenizer = None
  11. def load_text_model(self, use_local=True, local_model_path="./models/bin/all-MiniLM-L6-v2"):
  12. if self.text_model is None:
  13. if use_local and os.path.exists(local_model_path):
  14. print("🔧 从本地加载sentence-transformers模型...ok")
  15. self.text_model = SentenceTransformer(local_model_path)
  16. else:
  17. print("🔧 从网络加载sentence-transformers模型...ok")
  18. self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
  19. return self.text_model
  20. # 全局模型加载器
  21. model_loader = ModelLoader()