فهرست منبع

feat:聚合优化,增加计算线程;完善README.md文件

ChenYL 2 سال پیش
والد
کامیت
a45d6262e7
2 فایل تغییر یافته به همراه 87 افزوده شده و 34 حذف شده
  1. 5 0
      README.md
  2. 82 34
      src/agg.py

+ 5 - 0
README.md

@@ -1,5 +1,10 @@
 # 开发记录
 
+## 程序运行指令
+```commandline
+conda activate money-mining && python money.py
+```
+
 ## 待办列表
 * 主进程进度不显示
 * 子进程显示不合理

+ 82 - 34
src/agg.py

@@ -53,6 +53,9 @@ redis_pool: redis.ConnectionPool = None
 # 线程池
 thread_pool: ThreadPoolExecutor = None
 
+# 线程池(计算用)
+cal_thread_pool: ThreadPoolExecutor = None
+
 # 线程本地变量
 local_var = threading.local()
 
@@ -71,22 +74,25 @@ def agg_word(file_path: str):
     # 聚合阈值
     agg_threshold = 0.8
 
-    # 每份任务计算量
-    task_cal_num = 10000
+    # 每个进程任务计算量
+    per_process_task_num = 10000
+
+    # 每个线程任务计算量
+    per_thread_task_num = 50
 
     # 工作线程(减1是为了留一个处理器给redis)
     worker_num = os.cpu_count() - 1
     # worker_num = 1
-
-    # 正则表达式:聚合文件分文件
-    agg_file_pattern = re.compile(r"长尾词_合并_聚合_\d+_\d+.txt", re.I)
-
+    
     # 最大线程数
     max_threads = 2
 
     # redis最大连接数(和工作线程数保持一致,免得浪费)
     redis_max_conns = max_threads
 
+    # 正则表达式:聚合文件分文件
+    agg_file_pattern = re.compile(r"长尾词_合并_聚合_\d+_\d+.txt", re.I)
+
     # redis缓存
     m_redis_cache = redis.StrictRedis(host='127.0.0.1', port=6379)
 
@@ -160,14 +166,14 @@ def agg_word(file_path: str):
     with ProcessPoolExecutor(max_workers=worker_num, initializer=init_process,
                              initargs=(redis_max_conns, max_threads, file_path)) as process_pool:
         # 计算任务边界
-        task_list = utils.avg_split_task(word_total_num, task_cal_num, 1)
+        task_list = utils.avg_split_task(word_total_num, per_process_task_num, 1)
 
         # 提交任务
         process_futures = []
         for skip_line, pos in enumerate(task_list, start=1):
             skip_line = (skip_line % worker_num) + 1
             p_future = process_pool.submit(agg_word_process, agg_threshold, pos[0], pos[1], word_total_num,
-                                           skip_line)
+                                           skip_line, per_thread_task_num)
             process_futures.append(p_future)
 
         # 显示任务进度
@@ -332,9 +338,12 @@ def init_process(max_conns: int, max_threads: int, file_path: str):
     global thread_pool
     thread_pool = ThreadPoolExecutor(max_threads, initializer=init_thread, initargs=(file_path,))
 
+    global cal_thread_pool
+    cal_thread_pool = ThreadPoolExecutor(max_threads * 3, initargs=(file_path,))
+
 
 def agg_word_process(agg_threshold: float, start_pos: int, end_pos: int, final_pos: int,
-                     skip_line: int):
+                     skip_line: int, per_thread_task_num: int):
     """
     长尾词聚合处理
     :param agg_threshold: 聚合阈值
@@ -342,6 +351,7 @@ def agg_word_process(agg_threshold: float, start_pos: int, end_pos: int, final_p
     :param end_pos: 任务处理结束边界(不包含)
     :param final_pos: 总任务边界
     :param skip_line: 进度条显示位置
+    :param per_thread_task_num: 每个线程的计算量
     :return:
     """
 
