向量数据库查询优化指南

向量数据库查询优化指南

查询方法概述 🔍

向量数据库支持多种查询方法,每种方法都有其特点和适用场景:

1. KNN 查询

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

# 初始化客户端
client = QdrantClient("localhost", port=6333)

# 创建集合
client.recreate_collection(
    collection_name="products",
    vectors_config=VectorParams(size=128, distance=Distance.COSINE),
)

# KNN查询示例
def knn_search(query_vector, k=5):
    """执行KNN查询"""
    results = client.search(
        collection_name="products",
        query_vector=query_vector,
        limit=k
    )
    return results

适用场景:

  • 精确匹配需求
  • 数据量适中
  • 实时性要求高

2. 范围查询

def range_search(query_vector, threshold=0.8):
    """执行范围查询"""
    results = client.search(
        collection_name="products",
        query_vector=query_vector,
        score_threshold=threshold
    )
    return results

适用场景:

  • 相似度阈值筛选
  • 质量要求高
  • 召回率优先

3. 混合查询

from qdrant_client.http.models import Filter, FieldCondition, Range

def hybrid_search(query_vector, category, price_range):
    """执行混合查询"""
    filter = Filter(
        must=[
            FieldCondition(
                key="category",
                match={"text": category}
            ),
            FieldCondition(
                key="price",
                range=Range(
                    gte=price_range[0],
                    lte=price_range[1]
                )
            )
        ]
    )
    
    results = client.search(
        collection_name="products",
        query_vector=query_vector,
        query_filter=filter,
        limit=10
    )
    return results

适用场景:

  • 复杂查询条件
  • 精确过滤需求
  • 多维度筛选

参数调优 ⚙️

1. 批量查询优化

def batch_search_optimizer(query_vectors, batch_size=32):
    """批量查询优化器"""
    results = []
    
    # 分批处理
    for i in range(0, len(query_vectors), batch_size):
        batch = query_vectors[i:i + batch_size]
        
        # 并行查询
        batch_results = client.search_batch(
            collection_name="products",
            requests=[
                {
                    "vector": vec.tolist(),
                    "limit": 5,
                    "with_payload": True
                }
                for vec in batch
            ]
        )
        
        results.extend(batch_results)
    
    return results

2. 缓存策略

from functools import lru_cache
import time

class QueryCache:
    def __init__(self, capacity=1000, ttl=3600):
        self.capacity = capacity
        self.ttl = ttl
        self.cache = {}
        
    @lru_cache(maxsize=1000)
    def get(self, query_key):
        """获取缓存结果"""
        if query_key in self.cache:
            result, timestamp = self.cache[query_key]
            if time.time() - timestamp < self.ttl:
                return result
            else:
                del self.cache[query_key]
        return None
        
    def set(self, query_key, result):
        """设置缓存结果"""
        if len(self.cache) >= self.capacity:
            # 删除最旧的缓存
            oldest = min(self.cache.items(), key=lambda x: x[1][1])
            del self.cache[oldest[0]]
        
        self.cache[query_key] = (result, time.time())

3. 向量压缩

def compress_vectors(vectors, bits=8):
    """向量压缩"""
    # 计算向量范围
    v_min = vectors.min(axis=0)
    v_max = vectors.max(axis=0)
    v_range = v_max - v_min
    
    # 量化
    quantized = np.round(
        (vectors - v_min) / v_range * (2**bits - 1)
    ).astype(np.uint8)
    
    # 解量化函数
    def dequantize(q_vectors):
        return q_vectors * v_range / (2**bits - 1) + v_min
    
    return quantized, dequantize

过滤机制 🎯

1. 预过滤

def prefilter_search(query_vector, filters):
    """预过滤搜索"""
    # 构建过滤条件
    filter_conditions = Filter(
        must=[
            FieldCondition(
                key=key,
                match={"value": value}
            )
            for key, value in filters.items()
        ]
    )
    
    # 执行过滤后的向量搜索
    results = client.search(
        collection_name="products",
        query_vector=query_vector,
        query_filter=filter_conditions,
        limit=10
    )
    
    return results

2. 后过滤

def postfilter_search(query_vector, filter_func):
    """后过滤搜索"""
    # 先获取更多结果
    results = client.search(
        collection_name="products",
        query_vector=query_vector,
        limit=50  # 获取更多结果用于过滤
    )
    
    # 应用自定义过滤函数
    filtered_results = [
        r for r in results
        if filter_func(r.payload)
    ]
    
    return filtered_results[:10]  # 返回前10个过滤后的结果

3. 动态过滤

