#!/usr/bin/env python3
"""导入 workspace_backup 目录的会话数据到本地向量记忆"""

import os
import json
import hashlib
from datetime import datetime

# Put this script's own directory on sys.path so sibling modules (e.g. local_memory.py) are importable
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import chromadb
import requests

# Configuration
# Directory of the persistent ChromaDB store ("~" is expanded to the home directory).
CHROMA_PATH = os.path.expanduser("~/.hermes/local_memory_db")
# Embeddings endpoint and the model name sent in each request body.
EMBEDDING_URL = "https://api.nwafu-ai.cn/v1/embeddings"
EMBEDDING_MODEL = "netease"
# SECURITY: a credential should not live in source control. The key is now taken
# from the HERMES_EMBEDDING_API_KEY environment variable when that is set and
# non-empty; the literal below is kept only as a backward-compatible fallback.
# TODO(review): rotate this key and remove the fallback.
API_KEY = (
    os.environ.get("HERMES_EMBEDDING_API_KEY")
    or "sk-X4oFsluxb5FrJR0mC3D047F492254f6dA686EdD65a05PJmr"
)

def get_embedding(text: str) -> list:
    """Fetch the embedding vector for ``text`` from the embeddings endpoint.

    Sends a JSON POST to ``EMBEDDING_URL`` carrying ``EMBEDDING_MODEL`` and the
    input text, authenticated with ``API_KEY`` as a bearer token, using a
    30-second timeout.

    Returns:
        The ``"embedding"`` value of the first entry in the response's
        ``"data"`` list.

    Raises:
        Whatever ``requests.post`` / ``Response.raise_for_status`` raise on a
        transport failure or an HTTP error status.
    """
    request_headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {"model": EMBEDDING_MODEL, "input": text}

    response = requests.post(
        EMBEDDING_URL,
        headers=request_headers,
        json=payload,
        timeout=30,
    )
    # Fail loudly on 4xx/5xx instead of trying to parse an error body.
    response.raise_for_status()

    first_result = response.json()["data"][0]
    return first_result["embedding"]

def chunk_text(text: str, max_chars: int = 400) -> list:
    """Split ``text`` into chunks of at most ``max_chars`` characters.

    Strategy:
      1. Text that already fits is returned as a single chunk.
      2. Otherwise the text is split on blank lines (``"\\n\\n"``) and
         consecutive paragraphs are packed into one chunk, joined by a blank
         line, while they fit.
      3. A paragraph longer than ``max_chars`` is split after the Chinese
         sentence terminators ``。！？`` (and on any single newlines it
         contains); the pieces are concatenated without a separator while
         they fit.
      4. A single sentence that is still longer than ``max_chars`` is cut into
         consecutive ``max_chars``-sized slices. Previously such a sentence
         was truncated and everything after the first ``max_chars``
         characters was silently lost.

    Args:
        text: The text to split.
        max_chars: Maximum length of each returned chunk; must be >= 1 when
            the text does not already fit.

    Returns:
        A list of non-empty chunk strings (or ``[text]`` when the text fits).

    Raises:
        ValueError: If splitting is required and ``max_chars`` is < 1.
    """
    if len(text) <= max_chars:
        return [text]
    if max_chars < 1:
        raise ValueError("max_chars must be a positive integer")

    chunks = []
    current_chunk = ""

    for para in text.split('\n\n'):
        # The "+ 2" reserves room for the blank-line separator.
        if len(current_chunk) + len(para) + 2 <= max_chars:
            current_chunk += ("\n\n" if current_chunk else "") + para
            continue

        # The paragraph does not fit: flush what has been collected so far.
        if current_chunk:
            chunks.append(current_chunk)

        if len(para) <= max_chars:
            current_chunk = para
            continue

        # Oversized paragraph: break after sentence terminators.
        sentences = (
            para.replace('。', '。\n')
            .replace('！', '！\n')
            .replace('？', '？\n')
            .split('\n')
        )
        current_chunk = ""
        for sent in sentences:
            if len(current_chunk) + len(sent) <= max_chars:
                current_chunk += sent
                continue

            if current_chunk:
                chunks.append(current_chunk)
            # Hard-split a sentence that is itself too long, keeping every
            # character. For a sentence that fits, this yields one piece.
            pieces = [
                sent[start:start + max_chars]
                for start in range(0, len(sent), max_chars)
            ]
            chunks.extend(pieces[:-1])
            current_chunk = pieces[-1]

    if current_chunk:
        chunks.append(current_chunk)

    return chunks

def _content_to_text(content):
    """Flatten a message ``content`` value to plain text; None if it is unusable."""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return None

    text_parts = []
    for item in content:
        if isinstance(item, str):
            text_parts.append(item)
        elif isinstance(item, dict) and item.get('type') == 'text':
            text = item.get('text', '')
            # A null / non-string "text" field would break the join below.
            if isinstance(text, str):
                text_parts.append(text)
    return '\n'.join(text_parts)