@@ -355,7 +365,7 @@ def agg_word_process(agg_threshold: float, start_pos: int, end_pos: int, final_p
     with tqdm(total=process_len, desc='子进程-%s:文本聚合进度' % os.getpid(), unit='份', unit_scale=True,
               position=skip_line) as pbar:
 
-        thread_futures = [thread_pool.submit(agg_word_thread, main_word_position, agg_threshold) for main_word_position
+        thread_futures = [thread_pool.submit(agg_word_thread, main_word_position, agg_threshold, per_thread_task_num) for main_word_position
                           in
                           range(start_pos, end_pos)]
 
@@ -393,7 +403,14 @@ def init_thread(file_path: str):
     local_var.result_list = []
 
 
-def agg_word_thread(main_word_position: int, agg_threshold: float):
+def agg_word_thread(main_word_position: int, agg_threshold: float, per_thread_task_num: int):
+    """
+    聚合线程
+    :param main_word_position: 主关键词位置
+    :param agg_threshold: 聚合阈值
+    :param per_thread_task_num: 每个线程任务计算量
+    :return:
+    """
     try:
         # 获取已使用位图副本
         local_var.unused_bitmap.frombytes(local_var.redis_cache.get(CACHE_UNUSED_BITMAP))
@@ -445,37 +462,29 @@ def agg_word_thread(main_word_position: int, agg_threshold: float):
         # 延后编码成字符,以防前面直接返回
         main_word = main_word.decode(CHARSET_UTF_8)
 
-        # 计算相似度
-        for candidate_position in range(len(local_var.candidate_position_set)):
+        task_num = len(candidate_word_cache_list)
+        if task_num <= per_thread_task_num:
+            t_result = agg_word_cal(agg_threshold, main_word, main_word_stem_list, candidate_word_cache_list, candidate_word_stem_cache_list)
+            local_var.result_list.extend(t_result)
+        else:
+            task_list = utils.avg_split_task(task_num, per_thread_task_num)
+            cal_futures = [cal_thread_pool.submit(agg_word_cal, agg_threshold, main_word, main_word_stem_list,
+                                       candidate_word_cache_list[start_pos:end_pos],
+                                       candidate_word_stem_cache_list[start_pos:end_pos])
+                for start_pos, end_pos in task_list]
 
-            # 获取关键词、分词,如果存在为空则跳过
-            candidate_word = candidate_word_cache_list[int(candidate_position)]
-            if not candidate_word:
-                continue
-            candidate_word_stem = candidate_word_stem_cache_list[int(candidate_position)]
-            if not candidate_word_stem:
-                continue
-            candidate_word = candidate_word.decode(CHARSET_UTF_8)
-            candidate_word_stem_list = candidate_word_stem.decode(CHARSET_UTF_8).split(",")
-
-            # 计算相关性
-            try:
-                val = utils.cal_cos_sim(main_word, main_word_stem_list, candidate_word,
-                                        candidate_word_stem_list)
-                if val >= agg_threshold:
-                    local_var.redis_cache.setbit(CACHE_UNUSED_BITMAP, candidate_position, 1)
-                    local_var.result_list.append(candidate_word)
-            except Exception as e:
-                logging.error("主关键词:%s 发生异常,涉及的副关键词信息-关键词:%s,分词:%s" % (
-                    main_word, candidate_word, candidate_word_stem_list), e)
+            for cal_future in as_completed(cal_futures):
+                local_var.result_list.extend(cal_future.result())
 
         # 保存结果
         if not local_var.result_list:
             return
         local_var.file_writer.write("%s\n" % main_word)
-        for candidate_word in local_var.result_list:
+        for candidate_position, candidate_word in local_var.result_list:
             local_var.file_writer.write("%s\n" % candidate_word)
+            local_var.redis_pipeline.setbit(CACHE_UNUSED_BITMAP, candidate_position, 1)
         local_var.file_writer.write("\n")
+        local_var.redis_pipeline.execute()
 
     except Exception as e:
         logging.error("子进程发生异常", e)
@@ -486,3 +495,42 @@ def agg_word_thread(main_word_position: int, agg_threshold: float):
         local_var.unused_bitmap.clear()
 
     return
+
+
+def agg_word_cal(agg_threshold: float, main_word: str, main_word_stem_list: list,
+                        candidate_word_cache_list: list, candidate_word_stem_cache_list: list):
+    """
+    Score each candidate against the main keyword and collect the ones that match.
+
+    :param agg_threshold: value the result of utils.cal_cos_sim must reach to keep a candidate
+    :param main_word: main keyword
+    :param main_word_stem_list: stems of the main keyword
+    :param candidate_word_cache_list: encoded candidate keywords; falsy entries are skipped
+    :param candidate_word_stem_cache_list: encoded comma-joined stems; falsy entries are skipped
+    :return: list of (candidate_position, candidate_word); positions index into the lists
+        passed in, so they are relative to the slice when the caller passes slices
+    """
+    cal_result_list = []
+
+    for candidate_position, candidate_word in enumerate(candidate_word_cache_list):
+        # Skip candidates whose keyword or stems are missing
+        if not candidate_word:
+            continue
+        candidate_word_stem = candidate_word_stem_cache_list[candidate_position]
+        if not candidate_word_stem:
+            continue
+        candidate_word = candidate_word.decode(CHARSET_UTF_8)
+        candidate_word_stem_list = candidate_word_stem.decode(CHARSET_UTF_8).split(",")
+
+        try:
+            val = utils.cal_cos_sim(main_word, main_word_stem_list, candidate_word,
+                                    candidate_word_stem_list)
+            if val >= agg_threshold:
+                cal_result_list.append((candidate_position, candidate_word))
+        except Exception:
+            # logging.exception attaches the traceback; the values are passed as lazy
+            # %-arguments so that every argument has a matching placeholder.
+            logging.exception("主关键词:%s 发生异常,涉及的副关键词信息-关键词:%s,分词:%s",
+                              main_word, candidate_word, candidate_word_stem_list)
+
+    return cal_result_list