#!/usr/bin/env python3
"""导入 memory-tdai/conversations 目录的会话数据到本地向量记忆"""

import os
import json
import hashlib
from datetime import datetime
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import chromadb
import requests

# Configuration. Each value may be overridden through an environment variable;
# the defaults below are used when the variable is unset.
CHROMA_PATH = os.path.expanduser(
    os.environ.get("HERMES_CHROMA_PATH", "~/.hermes/local_memory_db")
)
EMBEDDING_URL = os.environ.get("HERMES_EMBEDDING_URL", "https://api.nwafu-ai.cn/v1/embeddings")
EMBEDDING_MODEL = os.environ.get("HERMES_EMBEDDING_MODEL", "netease")
# The embedding API key is a secret and must not live in source control.
# NOTE(review): the key that used to be hard-coded here is exposed in the
# repository history and should be revoked/rotated.
API_KEY = os.environ.get("HERMES_EMBEDDING_API_KEY", "")

def get_embedding(text: str) -> list:
    """Fetch the embedding vector for *text* from the configured endpoint.

    Posts ``{"model": EMBEDDING_MODEL, "input": text}`` to EMBEDDING_URL,
    authenticated with API_KEY as a bearer token, with a 30-second timeout.

    Returns:
        The ``embedding`` field of the first entry in the response's ``data``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        KeyError / IndexError: if the response body lacks ``data[0].embedding``.
    """
    request_headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {"model": EMBEDDING_MODEL, "input": text}

    response = requests.post(
        EMBEDDING_URL,
        headers=request_headers,
        json=payload,
        timeout=30,
    )
    response.raise_for_status()

    first_result = response.json()["data"][0]
    return first_result["embedding"]

def chunk_text(text: str, max_chars: int = 400) -> list:
    """Split *text* into chunks of at most *max_chars* characters.

    Paragraphs (separated by blank lines) are packed greedily into chunks
    and re-joined with a blank line. A paragraph longer than *max_chars* is
    broken after the Chinese sentence terminators 。！？ and at single
    newlines, and those pieces are packed back together with no separator
    (so single newlines inside such a paragraph are not preserved). A piece
    that is itself longer than *max_chars* is cut into consecutive
    *max_chars*-sized slices rather than truncated, so its text is kept.

    Args:
        text: Text to split.
        max_chars: Maximum chunk length in characters; must be positive.

    Returns:
        List of chunks. Text no longer than *max_chars* (including "") is
        returned unchanged as a one-element list.

    Raises:
        ValueError: if *max_chars* is not positive (slicing could never make
            progress).
    """
    if max_chars <= 0:
        raise ValueError(f"max_chars must be positive, got {max_chars}")
    if len(text) <= max_chars:
        return [text]

    chunks = []
    current_chunk = ""

    for para in text.split('\n\n'):
        # The +2 accounts for the blank-line separator added when joining.
        if len(current_chunk) + len(para) + 2 <= max_chars:
            current_chunk += ("\n\n" if current_chunk else "") + para
            continue

        if current_chunk:
            chunks.append(current_chunk)

        if len(para) <= max_chars:
            current_chunk = para
            continue

        # Oversized paragraph: split after sentence terminators and pack the
        # sentences back together.
        sentences = (
            para.replace('。', '。\n')
            .replace('！', '！\n')
            .replace('？', '？\n')
            .split('\n')
        )
        current_chunk = ""
        for sent in sentences:
            if len(current_chunk) + len(sent) <= max_chars:
                current_chunk += sent
                continue
            if current_chunk:
                chunks.append(current_chunk)
            # A single sentence longer than max_chars: emit full-size slices
            # instead of truncating (the old sent[:max_chars] lost the rest).
            while len(sent) > max_chars:
                chunks.append(sent[:max_chars])
                sent = sent[max_chars:]
            current_chunk = sent

    if current_chunk:
        chunks.append(current_chunk)

    return chunks

