
feat: aggregation feature, add multithreading, optimize code

ChenYL 2 years ago
parent
commit
a1be3a6293
3 changed files with 196 additions and 136 deletions
  1. README.md (+3 -1)
  2. src/agg.py (+192 -134)
  3. src/money.py (+1 -1)

+ 3 - 1
README.md

@@ -1,10 +1,12 @@
 # Development log
 
 ## To-do list
-* Add multithreading to long-tail keyword aggregation
 * Show multiprocess/multithread progress with tqdm
 
 ## Development progress
+* 2024-01-17
+  - Add a multiprocess initializer
+  - Add multithreading to long-tail keyword aggregation
 * 2024-01-16
   - Use redis to improve performance
 * 2023-12-15
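The two 2024-01-17 entries correspond to the concurrency structure introduced in src/agg.py below: the process pool gets an initializer that builds a per-process redis connection pool and a ThreadPoolExecutor, and each worker thread gets its own thread-local redis client. A minimal, hypothetical sketch of that pattern, not the project's code (handle_item, handle_chunk and run are illustrative names, and a local redis at 127.0.0.1:6379 is assumed):

import threading
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

import redis

# Per-process globals, populated by the process initializer.
redis_pool = None
thread_pool = None
local_data = threading.local()


def init_process(max_conns: int, max_threads: int):
    # Runs once in every worker process.
    global redis_pool, thread_pool
    redis_pool = redis.ConnectionPool(host="127.0.0.1", port=6379, max_connections=max_conns)
    thread_pool = ThreadPoolExecutor(max_threads, initializer=init_thread)


def init_thread():
    # Runs once in every worker thread: one redis client per thread.
    local_data.client = redis.StrictRedis(connection_pool=redis_pool)


def handle_item(position: int) -> int:
    # Illustrative per-item work using the thread-local redis client.
    local_data.client.setbit("unused_bitmap", position, 1)
    return position


def handle_chunk(start: int, end: int) -> int:
    # Runs inside a worker process and fans its chunk out to the thread pool.
    futures = [thread_pool.submit(handle_item, pos) for pos in range(start, end)]
    for f in as_completed(futures):
        f.result()
    return end - start


def run(total: int, workers: int = 4, threads: int = 3):
    chunk = (total + workers - 1) // workers
    with ProcessPoolExecutor(max_workers=workers, initializer=init_process,
                             initargs=(threads, threads)) as pool:
        futures = [pool.submit(handle_chunk, s, min(s + chunk, total))
                   for s in range(0, total, chunk)]
        for f in as_completed(futures):
            f.result()


if __name__ == "__main__":
    run(total=100)

In the actual diff, each subprocess also passes a distinct position (derived from skip_line) to tqdm so the per-worker progress bars render on separate rows.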

+ 192 - 134
src/agg.py

@@ -2,7 +2,8 @@
 import math
 import os
 import re
-from concurrent.futures import ProcessPoolExecutor, as_completed
+import threading
+from concurrent.futures import ProcessPoolExecutor, as_completed, ThreadPoolExecutor
 
 import jieba
 import redis
@@ -11,7 +12,7 @@ from tqdm import tqdm
 import utils
 import logging
 
-from src.constant import FILE_LONG_TAIL_MERGE
+from constant import FILE_LONG_TAIL_MERGE
 
 # File: 长尾词_合并_分词.txt
 FILE_LONG_TAIL_MERGE_SPLIT = "长尾词_合并_分词.txt"
@@ -23,7 +24,7 @@ FILE_LONG_TAIL_MERGE_AGG = "长尾词_合并_聚合.txt"
 FILE_LONG_TAIL_MERGE_REVERSE_INDEX = "长尾词_合并_倒排索引.txt"
 
 # Sub-file: 长尾词_合并_聚合_%s.txt
-FILE_LONG_TAIL_MERGE_AGG_PID = "长尾词_合并_聚合_%s.txt"
+FILE_LONG_TAIL_MERGE_AGG_PID = "长尾词_合并_聚合_%s_%s.txt"
 
 # Cache key prefix: word stems
 CACHE_WORD_STEM = "word:stem"
@@ -40,6 +41,18 @@ CACHE_UNUSED_BITMAP = "unused_bitmap"
 # Character set: UTF-8
 CHARSET_UTF_8 = "UTF-8"
 
+# redis connection pool
+redis_pool: redis.ConnectionPool = None
+
+# redis client
+redis_cache: redis.StrictRedis = None
+
+# thread pool
+thread_pool: ThreadPoolExecutor = None
+
+# thread-local variables
+local_variable = threading.local()
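+# Note: these globals are per worker process; init_process() assigns redis_pool and thread_pool,
+# and init_thread() assigns redis_cache and the thread-local file writer.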
+
 
 def agg_word(file_path: str):
     """
@@ -50,6 +63,7 @@ def agg_word(file_path: str):
 
     # Total number of long-tail keywords
     word_total_num = 0
+    # word_total_num = 100000
 
     # Aggregation threshold
     agg_threshold = 0.8
@@ -62,10 +76,16 @@ def agg_word(file_path: str):
     # worker_num = 1
 
     # Regex: per-worker aggregation sub-files
-    agg_file_pattern = re.compile(r"长尾词_合并_聚合_\d+.txt", re.I)
+    agg_file_pattern = re.compile(r"长尾词_合并_聚合_\d+_\d+.txt", re.I)
+
+    # Maximum number of threads
+    max_threads = 3
+
+    # Maximum number of redis connections (kept equal to the worker thread count to avoid waste)
+    redis_max_conns = max_threads
 
     # redis cache
-    redis_cache = redis.StrictRedis(host='127.0.0.1', port=6379)
+    m_redis_cache = redis.StrictRedis(host='127.0.0.1', port=6379)
 
     # Check that the files exist
     for file_name in [FILE_LONG_TAIL_MERGE, FILE_LONG_TAIL_MERGE_SPLIT,
@@ -88,63 +108,64 @@ def agg_word(file_path: str):
             if not word:
                 continue
             word_dict[position] = word
-    redis_cache.hset(CACHE_WORD, mapping=word_dict)
+    m_redis_cache.hset(CACHE_WORD, mapping=word_dict)
     # Record the total number of keywords
     word_total_num = len(word_dict)
     # Free memory
     del word_dict
-    #
-    # # Cache word segmentation
-    # word_split_file = os.path.join(file_path, FILE_LONG_TAIL_MERGE_SPLIT)
-    # word_split_dict = {}
-    # with open(word_split_file, "r", encoding=CHARSET_UTF_8) as f:
-    #     for position, word_split_line in enumerate(f, start=1):
-    #         word_split_line = utils.remove_line_break(word_split_line)
-    #         if not word_split_line:
-    #             continue
-    #         word_split_dict[position] = word_split_line
-    # redis_cache.hset(CACHE_WORD_STEM, mapping=word_split_dict)
-    # # Free memory
-    # del word_split_dict
-    #
-    # # Cache the reverse index
-    # word_reverse_index_file = os.path.join(file_path, FILE_LONG_TAIL_MERGE_REVERSE_INDEX)
-    # word_reverse_index_dict = {}
-    # # Word stem (key)
-    # key_pattern = re.compile(r"([^,]+),\[", re.I)
-    # # Position indexes
-    # index_pattern = re.compile(r"\d+", re.I)
-    # with open(word_reverse_index_file, "r", encoding="utf-8") as f:
-    #     for word_split_line in f:
-    #         key_m = key_pattern.match(word_split_line)
-    #         key = key_m.group(1)
-    #         val = index_pattern.findall(word_split_line[word_split_line.index(","):])
-    #         word_reverse_index_dict[key] = ",".join(val)
-    # redis_cache.hset(CACHE_WORD_REVERSE_INDEX, mapping=word_reverse_index_dict)
-    # # Free memory
-    # del word_reverse_index_dict
-
-    # Build the long-tail keyword usage bitmap
-    redis_cache.setbit(CACHE_UNUSED_BITMAP, word_total_num + 1, 1)
+
+    # Cache the word segmentation results
+    word_split_file = os.path.join(file_path, FILE_LONG_TAIL_MERGE_SPLIT)
+    word_split_dict = {}
+    with open(word_split_file, "r", encoding=CHARSET_UTF_8) as f:
+        for position, word_split_line in enumerate(f, start=1):
+            word_split_line = utils.remove_line_break(word_split_line)
+            if not word_split_line:
+                continue
+            word_split_dict[position] = word_split_line
+    m_redis_cache.hset(CACHE_WORD_STEM, mapping=word_split_dict)
+    # Free memory
+    del word_split_dict
+
+    # Cache the reverse index
+    word_reverse_index_file = os.path.join(file_path, FILE_LONG_TAIL_MERGE_REVERSE_INDEX)
+    word_reverse_index_dict = {}
+    # Word stem (key)
+    key_pattern = re.compile(r"([^,]+),\[", re.I)
+    # Position indexes
+    index_pattern = re.compile(r"\d+", re.I)
+    with open(word_reverse_index_file, "r", encoding="utf-8") as f:
+        for word_split_line in f:
+            key_m = key_pattern.match(word_split_line)
+            key = key_m.group(1)
+            val = index_pattern.findall(word_split_line[word_split_line.index(","):])
+            word_reverse_index_dict[key] = ",".join(val)
+    m_redis_cache.hset(CACHE_WORD_REVERSE_INDEX, mapping=word_reverse_index_dict)
+    # Free memory
+    del word_reverse_index_dict
+
+    # Clear first, then rebuild the long-tail keyword usage bitmap
+    m_redis_cache.delete(CACHE_UNUSED_BITMAP)
+    m_redis_cache.setbit(CACHE_UNUSED_BITMAP, word_total_num + 1, 1)
 
     # Submit the tasks and write out the results
     word_agg_file = os.path.join(file_path, FILE_LONG_TAIL_MERGE_AGG)
 
-    with ProcessPoolExecutor(max_workers=worker_num) as process_pool:
+    with ProcessPoolExecutor(max_workers=worker_num, initializer=init_process,
+                             initargs=(redis_max_conns, max_threads, file_path)) as process_pool:
         # Compute the task boundaries
         task_list = utils.avg_split_task(word_total_num, task_cal_num, 1)
 
         # Submit the tasks
         process_futures = []
         for skip_line, pos in enumerate(task_list, start=1):
-            # skip_line =  (skip_line % 4 ) + 1
-            skip_line = 1
-            p_future = process_pool.submit(agg_word_process, file_path, agg_threshold, pos[0], pos[1], word_total_num,
+            skip_line = (skip_line % worker_num) + 1
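+            # skip_line is reused as the tqdm position, so each worker's progress bar gets its own row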
+            p_future = process_pool.submit(agg_word_process, agg_threshold, pos[0], pos[1], word_total_num,
                                            skip_line)
             process_futures.append(p_future)
 
         # Show task progress
-        with tqdm(total=len(process_futures), desc='文本聚合进度', unit='份', unit_scale=True, position=0) as pbar:
+        with tqdm(total=len(process_futures), desc='文本聚合进度', unit='份', unit_scale=True) as pbar:
             for p_future in as_completed(process_futures):
                 p_result = p_future.result()
                 # Update the progress bar
@@ -232,6 +253,7 @@ def prepare_word_split_and_reverse_index(file_path: str):
         # Shut down the process pool
         process_pool.shutdown()
 
+
 def word_split_reverse(input_file: str, start_pos: int, end_pos: int):
     """
     Word segmentation and reverse-index construction
@@ -288,12 +310,27 @@ def word_split_reverse(input_file: str, start_pos: int, end_pos: int):
 
     return start_pos, word_arr_list, word_reverse_index
 
-def agg_word_process(file_path: str, agg_threshold: float, start_pos: int, end_pos: int, final_pos: int,
+
+def init_process(max_conns: int, max_threads: int, file_path: str):
+    """
+    Initialize the worker process
+    :param max_conns: maximum number of redis connections
+    :param max_threads: maximum number of threads
+    :param file_path: output file path
+    :return:
+    """
+    # Initialize the redis connection pool
+    global redis_pool
+    redis_pool = redis.ConnectionPool(host='127.0.0.1', port=6379, max_connections=max_conns)
+
+    global thread_pool
+    thread_pool = ThreadPoolExecutor(max_threads, initializer=init_thread, initargs=(file_path,))
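+    # Each worker process thus owns one redis connection pool and one thread pool;
+    # init_thread runs once in every thread of that pool.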
+
+
+def agg_word_process(agg_threshold: float, start_pos: int, end_pos: int, final_pos: int,
                      skip_line: int):
     """
     Long-tail keyword aggregation processing
-    :param file_path: file path
-    :param word_file: long-tail keyword file path
     :param agg_threshold: aggregation threshold
     :param start_pos: start boundary of the task range (inclusive)
     :param end_pos: end boundary of the task range (exclusive)
@@ -302,15 +339,6 @@ def agg_word_process(file_path: str, agg_threshold: float, start_pos: int, end_p
     :return:
     """
 
-    # Create the temporary result file
-    word_agg_file = os.path.join(file_path, FILE_LONG_TAIL_MERGE_AGG_PID % os.getpid())
-
-    # redis connection pool
-    redis_pool = redis.ConnectionPool(host='127.0.0.1', port=6379, max_connections=1)
-
-    # redis cache
-    redis_cache = redis.StrictRedis(connection_pool=redis_pool)
-
     # Progress length
     process_len = 0
     if end_pos == -1:
@@ -318,90 +346,120 @@ def agg_word_process(file_path: str, agg_threshold: float, start_pos: int, end_p
     else:
         process_len = end_pos - start_pos
 
-    with (open(word_agg_file, "a", encoding="UTF-8") as fo,
-          tqdm(total=process_len, desc='子进程-%s:文本聚合进度' % os.getpid(), unit='份', unit_scale=True,
-               position=skip_line) as pbar):
+    with tqdm(total=process_len, desc='子进程-%s:文本聚合进度' % os.getpid(), unit='份', unit_scale=True,
+              position=skip_line) as pbar:
 
-        for main_word_position in range(start_pos, end_pos):
-            try:
-                # If the main word has already been used, skip it; otherwise mark it as used
-                if redis_cache.getbit(CACHE_UNUSED_BITMAP, main_word_position):
-                    continue
-                else:
-                    redis_cache.setbit(CACHE_UNUSED_BITMAP, main_word_position, 1)
+        thread_futures = [thread_pool.submit(agg_word_thread, main_word_position, agg_threshold)
+                          for main_word_position in range(start_pos, end_pos)]
 
-                # Fetch the main word and strip the line break
-                main_word = redis_cache.hget(CACHE_WORD, main_word_position)
-                if not main_word:
-                    continue
-                main_word = main_word.decode(CHARSET_UTF_8)
+        for t_future in as_completed(thread_futures):
+            t_result = t_future.result()
+            # Update the progress bar
+            pbar.update(1)
 
-                # Fetch the main word's segmentation result
-                main_word_stem = redis_cache.hget(CACHE_WORD_STEM, main_word_position)
-                # Skip if empty
-                if not main_word_stem:
-                    continue
-                main_word_stem_list = main_word_stem.decode(CHARSET_UTF_8).split(",")
-
-                # Fetch candidate position indexes from the reverse index
-                candidate_position_set = set()
-                temp_candidate_position_list = redis_cache.hmget(CACHE_WORD_REVERSE_INDEX, main_word_stem_list)
-                for temp_candidate_position in temp_candidate_position_list:
-                    if not temp_candidate_position:
-                        continue
-                    candidate_position_set.update(temp_candidate_position.decode(CHARSET_UTF_8).split(","))
-
-                # Skip if there are no candidate words to compute
-                if not candidate_position_set:
+    return
+
+
+def init_thread(file_path: str):
+    """
+    Initialize the aggregation worker thread
+    :param file_path: output file path
+    :return:
+    """
+    # Initialize the redis client
+    global redis_cache
+    redis_cache = redis.StrictRedis(connection_pool=redis_pool)
+
+    # Create the temporary result file for this thread
+    word_agg_file = os.path.join(file_path,
+                                 FILE_LONG_TAIL_MERGE_AGG_PID % (os.getpid(), threading.current_thread().ident))
+    local_variable.file_writer = open(word_agg_file, "w", encoding=CHARSET_UTF_8)
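+    # Each thread writes to its own pid/thread-id file, so result writes need no cross-thread locking.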
+
+
+def agg_word_thread(main_word_position: int, agg_threshold: float):
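+    """
+    Aggregate one main keyword against its candidate keywords (runs on a worker thread)
+    :param main_word_position: position (line number) of the main keyword
+    :param agg_threshold: aggregation threshold
+    :return:
+    """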
+    try:
+        # If the main word has already been used, skip it; otherwise mark it as used
+        if redis_cache.getbit(CACHE_UNUSED_BITMAP, main_word_position):
+            return
+        else:
+            redis_cache.setbit(CACHE_UNUSED_BITMAP, main_word_position, 1)
+
+        # Fetch the main word and strip the line break
+        main_word = redis_cache.hget(CACHE_WORD, main_word_position)
+        if not main_word:
+            return
+        main_word = main_word.decode(CHARSET_UTF_8)
+
+        # Fetch the main word's segmentation result
+        main_word_stem = redis_cache.hget(CACHE_WORD_STEM, main_word_position)
+        # Skip if empty
+        if not main_word_stem:
+            return
+        main_word_stem_list = main_word_stem.decode(CHARSET_UTF_8).split(",")
+
+        # Fetch candidate position indexes from the reverse index
+        candidate_position_set = set()
+        temp_candidate_position_list = redis_cache.hmget(CACHE_WORD_REVERSE_INDEX, main_word_stem_list)
+        for temp_candidate_position in temp_candidate_position_list:
+            if not temp_candidate_position:
+                continue
+            # Exclude candidates that have already been aggregated
+            for candidate_position in temp_candidate_position.decode(CHARSET_UTF_8).split(","):
+                if redis_cache.getbit(CACHE_UNUSED_BITMAP, candidate_position):
                     continue
+                candidate_position_set.add(candidate_position)
 
-                # Result list
-                result_list = []
-
-                # Compute similarity
-                for candidate_position in candidate_position_set:
-                    # Skip duplicates
-                    if main_word_position == candidate_position:
-                        continue
-
-                    # Fetch the candidate keyword
-                    candidate_word = redis_cache.hget(CACHE_WORD, candidate_position)
-                    if not candidate_word:
-                        continue
-                    candidate_word = candidate_word.decode(CHARSET_UTF_8)
-
-                    # Fetch the candidate's segmentation result
-                    candidate_word_stem = redis_cache.hget(CACHE_WORD_STEM, candidate_position)
-                    # Skip if empty
-                    if not candidate_word_stem:
-                        continue
-                    candidate_word_stem_list = candidate_word_stem.decode(CHARSET_UTF_8).split(",")
-
-                    # Compute relevance
-                    try:
-                        val = utils.cal_cos_sim(main_word, main_word_stem_list, candidate_word,
-                                                candidate_word_stem_list)
-                        if val >= agg_threshold:
-                            redis_cache.setbit(CACHE_UNUSED_BITMAP, candidate_position, 1)
-                            result_list.append(candidate_word)
-                    except Exception as e:
-                        logging.error("主关键词:%s 发生异常,涉及的副关键词信息-关键词:%s,分词:%s" % (
-                            main_word, candidate_word, candidate_word_stem_list), e)
-
-                # Save the results
-                if len(result_list) > 0:
-                    fo.write("%s\n" % main_word)
-                    for candidate_word in result_list:
-                        fo.write("%s\n" % candidate_word)
-                    fo.write("\n")
-
-                # Clear the containers
-                candidate_position_set.clear()
-                result_list.clear()
+        # Skip if there are no candidate words to compute
+        if not candidate_position_set:
+            return
 
-                # Update the progress bar
-                pbar.update(1)
+        # Result list
+        result_list = []
+
+        # Compute similarity
+        for candidate_position in candidate_position_set:
+            # Skip the main word itself (candidate positions are strings, so compare as strings)
+            if str(main_word_position) == candidate_position:
+                continue
+
+            # Fetch the candidate keyword
+            candidate_word = redis_cache.hget(CACHE_WORD, candidate_position)
+            if not candidate_word:
+                continue
+            candidate_word = candidate_word.decode(CHARSET_UTF_8)
+
+            # Fetch the candidate's segmentation result
+            candidate_word_stem = redis_cache.hget(CACHE_WORD_STEM, candidate_position)
+            # Skip if empty
+            if not candidate_word_stem:
+                continue
+            candidate_word_stem_list = candidate_word_stem.decode(CHARSET_UTF_8).split(",")
+
+            # Compute relevance
+            try:
+                val = utils.cal_cos_sim(main_word, main_word_stem_list, candidate_word,
+                                        candidate_word_stem_list)
+                if val >= agg_threshold:
+                    redis_cache.setbit(CACHE_UNUSED_BITMAP, candidate_position, 1)
+                    result_list.append(candidate_word)
             except Exception as e:
-                logging.error("子进程发生异常", e)
+                logging.error("主关键词:%s 发生异常,涉及的副关键词信息-关键词:%s,分词:%s" % (
+                    main_word, candidate_word, candidate_word_stem_list), e)
+
+        # Save the results
+        if len(result_list) > 0:
+            local_variable.file_writer.write("%s\n" % main_word)
+            for candidate_word in result_list:
+                local_variable.file_writer.write("%s\n" % candidate_word)
+            local_variable.file_writer.write("\n")
+            local_variable.file_writer.flush()
+
+        # Clear the containers
+        candidate_position_set.clear()
+        result_list.clear()
+
+    except Exception as e:
+        logging.error("子进程发生异常", e)
 
     return

+ 1 - 1
src/money.py

@@ -7,7 +7,7 @@ import jieba
 
 import utils
 from agg import agg_word
-from src.constant import FILE_LONG_TAIL_MERGE
+from constant import FILE_LONG_TAIL_MERGE
 
 # File suffix: 长尾词.txt
 FILE_SUFFIX_LONG_TAIL = "_长尾词.txt"