| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- """
- 从 chat_history.db 数据库读取 intent_recognition_records 表,
- 按语言(中文/英文)分别导出为 CSV 文件
- 支持翻译功能:将中文记录翻译为英文,英文记录翻译为中文
- """
- import sqlite3
- import csv
- import re
- import translators as ts
- import pandas as pd
def translate(
    text: str,
    from_language: str,
    to_language: str,
    translator: str = 'youdao',
    timeout: int = 10,
) -> str:
    """Translate ``text`` from one language to another via ``translators``.

    Args:
        text: Text to translate. Falsy or non-string values are returned
            unchanged without contacting any engine.
        from_language: Source language code forwarded to the engine.
        to_language: Target language code forwarded to the engine.
        translator: Engine name forwarded to ``ts.translate_text``
            (defaults to ``'youdao'``).
        timeout: Timeout value forwarded to ``ts.translate_text``.

    Returns:
        The translation with surrounding whitespace removed. Whitespace-only
        input yields ``""``. If the engine returns ``None``, a dict without a
        known key, a non-string value, or raises, the (stripped) input text
        is returned instead, so this function is best-effort and never
        propagates an engine error.
    """
    if not (text and isinstance(text, str)):
        return text

    try:
        # Normalise the input first; nothing to send if only whitespace.
        text = text.strip()
        if not text:
            return ""

        outcome = ts.translate_text(
            text,
            translator=translator,
            from_language=from_language,
            to_language=to_language,
            timeout=timeout,
        )

        # No result at all: fall back to the source text.
        if outcome is None:
            return text

        # Some engines may answer with a dict; pick the first known key.
        if isinstance(outcome, dict):
            for key in ('text', 'translated', 'result'):
                if key in outcome:
                    outcome = outcome[key]
                    break
            else:
                print(f" [警告] 未知的字典格式: {outcome}")
                return text

        # Anything that is still not a string cannot be used as a result.
        if not isinstance(outcome, str):
            print(f" [警告] 非字符串返回: type={type(outcome)}")
            return text

        return outcome.strip()
    except Exception as e:
        # Best-effort: report the failure and keep the source text.
        print(f"翻译错误: {text} -> {str(e)}")
        return text
def is_chinese(text):
    """Return ``True`` if ``text`` contains a character in U+4E00..U+9FFF.

    Falsy input (``None`` or an empty string) yields ``False``.
    """
    if text:
        # One character in the CJK ideograph range is enough.
        return re.search(r'[\u4e00-\u9fff]', text) is not None
    return False
def translate_to_chinese(text):
    """Translate ``text`` from English ('en') to Chinese ('zh').

    Falsy input is returned unchanged. If the translation is empty or the
    call raises, the original ``text`` is returned.
    """
    if text:
        try:
            translated = translate(text, from_language='en', to_language='zh')
        except Exception as exc:
            print(f"翻译失败 (->中文): {text[:30]}... 错误: {exc}")
        else:
            if translated:
                return translated
    return text
def translate_to_english(text):
    """Translate ``text`` from Chinese ('zh') to English ('en').

    Falsy input is returned unchanged. If the translation is empty or the
    call raises, the original ``text`` is returned.
    """
    if not text:
        return text

    try:
        translated = translate(text, from_language='zh', to_language='en')
    except Exception as exc:
        print(f"翻译失败 (->英文): {text[:30]}... 错误: {exc}")
        return text

    # An empty translation falls back to the source text.
    return translated or text
def _translate_rows(rows, translate_fn, language_name):
    """Return copies of ``rows`` whose record text went through ``translate_fn``."""
    translated_rows = []
    total = len(rows)
    for done, (id_, text, timestamp, label) in enumerate(rows, start=1):
        translated_rows.append((id_, translate_fn(text), timestamp, label))
        # Progress report every 100 rows; translation is slow.
        if done % 100 == 0:
            print(f" 已翻译 {done}/{total} 条{language_name}记录")
    return translated_rows


def _write_records_csv(path, rows):
    """Write ``rows`` to ``path`` as UTF-8-with-BOM CSV under the fixed header."""
    with open(path, 'w', encoding='utf-8-sig', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['id', 'record', 'time', 'label'])
        writer.writerows(rows)


