From 6b5b05de17ea9728cf1f3d9e9781a3f79fcf8a5f Mon Sep 17 00:00:00 2001 From: dev07 Date: Tue, 9 Sep 2025 15:14:44 +0800 Subject: [PATCH] fix --- sql_ml.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/sql_ml.py b/sql_ml.py index 9884cc8..7843064 100644 --- a/sql_ml.py +++ b/sql_ml.py @@ -262,7 +262,7 @@ def run_dbscan_on_corpus(corpus, eps, min_samples, max_samples=10): return {processed_corpus[i]: [corpus[i]][:max_samples] for i in range(len(corpus))} -def extract_templates_iterative(input_files, output_file, rules, batch_size=1000, eps=0.4, min_samples=2, max_samples_per_template=0): +def extract_templates_iterative(input_files, output_file, rules, batch_size=1000, eps=0.4, min_samples=2, max_samples_per_template=0, content_key='content'): """ 使用小批量迭代的混合策略来提取模板,并为每个模板收集最多10个原始数据集。 支持多个输入文件。 @@ -280,10 +280,16 @@ def extract_templates_iterative(input_files, output_file, rules, batch_size=1000 total_lines += sum(1 for _ in f) for input_file in input_files: + print(f"\n--- 开始处理文件: {input_file} ---") + # 计算当前文件的行数 with open(input_file, 'r', encoding='utf-8') as f: - for line in tqdm(f, total=total_lines, desc="主进程"): + file_lines = sum(1 for _ in f) + + with open(input_file, 'r', encoding='utf-8') as f: + for line in tqdm(f, total=file_lines, desc=f"处理 {input_file.split('/')[-1]}"): try: - content = json.loads(line).get('content') + data = json.loads(line) + content = data.get(content_key) if not content: continue normalized_content = normalize_text(content) @@ -365,7 +371,7 @@ def extract_templates_iterative(input_files, output_file, rules, batch_size=1000 print(f"错误:找不到输入文件 {e.filename}。") return -def extract_values_with_templates(input_files, template_file, output_file): +def extract_values_with_templates(input_files, template_file, output_file, content_key='content'): """ 使用DBSCAN生成的模板从原始消息中提取参数值 支持多个输入文件。 @@ -390,11 +396,16 @@ def extract_values_with_templates(input_files, template_file, output_file): total_lines += sum(1 for _ in f) for input_file in input_files: + print(f"\n--- 开始处理文件: {input_file} ---") + # 计算当前文件的行数 with open(input_file, 'r', encoding='utf-8') as f: - for line in tqdm(f, total=total_lines, desc="提取参数"): + file_lines = sum(1 for _ in f) + + with open(input_file, 'r', encoding='utf-8') as f: + for line in tqdm(f, total=file_lines, desc=f"提取 {input_file.split('/')[-1]}"): try: data = json.loads(line) - content = data.get('content', '') + content = data.get(content_key, '') if not content: continue @@ -436,6 +447,7 @@ if __name__ == "__main__": parser.add_argument('--eps', type=float, default=0.4, help='DBSCAN eps parameter') parser.add_argument('--min_samples', type=int, default=5, help='DBSCAN min_samples parameter') parser.add_argument('--extract_values', action='store_true', help='Extract values using generated templates') + parser.add_argument('--content_key', type=str, default='content', help='Key to extract content from JSON objects (default: content)') args = parser.parse_args() @@ -444,7 +456,8 @@ if __name__ == "__main__": extract_values_with_templates( input_files=args.input_file, template_file='templates_iterative.txt', - output_file=args.output_file + output_file=args.output_file, + content_key=args.content_key ) else: # 执行模板提取 @@ -454,5 +467,6 @@ if __name__ == "__main__": rules=PREDEFINED_RULES, batch_size=args.batch_size, eps=args.eps, - min_samples=args.min_samples + min_samples=args.min_samples, + content_key=args.content_key ) \ No newline at end of file