#!/usr/bin/env python3
"""
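Local vector memory system (CLI).

- Embeddings: SiliconFlow Qwen/Qwen3-VL-Embedding-8B for the primary
  4096-dim collection; netease-youdao/bce-embedding-base_v1 (768-dim) for the
  legacy collection
- Vector database: Chroma (local, persistent)
- Reranker: Qwen/Qwen3-VL-Reranker-8B, falling back to
  netease-youdao/bce-reranker-base_v1

Usage:
  python local_memory.py add "memory text"
  python local_memory.py search "query text" [--top_k 5] [--rerank]
  python local_memory.py list [--limit 20]
  python local_memory.py delete <memory_id>
  python local_memory.py count
  python local_memory.py export > memories.json
  python local_memory.py import < memories.json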
"""

import os
import sys
import json
import argparse
import hashlib
import requests
from datetime import datetime
from typing import List, Optional

import chromadb
from chromadb.config import Settings

# Configuration
CHROMA_PATH = os.path.expanduser("~/.hermes/local_memory_db")
COLLECTION_NAME = "memories_4096"  # primary collection (4096-dim embeddings)
COLLECTION_NAME_FALLBACK = "memories"  # legacy collection (768-dim embeddings), searched as a fallback

# The API key is read from the environment instead of being hard-coded: a key
# embedded in source is exposed to anyone who can read this file or its
# version history. Every endpoint below is SiliconFlow, so one key serves all.
# NOTE(review): the key that used to be embedded here should be treated as
# leaked and rotated.
SILICONFLOW_API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")

# Primary API: SiliconFlow Qwen3-VL models (4096-dim)
EMBEDDING_API_URL = "https://api.siliconflow.cn/v1/embeddings"
EMBEDDING_API_KEY = SILICONFLOW_API_KEY
EMBEDDING_MODEL = "Qwen/Qwen3-VL-Embedding-8B"

RERANKER_API_URL = "https://api.siliconflow.cn/v1/rerank"
RERANKER_API_KEY = SILICONFLOW_API_KEY
RERANKER_MODEL = "Qwen/Qwen3-VL-Reranker-8B"

# Fallback API: SiliconFlow BCE models (768-dim, used for the legacy collection)
FALLBACK_EMBEDDING_API_URL = "https://api.siliconflow.cn/v1/embeddings"
FALLBACK_EMBEDDING_API_KEY = SILICONFLOW_API_KEY
FALLBACK_EMBEDDING_MODEL = "netease-youdao/bce-embedding-base_v1"

FALLBACK_RERANKER_API_URL = "https://api.siliconflow.cn/v1/rerank"
FALLBACK_RERANKER_API_KEY = SILICONFLOW_API_KEY
FALLBACK_RERANKER_MODEL = "netease-youdao/bce-reranker-base_v1"


def get_embedding(text: str) -> List[float]:
    """Return the primary-model embedding of ``text``.

    Posts ``text`` to ``EMBEDDING_API_URL`` with ``EMBEDDING_MODEL`` and
    returns ``data[0].embedding`` from the JSON response. Every caller in this
    module stores or queries the result in the primary ``COLLECTION_NAME``
    collection (4096-dim).

    There is intentionally no fallback to ``FALLBACK_EMBEDDING_MODEL``: that
    model returns 768-dim vectors, which cannot be added to or queried against
    the 4096-dim collection. A silent fallback either made the following
    Chroma call fail with a misleading dimension error, or, when the primary
    collection was still empty, let Chroma fix the collection's dimensionality
    at 768 so every later 4096-dim insert failed. For the legacy 768-dim
    collection, use ``get_fallback_embedding``.

    Raises:
        requests.RequestException: on connection errors, timeouts, or
            non-2xx responses.
        KeyError, IndexError: if the response has no ``data[0].embedding``.
    """
    headers = {
        "Authorization": f"Bearer {EMBEDDING_API_KEY}",
        "Content-Type": "application/json",
    }
    data = {"model": EMBEDDING_MODEL, "input": text}
    resp = requests.post(EMBEDDING_API_URL, headers=headers, json=data, timeout=30)
    resp.raise_for_status()
    return resp.json()["data"][0]["embedding"]


def rerank(query: str, documents: List[str], top_k: int = 5) -> List[dict]:
    """Rerank ``documents`` against ``query`` via the reranker API.

    Tries the primary reranker (``RERANKER_MODEL``). If that request fails
    for any reason, it retries once with ``FALLBACK_RERANKER_MODEL``. Unlike
    embeddings, rerank scores carry no vector dimensionality, so switching
    models here is safe.

    Args:
        query: Query text.
        documents: Candidate texts. An empty list returns ``[]`` without any
            network call.
        top_k: Maximum number of results to request; capped at
            ``len(documents)``.

    Returns:
        The ``results`` list from the API response (``[]`` if absent). The
        caller in this module reads ``index`` and ``relevance_score`` from
        each entry.

    Raises:
        requests.RequestException: if the fallback request also fails.
    """
    if not documents:
        return []

    def _call(url: str, api_key: str, model: str) -> List[dict]:
        # Both endpoints share the same request/response format.
        resp = requests.post(
            url,
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": model,
                "query": query,
                "documents": documents,
                "top_n": min(top_k, len(documents)),
            },
            timeout=30,
        )
        resp.raise_for_status()
        return resp.json().get("results", [])

    try:
        return _call(RERANKER_API_URL, RERANKER_API_KEY, RERANKER_MODEL)
    except Exception as e:
        # Diagnostics go to stderr: stdout carries the CLI's JSON output.
        print(
            f"[Reranker] {RERANKER_MODEL} 失败: {e}，尝试备用模型 {FALLBACK_RERANKER_MODEL}...",
            file=sys.stderr,
        )
    return _call(FALLBACK_RERANKER_API_URL, FALLBACK_RERANKER_API_KEY, FALLBACK_RERANKER_MODEL)


def get_client():
    """Open the persistent Chroma store at ``CHROMA_PATH``, creating the directory if needed."""
    db_dir = CHROMA_PATH
    os.makedirs(db_dir, exist_ok=True)
    client = chromadb.PersistentClient(path=db_dir)
    return client


def get_collection(client):
    """Return the primary ``COLLECTION_NAME`` collection, creating it on first use."""
    collection_meta = {"description": "本地向量记忆"}
    return client.get_or_create_collection(name=COLLECTION_NAME, metadata=collection_meta)


def generate_id(content: str) -> str:
    """Build a 12-hex-character memory ID from ``content`` plus the current timestamp.

    Mixing in the time means that adding the same content twice usually
    yields two different IDs.
    """
    # f-string (not ``+``) so non-str content is stringified rather than rejected.
    seed = f"{content}{datetime.now().isoformat()}"
    digest = hashlib.md5(seed.encode()).hexdigest()
    return digest[:12]


def add_memory(content: str) -> str:
    """Embed ``content`` and store it in the primary collection.

    The record gets a ``created_at`` ISO timestamp in its metadata.

    Returns:
        The newly generated memory ID.
    """
    collection = get_collection(get_client())

    new_id = generate_id(content)
    vector = get_embedding(content)

    collection.add(
        ids=[new_id],
        embeddings=[vector],
        documents=[content],
        metadatas=[{"created_at": datetime.now().isoformat()}],
    )
    return new_id


def get_fallback_embedding(text: str) -> List[float]:
    """Return the legacy-model (768-dim) embedding of ``text``.

    Used to query the legacy ``COLLECTION_NAME_FALLBACK`` collection, whose
    vectors were produced by ``FALLBACK_EMBEDDING_MODEL``.

    Raises:
        requests.RequestException: on connection errors, timeouts, or
            non-2xx responses.
    """
    payload = {"model": FALLBACK_EMBEDDING_MODEL, "input": text}
    auth_headers = {
        "Authorization": f"Bearer {FALLBACK_EMBEDDING_API_KEY}",
        "Content-Type": "application/json",
    }
    response = requests.post(
        FALLBACK_EMBEDDING_API_URL, headers=auth_headers, json=payload, timeout=30
    )
    response.raise_for_status()
    first_item = response.json()["data"][0]
    return first_item["embedding"]


def _query_collection(client, name, embed, query, n_results, source, label):
    """Return up to ``n_results`` nearest-neighbour hits for ``query`` in collection ``name``.

    ``embed`` must produce vectors with the collection's dimensionality. Each
    hit is ``{"id", "content", "score", "metadata", "source"}``, where
    ``score`` is ``1 - distance``. Any failure (missing collection, embedding
    API error, dimension mismatch, ...) is reported on stderr, tagged with
    ``label``, and yields ``[]``, so one collection cannot break the other.
    """
    try:
        col = client.get_collection(name)
        total = col.count()
        if total == 0:
            return []
        results = col.query(
            query_embeddings=[embed(query)],
            n_results=min(n_results, total),
        )
        docs = results["documents"][0]
        if not docs:
            return []
        ids = results["ids"][0]
        dists = results["distances"][0] if results["distances"] else None
        metas = results["metadatas"][0] if results["metadatas"] else None
        return [
            {
                "id": ids[i],
                "content": doc,
                "score": 1 - dists[i] if dists is not None else 0,
                "metadata": metas[i] if metas is not None else {},
                "source": source,
            }
            for i, doc in enumerate(docs)
        ]
    except Exception as e:
        # Diagnostics go to stderr: stdout carries the CLI's JSON output.
        print(f"[搜索] {label} collection 失败: {e}", file=sys.stderr)
        return []


def search_memory(query: str, top_k: int = 5, use_rerank: bool = False) -> List[dict]:
    """Search the primary (4096-dim) and legacy (768-dim) collections together.

    Args:
        query: Free-text query. It is embedded separately for each collection,
            using the model that collection was built with.
        top_k: Maximum number of results; ``<= 0`` returns ``[]``.
        use_rerank: If True, fetch ``2 * top_k`` candidates per collection and
            reorder them with ``rerank``; each returned hit then also carries
            ``rerank_score``. If the reranker fails completely, results fall
            back to vector-score order.

    Returns:
        Up to ``top_k`` hit dicts with keys ``id``, ``content``, ``score``,
        ``metadata`` and ``source`` (``"4096"`` or ``"768"``). Legacy hits
        whose content duplicates an earlier hit are dropped.

    NOTE(review): ``score`` is ``1 - distance``. This file never sets
    ``hnsw:space``, so both collections presumably use Chroma's default L2
    distance. In that case ``1 - distance`` is not a bounded similarity, and
    scores from the two embedding models are not strictly comparable when
    merged.
    """
    if top_k <= 0:
        return []

    client = get_client()
    # Over-fetch when reranking so the reranker has extra candidates.
    n_candidates = top_k * 2 if use_rerank else top_k

    # 1. Primary collection (4096-dim).
    memories = _query_collection(
        client, COLLECTION_NAME, get_embedding, query, n_candidates, "4096", "新"
    )

    # 2. Legacy collection (768-dim); skip content already returned.
    seen = {m["content"] for m in memories}
    legacy_hits = _query_collection(
        client, COLLECTION_NAME_FALLBACK, get_fallback_embedding, query, n_candidates, "768", "旧"
    )
    for hit in legacy_hits:
        if hit["content"] not in seen:
            seen.add(hit["content"])
            memories.append(hit)

    if not memories:
        return []

    # Optional reranking; degrade to vector-score order if it fails.
    if use_rerank and len(memories) > 1:
        try:
            reranked = rerank(query, [m["content"] for m in memories], top_k)
        except Exception as e:
            print(f"[搜索] rerank 失败，按向量分数排序: {e}", file=sys.stderr)
        else:
            reranked_memories = []
            for r in reranked:
                mem = memories[r["index"]].copy()
                mem["rerank_score"] = r["relevance_score"]
                reranked_memories.append(mem)
            return reranked_memories[:top_k]

    memories.sort(key=lambda m: m["score"], reverse=True)
    return memories[:top_k]


def list_memories(limit: int = 20) -> List[dict]:
    """Return up to ``limit`` memories from the primary collection.

    Each entry is ``{"id", "content", "metadata"}``; ``metadata`` is ``{}``
    when Chroma returns no metadata list. Only ``COLLECTION_NAME`` is read.
    An empty collection returns ``[]`` without issuing a ``get``.
    """
    collection = get_collection(get_client())
    if not collection.count():
        return []

    page = collection.get(limit=limit)
    page_ids = page["ids"]
    page_metas = page["metadatas"]

    listed = []
    for idx, text in enumerate(page["documents"]):
        listed.append(
            {
                "id": page_ids[idx],
                "content": text,
                "metadata": page_metas[idx] if page_metas else {},
            }
        )
    return listed


def delete_memory(memory_id: str) -> bool:
    """Delete a memory by ID from whichever collection holds it.

    Both the primary collection and the legacy ``COLLECTION_NAME_FALLBACK``
    collection are checked, because ``search_memory`` can return IDs from
    either one.

    Returns:
        True if a record with ``memory_id`` was found and deleted in at least
        one collection. False if the ID is empty, was not found anywhere, or
        deletion failed.
    """
    if not memory_id:
        return False

    client = get_client()
    collections = [get_collection(client)]
    try:
        collections.append(client.get_collection(COLLECTION_NAME_FALLBACK))
    except Exception:
        # The legacy collection is optional; its absence is not an error.
        pass

    deleted = False
    for collection in collections:
        try:
            # Chroma's delete() does not raise for unknown IDs, so check for
            # existence first; otherwise a missing ID would report success.
            if not collection.get(ids=[memory_id])["ids"]:
                continue
            collection.delete(ids=[memory_id])
            deleted = True
        except Exception as e:
            print(f"[删除] {collection.name} 失败: {e}", file=sys.stderr)
    return deleted


def export_memories() -> List[dict]:
    """Dump every record of the primary collection for JSON export.

    Returns a list of ``{"id", "content", "metadata"}`` dicts; ``metadata``
    is ``{}`` when Chroma returns no metadata list. The legacy collection is
    not included.
    """
    everything = get_collection(get_client()).get()
    id_list = everything["ids"]
    meta_list = everything["metadatas"]
    return [
        {
            "id": id_list[pos],
            "content": text,
            "metadata": meta_list[pos] if meta_list else {},
        }
        for pos, text in enumerate(everything["documents"])
    ]


def import_memories(memories: List[dict]) -> int:
    """Embed and insert memories (e.g. from ``export``) into the primary collection.

    Each item should be a dict with a non-empty ``content``; ``id`` and
    ``metadata`` are optional. Every item is re-embedded with
    ``get_embedding`` before insertion.

    An item is skipped (and not counted) in any of these cases:

    - it is not a dict;
    - its content is empty;
    - its ID already exists in the collection.

    Per-item failures are reported on stderr and do not abort the import.

    Returns:
        The number of memories actually added.
    """
    if not memories:
        return 0

    collection = get_collection(get_client())

    count = 0
    for mem in memories:
        if not isinstance(mem, dict):
            # Previously the error handler itself called mem.get(), which
            # crashed the whole import on a non-dict item.
            print(f"导入失败: 无效条目 {str(mem)[:50]}...", file=sys.stderr)
            continue
        content = mem.get("content", "")
        if not content:
            continue
        try:
            memory_id = mem.get("id") or generate_id(content)
            # Depending on the Chroma version, add() with an existing ID is
            # ignored (logged as a warning) rather than raising. Check first
            # so the returned count is accurate and no embedding call is wasted.
            if collection.get(ids=[memory_id])["ids"]:
                continue
            embedding = get_embedding(content)
            # Copy so the caller's dict is not mutated, and tolerate a null
            # "metadata" field (export emits whatever Chroma stored).
            metadata = dict(mem.get("metadata") or {})
            metadata.setdefault("created_at", datetime.now().isoformat())

            collection.add(
                ids=[memory_id],
                embeddings=[embedding],
                documents=[content],
                metadatas=[metadata],
            )
            count += 1
        except Exception as e:
            print(f"导入失败: {str(content)[:50]}... - {e}", file=sys.stderr)

    return count


def main():
    """Command-line entry point.

    Registers the ``add`` / ``search`` / ``list`` / ``delete`` / ``export`` /
    ``import`` / ``count`` subcommands, runs the selected one, and prints its
    result as JSON on stdout. With no subcommand, prints the help text.
    """
    parser = argparse.ArgumentParser(description="本地向量记忆系统")
    commands = parser.add_subparsers(dest="command", help="命令")

    add_cmd = commands.add_parser("add", help="添加记忆")
    add_cmd.add_argument("content", help="记忆内容")

    search_cmd = commands.add_parser("search", help="搜索记忆")
    search_cmd.add_argument("query", help="查询内容")
    search_cmd.add_argument("--top_k", type=int, default=5, help="返回数量")
    search_cmd.add_argument("--rerank", action="store_true", help="使用 reranker")

    list_cmd = commands.add_parser("list", help="列出记忆")
    list_cmd.add_argument("--limit", type=int, default=20, help="数量限制")

    delete_cmd = commands.add_parser("delete", help="删除记忆")
    delete_cmd.add_argument("memory_id", help="记忆ID")

    # Argument-less subcommands; registration order is the order in --help.
    for cmd_name, cmd_help in (
        ("export", "导出记忆"),
        ("import", "导入记忆"),
        ("count", "统计记忆数量"),
    ):
        commands.add_parser(cmd_name, help=cmd_help)

    args = parser.parse_args()
    command = args.command

    if command == "add":
        new_id = add_memory(args.content)
        print(json.dumps({"success": True, "id": new_id}))
    elif command == "search":
        found = search_memory(args.query, args.top_k, args.rerank)
        print(json.dumps({"results": found}, ensure_ascii=False, indent=2))
    elif command == "list":
        listed = list_memories(args.limit)
        print(json.dumps({"memories": listed, "count": len(listed)}, ensure_ascii=False, indent=2))
    elif command == "delete":
        ok = delete_memory(args.memory_id)
        print(json.dumps({"success": ok}))
    elif command == "export":
        dumped = export_memories()
        print(json.dumps(dumped, ensure_ascii=False, indent=2))
    elif command == "import":
        # Accept either a bare list or the {"memories": [...]} shape.
        payload = json.load(sys.stdin)
        items = payload if isinstance(payload, list) else payload.get("memories", [])
        imported = import_memories(items)
        print(json.dumps({"success": True, "imported": imported}))
    elif command == "count":
        total = get_collection(get_client()).count()
        print(json.dumps({"count": total}))
    else:
        parser.print_help()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
