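fix: make the JSON content key configurable (--content_key) and size each tqdm progress bar per input file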
This commit is contained in:
parent
070ce91c53
commit
6b5b05de17
30
sql_ml.py
30
sql_ml.py
@ -262,7 +262,7 @@ def run_dbscan_on_corpus(corpus, eps, min_samples, max_samples=10):
|
||||
return {processed_corpus[i]: [corpus[i]][:max_samples] for i in range(len(corpus))}
|
||||
|
||||
|
||||
def extract_templates_iterative(input_files, output_file, rules, batch_size=1000, eps=0.4, min_samples=2, max_samples_per_template=0):
|
||||
def extract_templates_iterative(input_files, output_file, rules, batch_size=1000, eps=0.4, min_samples=2, max_samples_per_template=0, content_key='content'):
|
||||
"""
|
||||
使用小批量迭代的混合策略来提取模板,并为每个模板收集最多10个原始数据集。
|
||||
支持多个输入文件。
|
||||
@ -280,10 +280,16 @@ def extract_templates_iterative(input_files, output_file, rules, batch_size=1000
|
||||
total_lines += sum(1 for _ in f)
|
||||
|
||||
for input_file in input_files:
|
||||
print(f"\n--- 开始处理文件: {input_file} ---")
|
||||
# 计算当前文件的行数
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
for line in tqdm(f, total=total_lines, desc="主进程"):
|
||||
file_lines = sum(1 for _ in f)
|
||||
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
for line in tqdm(f, total=file_lines, desc=f"处理 {input_file.split('/')[-1]}"):
|
||||
try:
|
||||
content = json.loads(line).get('content')
|
||||
data = json.loads(line)
|
||||
content = data.get(content_key)
|
||||
if not content: continue
|
||||
|
||||
normalized_content = normalize_text(content)
|
||||
@ -365,7 +371,7 @@ def extract_templates_iterative(input_files, output_file, rules, batch_size=1000
|
||||
print(f"错误:找不到输入文件 {e.filename}。")
|
||||
return
|
||||
|
||||
def extract_values_with_templates(input_files, template_file, output_file):
|
||||
def extract_values_with_templates(input_files, template_file, output_file, content_key='content'):
|
||||
"""
|
||||
使用DBSCAN生成的模板从原始消息中提取参数值
|
||||
支持多个输入文件。
|
||||
@ -390,11 +396,16 @@ def extract_values_with_templates(input_files, template_file, output_file):
|
||||
total_lines += sum(1 for _ in f)
|
||||
|
||||
for input_file in input_files:
|
||||
print(f"\n--- 开始处理文件: {input_file} ---")
|
||||
# 计算当前文件的行数
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
for line in tqdm(f, total=total_lines, desc="提取参数"):
|
||||
file_lines = sum(1 for _ in f)
|
||||
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
for line in tqdm(f, total=file_lines, desc=f"提取 {input_file.split('/')[-1]}"):
|
||||
try:
|
||||
data = json.loads(line)
|
||||
content = data.get('content', '')
|
||||
content = data.get(content_key, '')
|
||||
|
||||
if not content:
|
||||
continue
|
||||
@ -436,6 +447,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--eps', type=float, default=0.4, help='DBSCAN eps parameter')
|
||||
parser.add_argument('--min_samples', type=int, default=5, help='DBSCAN min_samples parameter')
|
||||
parser.add_argument('--extract_values', action='store_true', help='Extract values using generated templates')
|
||||
parser.add_argument('--content_key', type=str, default='content', help='Key to extract content from JSON objects (default: content)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -444,7 +456,8 @@ if __name__ == "__main__":
|
||||
extract_values_with_templates(
|
||||
input_files=args.input_file,
|
||||
template_file='templates_iterative.txt',
|
||||
output_file=args.output_file
|
||||
output_file=args.output_file,
|
||||
content_key=args.content_key
|
||||
)
|
||||
else:
|
||||
# 执行模板提取
|
||||
@ -454,5 +467,6 @@ if __name__ == "__main__":
|
||||
rules=PREDEFINED_RULES,
|
||||
batch_size=args.batch_size,
|
||||
eps=args.eps,
|
||||
min_samples=args.min_samples
|
||||
min_samples=args.min_samples,
|
||||
content_key=args.content_key
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user