def export_records():
    """Export ``intent_recognition_records`` to one Chinese and one English CSV.

    Reads ``id, record, time, label`` from ``chat_history.db``, splits the
    rows with ``is_chinese`` on the record text, translates each group into
    the other language, and writes:

    * ``intent_records_chinese.csv``: original Chinese rows followed by the
      English rows translated to Chinese.
    * ``intent_records_english.csv``: original English rows followed by the
      Chinese rows translated to English.

    The database connection is closed as soon as the rows are fetched, so it
    is released even if translation or file writing fails later.
    """
    # Fetch everything up front and release the connection immediately.
    conn = sqlite3.connect('chat_history.db')
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT id, record, time, label FROM intent_recognition_records")
        all_records = cursor.fetchall()
    finally:
        conn.close()

    # Split by language, judged on the record text (second column).
    chinese_records = []
    english_records = []
    for record in all_records:
        if is_chinese(record[1]):
            chinese_records.append(record)
        else:
            english_records.append(record)

    # Chinese -> English
    print("正在翻译中文记录为英文...")
    chinese_records_en = _translate_rows(chinese_records, translate_to_english, '中文')

    # English -> Chinese
    print("正在翻译英文记录为中文...")
    english_records_zh = _translate_rows(english_records, translate_to_chinese, '英文')

    # Chinese CSV: chinese_records + english_records_zh (all Chinese content).
    all_chinese_records = chinese_records + english_records_zh
    print(f'中文记录为{len(all_chinese_records)}条')
    # English CSV: english_records + chinese_records_en (all English content).
    all_english_records = english_records + chinese_records_en
    print(f'英文记录为{len(all_english_records)}条')

    _write_records_csv('intent_records_chinese.csv', all_chinese_records)
    _write_records_csv('intent_records_english.csv', all_english_records)

    print("\n导出完成!")
    print("中文CSV文件 (intent_records_chinese.csv):")
    print(f" - 原中文记录: {len(chinese_records)} 条")
    print(f" - 英文翻译为中文: {len(english_records_zh)} 条")
    print(f" - 总计: {len(all_chinese_records)} 条")
    print("英文CSV文件 (intent_records_english.csv):")
    print(f" - 原英文记录: {len(english_records)} 条")
    print(f" - 中文翻译为英文: {len(chinese_records_en)} 条")
    print(f" - 总计: {len(all_english_records)} 条")
def clean_csv(file_path, output_path):
    """Shift labels down by one, drop negative results, keep record/label.

    Every value in the ``label`` column is decremented by 1; rows whose
    shifted label is below 0 are discarded, and only the ``record`` and
    ``label`` columns are kept.

    :param file_path: path of the CSV file to read
    :param output_path: path the cleaned CSV is written to (UTF-8 with BOM,
        no index column)
    :return: the cleaned DataFrame
    """
    print(f"正在整理CSV文件中的数据: {file_path}")
    table = pd.read_csv(file_path, encoding='utf-8')

    # Shift labels down by one so rows to discard end up negative.
    table['label'] = table['label'].sub(1)

    # Keep non-negative labels and only the two columns of interest.
    non_negative = table['label'] >= 0
    cleaned = table.loc[non_negative, ['record', 'label']]

    cleaned.to_csv(output_path, index=False, encoding='utf-8-sig')
    print(f"处理完成,共保留 {len(cleaned)} 条记录,保存至{output_path}")
    return cleaned
def print_statistics(file_path):
    """Print row count, label distribution and missing-label count of a CSV.

    Percentages are relative to the total number of rows. Rows with a
    missing label are not part of the distribution and are reported on a
    separate line when present.

    :param file_path: path of a CSV file containing a ``label`` column
    """
    print(f"正在打印CSV文件统计数据: {file_path}")
    table = pd.read_csv(file_path, encoding='utf-8')

    # Total number of rows.
    n_rows = table.shape[0]
    print(f"总记录数: {n_rows}")

    # Occurrences per label value, ordered by label.
    distribution = table['label'].value_counts().sort_index()
    print("\nLabel分布:")
    for value, occurrences in distribution.items():
        share = occurrences / n_rows * 100
        print(f" Label {value}: {occurrences} 条 ({share:.2f}%)")

    # Rows without a label.
    n_missing = table['label'].isna().sum()
    if n_missing > 0:
        print(f"\n缺失值: {n_missing} 条")
    print("-" * 40)
if __name__ == '__main__':
    # Step 1: export every record from the database.
    # export_records()
    # Step 2: clean the exported data.
    # clean_csv('intent_records_chinese.csv','data_zh.csv')
    # clean_csv('intent_records_english.csv','data_en.csv')
    # Step 3: print statistics for the cleaned files (English first).
    for cleaned_csv in ('data_en.csv', 'data_zh.csv'):
        print_statistics(cleaned_csv)
|