fix
This commit is contained in:
parent
070ce91c53
commit
6b5b05de17
30
sql_ml.py
30
sql_ml.py
@ -262,7 +262,7 @@ def run_dbscan_on_corpus(corpus, eps, min_samples, max_samples=10):
|
|||||||
return {processed_corpus[i]: [corpus[i]][:max_samples] for i in range(len(corpus))}
|
return {processed_corpus[i]: [corpus[i]][:max_samples] for i in range(len(corpus))}
|
||||||
|
|
||||||
|
|
||||||
def extract_templates_iterative(input_files, output_file, rules, batch_size=1000, eps=0.4, min_samples=2, max_samples_per_template=0):
|
def extract_templates_iterative(input_files, output_file, rules, batch_size=1000, eps=0.4, min_samples=2, max_samples_per_template=0, content_key='content'):
|
||||||
"""
|
"""
|
||||||
使用小批量迭代的混合策略来提取模板,并为每个模板收集最多10个原始数据集。
|
使用小批量迭代的混合策略来提取模板,并为每个模板收集最多10个原始数据集。
|
||||||
支持多个输入文件。
|
支持多个输入文件。
|
||||||
@ -280,10 +280,16 @@ def extract_templates_iterative(input_files, output_file, rules, batch_size=1000
|
|||||||
total_lines += sum(1 for _ in f)
|
total_lines += sum(1 for _ in f)
|
||||||
|
|
||||||
for input_file in input_files:
|
for input_file in input_files:
|
||||||
|
print(f"\n--- 开始处理文件: {input_file} ---")
|
||||||
|
# 计算当前文件的行数
|
||||||
with open(input_file, 'r', encoding='utf-8') as f:
|
with open(input_file, 'r', encoding='utf-8') as f:
|
||||||
for line in tqdm(f, total=total_lines, desc="主进程"):
|
file_lines = sum(1 for _ in f)
|
||||||
|
|
||||||
|
with open(input_file, 'r', encoding='utf-8') as f:
|
||||||
|
for line in tqdm(f, total=file_lines, desc=f"处理 {input_file.split('/')[-1]}"):
|
||||||
try:
|
try:
|
||||||
content = json.loads(line).get('content')
|
data = json.loads(line)
|
||||||
|
content = data.get(content_key)
|
||||||
if not content: continue
|
if not content: continue
|
||||||
|
|
||||||
normalized_content = normalize_text(content)
|
normalized_content = normalize_text(content)
|
||||||
@ -365,7 +371,7 @@ def extract_templates_iterative(input_files, output_file, rules, batch_size=1000
|
|||||||
print(f"错误:找不到输入文件 {e.filename}。")
|
print(f"错误:找不到输入文件 {e.filename}。")
|
||||||
return
|
return
|
||||||
|
|
||||||
def extract_values_with_templates(input_files, template_file, output_file):
|
def extract_values_with_templates(input_files, template_file, output_file, content_key='content'):
|
||||||
"""
|
"""
|
||||||
使用DBSCAN生成的模板从原始消息中提取参数值
|
使用DBSCAN生成的模板从原始消息中提取参数值
|
||||||
支持多个输入文件。
|
支持多个输入文件。
|
||||||
@ -390,11 +396,16 @@ def extract_values_with_templates(input_files, template_file, output_file):
|
|||||||
total_lines += sum(1 for _ in f)
|
total_lines += sum(1 for _ in f)
|
||||||
|
|
||||||
for input_file in input_files:
|
for input_file in input_files:
|
||||||
|
print(f"\n--- 开始处理文件: {input_file} ---")
|
||||||
|
# 计算当前文件的行数
|
||||||
with open(input_file, 'r', encoding='utf-8') as f:
|
with open(input_file, 'r', encoding='utf-8') as f:
|
||||||
for line in tqdm(f, total=total_lines, desc="提取参数"):
|
file_lines = sum(1 for _ in f)
|
||||||
|
|
||||||
|
with open(input_file, 'r', encoding='utf-8') as f:
|
||||||
|
for line in tqdm(f, total=file_lines, desc=f"提取 {input_file.split('/')[-1]}"):
|
||||||
try:
|
try:
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
content = data.get('content', '')
|
content = data.get(content_key, '')
|
||||||
|
|
||||||
if not content:
|
if not content:
|
||||||
continue
|
continue
|
||||||
@ -436,6 +447,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument('--eps', type=float, default=0.4, help='DBSCAN eps parameter')
|
parser.add_argument('--eps', type=float, default=0.4, help='DBSCAN eps parameter')
|
||||||
parser.add_argument('--min_samples', type=int, default=5, help='DBSCAN min_samples parameter')
|
parser.add_argument('--min_samples', type=int, default=5, help='DBSCAN min_samples parameter')
|
||||||
parser.add_argument('--extract_values', action='store_true', help='Extract values using generated templates')
|
parser.add_argument('--extract_values', action='store_true', help='Extract values using generated templates')
|
||||||
|
parser.add_argument('--content_key', type=str, default='content', help='Key to extract content from JSON objects (default: content)')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -444,7 +456,8 @@ if __name__ == "__main__":
|
|||||||
extract_values_with_templates(
|
extract_values_with_templates(
|
||||||
input_files=args.input_file,
|
input_files=args.input_file,
|
||||||
template_file='templates_iterative.txt',
|
template_file='templates_iterative.txt',
|
||||||
output_file=args.output_file
|
output_file=args.output_file,
|
||||||
|
content_key=args.content_key
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 执行模板提取
|
# 执行模板提取
|
||||||
@ -454,5 +467,6 @@ if __name__ == "__main__":
|
|||||||
rules=PREDEFINED_RULES,
|
rules=PREDEFINED_RULES,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
eps=args.eps,
|
eps=args.eps,
|
||||||
min_samples=args.min_samples
|
min_samples=args.min_samples,
|
||||||
|
content_key=args.content_key
|
||||||
)
|
)
|
Loading…
x
Reference in New Issue
Block a user