#!/usr/bin/env python3
"""迁移记忆到 4096 维 Qwen3-VL-Embedding"""

import os
import sys
import time

import chromadb
import requests

# Local ChromaDB store and the source/target collection names.
CHROMA_PATH = "/home/ubuntu/.hermes/local_memory_db"
OLD_COLLECTION = "memories"
NEW_COLLECTION = "memories_4096"

# SiliconFlow embeddings API.
API_URL = "https://api.siliconflow.cn/v1/embeddings"
# Read the key from the environment rather than committing a secret to source
# control. The literal key that used to be hard-coded here must be treated as
# leaked and rotated.
API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
MODEL = "Qwen/Qwen3-VL-Embedding-8B"

# Batching.
BATCH_SIZE = 50  # texts sent per embeddings request
# Provider quota: 2000 per minute. NOTE(review): nothing in this script
# references RATE_LIMIT, so no throttling is actually applied.
RATE_LIMIT = 2000

def get_embeddings_batch(texts, timeout=60):
    """Embed a batch of texts in a single request to the embeddings API.

    Args:
        texts: list of strings to embed; sent as the ``input`` field.
        timeout: request timeout in seconds, passed to ``requests.post``.

    Returns:
        A list of embedding vectors, one per entry of ``texts``, in the same
        order as ``texts``.

    Raises:
        RuntimeError: if ``API_KEY`` is empty, or the response does not hold
            exactly one embedding per input text.
        requests.HTTPError: if the API answers with a non-2xx status.
    """
    if not API_KEY:
        raise RuntimeError("API_KEY is empty; set SILICONFLOW_API_KEY before running")

    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {"model": MODEL, "input": texts}
    resp = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    resp.raise_for_status()
    items = resp.json()["data"]

    # Items presumably carry an OpenAI-style "index" giving their input
    # position; sort on it so vectors line up with ids. The sort is stable, so
    # the original order is kept when the field is absent.
    ordered = sorted(items, key=lambda item: item.get("index", 0))
    if len(ordered) != len(texts):
        raise RuntimeError(
            f"expected {len(texts)} embeddings, API returned {len(ordered)}"
        )
    return [item["embedding"] for item in ordered]

def main():
    """Copy every record of OLD_COLLECTION into NEW_COLLECTION, re-embedding
    each document with MODEL via ``get_embeddings_batch``.

    The migration is resumable: ids already present in NEW_COLLECTION are
    skipped, so re-running the script only processes what is still missing
    (including batches that failed on a previous run). Records whose document
    is None or empty have nothing to embed and are skipped and counted.

    Returns:
        0 if every pending batch was migrated, 1 if any batch failed.
    """
    client = chromadb.PersistentClient(path=CHROMA_PATH)

    # Source collection.
    old_col = client.get_collection(OLD_COLLECTION)
    total = old_col.count()
    print(f"旧 collection: {total} 条记忆")

    # Target collection. get_or_create_collection replaces a bare `except:`
    # around get_collection, which also swallowed KeyboardInterrupt and any
    # unrelated error.
    new_col = client.get_or_create_collection(
        name=NEW_COLLECTION,
        metadata={"description": "本地向量记忆 (4096维 Qwen3-VL-Embedding)"},
    )
    existing = new_col.count()
    if existing:
        print(f"新 collection 已存在: {existing} 条，继续迁移...")
    else:
        print("创建新 collection")

    # Load all source records in one call.
    all_data = old_col.get(include=["documents", "metadatas"])

    # Only ids are needed to detect already-migrated records; include=[]
    # avoids pulling every document and metadata of the target collection.
    migrated = set(new_col.get(include=[])["ids"]) if existing else set()

    pending = []
    skipped = 0
    for rec_id, doc, meta in zip(
        all_data["ids"], all_data["documents"], all_data["metadatas"]
    ):
        if rec_id in migrated:
            continue
        if not doc:
            # Nothing to embed; sending None/"" would presumably make the
            # API reject the whole batch, losing the other records in it.
            skipped += 1
            continue
        pending.append((rec_id, doc, meta))

    print(f"待迁移: {len(pending)} 条")
    if skipped:
        print(f"跳过空文档: {skipped} 条")

    if not pending:
        print("全部已迁移完成！")
        return 0

    success = 0
    errors = 0
    start_time = time.monotonic()

    for offset in range(0, len(pending), BATCH_SIZE):
        batch = pending[offset:offset + BATCH_SIZE]
        batch_ids = [rec[0] for rec in batch]
        batch_docs = [rec[1] for rec in batch]
        batch_metas = [rec[2] for rec in batch]

        try:
            embeddings = get_embeddings_batch(batch_docs)
            new_col.add(
                ids=batch_ids,
                documents=batch_docs,
                metadatas=batch_metas,
                embeddings=embeddings,
            )
        except Exception as e:
            # Best effort per batch: report and continue. The failed ids stay
            # absent from new_col, so a re-run picks them up again.
            errors += len(batch)
            print(f"错误: {e}")
            time.sleep(1)
            continue

        success += len(batch)
        elapsed = time.monotonic() - start_time
        rate = success / elapsed * 60 if elapsed > 0 else 0
        print(f"进度: {success}/{len(pending)} ({success*100//len(pending)}%) | 速率: {rate:.0f}/min")

    print(f"\n完成！成功: {success}, 错误: {errors}")
    print(f"新 collection 总数: {new_col.count()}")
    return 1 if errors else 0

if __name__ == "__main__":
    # Propagate main()'s status (None/0 = success) as the process exit code so
    # shells and schedulers can detect a partially failed migration.
    sys.exit(main())
