#!/usr/bin/env python3
"""
从 OpenClaw 聊天记录提取对话，存入本地向量数据库
"""

import os
import re
import sys
import json
import glob
from datetime import datetime
from pathlib import Path

# 添加脚本目录到路径
sys.path.insert(0, str(Path(__file__).parent))
from local_memory import add_memory, get_client, get_collection

# Directory holding the backed-up OpenClaw session logs (*.jsonl files).
# "~" is expanded to the current user's home directory once, at import time.
SESSIONS_DIR = os.path.expanduser(
    "~/.hermes/openclaw_sessions_backup/main_sessions"
)

def extract_text_content(content):
    """Flatten a message ``content`` field into plain text.

    A plain string is returned unchanged. A list is scanned element by
    element: string elements are kept as-is, dict elements contribute their
    ``"text"`` value (default ``""``) only when their ``"type"`` is
    ``"text"``, and anything else is ignored. The kept parts are joined with
    newlines. Any other kind of input yields ``""``.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""

    def _is_text_part(part):
        # Bare strings, or {"type": "text", ...} dicts; images etc. are skipped.
        return isinstance(part, str) or (
            isinstance(part, dict) and part.get("type") == "text"
        )

    return "\n".join(
        part if isinstance(part, str) else part.get("text", "")
        for part in content
        if _is_text_part(part)
    )

def is_meaningful(text):
    """Return True if *text* looks like real conversation worth storing.

    Rejects empty/falsy input and anything shorter than 20 characters, then
    rejects text whose first 200 characters contain a marker of system
    output, tool chatter, or errors (subagent context, heartbeats, shell
    errors, tracebacks, system notes).
    """
    if len(text or "") < 20:
        return False

    noise_markers = (
        "[Subagent Context]",
        "HEARTBEAT_OK",
        "Command not found",
        "bash:",
        "Error:",
        "Traceback",
        "[System note:",
    )
    # Only the opening of the message is inspected, so a marker mentioned
    # later in a genuine reply does not disqualify it.
    head = text[:200]
    return not any(marker in head for marker in noise_markers)

def chunk_text(text, max_len=500):
    """Split *text* into chunks of at most ``max_len`` characters.

    Text no longer than ``max_len`` is returned unchanged as a single chunk.
    Otherwise the text is cut into pieces right after each Chinese sentence
    terminator (``。``, ``！``, ``？``) and after each newline. Consecutive
    pieces are then packed greedily into chunks. The text's own newlines stay
    inside a chunk, so separate lines are never glued together. A single
    piece longer than ``max_len`` (e.g. English prose without those
    terminators) is hard-split into ``max_len``-sized slices. Chunks are
    whitespace-stripped and blank chunks are dropped.

    Args:
        text: the string to split.
        max_len: maximum chunk length in characters; must be >= 1.

    Returns:
        A non-empty list of strings. If no non-blank chunk can be produced
        (e.g. *text* is only whitespace), returns ``[text[:max_len]]``.

    Raises:
        ValueError: if ``max_len`` is less than 1.
    """
    if max_len < 1:
        raise ValueError("max_len must be >= 1")
    if len(text) <= max_len:
        return [text]

    # Zero-width split after each boundary character keeps that character
    # (including "\n") attached to its piece, so "".join(pieces) == text.
    pieces = re.split(r"(?<=[。！？\n])", text)

    chunks = []
    current = ""
    for piece in pieces:
        if len(current) + len(piece) <= max_len:
            current += piece
            continue
        if current.strip():
            chunks.append(current.strip())
        # A piece that alone exceeds max_len would otherwise become an
        # oversized chunk; slice it down to the limit.
        while len(piece) > max_len:
            head = piece[:max_len].strip()
            if head:
                chunks.append(head)
            piece = piece[max_len:]
        current = piece
    if current.strip():
        chunks.append(current.strip())

    return chunks if chunks else [text[:max_len]]

def process_session(filepath):
    """Extract memory records from one OpenClaw session JSONL file.

    Each line is parsed as JSON. Only objects with ``"type": "message"`` whose
    ``message.role`` is ``"user"`` or ``"assistant"`` are used. Their content
    is flattened with ``extract_text_content``, filtered with
    ``is_meaningful``, and split with ``chunk_text`` (``max_len=800``). Every
    chunk of at least 20 characters becomes a
    ``{"content": chunk, "metadata": {...}}`` dict. The metadata holds the
    role, session id, the record's ``timestamp`` (default ``""``) and
    ``source="openclaw"``.

    Malformed lines (invalid JSON, JSON that is not an object, or a
    ``"message"`` that is not an object) are skipped one at a time. Any other
    error is reported on stderr, and the memories gathered before it are
    returned.

    Args:
        filepath: path to a ``*.jsonl`` session file; its stem is used as
            the ``session_id``.

    Returns:
        A list of memory dicts (possibly empty).
    """
    memories = []
    session_id = Path(filepath).stem

    try:
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    data = json.loads(line.strip())
                except json.JSONDecodeError:
                    continue

                # Valid JSON is not necessarily an object. Without this check a
                # list/string/null line raises on .get(), and the outer handler
                # then abandons the rest of the file.
                if not isinstance(data, dict) or data.get("type") != "message":
                    continue

                msg = data.get("message")
                if not isinstance(msg, dict):
                    continue

                role = msg.get("role")
                # Only user and assistant turns are kept.
                if role not in ("user", "assistant"):
                    continue

                text = extract_text_content(msg.get("content", []))
                if not is_meaningful(text):
                    continue

                timestamp = data.get("timestamp", "")
                for chunk in chunk_text(text, max_len=800):
                    # Fragments this short carry too little meaning to store.
                    if len(chunk) < 20:
                        continue
                    memories.append({
                        "content": chunk,
                        "metadata": {
                            "role": role,
                            "session_id": session_id,
                            "timestamp": timestamp,
                            "source": "openclaw"
                        }
                    })
    except Exception as e:
        # Best-effort: report and return what was collected, so that a single
        # unreadable file does not stop the caller's loop over all sessions.
        print(f"处理 {filepath} 出错: {e}", file=sys.stderr)

    return memories

def _dedupe_memories(memories):
    """Return *memories* without entries whose ``content`` repeats an earlier one.

    The whole content string is compared, not a prefix, so distinct chunks
    that merely share an opening are all kept. The first occurrence wins and
    input order is preserved.
    """
    seen = set()
    unique = []
    for mem in memories:
        content = mem["content"]
        if content in seen:
            continue
        seen.add(content)
        unique.append(mem)
    return unique


def main():
    """Extract OpenClaw chat history and import it into the local vector store.

    Steps:
    1. Find ``*.jsonl`` session files in ``SESSIONS_DIR``, skipping files
       whose name contains ``"trajectory"``.
    2. Extract memories from each file with ``process_session``.
    3. Drop duplicate contents.
    4. Store each memory with ``add_memory``.

    Progress and the final database count are printed to stdout. The first
    five import errors are echoed to stderr.
    """
    # Match against the file name only, so that "trajectory" appearing in the
    # directory path itself cannot exclude every file. Sorting gives a
    # reproducible order (glob order is filesystem-dependent), which also
    # makes the copy of a duplicate that survives deduplication deterministic.
    session_files = sorted(
        path
        for path in glob.glob(os.path.join(SESSIONS_DIR, "*.jsonl"))
        if "trajectory" not in os.path.basename(path)
    )

    print(f"找到 {len(session_files)} 个 session 文件")

    all_memories = []
    for i, filepath in enumerate(session_files, start=1):
        all_memories.extend(process_session(filepath))
        if i % 50 == 0:
            print(f"已处理 {i}/{len(session_files)} 个文件，提取 {len(all_memories)} 条记忆")

    print(f"\n总共提取 {len(all_memories)} 条记忆")

    unique_memories = _dedupe_memories(all_memories)
    print(f"去重后 {len(unique_memories)} 条记忆")

    print("\n开始导入向量数据库...")

    success = 0
    failed = 0
    for i, mem in enumerate(unique_memories, start=1):
        try:
            # NOTE(review): mem["metadata"] (role, session_id, timestamp,
            # source) is built by process_session but never passed on here;
            # confirm whether add_memory can accept it.
            add_memory(mem["content"])
            success += 1
        except Exception as e:
            # Best-effort import: count every failure, show only the first few.
            failed += 1
            if failed <= 5:
                print(f"导入失败: {str(e)[:100]}", file=sys.stderr)

        if i % 100 == 0:
            print(f"已导入 {i}/{len(unique_memories)}，成功 {success}，失败 {failed}")

    print(f"\n完成！成功导入 {success} 条，失败 {failed} 条")

    # Report the total size of the collection after the import.
    client = get_client()
    collection = get_collection(client)
    print(f"数据库总记忆数: {collection.count()}")


if __name__ == "__main__":
    main()