class DynamicFilter:
    def __init__(self):
        self.filters = {}
        
    def add_filter(self, name, condition):
        """添加过滤条件"""
        self.filters[name] = condition
        
    def remove_filter(self, name):
        """移除过滤条件"""
        if name in self.filters:
            del self.filters[name]
    
    def apply_filters(self, query_vector):
        """应用所有过滤条件"""
        filter_conditions = Filter(
            must=[
                condition
                for condition in self.filters.values()
            ]
        )
        
        results = client.search(
            collection_name="products",
            query_vector=query_vector,
            query_filter=filter_conditions,
            limit=10
        )
        
        return results

性能监控 📊

1. 查询性能追踪

import time
from dataclasses import dataclass
from typing import List, Dict

@dataclass
class QueryMetrics:
    latency: float
    result_count: int
    filter_count: int
    cache_hit: bool

class QueryMonitor:
    def __init__(self):
        self.metrics: List[QueryMetrics] = []
        
    def track_query(self, func):
        """查询性能追踪装饰器"""
        def wrapper(*args, **kwargs):
            start_time = time.time()
            
            # 执行查询
            results = func(*args, **kwargs)
            
            # 记录指标
            latency = time.time() - start_time
            self.metrics.append(QueryMetrics(
                latency=latency,
                result_count=len(results),
                filter_count=len(kwargs.get('filters', {})),
                cache_hit=hasattr(results, '_cache_info')
            ))
            
            return results
        return wrapper
    
    def get_statistics(self) -> Dict:
        """获取性能统计"""
        if not self.metrics:
            return {}
            
        latencies = [m.latency for m in self.metrics]
        return {
            'avg_latency': sum(latencies) / len(latencies),
            'max_latency': max(latencies),
            'min_latency': min(latencies),
            'total_queries': len(self.metrics),
            'cache_hit_rate': sum(1 for m in self.metrics if m.cache_hit) / len(self.metrics)
        }

2. 资源监控

import psutil
import threading
from typing import Dict

class ResourceMonitor:
    def __init__(self, interval=1):
        self.interval = interval
        self.stats: Dict = {}
        self._stop = threading.Event()
        
    def start(self):
        """开始监控"""
        thread = threading.Thread(target=self._monitor)
        thread.start()
        
    def stop(self):
        """停止监控"""
        self._stop.set()
        
    def _monitor(self):
        """监控线程"""
        while not self._stop.is_set():
            self.stats = {
                'cpu_percent': psutil.cpu_percent(interval=1),
                'memory_percent': psutil.virtual_memory().percent,
                'disk_io': psutil.disk_io_counters(),
                'network_io': psutil.net_io_counters()
            }
            self._stop.wait(self.interval)
    
    def get_stats(self) -> Dict:
        """获取监控数据"""
        return self.stats.copy()

3. 告警系统

import smtplib
from email.mime.text import MIMEText
from typing import List, Callable

class AlertSystem:
    def __init__(self, thresholds: Dict[str, float]):
        self.thresholds = thresholds
        self.handlers: List[Callable] = []
        
    def add_handler(self, handler: Callable):
        """添加告警处理器"""
        self.handlers.append(handler)
        
    def check_metrics(self, metrics: Dict):
        """检查指标是否超过阈值"""
        alerts = []
        
        for metric, value in metrics.items():
            if metric in self.thresholds:
                if value > self.thresholds[metric]:
                    alerts.append(f"{metric} exceeded threshold: {value}")
        
        if alerts:
            self._trigger_alerts(alerts)
    
    def _trigger_alerts(self, alerts: List[str]):
        """触发告警"""
        message = "\n".join(alerts)
        for handler in self.handlers:
            handler(message)
            
    @staticmethod
    def email_handler(smtp_config: Dict):
        """邮件告警处理器"""
        def send_email(message: str):
            msg = MIMEText(message)
            msg['Subject'] = 'Vector DB Alert'
            msg['From'] = smtp_config['from']
            msg['To'] = smtp_config['to']
            
            with smtplib.SMTP(smtp_config['host'], smtp_config['port']) as server:
                server.login(smtp_config['user'], smtp_config['password'])
                server.send_message(msg)
        
        return send_email

最佳实践 💡

1. 查询优化建议

  • 使用批量查询减少网络开销
  • 合理设置缓存策略
  • 选择合适的过滤时机

2. 性能优化要点

  • 监控关键指标
  • 设置合理的告警阈值
  • 定期优化和维护

3. 注意事项

  • 平衡查询精度和性能
  • 合理使用缓存
  • 及时处理告警

小结 📝

本章我们深入学习了向量数据库的查询优化技术:

  1. 查询方法

    • KNN查询
    • 范围查询
    • 混合查询
  2. 参数调优

    • 批量优化
    • 缓存策略
    • 向量压缩
  3. 过滤机制

    • 预过滤
    • 后过滤
    • 动态过滤
  4. 性能监控

    • 查询追踪
    • 资源监控
    • 告警系统

下一章,我们将学习向量数据库的运维管理!