export_records.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. """
  2. 从 chat_history.db 数据库读取 intent_recognition_records 表,
  3. 按语言(中文/英文)分别导出为 CSV 文件
  4. 支持翻译功能:将中文记录翻译为英文,英文记录翻译为中文
  5. """
  6. import sqlite3
  7. import csv
  8. import re
  9. import translators as ts
  10. import pandas as pd
  11. def translate(text:str, from_language:str, to_language:str, translator: str = 'youdao', timeout: int = 10) -> str:
  12. """
  13. 中文翻译为英文
  14. 参数:
  15. text: 要翻译的中文文本
  16. translator: 翻译引擎 ('google', 'bing', 'baidu', 'youdao'等)
  17. timeout: 超时时间(秒)
  18. 返回:
  19. 翻译后的英文文本
  20. """
  21. if not text or not isinstance(text, str):
  22. return text
  23. try:
  24. # 清理文本
  25. text = text.strip()
  26. if not text:
  27. return ""
  28. # 使用translators库进行翻译
  29. translated_text = ts.translate_text(
  30. text,
  31. translator=translator,
  32. from_language=from_language,
  33. to_language=to_language,
  34. timeout=timeout
  35. )
  36. # 如果返回None,返回原文本
  37. if translated_text is None:
  38. return text
  39. # 处理可能的字典返回值(某些翻译引擎可能返回字典)
  40. if isinstance(translated_text, dict):
  41. # 尝试从字典中提取翻译结果
  42. if 'text' in translated_text:
  43. translated_text = translated_text['text']
  44. elif 'translated' in translated_text:
  45. translated_text = translated_text['translated']
  46. elif 'result' in translated_text:
  47. translated_text = translated_text['result']
  48. else:
  49. print(f" [警告] 未知的字典格式: {translated_text}")
  50. return text
  51. # 确保返回字符串
  52. if isinstance(translated_text, str):
  53. return translated_text.strip()
  54. else:
  55. print(f" [警告] 非字符串返回: type={type(translated_text)}")
  56. return text
  57. except Exception as e:
  58. print(f"翻译错误: {text} -> {str(e)}")
  59. # 出错时返回原文本
  60. return text
  61. def is_chinese(text):
  62. """判断文本是否包含中文字符"""
  63. if not text:
  64. return False
  65. # 匹配中文字符的 Unicode 范围
  66. chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
  67. return bool(chinese_pattern.search(text))
  68. def translate_to_chinese(text):
  69. """将文本翻译为中文"""
  70. if not text:
  71. return text
  72. try:
  73. result = translate(text, from_language='en', to_language='zh')
  74. return result if result else text
  75. except Exception as e:
  76. print(f"翻译失败 (->中文): {text[:30]}... 错误: {e}")
  77. return text
  78. def translate_to_english(text):
  79. """将文本翻译为英文"""
  80. if not text:
  81. return text
  82. try:
  83. result = translate(text, from_language='zh', to_language='en')
  84. return result if result else text
  85. except Exception as e:
  86. print(f"翻译失败 (->英文): {text[:30]}... 错误: {e}")
  87. return text
  88. def export_records():
  89. # 连接数据库
  90. conn = sqlite3.connect('chat_history.db')
  91. cursor = conn.cursor()
  92. # 读取所有数据
  93. cursor.execute("SELECT id, record, time, label FROM intent_recognition_records")
  94. all_records = cursor.fetchall()
  95. # 按语言分类
  96. chinese_records = []
  97. english_records = []
  98. for record in all_records:
  99. record_text = record[1] # record 字段
  100. if is_chinese(record_text):
  101. chinese_records.append(record)
  102. else:
  103. english_records.append(record)
  104. # 翻译中文记录为英文
  105. print("正在翻译中文记录为英文...")
  106. chinese_records_en = []
  107. for i, record in enumerate(chinese_records):
  108. id_, text, time, label = record
  109. translated_text = translate_to_english(text) # 中文翻译为英文
  110. chinese_records_en.append((id_, translated_text, time, label))
  111. if (i + 1) % 100 == 0:
  112. print(f" 已翻译 {i + 1}/{len(chinese_records)} 条中文记录")
  113. # 翻译英文记录为中文
  114. print("正在翻译英文记录为中文...")
  115. english_records_zh = []
  116. for i, record in enumerate(english_records):
  117. id_, text, time, label = record
  118. translated_text = translate_to_chinese(text)
  119. english_records_zh.append((id_, translated_text, time, label))
  120. if (i + 1) % 100 == 0:
  121. print(f" 已翻译 {i + 1}/{len(english_records)} 条英文记录")
  122. # 合并记录
  123. # 中文CSV:chinese_records + english_records_cn(全部为中文内容)
  124. all_chinese_records = chinese_records + english_records_zh
  125. print(f'中文记录为{len(all_chinese_records)}条')
  126. # 英文CSV:english_records + chinese_records_en(全部为英文内容)
  127. all_english_records = english_records + chinese_records_en
  128. print(f'英文记录为{len(all_english_records)}条')
  129. # 导出中文CSV文件
  130. with open('intent_records_chinese.csv', 'w', encoding='utf-8-sig', newline='') as f:
  131. writer = csv.writer(f)
  132. writer.writerow(['id', 'record', 'time', 'label'])
  133. writer.writerows(all_chinese_records)
  134. # 导出英文CSV文件
  135. with open('intent_records_english.csv', 'w', encoding='utf-8-sig', newline='') as f:
  136. writer = csv.writer(f)
  137. writer.writerow(['id', 'record', 'time', 'label'])
  138. writer.writerows(all_english_records)
  139. conn.close()
  140. print(f"\n导出完成!")
  141. print(f"中文CSV文件 (intent_records_chinese.csv):")
  142. print(f" - 原中文记录: {len(chinese_records)} 条")
  143. print(f" - 英文翻译为中文: {len(english_records_zh)} 条")
  144. print(f" - 总计: {len(all_chinese_records)} 条")
  145. print(f"英文CSV文件 (intent_records_english.csv):")
  146. print(f" - 原英文记录: {len(english_records)} 条")
  147. print(f" - 中文翻译为英文: {len(chinese_records_en)} 条")
  148. print(f" - 总计: {len(all_english_records)} 条")
  149. def clean_csv(file_path, output_path):
  150. """
  151. 将label列全部减去1,抛弃负数仅保留0和1的行,仅保留record列和label列
  152. :param file_path: 输入的CSV文件路径
  153. :return: 处理后的DataFrame
  154. """
  155. print(f"正在整理CSV文件中的数据: {file_path}")
  156. # 使用pandas读取
  157. df = pd.read_csv(file_path, encoding='utf-8')
  158. # 将label列的数值全部减去1
  159. df['label'] = df['label'] - 1 # 剔除label为负的行,无效标注
  160. # 剔除所有label小于0的行,仅保留label为0和1的行
  161. df = df[df['label'] >= 0]
  162. # 仅保留record列和label列
  163. df = df[['record', 'label']]
  164. # 保存回原文件
  165. df.to_csv(output_path, index=False, encoding='utf-8-sig')
  166. print(f"处理完成,共保留 {len(df)} 条记录,保存至{output_path}")
  167. return df
  168. def print_statistics(file_path):
  169. """
  170. 打印CSV文件中label的统计信息
  171. :param file_path: CSV文件路径
  172. """
  173. print(f"正在打印CSV文件统计数据: {file_path}")
  174. df = pd.read_csv(file_path, encoding='utf-8')
  175. # 总记录数
  176. total_count = len(df)
  177. print(f"总记录数: {total_count}")
  178. # label分布统计
  179. label_counts = df['label'].value_counts().sort_index()
  180. print("\nLabel分布:")
  181. for label, count in label_counts.items():
  182. percentage = count / total_count * 100
  183. print(f" Label {label}: {count} 条 ({percentage:.2f}%)")
  184. # 缺失值统计
  185. missing_count = df['label'].isna().sum()
  186. if missing_count > 0:
  187. print(f"\n缺失值: {missing_count} 条")
  188. print("-" * 40)
  189. if __name__ == '__main__':
  190. # 第一步,从数据库导出所有记录
  191. # export_records()
  192. # 第二步,整理数据
  193. # clean_csv('intent_records_chinese.csv','data_zh.csv')
  194. # clean_csv('intent_records_english.csv','data_en.csv')
  195. # 打印统计数据
  196. print_statistics('data_en.csv')
  197. print_statistics('data_zh.csv')