#!/usr/bin/env python3
"""
从 Hermes 聊天记录提取对话，存入本地向量数据库
"""

import os
import sys
import json
import glob
from datetime import datetime
from pathlib import Path

# 添加脚本目录到路径
sys.path.insert(0, str(Path(__file__).parent))
from local_memory import add_memory, get_client, get_collection

# Directory scanned for Hermes session transcripts (*.jsonl); "~" is expanded
# to the current user's home directory.
SESSIONS_DIR = os.path.expanduser("~/.hermes/sessions")

def extract_text_content(content, max_tool_result_len=500):
    """Flatten a message ``content`` field into plain text.

    ``content`` may be a plain string (returned unchanged) or a list of
    content blocks. From a list, this collects:

    * ``{"type": "text"}`` blocks -> their ``text`` value (non-string values
      such as ``None`` are ignored instead of breaking the join);
    * ``{"type": "tool_result"}`` blocks -> their ``content``, if it is a
      string shorter than ``max_tool_result_len``. List-form tool results are
      flattened recursively first; empty results are skipped;
    * bare strings -> as-is.

    Collected pieces are joined with newlines. Any other type of ``content``
    (e.g. ``None``) yields ``""``.

    Args:
        content: The raw ``content`` value from a session message.
        max_tool_result_len: Tool results of this length or longer are dropped
            as likely noise (default 500, the original hard-coded limit).

    Returns:
        The extracted text, possibly empty.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""

    texts = []
    for item in content:
        if isinstance(item, str):
            texts.append(item)
            continue
        if not isinstance(item, dict):
            continue

        item_type = item.get("type")
        if item_type == "text":
            text = item.get("text", "")
            # Guard against {"text": null} etc.; a non-str here would make
            # "\n".join raise and abort processing of the whole session file.
            if isinstance(text, str):
                texts.append(text)
        elif item_type == "tool_result":
            # Tool results can be valuable too, but only short ones.
            result = item.get("content", "")
            if isinstance(result, list):
                # Tool results may themselves be a list of content blocks.
                result = extract_text_content(result, max_tool_result_len)
                if not result:
                    continue
            if isinstance(result, str) and len(result) < max_tool_result_len:
                texts.append(result)
    return "\n".join(texts)

def is_meaningful(text):
    """Return True when *text* looks like real conversation worth storing.

    Text is rejected if it is empty/None, shorter than 30 characters, or if
    any known noise marker (system notes, heartbeats, shell errors, tracebacks,
    raw JSON tool output, HTTP errors) appears within its first 300 characters.
    """
    if not text or len(text) < 30:
        return False

    # Markers are only searched for near the start of the text, where
    # tool/system output usually announces itself.
    noise_markers = (
        "[System note:",
        "HEARTBEAT_OK",
        "Command not found",
        "bash: ",
        "Traceback (most recent",
        "Error:",
        "exit_code",
        '{"output":',
        '{"success":',
        "```json",
        "HTTP 404",
        "HTTP 503",
    )
    head = text[:300]
    return not any(marker in head for marker in noise_markers)

def chunk_text(text, max_len=400):
    """Split *text* into chunks of at most ``max_len`` characters.

    Text no longer than ``max_len`` is returned as a single chunk unchanged.
    Otherwise the text is split on blank lines (``"\\n\\n"``) and consecutive
    paragraphs are packed greedily into chunks, separated by ``"\\n\\n"``.
    A paragraph that on its own exceeds ``max_len`` is hard-split into
    ``max_len``-sized slices so that no chunk exceeds the limit. Chunks are
    whitespace-stripped and empty ones are dropped.

    Args:
        text: The text to split.
        max_len: Maximum chunk length in characters.

    Returns:
        A non-empty list of chunks. If every chunk is blank, the list holds
        ``text[:max_len]``.
    """
    if len(text) <= max_len:
        return [text]

    # Guard the slicing step so a non-positive max_len cannot make range() raise.
    step = max(max_len, 1)

    chunks = []
    current = ""
    for para in text.split("\n\n"):
        # Oversized paragraphs are sliced so the limit is actually enforced;
        # normal paragraphs (and empty ones) pass through as a single piece.
        pieces = [para[i:i + step] for i in range(0, len(para), step)] or [""]
        for piece in pieces:
            if len(current) + len(piece) <= max_len:
                current += piece + "\n\n"
            else:
                if current.strip():
                    chunks.append(current.strip())
                current = piece + "\n\n"
    if current.strip():
        chunks.append(current.strip())

    return chunks if chunks else [text[:max_len]]

def process_jsonl(filepath):
    """Extract memory records from a single Hermes session ``.jsonl`` file.

    Each line is parsed as JSON. Blank lines, malformed lines and lines that
    are not JSON objects are skipped. Only ``user`` and ``assistant`` messages
    are kept. Their content is flattened with ``extract_text_content``,
    filtered with ``is_meaningful`` and split with ``chunk_text``. Chunks
    shorter than 30 characters are dropped.

    Args:
        filepath: Path to the session file. Its stem is used as the
            ``session_id`` of every record.

    Returns:
        A list of dicts with keys ``content``, ``role``, ``session_id`` and
        ``source`` (always ``"hermes"``). If an error occurs while reading, it
        is reported on stderr and the records collected so far are returned.
    """
    memories = []
    session_id = Path(filepath).stem

    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # A valid JSON line that is not an object (list, string,
                # number) has no .get(); skip it rather than letting the
                # AttributeError abort the rest of the file.
                if not isinstance(data, dict):
                    continue

                role = data.get("role")
                # Only user and assistant messages become memories.
                if role not in ("user", "assistant"):
                    continue

                text = extract_text_content(data.get("content", ""))
                if not is_meaningful(text):
                    continue

                # Split long messages; tiny leftover chunks are not useful.
                for chunk in chunk_text(text, max_len=400):
                    if len(chunk) >= 30:
                        memories.append({
                            "content": chunk,
                            "role": role,
                            "session_id": session_id,
                            "source": "hermes"
                        })
    except Exception as e:
        # Best effort per file: report and keep whatever was collected.
        print(f"处理 {filepath} 出错: {e}", file=sys.stderr)

    return memories

def _find_session_files():
    """Return sorted Hermes session ``.jsonl`` paths, excluding request dumps."""
    pattern = os.path.join(SESSIONS_DIR, "*.jsonl")
    # Test the file name only, so a directory component that happens to
    # contain "request_dump" cannot hide every session file. Sorting makes
    # runs deterministic (glob order is arbitrary).
    return sorted(
        path for path in glob.glob(pattern)
        if "request_dump" not in os.path.basename(path)
    )


def _dedupe_memories(memories, prefix_len=150):
    """Keep the first memory for each distinct leading ``prefix_len`` characters."""
    seen = set()
    unique = []
    for mem in memories:
        # Store the prefix itself rather than hash(prefix) so distinct
        # prefixes can never collide. Chunks that share a prefix but differ
        # later are deliberately treated as duplicates.
        key = mem["content"][:prefix_len]
        if key not in seen:
            seen.add(key)
            unique.append(mem)
    return unique


def _import_memories(memories):
    """Add each memory to the vector store and return ``(success, failed)``."""
    success = 0
    failed = 0
    total = len(memories)
    for i, mem in enumerate(memories, start=1):
        try:
            # NOTE(review): only the text is passed; role/session_id/source
            # metadata is dropped. Check whether add_memory accepts metadata.
            add_memory(mem["content"])
            success += 1
        except Exception as e:
            # Best effort: keep going, but only echo the first few errors.
            failed += 1
            if failed <= 3:
                print(f"导入失败: {str(e)[:80]}", file=sys.stderr)

        if i % 100 == 0:
            print(f"已导入 {i}/{total}，成功 {success}，失败 {failed}")
    return success, failed


def main():
    """Import all Hermes session transcripts into the local vector database.

    The steps are:

    1. Collect ``*.jsonl`` files from ``SESSIONS_DIR``, skipping
       ``request_dump`` files.
    2. Extract memory chunks from each file.
    3. Drop chunks whose first 150 characters were already seen.
    4. Add the remaining chunks via ``add_memory``.
    5. Print progress and the final collection size.
    """
    session_files = _find_session_files()
    print(f"找到 {len(session_files)} 个 Hermes session 文件")

    all_memories = []
    for i, filepath in enumerate(session_files, start=1):
        all_memories.extend(process_jsonl(filepath))
        if i % 10 == 0:
            print(f"已处理 {i}/{len(session_files)} 个文件，提取 {len(all_memories)} 条记忆")

    print(f"\n总共提取 {len(all_memories)} 条记忆")

    unique_memories = _dedupe_memories(all_memories)
    print(f"去重后 {len(unique_memories)} 条记忆")

    print("\n开始导入向量数据库...")
    success, failed = _import_memories(unique_memories)
    print(f"\n完成！成功导入 {success} 条，失败 {failed} 条")

    # Report the total size of the collection after the import.
    client = get_client()
    collection = get_collection(client)
    print(f"数据库总记忆数: {collection.count()}")

# Run the import only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