def extract_memories_from_jsonl(file_path: str) -> list:
    """Extract memory records from a JSONL session file.

    Each non-empty line is parsed as JSON. A record is used only when it is an
    object whose ``"message"`` is an object with role ``"user"`` or
    ``"assistant"`` and whose content yields at least 20 characters of text.
    Content may be a string, or a list of strings and
    ``{"type": "text", "text": ...}`` items, which are joined with newlines.

    Skipped records:
      * lines that are not valid JSON, or not shaped as described above
        (previously a non-object line, a null message or non-string content
        raised and aborted the whole file);
      * text starting with ``{`` or ``[`` (treated as tool/code output);
      * text shorter than 100 characters when the message's string form
        mentions ``toolCall`` or ``toolResult``.

    The remaining text is split with ``chunk_text(..., max_chars=400)`` and
    every chunk of at least 20 characters becomes one record.

    Args:
        file_path: Path to the JSONL file (read as UTF-8).

    Returns:
        A list of ``{"content": str, "metadata": {...}}`` dicts, where the
        metadata holds ``role``, ``timestamp``, ``file`` (base name of
        ``file_path``) and ``source``.
    """
    memories = []
    source_file = os.path.basename(file_path)

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue

            # Valid JSON is not necessarily an object with a message object.
            if not isinstance(data, dict):
                continue
            message = data.get('message')
            if not isinstance(message, dict):
                continue

            # Only user and assistant messages are kept.
            role = message.get('role', '')
            if role not in ('user', 'assistant'):
                continue

            content = _content_to_text(message.get('content', ''))
            if not content or len(content) < 20:
                continue

            # Skip tool calls and code output.
            if content.startswith(('{', '[')):
                continue
            if len(content) < 100:
                # Short text attached to a tool call/result is dropped;
                # longer text is kept as a meaningful summary.
                message_repr = str(message)
                if 'toolCall' in message_repr or 'toolResult' in message_repr:
                    continue

            timestamp = data.get('timestamp', '')

            for chunk in chunk_text(content, max_chars=400):
                if len(chunk) >= 20:
                    memories.append({
                        'content': chunk,
                        'metadata': {
                            'role': role,
                            'timestamp': timestamp,
                            'file': source_file,
                            'source': 'openclaw_workspace_backup'
                        }
                    })

    return memories

def main():
    """Import every ``*.jsonl`` session file from the backup directory into ChromaDB.

    Steps:
      1. Collect the ``.jsonl`` files in the backup directory, in sorted order.
         ``os.listdir`` returns entries in arbitrary order, and the first
         occurrence of a duplicate chunk decides which metadata is stored, so
         sorting makes the result reproducible between runs.
      2. Extract memory records from each file.
      3. De-duplicate records by content, keyed on the first 12 hex digits of
         the MD5 of the content; the same value is used as the document id.
      4. Embed each unique record and upsert it into the ``memories``
         collection. A failure on one record is counted, and reported for the
         first five, without stopping the import.

    Progress and a final summary are printed to stdout. Returns None.
    """
    backup_dir = os.path.expanduser("~/.hermes/openclaw_sessions_backup/workspace_backup")

    # isdir rather than exists: a regular file at this path cannot be listed.
    if not os.path.isdir(backup_dir):
        print(f"目录不存在: {backup_dir}")
        return

    # Collect all jsonl files, skipping directories that merely end in ".jsonl".
    files = sorted(
        f for f in os.listdir(backup_dir)
        if f.endswith('.jsonl') and os.path.isfile(os.path.join(backup_dir, f))
    )
    print(f"找到 {len(files)} 个文件")

    # Extract all memories.
    all_memories = []
    for f in files:
        file_path = os.path.join(backup_dir, f)
        file_size = os.path.getsize(file_path) / 1024 / 1024
        print(f"处理 {f} ({file_size:.1f}MB)...")
        memories = extract_memories_from_jsonl(file_path)
        print(f"  提取 {len(memories)} 条记忆")
        all_memories.extend(memories)

    print(f"\n总共提取 {len(all_memories)} 条记忆")

    # De-duplicate. The dict keeps insertion order and setdefault keeps the
    # first record seen for an id; the id is reused below as the document id.
    unique_memories = {}
    for m in all_memories:
        doc_id = hashlib.md5(m['content'].encode()).hexdigest()[:12]
        unique_memories.setdefault(doc_id, m)

    total = len(unique_memories)
    print(f"去重后 {total} 条记忆")

    # Connect to the database.
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection = client.get_or_create_collection(
        name="memories",
        metadata={"hnsw:space": "cosine"}
    )

    print("\n开始导入向量数据库...")
    success = 0
    failed = 0

    for i, (doc_id, mem) in enumerate(unique_memories.items()):
        try:
            embedding = get_embedding(mem['content'])
            collection.upsert(
                ids=[doc_id],
                embeddings=[embedding],
                documents=[mem['content']],
                metadatas=[{
                    **mem['metadata'],
                    'created_at': datetime.now().isoformat()
                }]
            )
            success += 1
        except Exception as e:
            # Best-effort import: keep going, but only print the first few
            # errors so a systematic failure does not flood the output.
            failed += 1
            if failed <= 5:
                print(f"导入失败: {e}")

        if (i + 1) % 100 == 0:
            print(f"已导入 {i+1}/{total}，成功 {success}，失败 {failed}")

    print(f"\n完成！成功导入 {success} 条，失败 {failed} 条")
    print(f"数据库总记忆数: {collection.count()}")

# Run the import only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
