fix

parent 6267458d07
commit 070ce91c53

.gitignore (vendored) | 16
@@ -4,4 +4,20 @@
*.txt
*.csv
*.log
*.pkl
*.pth
*.pyc
*.zip
*.ipynb
*.pyd
*.egg-info
*.egg
*.so
*.sql
*.db
*.db3
*.dll
*.pyo


__pycache__
sql2json.py (new file) | 156

@@ -0,0 +1,156 @@
import re
import json
from tqdm import tqdm

# Specify fields to extract
selected_fields = ['content']

# Mapping from original field names to desired keys
mapping = {
    'content': 'content',
}

# Function to extract fields from CREATE TABLE
def extract_fields(create_table_sql):
    # Find the fields section
    match = re.search(r'CREATE TABLE.*?\((.*)\).*?;', create_table_sql, re.DOTALL)
    if not match:
        return []

    fields_section = match.group(1)
    print("Fields section:", repr(fields_section))
    # Split by comma, but handle nested parentheses
    fields = []
    current_field = ""
    paren_count = 0
    for char in fields_section:
        if char == '(':
            paren_count += 1
        elif char == ')':
            paren_count -= 1
        elif char == ',' and paren_count == 0:
            fields.append(current_field.strip())
            current_field = ""
            continue
        current_field += char
    if current_field.strip():
        fields.append(current_field.strip())

    print("Fields:", fields)

    # Extract field names
    field_names = []
    for field in fields:
        field = field.strip()
        if field.startswith('`') and not field.upper().startswith(('PRIMARY', 'UNIQUE', 'KEY', 'INDEX')):
            name = field.split('`')[1]
            field_names.append(name)
    return field_names

# Function to extract data from INSERT statements
def extract_data(insert_sql, field_names):
    # Find all INSERT VALUES
    print("Insert sql:", repr(insert_sql[:200]))
    inserts = re.findall(r'INSERT.*?INTO.*?VALUES\s*\((.*?)\);', insert_sql, re.DOTALL)
    print("Inserts found:", len(inserts))
    for i, ins in enumerate(inserts):
        print(f"Insert {i}:", repr(ins[:100]))
    data = []
    for insert in inserts:
        # Split values by comma, but handle strings
        values = []
        current_value = ""
        in_string = False
        for char in insert:
            if char == "'" and not in_string:
                in_string = True
            elif char == "'" and in_string:
                in_string = False
            elif char == ',' and not in_string:
                values.append(current_value.strip())
                current_value = ""
                continue
            current_value += char
        if current_value.strip():
            values.append(current_value.strip())

        # Clean values (remove quotes)
        cleaned_values = []
        for val in values:
            val = val.strip()
            if val.startswith("'") and val.endswith("'"):
                val = val[1:-1]
            cleaned_values.append(val)

        # Get indices for selected fields
        selected_indices = [field_names.index(f) for f in selected_fields]

        # Create dict with mapped keys
        row = {}
        for f, i in zip(selected_fields, selected_indices):
            key = mapping.get(f, f)
            row[key] = cleaned_values[i]
        data.append(row)
    return data

# Main logic
if __name__ == "__main__":
    print("Calculating total lines...")
    with open('maya_business_bill.sql', 'r') as f:
        total_lines = sum(1 for _ in f)
    print(f"Total lines: {total_lines}")

    with open('maya_business_bill.sql', 'r') as f:
        create_table_sql = ''
        in_create = False
        field_names = []
        selected_indices = []
        with open('output.jsonl', 'w') as out_f:
            for line in tqdm(f, total=total_lines, desc="Processing SQL"):
                if line.strip().startswith('CREATE TABLE'):
                    in_create = True
                if in_create:
                    create_table_sql += line
                    if line.strip().endswith(';'):
                        in_create = False
                        field_names = extract_fields(create_table_sql)
                        print("Extracted fields:", field_names)
                        selected_indices = [field_names.index(f) for f in selected_fields]
                elif line.strip().startswith('INSERT INTO'):
                    # Process INSERT
                    match = re.search(r'INSERT.*?VALUES\s*\((.*?)\);', line.strip())
                    if match:
                        insert = match.group(1)
                        # Parse values
                        values = []
                        current_value = ""
                        in_string = False
                        for char in insert:
                            if char == "'" and not in_string:
                                in_string = True
                            elif char == "'" and in_string:
                                in_string = False
                            elif char == ',' and not in_string:
                                values.append(current_value.strip())
                                current_value = ""
                                continue
                            current_value += char
                        if current_value.strip():
                            values.append(current_value.strip())

                        # Clean values
                        cleaned_values = []
                        for val in values:
                            val = val.strip()
                            if val.startswith("'") and val.endswith("'"):
                                val = val[1:-1]
                            cleaned_values.append(val)

                        # Create row
                        row = {}
                        for f, i in zip(selected_fields, selected_indices):
                            key = mapping.get(f, f)
                            row[key] = cleaned_values[i]
                        out_f.write(json.dumps(row) + '\n')

    print("JSONL output saved to output.jsonl")
sql_ml.py | 197

@@ -36,6 +36,19 @@ PLACEHOLDER_PATTERNS = {
     '<交易类型>': r'(.+?)',
 }
 
+def get_placeholder(action):
+    """
+    Return the placeholder that corresponds to the action type of a JSON message.
+    """
+    if 'Received' in action:
+        return '付款人号码'
+    elif 'Sent' in action:
+        return '收款人号码'
+    elif 'Refunded' in action:
+        return '付款人号码'
+    else:
+        return '付款人号码'  # default
+
 def normalize_text(text):
     # Pattern 8: payment received from a bank (this rule must run first)
     text = re.sub(
@@ -153,6 +166,14 @@ def normalize_text(text):
         text
     )
 
+    # New rule: handle all JSON-format messages uniformly
+    # Matches the various action types: Received money from, Sent money to, Received settlement from, Reversed settlement from, etc.
+    text = re.sub(
+        r'\\\"(Received money from|Sent money to|Received settlement from|Reversed settlement from|Refunded money via|Sent money via)\\\",\\\"target\\\":\\\"(.+?)\\\"',
+        lambda m: f'\\\"{m.group(1)}\\\",\\\"target\\\":\\\"<{get_placeholder(m.group(1))}>\\\"',
+        text
+    )
+
     return text
 
 def template_to_regex(template):
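
To make the new rule concrete, a small before/after sketch (the sample fragment is invented; the placeholder token comes from get_placeholder):

    # Input content containing an escaped-JSON fragment such as:
    #   ...\"action\":\"Sent money to\",\"target\":\"09171234567\"...
    # matches the rule; 'Sent money to' maps to '收款人号码', so normalize_text yields:
    #   ...\"action\":\"Sent money to\",\"target\":\"<收款人号码>\"...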
@@ -241,9 +262,10 @@ def run_dbscan_on_corpus(corpus, eps, min_samples, max_samples=10):
     return {processed_corpus[i]: [corpus[i]][:max_samples] for i in range(len(corpus))}
 
 
-def extract_templates_iterative(input_file, output_file, rules, batch_size=1000, eps=0.4, min_samples=2, max_samples_per_template=0):
+def extract_templates_iterative(input_files, output_file, rules, batch_size=1000, eps=0.4, min_samples=2, max_samples_per_template=0):
     """
     Extract templates with a hybrid small-batch iterative strategy, collecting up to 10 original samples per template.
+    Supports multiple input files.
     """
     print("--- Starting iterative template extraction ---")
     final_templates = {}  # template -> list of original contents
@@ -251,61 +273,64 @@ def extract_templates_iterative(input_file, output_file, rules, batch_size=1000,
     batch_num = 1
 
     try:
-        print(f"Step 1: processing '{input_file}' line by line and building the template library on the fly...")
-        with open(input_file, 'r', encoding='utf-8') as f:
-            total_lines = sum(1 for _ in f)
+        print(f"Step 1: processing input files {input_files} line by line and building the template library on the fly...")
+        total_lines = 0
+        for input_file in input_files:
+            with open(input_file, 'r', encoding='utf-8') as f:
+                total_lines += sum(1 for _ in f)
 
-        with open(input_file, 'r', encoding='utf-8') as f:
-            for line in tqdm(f, total=total_lines, desc="Main pass"):
-                try:
-                    content = json.loads(line).get('content')
-                    if not content: continue
-
-                    normalized_content = normalize_text(content)
-
-                    # 1. Check whether it matches an already-discovered template
-                    if normalized_content in final_templates:
-                        if len(final_templates[normalized_content]) < 10:
-                            final_templates[normalized_content].append(content)
-                        continue
-
-                    # 2. Check whether it matches a predefined rule
-                    matched_by_rule = False
-                    for rule in rules:
-                        if rule['pattern'].match(content):
-                            if normalized_content not in final_templates:
-                                final_templates[normalized_content] = []
-                            if len(final_templates[normalized_content]) < 10:
-                                final_templates[normalized_content].append(content)
-                            matched_by_rule = True
-                            break
-
-                    if matched_by_rule:
-                        continue
-
-                    # 3. If nothing matched, add it to the pending batch
-                    unmatched_batch.append(content)
-
-                    # 4. Check whether to trigger batch processing
-                    if len(unmatched_batch) >= batch_size:
-                        print(f"\n--- Processing batch #{batch_num} (size: {len(unmatched_batch)}) ---")
-                        newly_found_templates = run_dbscan_on_corpus(unmatched_batch, eps, min_samples, 10)
-
-                        print(f"Batch #{batch_num}: DBSCAN found {len(newly_found_templates)} candidate templates.")
-                        for template, originals in newly_found_templates.items():
-                            if template in final_templates:
-                                remaining = 10 - len(final_templates[template])
-                                final_templates[template].extend(originals[:remaining])
-                            else:
-                                final_templates[template] = originals[:10]
-                        print(f"Current total templates: {len(final_templates)}")
-
-                        unmatched_batch.clear()
-                        batch_num += 1
-
-                except (json.JSONDecodeError, AttributeError):
-                    continue
+        for input_file in input_files:
+            with open(input_file, 'r', encoding='utf-8') as f:
+                for line in tqdm(f, total=total_lines, desc="Main pass"):
+                    try:
+                        content = json.loads(line).get('content')
+                        if not content: continue
+
+                        normalized_content = normalize_text(content)
+
+                        # 1. Check whether it matches an already-discovered template
+                        if normalized_content in final_templates:
+                            if len(final_templates[normalized_content]) < 10:
+                                final_templates[normalized_content].append(content)
+                            continue
+
+                        # 2. Check whether it matches a predefined rule
+                        matched_by_rule = False
+                        for rule in rules:
+                            if rule['pattern'].match(content):
+                                if normalized_content not in final_templates:
+                                    final_templates[normalized_content] = []
+                                if len(final_templates[normalized_content]) < 10:
+                                    final_templates[normalized_content].append(content)
+                                matched_by_rule = True
+                                break
+
+                        if matched_by_rule:
+                            continue
+
+                        # 3. If nothing matched, add it to the pending batch
+                        unmatched_batch.append(content)
+
+                        # 4. Check whether to trigger batch processing
+                        if len(unmatched_batch) >= batch_size:
+                            print(f"\n--- Processing batch #{batch_num} (size: {len(unmatched_batch)}) ---")
+                            newly_found_templates = run_dbscan_on_corpus(unmatched_batch, eps, min_samples, 10)
+
+                            print(f"Batch #{batch_num}: DBSCAN found {len(newly_found_templates)} candidate templates.")
+                            for template, originals in newly_found_templates.items():
+                                if template in final_templates:
+                                    remaining = 10 - len(final_templates[template])
+                                    final_templates[template].extend(originals[:remaining])
+                                else:
+                                    final_templates[template] = originals[:10]
+                            print(f"Current total templates: {len(final_templates)}")
+
+                            unmatched_batch.clear()
+                            batch_num += 1
+
+                    except (json.JSONDecodeError, AttributeError):
+                        continue
 
         # --- Final pass ---
         print("\n--- All files processed; handling the last batch of remaining content ---")
         if unmatched_batch:
@@ -336,13 +361,14 @@ def extract_templates_iterative(input_file, output_file, rules, batch_size=1000,
 
         print(f"All templates were successfully written to '{output_file}'.")
 
-    except FileNotFoundError:
-        print(f"Error: input file '{input_file}' not found.")
+    except FileNotFoundError as e:
+        print(f"Error: input file {e.filename} not found.")
         return
 
-def extract_values_with_templates(input_file, template_file, output_file):
+def extract_values_with_templates(input_files, template_file, output_file):
     """
     Extract parameter values from the original messages using the DBSCAN-generated templates.
+    Supports multiple input files.
     """
     print("--- Starting parameter extraction with templates ---")
 
@@ -358,32 +384,35 @@ def extract_values_with_templates(input_file, template_file, output_file):
     # Extract values from the original data
     extracted_values = []
 
-    with open(input_file, 'r', encoding='utf-8') as f:
-        total_lines = sum(1 for _ in f)
+    total_lines = 0
+    for input_file in input_files:
+        with open(input_file, 'r', encoding='utf-8') as f:
+            total_lines += sum(1 for _ in f)
 
-    with open(input_file, 'r', encoding='utf-8') as f:
-        for line in tqdm(f, total=total_lines, desc="Extracting parameters"):
-            try:
-                data = json.loads(line)
-                content = data.get('content', '')
-
-                if not content:
-                    continue
-
-                # Try to match each template
-                for template in templates:
-                    parameters = extract_parameters(template, content)
-                    if parameters:
-                        extracted_values.append({
-                            'template': template,
-                            'message': content,
-                            'parameters': parameters
-                        })
-                        # Stop at the first matching template
-                        break
-
-            except (json.JSONDecodeError, Exception):
-                continue
+    for input_file in input_files:
+        with open(input_file, 'r', encoding='utf-8') as f:
+            for line in tqdm(f, total=total_lines, desc="Extracting parameters"):
+                try:
+                    data = json.loads(line)
+                    content = data.get('content', '')
+
+                    if not content:
+                        continue
+
+                    # Try to match each template
+                    for template in templates:
+                        parameters = extract_parameters(template, content)
+                        if parameters:
+                            extracted_values.append({
+                                'template': template,
+                                'message': content,
+                                'parameters': parameters
+                            })
+                            # Stop at the first matching template
+                            break
+
+                except (json.JSONDecodeError, Exception):
+                    continue
 
     # Save the extracted values
     with open(output_file, 'w', encoding='utf-8') as f:
|
||||
|
||||
# --- 使用示例 ---
|
||||
# 假设您已经运行了上一个脚本,生成了 'content_filtered.jsonl'
|
||||
input_jsonl_file = 'content_filtered.jsonl'
|
||||
input_jsonl_files = ['content_filtered.jsonl', 'output.jsonl'] # 默认单个文件,可扩展为多个
|
||||
output_template_file = 'templates_iterative.txt'
|
||||
BATCH_PROCESSING_SIZE = 10000 # 可以根据你的内存和数据量调整
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Extract templates from GCash transaction data.')
|
||||
parser.add_argument('--input_file', type=str, default=input_jsonl_file, help='Input JSONL file path')
|
||||
parser.add_argument('--input_file', type=str, nargs='+', default=input_jsonl_files, help='Input JSONL file paths (multiple files supported)')
|
||||
parser.add_argument('--output_file', type=str, default=output_template_file, help='Output template file path')
|
||||
parser.add_argument('--batch_size', type=int, default=BATCH_PROCESSING_SIZE, help='Batch processing size (data volume)')
|
||||
parser.add_argument('--eps', type=float, default=0.4, help='DBSCAN eps parameter')
|
||||
@@ -413,14 +442,14 @@ if __name__ == "__main__":
 
     if args.extract_values:
         # Run parameter extraction
         extract_values_with_templates(
-            input_file=args.input_file,
-            template_file=args.output_file,
-            output_file='extracted_parameters.jsonl'
+            input_files=args.input_file,
+            template_file='templates_iterative.txt',
+            output_file=args.output_file
         )
     else:
         # Run template extraction
         extract_templates_iterative(
-            input_file=args.input_file,
+            input_files=args.input_file,
             output_file=args.output_file,
             rules=PREDEFINED_RULES,
             batch_size=args.batch_size,
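
A minimal usage sketch for the updated CLI, assuming --extract_values is a boolean flag defined in a hunk not shown here:

    # Template extraction over both default inputs:
    #   python sql_ml.py --input_file content_filtered.jsonl output.jsonl --output_file templates_iterative.txt
    # Parameter extraction with the previously generated templates (the template
    # file is now hardcoded to 'templates_iterative.txt'):
    #   python sql_ml.py --extract_values --output_file extracted_parameters.jsonl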