# -*- coding:utf-8 -*-
"""Utility helpers: pickle caching, stop-word loading, task splitting and
cosine similarity between words."""
import math
import os
import pickle
import re

import numpy as np

# Directory containing the stop-word files.
STOP_WORD_DIR = "./conf/stopwords"
# Directory for temporary/cache files.
TEMP_PATH = "../tmp"
# File name of the cached stop-word model.
STOP_WORD_CACHE = "stop_word.pkl"
# Characters that need extra escaping when embedded in a regular expression.
RE_SPECIAL_SYMBOL = [".", "?", "^", "$", "*", "+", "\\", "[", "]", "|", "{", "}", "(", ")"]
  15. def save_obj(path, obj):
  16. """
  17. 保存对象至本地
  18. """
  19. with open(path, "wb") as f:
  20. pickle.dump(obj, f)
  21. def load_obj(path):
  22. """
  23. 加载对象
  24. """
  25. with open(path, "rb") as f:
  26. return pickle.load(f)
  27. def load_stop_word():
  28. """
  29. 加载停用词
  30. """
  31. # 判断临时文件路径是否存在,不存在则重新创建
  32. if not os.path.exists(TEMP_PATH):
  33. os.makedirs(TEMP_PATH)
  34. # 判断是否存在缓存
  35. stop_word_cache_path = os.path.join(TEMP_PATH, STOP_WORD_CACHE)
  36. if os.path.exists(stop_word_cache_path) and os.path.isfile(stop_word_cache_path):
  37. return load_obj(stop_word_cache_path)
  38. # 停用词容器
  39. stop_word = set()
  40. # 构建停用词列表
  41. stop_word_files = os.listdir(STOP_WORD_DIR)
  42. for file in stop_word_files:
  43. stop_word_file = os.path.join(STOP_WORD_DIR, file)
  44. with open(stop_word_file, encoding="UTF-8") as f:
  45. for item in f:
  46. # 移除换行符
  47. stop_word.add(item.replace("\n", "").replace("\r", ""))
  48. # 改成dict提升检索速度
  49. stop_word_dict = {}
  50. for item in stop_word:
  51. stop_word_dict[item] = None
  52. # 保存本地作为缓存
  53. save_obj(stop_word_cache_path, stop_word_dict)
  54. return stop_word_dict
  55. def avg_split_task(total: int, split_internal: int, start=0):
  56. """
  57. 平分任务,包含开始位置,不包含结束位置,开始位置是从0开始
  58. :param start: 开始位置
  59. :param total: 任务总数量
  60. :param split_internal: 每份数量
  61. :return: (开始位置,结束位置)
  62. """
  63. # 分割的任务份数
  64. split_num = math.ceil(total / split_internal)
  65. # 平分
  66. tasks = []
  67. for i in range(split_num):
  68. # 计算平分点在列表中的位置
  69. start_pos = i * split_internal
  70. end_pos = i * split_internal + split_internal
  71. if i == 0:
  72. start_pos = start
  73. # 如果超过列表大小需要额外处理
  74. if end_pos >= total:
  75. end_pos = -1
  76. tasks.append([start_pos, end_pos])
  77. return tasks
  78. def cal_cos_sim(a_word: str, a_stem: list, b_word: str, b_stem: list):
  79. """
  80. 计算余弦相似性
  81. :param a_word: A词
  82. :param a_stem: A词根列表
  83. :param b_word: B词
  84. :param b_stem: B词根列表
  85. :return: 余弦值
  86. """
  87. # 合并词根
  88. union_stem = list(set(a_stem).union(set(b_stem)))
  89. # 生成词向量
  90. a_vec, b_vec = [], []
  91. for word in union_stem:
  92. if word in RE_SPECIAL_SYMBOL:
  93. word = "\\" + word
  94. if word == "c++":
  95. word = "c\\+\\+"
  96. a_vec.append(len(re.findall(word, a_word)))
  97. b_vec.append(len(re.findall(word, b_word)))
  98. # 计算余弦相关性
  99. vec1 = np.array(a_vec)
  100. vec2 = np.array(b_vec)
  101. val = (np.linalg.norm(vec1) * np.linalg.norm(vec2))
  102. if val == 0:
  103. return 0
  104. return vec1.dot(vec2) / val
  105. def remove_line_break(line: str):
  106. """
  107. 移除换行符
  108. :param line: 待处理文本
  109. :return: 替换后的结果
  110. """
  111. if line:
  112. return line.replace("\r", "").replace("\n", "")
  113. return line