downloadMT5Model.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. # 注意:请把该文件放到文件根目录运行
  2. # 下载或更新本地的模型
  3. import os
  4. def download_mT5_model_only():
  5. """只下载 mT5 摘要模型"""
  6. model_name = "csebuetnlp/mT5_multilingual_XLSum"
  7. save_path = "./models/bin/mT5_multilingual_XLSum"
  8. print(f"下载摘要模型: {model_name}")
  9. print(f"保存到: {save_path}")
  10. # 确保目录存在
  11. os.makedirs(save_path, exist_ok=True)
  12. # 下载模型
  13. from transformers import MT5ForConditionalGeneration, MT5Tokenizer
  14. print("下载 mT5 模型中...")
  15. model = MT5ForConditionalGeneration.from_pretrained(model_name)
  16. print("下载 mT5 分词器中...")
  17. tokenizer = MT5Tokenizer.from_pretrained(model_name)
  18. print("保存到本地...")
  19. model.save_pretrained(save_path)
  20. tokenizer.save_pretrained(save_path)
  21. print(f"✅ mT5 模型已保存到: {save_path}")
  22. return model, tokenizer
  23. # 在 __main__ 中可以选择调用哪个
  24. if __name__ == "__main__":
  25. # 只下载摘要模型
  26. download_mT5_model_only()
  27. # 或者下载所有模型
  28. # download_all_models()