RAG系统优化:性能、质量与成本的平衡

Header Image

性能优化 🚀

检索效率提升

优化向量检索的性能是提升系统响应速度的关键:

import faiss
import numpy as np
from typing import List, Dict, Optional
import time

class OptimizedRetriever:
    def __init__(
        self,
        dimension: int,
        index_type: str = "IVF",
        metric: str = "cosine",
        nlist: int = 100
    ):
        """初始化检索器
        
        Args:
            dimension: 向量维度
            index_type: 索引类型 ('flat', 'IVF', 'HNSW')
            metric: 距离度量方式 ('cosine', 'l2')
            nlist: IVF索引的聚类中心数量
        """
        self.dimension = dimension
        self.index_type = index_type
        self.metric = metric
        
        # 创建FAISS索引
        if metric == "cosine":
            self.index = faiss.IndexFlatIP(dimension)  # 内积用于余弦相似度
        else:
            self.index = faiss.IndexFlatL2(dimension)  # L2距离
            
        if index_type == "IVF":
            # 创建IVF索引
            quantizer = faiss.IndexFlatL2(dimension)
            self.index = faiss.IndexIVFFlat(
                quantizer, dimension, nlist
            )
        elif index_type == "HNSW":
            # 创建HNSW索引
            self.index = faiss.IndexHNSWFlat(dimension, 32)
            
    def batch_add(
        self,
        vectors: np.ndarray,
        batch_size: int = 10000
    ):
        """批量添加向量"""
        total_vectors = len(vectors)
        for i in range(0, total_vectors, batch_size):
            batch = vectors[i:min(i + batch_size, total_vectors)]
            self.index.add(batch)
            
    def search(
        self,
        query_vector: np.ndarray,
        k: int = 10,
        nprobe: Optional[int] = None
    ) -> tuple:
        """执行向量检索
        
        Args:
            query_vector: 查询向量
            k: 返回最相似的k个结果
            nprobe: IVF索引搜索的聚类中心数量
            
        Returns:
            (distances, indices)
        """
        if nprobe and self.index_type == "IVF":
            self.index.nprobe = nprobe
            
        start_time = time.time()
        distances, indices = self.index.search(
            query_vector.reshape(1, -1), k
        )
        search_time = time.time() - start_time
        
        return distances[0], indices[0], search_time

缓存策略

实现高效的缓存机制:

from functools import lru_cache
import hashlib
import json
from typing import Any, Optional
import redis

class CacheManager:
    def __init__(
        self,
        redis_host: str = "localhost",
        redis_port: int = 6379,
        ttl: int = 3600  # 1小时过期
    ):
        self.redis_client = redis.Redis(
            host=redis_host,
            port=redis_port,
            decode_responses=True
        )
        self.ttl = ttl
        
    def _generate_key(self, data: Any) -> str:
        """生成缓存键"""
        if isinstance(data, str):
            key_data = data
        else:
            key_data = json.dumps(data, sort_keys=True)
        return hashlib.md5(key_data.encode()).hexdigest()
    
    def get(self, key_data: Any) -> Optional[str]:
        """获取缓存数据"""
        key = self._generate_key(key_data)
        return self.redis_client.get(key)
    
    def set(self, key_data: Any, value: str):
        """设置缓存数据"""
        key = self._generate_key(key_data)
        self.redis_client.setex(key, self.ttl, value)
        
    @lru_cache(maxsize=1000)
    def get_local_cache(self, key: str) -> Optional[str]:
        """本地缓存"""
        return self.get(key)

class OptimizedRAG:
    def __init__(self):
        self.cache_manager = CacheManager()
        
    def query(self, question: str) -> Dict[str, Any]:
        """优化后的查询处理"""
        # 检查缓存
        cached_result = self.cache_manager.get(question)
        if cached_result:
            return json.loads(cached_result)
            
        # 处理查询
        result = self._process_query(question)
        
        # 存入缓存
        self.cache_manager.set(
            question,
            json.dumps(result)
        )
        
        return result
    
    def _process_query(self, question: str) -> Dict[str, Any]:
        """实际的查询处理逻辑"""
        # 实现查询处理逻辑
        pass

质量优化 📈

召回率提升

class RecallOptimizer:
    def __init__(self):
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        
    def generate_query_variations(
        self,
        query: str,
        num_variations: int = 3
    ) -> List[str]:
        """生成查询变体"""
        prompt = f"""请为以下查询生成{num_variations}个语义相似的变体:

查询:{query}

要求:
1. 保持原始语义
2. 使用不同的表达方式
3. 考虑可能的上下文

请直接返回变体列表,每行一个。"""
        
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "你是一个专业的查询重写专家。"},
                {"role": "user", "content": prompt}
            ],
            temperature=0.7
        )
        
        variations = response.choices[0].message.content.strip().split('\n')
        return variations[:num_variations]
    
    def hybrid_search(
        self,
        query: str,
        documents: List[Dict],
        top_k: int = 10
    ) -> List[Dict]:
        """混合检索策略"""
        # 生成查询变体
        query_variations = self.generate_query_variations(query)
        
        # 向量检索结果
        vector_results = self._vector_search(
            query_variations + [query],
            documents,
            top_k
        )
        
        # 关键词检索结果
        keyword_results = self._keyword_search(
            query,
            documents,
            top_k
        )
        
        # 合并结果
        combined_results = self._merge_results(
            vector_results,
            keyword_results
        )
        
        return combined_results[:top_k]