def _content_to_text(content) -> str:
    """Flatten a message ``content`` value into plain text.

    A str is returned unchanged. A list is treated as content parts: the
    ``text`` of ``{"type": "text"}`` dicts and bare strings are kept and
    joined with newlines (non-string ``text`` values such as null are
    skipped). Any other type (None, dict, number, ...) yields "" so the
    caller skips the record.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        text_parts = []
        for item in content:
            if isinstance(item, dict) and item.get('type') == 'text':
                part = item.get('text', '')
                if isinstance(part, str):
                    text_parts.append(part)
            elif isinstance(item, str):
                text_parts.append(item)
        return '\n'.join(text_parts)
    return ""

def extract_memories_from_jsonl(file_path: str) -> list:
    """Read one JSONL conversation file and turn its messages into memory records.

    Each line is parsed as a JSON object. Text is taken from ``content``,
    else ``text``, else ``message`` (a dict's ``content``/``role``, or any
    other value stringified). Blank lines, invalid JSON, non-object JSON,
    texts shorter than 20 characters, and texts that look like raw JSON
    (starting with ``{`` or ``[``) are skipped. Remaining text is split with
    ``chunk_text`` and chunks shorter than 20 characters are dropped.

    Args:
        file_path: Path to a ``.jsonl`` file. Undecodable bytes are replaced
            with U+FFFD instead of aborting the whole import.

    Returns:
        List of ``{'content': str, 'metadata': dict}`` records whose metadata
        holds ``role``, ``timestamp``, ``file`` (basename) and ``source``.
    """
    memories = []
    file_name = os.path.basename(file_path)

    with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Only JSON objects can carry a message; arrays/scalars would
            # crash on .get() below.
            if not isinstance(data, dict):
                continue

            # The memory-tdai record layout varies, so try several fields.
            role = data.get('role', 'unknown')
            timestamp = data.get('timestamp', data.get('created_at', ''))

            if 'content' in data:
                raw_content = data['content']
            elif 'text' in data:
                raw_content = data['text']
            elif 'message' in data:
                msg = data['message']
                if isinstance(msg, dict):
                    raw_content = msg.get('content', '')
                    role = msg.get('role', role)
                else:
                    raw_content = str(msg)
            else:
                raw_content = ''

            # Chroma metadata values are restricted to scalar types, and
            # some chromadb versions reject None, so normalize JSON nulls.
            if role is None:
                role = 'unknown'
            if timestamp is None:
                timestamp = ''

            content = _content_to_text(raw_content)
            if len(content) < 20:
                continue

            # Skip payloads that are raw JSON/code rather than prose.
            if content.strip().startswith(('{', '[')):
                continue

            memories.extend(
                {
                    'content': chunk,
                    'metadata': {
                        'role': role,
                        'timestamp': timestamp,
                        'file': file_name,
                        'source': 'memory_tdai_conversations',
                    },
                }
                for chunk in chunk_text(content, max_chars=400)
                if len(chunk) >= 20
            )

    return memories

def _content_id(content: str) -> str:
    """Return the 12-hex-char MD5 prefix used both to de-duplicate and as the Chroma id."""
    return hashlib.md5(content.encode()).hexdigest()[:12]

def _dedupe_memories(memories: list) -> list:
    """Return ``(doc_id, memory)`` pairs, keeping the first memory seen for each content id."""
    unique = {}
    for mem in memories:
        unique.setdefault(_content_id(mem['content']), mem)
    return list(unique.items())

def main():
    """Import all ``*.jsonl`` files from the memory-tdai conversations directory.

    Extracts memories from every file (processed in sorted name order so that
    duplicate resolution is reproducible), de-duplicates them by content id,
    then embeds and upserts each one into the ``memories`` collection at
    CHROMA_PATH. A failure on one memory is counted (the first three are
    printed) and the import continues; progress is printed every 100 items.
    """
    conv_dir = os.path.expanduser("~/.hermes/memory-tdai/conversations")

    if not os.path.exists(conv_dir):
        print(f"目录不存在: {conv_dir}")
        return

    # Without a key every embedding request would fail; stop before doing work.
    if not API_KEY:
        print("未配置 embedding API key (API_KEY 为空)，已中止")
        return

    # os.listdir order is arbitrary; sort so "first duplicate wins" is stable.
    files = sorted(
        name for name in os.listdir(conv_dir)
        if name.endswith('.jsonl') and os.path.isfile(os.path.join(conv_dir, name))
    )
    print(f"找到 {len(files)} 个文件")

    all_memories = []
    for name in files:
        memories = extract_memories_from_jsonl(os.path.join(conv_dir, name))
        print(f"  {name}: {len(memories)} 条记忆")
        all_memories.extend(memories)

    print(f"\n总共提取 {len(all_memories)} 条记忆")

    # 去重 (de-duplicate by content hash; the same hash is reused as the id)
    unique_memories = _dedupe_memories(all_memories)
    total = len(unique_memories)
    print(f"去重后 {total} 条记忆")

    client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection = client.get_or_create_collection(name="memories", metadata={"hnsw:space": "cosine"})

    print("\n开始导入...")
    success = 0
    failed = 0

    for i, (doc_id, mem) in enumerate(unique_memories, start=1):
        try:
            embedding = get_embedding(mem['content'])
            collection.upsert(
                ids=[doc_id],
                embeddings=[embedding],
                documents=[mem['content']],
                metadatas=[{**mem['metadata'], 'created_at': datetime.now().isoformat()}],
            )
            success += 1
        except Exception as e:  # per-item boundary: record the failure, keep importing
            failed += 1
            if failed <= 3:
                print(f"导入失败: {e}")

        if i % 100 == 0:
            print(f"已导入 {i}/{total}，成功 {success}，失败 {failed}")

    print(f"\n完成！成功 {success}，失败 {failed}")
    print(f"数据库总记忆数: {collection.count()}")

if __name__ == "__main__":
    # Run the import only when this file is executed directly, not when imported.
    main()
