This commit is contained in:
dev07 2025-09-09 15:14:44 +08:00
parent 070ce91c53
commit 6b5b05de17

View File

@ -262,7 +262,7 @@ def run_dbscan_on_corpus(corpus, eps, min_samples, max_samples=10):
return {processed_corpus[i]: [corpus[i]][:max_samples] for i in range(len(corpus))}
def extract_templates_iterative(input_files, output_file, rules, batch_size=1000, eps=0.4, min_samples=2, max_samples_per_template=0):
def extract_templates_iterative(input_files, output_file, rules, batch_size=1000, eps=0.4, min_samples=2, max_samples_per_template=0, content_key='content'):
"""
使用小批量迭代的混合策略来提取模板并为每个模板收集最多10个原始数据集
支持多个输入文件
@ -280,10 +280,16 @@ def extract_templates_iterative(input_files, output_file, rules, batch_size=1000
total_lines += sum(1 for _ in f)
for input_file in input_files:
print(f"\n--- 开始处理文件: {input_file} ---")
# 计算当前文件的行数
with open(input_file, 'r', encoding='utf-8') as f:
for line in tqdm(f, total=total_lines, desc="主进程"):
file_lines = sum(1 for _ in f)
with open(input_file, 'r', encoding='utf-8') as f:
for line in tqdm(f, total=file_lines, desc=f"处理 {input_file.split('/')[-1]}"):
try:
content = json.loads(line).get('content')
data = json.loads(line)
content = data.get(content_key)
if not content: continue
normalized_content = normalize_text(content)
@ -365,7 +371,7 @@ def extract_templates_iterative(input_files, output_file, rules, batch_size=1000
print(f"错误:找不到输入文件 {e.filename}")
return
def extract_values_with_templates(input_files, template_file, output_file):
def extract_values_with_templates(input_files, template_file, output_file, content_key='content'):
"""
使用DBSCAN生成的模板从原始消息中提取参数值
支持多个输入文件
@ -390,11 +396,16 @@ def extract_values_with_templates(input_files, template_file, output_file):
total_lines += sum(1 for _ in f)
for input_file in input_files:
print(f"\n--- 开始处理文件: {input_file} ---")
# 计算当前文件的行数
with open(input_file, 'r', encoding='utf-8') as f:
for line in tqdm(f, total=total_lines, desc="提取参数"):
file_lines = sum(1 for _ in f)
with open(input_file, 'r', encoding='utf-8') as f:
for line in tqdm(f, total=file_lines, desc=f"提取 {input_file.split('/')[-1]}"):
try:
data = json.loads(line)
content = data.get('content', '')
content = data.get(content_key, '')
if not content:
continue
@ -436,6 +447,7 @@ if __name__ == "__main__":
parser.add_argument('--eps', type=float, default=0.4, help='DBSCAN eps parameter')
parser.add_argument('--min_samples', type=int, default=5, help='DBSCAN min_samples parameter')
parser.add_argument('--extract_values', action='store_true', help='Extract values using generated templates')
parser.add_argument('--content_key', type=str, default='content', help='Key to extract content from JSON objects (default: content)')
args = parser.parse_args()
@ -444,7 +456,8 @@ if __name__ == "__main__":
extract_values_with_templates(
input_files=args.input_file,
template_file='templates_iterative.txt',
output_file=args.output_file
output_file=args.output_file,
content_key=args.content_key
)
else:
# 执行模板提取
@ -454,5 +467,6 @@ if __name__ == "__main__":
rules=PREDEFINED_RULES,
batch_size=args.batch_size,
eps=args.eps,
min_samples=args.min_samples
min_samples=args.min_samples,
content_key=args.content_key
)