utils.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # -*- coding:utf-8 -*-
  2. import json
  3. import math
  4. import os
  5. import pickle
  6. # 停用词存放文件夹
  7. STOP_WORD_DIR = "./resources/stopwords"
  8. # 临时文件路径
  9. TEMP_PATH = "../tmp"
  10. # 停用词模型
  11. STOP_WORD_CACHE = "stop_word.pkl"
  12. # 正则表达式中需要额外处理的特殊符号
  13. RE_SPECIAL_SYMBOL = [".", "?", "^", "$", "*", "+", "\\", "[", "]", "|", "{", "}", "(", ")"]
  14. def save_obj(path, obj):
  15. """
  16. 保存对象至本地
  17. """
  18. with open(path, "wb") as f:
  19. pickle.dump(obj, f)
  20. def load_obj(path):
  21. """
  22. 加载对象
  23. """
  24. with open(path, "rb") as f:
  25. return pickle.load(f)
  26. def load_stop_word():
  27. """
  28. 加载停用词
  29. """
  30. # 判断临时文件路径是否存在,不存在则重新创建
  31. if not os.path.exists(TEMP_PATH):
  32. os.makedirs(TEMP_PATH)
  33. # 判断是否存在缓存
  34. stop_word_cache_path = os.path.join(TEMP_PATH, STOP_WORD_CACHE)
  35. if os.path.exists(stop_word_cache_path) and os.path.isfile(stop_word_cache_path):
  36. return load_obj(stop_word_cache_path)
  37. # 停用词容器
  38. stop_word = set()
  39. # 构建停用词列表
  40. stop_word_files = os.listdir(STOP_WORD_DIR)
  41. for file in stop_word_files:
  42. stop_word_file = os.path.join(STOP_WORD_DIR, file)
  43. with open(stop_word_file, encoding="UTF-8") as f:
  44. for item in f:
  45. # 移除换行符
  46. stop_word.add(item.replace("\n", "").replace("\r", ""))
  47. # 改成dict提升检索速度
  48. stop_word_dict = {}
  49. for item in stop_word:
  50. stop_word_dict[item] = None
  51. # 保存本地作为缓存
  52. save_obj(stop_word_cache_path, stop_word_dict)
  53. return stop_word_dict
  54. def avg_split_task(total: int, split_internal: int, start=0):
  55. """
  56. 平分任务,包含开始位置,不包含结束位置,开始位置是从0开始
  57. :param start: 开始位置
  58. :param total: 任务总数量
  59. :param split_internal: 每份数量
  60. :return: (开始位置,结束位置)
  61. """
  62. # 分割的任务份数
  63. split_num = math.ceil(total / split_internal)
  64. # 平分
  65. tasks = []
  66. for i in range(split_num):
  67. # 计算平分点在列表中的位置
  68. start_pos = i * split_internal
  69. end_pos = i * split_internal + split_internal
  70. if i == 0:
  71. start_pos = start
  72. # 如果超过列表大小需要额外处理
  73. if end_pos >= total:
  74. end_pos = -1
  75. tasks.append([start_pos, end_pos])
  76. return tasks
  77. def remove_line_break(line: str):
  78. """
  79. 移除换行符
  80. :param line: 待处理文本
  81. :return: 替换后的结果
  82. """
  83. if line:
  84. return line.replace("\r", "").replace("\n", "")
  85. return line
  86. def saveJson(save_path: str, save_obj: dict):
  87. """
  88. 保存为json文件
  89. :param save_path: 保存的路径
  90. :param save_obj: 保存的内容对象
  91. :return:
  92. """
  93. with open(save_path, 'w', encoding='utf-8') as f:
  94. f.write(json.dumps(save_obj))
  95. def load_json(path: str):
  96. """
  97. 加载json文件
  98. :param path:
  99. :return:
  100. """
  101. if os.path.exists(path) and os.path.isfile(path):
  102. with open(path, 'r', encoding='utf-8') as f:
  103. return json.loads(f.read())
  104. return dict()