准确率优化

class AccuracyOptimizer:
    def __init__(self):
        self.client = OpenAI()
        
    def validate_answer(
        self,
        question: str,
        answer: str,
        context: str
    ) -> Dict[str, Any]:
        """验证答案质量"""
        prompt = f"""请验证以下答案的准确性:

问题:{question}
上下文:{context}
答案:{answer}

请检查:
1. 答案是否完全基于上下文
2. 是否存在事实错误
3. 是否有逻辑问题
4. 是否需要补充说明

请以JSON格式返回检查结果。"""
        
        response = self.client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "你是一个专业的答案质量验证专家。"},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3
        )
        
        try:
            result = json.loads(response.choices[0].message.content)
            return result
        except:
            return {
                "is_accurate": False,
                "issues": ["无法解析验证结果"]
            }
            
    def improve_answer(
        self,
        answer: str,
        validation_result: Dict[str, Any]
    ) -> str:
        """优化答案"""
        if validation_result.get("is_accurate", False):
            return answer
            
        prompt = f"""请根据以下问题修改答案:

原始答案:{answer}

发现的问题:
{json.dumps(validation_result.get('issues', []), indent=2, ensure_ascii=False)}

请提供修改后的答案。"""
        
        response = self.client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "你是一个专业的答案优化专家。"},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3
        )
        
        return response.choices[0].message.content

成本控制 💰

Token优化

class TokenOptimizer:
    def __init__(self, model_name: str = "gpt-3.5-turbo"):
        self.tokenizer = tiktoken.encoding_for_model(model_name)
        
    def estimate_tokens(self, text: str) -> int:
        """估算token数量"""
        return len(self.tokenizer.encode(text))
    
    def optimize_context(
        self,
        context: str,
        max_tokens: int = 2000
    ) -> str:
        """优化上下文长度"""
        tokens = self.tokenizer.encode(context)
        
        if len(tokens) <= max_tokens:
            return context
            
        # 保留最重要的部分
        optimized_tokens = tokens[:max_tokens]
        return self.tokenizer.decode(optimized_tokens)
    
    def chunk_long_text(
        self,
        text: str,
        chunk_size: int = 1000
    ) -> List[str]:
        """分块处理长文本"""
        tokens = self.tokenizer.encode(text)
        chunks = []
        
        for i in range(0, len(tokens), chunk_size):
            chunk_tokens = tokens[i:i + chunk_size]
            chunks.append(self.tokenizer.decode(chunk_tokens))
            
        return chunks

资源监控

import psutil
import logging
from dataclasses import dataclass
from datetime import datetime

@dataclass
class ResourceMetrics:
    cpu_percent: float
    memory_percent: float
    disk_usage: float
    timestamp: datetime

class ResourceMonitor:
    def __init__(self, log_file: str = "rag_metrics.log"):
        self.logger = logging.getLogger("RAGMonitor")
        handler = logging.FileHandler(log_file)
        handler.setFormatter(
            logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
        )
        self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)
        
    def get_metrics(self) -> ResourceMetrics:
        """获取资源使用指标"""
        return ResourceMetrics(
            cpu_percent=psutil.cpu_percent(),
            memory_percent=psutil.virtual_memory().percent,
            disk_usage=psutil.disk_usage('/').percent,
            timestamp=datetime.now()
        )
        
    def log_metrics(self, metrics: ResourceMetrics):
        """记录资源指标"""
        self.logger.info(
            f"CPU: {metrics.cpu_percent}% | "
            f"Memory: {metrics.memory_percent}% | "
            f"Disk: {metrics.disk_usage}%"
        )
        
    def check_thresholds(
        self,
        metrics: ResourceMetrics,
        cpu_threshold: float = 80,
        memory_threshold: float = 80,
        disk_threshold: float = 80
    ) -> bool:
        """检查资源使用是否超过阈值"""
        if metrics.cpu_percent > cpu_threshold:
            self.logger.warning(f"CPU使用率过高: {metrics.cpu_percent}%")
            return False
            
        if metrics.memory_percent > memory_threshold:
            self.logger.warning(f"内存使用率过高: {metrics.memory_percent}%")
            return False
            
        if metrics.disk_usage > disk_threshold:
            self.logger.warning(f"磁盘使用率过高: {metrics.disk_usage}%")
            return False
            
        return True

实践建议 💡

  1. 性能优化

    • 选择合适的索引类型
    • 实现多级缓存
    • 使用异步处理
  2. 质量优化

    • 实施混合检索
    • 添加答案验证
    • 持续收集反馈
  3. 成本控制

    • 优化Token使用
    • 监控资源消耗
    • 实施缓存策略

监控指标 📊

  1. 性能指标

    • 平均响应时间
    • 检索延迟
    • 生成延迟
  2. 质量指标

    • 答案准确率
    • 用户满意度
    • 召回率/精确率
  3. 成本指标

    • Token消耗量
    • API调用次数
    • 资源使用率

小结 📝

本章我们深入探讨了RAG系统的优化策略:

  1. 性能优化

    • 检索效率提升
    • 缓存策略实现
    • 异步处理优化
  2. 质量优化

    • 召回率提升
    • 准确率优化
    • 答案验证
  3. 成本控制

    • Token优化
    • 资源监控
    • 成本管理

通过合理的优化策略,我们可以在性能、质量和成本之间找到最佳平衡点,构建高效可靠的生产级RAG系统。