API封装

本节将介绍如何封装LLM API，创建一个健壮、易用的客户端库。我们将学习如何处理错误、实现重试机制和速率限制。

基础API客户端

基类设计

from typing import Dict, List, Optional
import requests
from abc import ABC, abstractmethod

class BaseLLMClient(ABC):
    """Abstract base class for LLM API clients.

    Stores the credentials and a shared ``requests.Session`` whose default
    headers carry bearer-token authentication and a JSON content type.
    Concrete clients must implement :meth:`chat_completion` and
    :meth:`embeddings`.
    """

    def __init__(self, api_key: str, base_url: str):
        """
        Store credentials and prepare an authenticated HTTP session.

        Args:
            api_key: Secret key sent as a bearer token on every request.
            base_url: Root URL of the API; trailing slashes are stripped so
                endpoint paths can be appended with a single ``/``.
        """
        default_headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        http = requests.Session()
        http.headers.update(default_headers)
        self.session = http

    @abstractmethod
    def chat_completion(self, messages: List[Dict[str, str]], **kwargs) -> Dict:
        """
        Send a chat conversation and return the API response.

        Args:
            messages: Conversation as a list of message dicts.
            **kwargs: Provider-specific request options.

        Returns:
            The decoded API response.
        """

    @abstractmethod
    def embeddings(self, text: str) -> List[float]:
        """
        Compute an embedding vector for a piece of text.

        Args:
            text: Input text to embed.

        Returns:
            The embedding vector.
        """

OpenAI客户端实现

import json
from typing import Dict, List, Optional

class OpenAIClient(BaseLLMClient):
    """Synchronous client for the OpenAI REST API.

    Every request goes through the authenticated ``requests.Session`` set up
    by :class:`BaseLLMClient` and is bounded by ``timeout``, so a stalled
    connection cannot block the caller indefinitely.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.openai.com/v1",
        timeout: Optional[float] = 600.0
    ):
        """
        Initialize the client.

        Args:
            api_key: OpenAI API key.
            base_url: API root URL. Override it to reach a proxy or another
                server exposing the same ``/chat/completions`` and
                ``/embeddings`` routes.
            timeout: Seconds ``requests`` waits to connect and between
                received bytes before raising ``requests.exceptions.Timeout``.
                ``None`` restores the ``requests`` default of waiting forever.
                The default is generous so long generations are not cut off.
        """
        super().__init__(api_key=api_key, base_url=base_url)
        self.timeout = timeout

    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> Dict:
        """
        Call the chat completion endpoint.

        Args:
            messages: Conversation as a list of message dicts.
            model: Model name.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate; omitted from
                the request when None.
            **kwargs: Extra fields merged into the JSON body (they can also
                override ``model``/``messages``/``temperature``).

        Returns:
            The decoded JSON response.

        Raises:
            requests.exceptions.HTTPError: for 4xx/5xx responses.
            requests.exceptions.Timeout: when ``timeout`` elapses.
        """
        endpoint = f"{self.base_url}/chat/completions"

        data = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            **kwargs
        }

        if max_tokens is not None:
            data["max_tokens"] = max_tokens

        # Without an explicit timeout requests would wait forever.
        response = self.session.post(endpoint, json=data, timeout=self.timeout)
        response.raise_for_status()

        return response.json()

    def embeddings(
        self,
        text: str,
        model: str = "text-embedding-ada-002"
    ) -> List[float]:
        """
        Get the embedding vector for a text.

        Args:
            text: Input text.
            model: Embedding model name.

        Returns:
            The embedding of the first (only) input, i.e.
            ``response["data"][0]["embedding"]``.

        Raises:
            requests.exceptions.HTTPError: for 4xx/5xx responses.
            requests.exceptions.Timeout: when ``timeout`` elapses.
        """
        endpoint = f"{self.base_url}/embeddings"

        response = self.session.post(
            endpoint,
            json={
                "model": model,
                "input": text
            },
            timeout=self.timeout
        )
        response.raise_for_status()

        return response.json()["data"][0]["embedding"]

错误处理

自定义异常

class LLMError(Exception):
    """Base class for every error raised by the LLM client wrappers."""

class AuthenticationError(LLMError):
    """The API rejected the credentials (mapped from HTTP 401)."""

class RateLimitError(LLMError):
    """A request was throttled, either by the API (HTTP 429) or by a local limiter."""

class APIError(LLMError):
    """An API call failed with an HTTP error status.

    Args:
        message: Human-readable description; becomes ``str(error)``.
        status_code: HTTP status code returned by the server.
        response: Decoded error payload from the server.
    """

    def __init__(self, message: str, status_code: int, response: Dict):
        super().__init__(message)
        self.status_code, self.response = status_code, response

错误处理装饰器

from functools import wraps
import requests

def _http_error_to_llm_error(error):
    """Map a ``requests`` HTTPError to the matching LLMError subclass instance."""
    response = error.response
    if response is None:
        # HTTPError raised without an attached response: no status to map.
        return LLMError(f"API request failed: {error}")

    status_code = response.status_code
    if status_code == 401:
        return AuthenticationError("Invalid API key")
    if status_code == 429:
        return RateLimitError("Rate limit exceeded")

    try:
        error_data = response.json()
    except ValueError:
        # Response.json() raises a ValueError subclass whose concrete type
        # depends on the JSON backend, so catch the common base class.
        error_data = {"error": response.text}
    if not isinstance(error_data, dict):
        error_data = {"error": error_data}

    detail = error_data.get("error", "Unknown error")
    if isinstance(detail, dict):
        # OpenAI-style bodies look like {"error": {"message": ..., ...}}.
        detail = detail.get("message", detail)
    return APIError(f"API request failed: {detail}", status_code, error_data)


def handle_api_errors(func):
    """
    Decorator that translates ``requests`` failures into the LLMError hierarchy.

    The mapping is:

    - HTTP 401 -> AuthenticationError.
    - HTTP 429 -> RateLimitError.
    - Other HTTP errors -> APIError, carrying the status code and the parsed body.
    - Connection failures, timeouts and any other exception -> LLMError.

    Exceptions that already are LLMError instances (for example a
    RateLimitError raised by the wrapped function itself) propagate unchanged.
    The original exception is always chained as ``__cause__``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except LLMError:
            # Already in our hierarchy: keep its precise type.
            raise
        except requests.exceptions.HTTPError as e:
            raise _http_error_to_llm_error(e) from e
        except requests.exceptions.ConnectionError as e:
            raise LLMError("Failed to connect to API server") from e
        except requests.exceptions.Timeout as e:
            raise LLMError("API request timed out") from e
        except Exception as e:
            raise LLMError(f"Unexpected error: {str(e)}") from e

    return wrapper

重试机制

重试装饰器

import asyncio
import inspect
import logging
import time
from typing import Type, Tuple

_retry_logger = logging.getLogger(__name__)


def retry_with_exponential_backoff(
    max_retries: int = 3,
    base_delay: float = 1,
    max_delay: float = 60,
    retryable_exceptions: Tuple[Type[Exception], ...] = (
        RateLimitError,
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout
    )
):
    """
    Decorator factory that retries a call with exponential backoff.

    The wrapped callable is attempted up to ``max_retries + 1`` times. After
    failed attempt number ``n`` (counting from 0) the wrapper waits
    ``min(base_delay * 2**n, max_delay)`` seconds. Only exceptions that are
    instances of ``retryable_exceptions`` trigger a retry. Any other exception,
    and the exception from the final attempt, propagates unchanged.

    Works on plain functions and on ``async def`` functions. For the latter
    the wait uses ``asyncio.sleep``, so the event loop is not blocked.

    Args:
        max_retries: Maximum number of retries after the first attempt.
        base_delay: Delay before the first retry (seconds).
        max_delay: Upper bound for any single delay (seconds).
        retryable_exceptions: Exception types that should be retried.

    Raises:
        ValueError: if ``max_retries`` is negative.
    """
    if max_retries < 0:
        raise ValueError("max_retries must be >= 0")

    def backoff(attempt: int) -> float:
        return min(base_delay * (2 ** attempt), max_delay)

    def report(attempt: int, error: Exception, delay: float) -> None:
        _retry_logger.warning(
            "Attempt %d failed: %s. Retrying in %s seconds...",
            attempt + 1, error, delay
        )

    def decorator(func):
        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                for attempt in range(max_retries):
                    try:
                        return await func(*args, **kwargs)
                    except retryable_exceptions as e:
                        delay = backoff(attempt)
                        report(attempt, e, delay)
                        await asyncio.sleep(delay)
                # Final attempt: any exception reaches the caller.
                return await func(*args, **kwargs)
            return async_wrapper

        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except retryable_exceptions as e:
                    delay = backoff(attempt)
                    report(attempt, e, delay)
                    time.sleep(delay)
            # Final attempt: any exception reaches the caller.
            return func(*args, **kwargs)
        return wrapper
    return decorator

速率限制

令牌桶算法

import time
from threading import Lock

class TokenBucket:
    """Thread-safe token-bucket rate limiter.

    The bucket starts full, refills continuously at ``tokens_per_second``
    and never holds more than ``max_tokens``. Elapsed time is measured with
    ``time.monotonic`` so wall-clock adjustments cannot add or remove tokens.
    """

    def __init__(
        self,
        tokens_per_second: float,
        max_tokens: int
    ):
        """
        Initialize the bucket.

        Args:
            tokens_per_second: Refill rate; must be positive.
            max_tokens: Bucket capacity (largest burst); must be positive.

        Raises:
            ValueError: if either argument is not positive.
        """
        if tokens_per_second <= 0:
            raise ValueError("tokens_per_second must be positive")
        if max_tokens <= 0:
            raise ValueError("max_tokens must be positive")
        self.tokens_per_second = tokens_per_second
        self.max_tokens = max_tokens
        self.tokens = max_tokens
        self.last_update = time.monotonic()
        self.lock = Lock()

    def _add_tokens(self):
        """Credit tokens for the time elapsed since the last update (caller holds the lock)."""
        now = time.monotonic()
        time_passed = now - self.last_update
        self.tokens = min(
            self.tokens + time_passed * self.tokens_per_second,
            self.max_tokens
        )
        self.last_update = now

    def acquire(self, tokens: int = 1, timeout: Optional[float] = None) -> bool:
        """
        Take ``tokens`` from the bucket, waiting for a refill if necessary.

        Args:
            tokens: Number of tokens needed.
            timeout: Maximum seconds to wait; ``None`` waits as long as needed.

        Returns:
            True if the tokens were taken, False if ``timeout`` expired first.

        Raises:
            ValueError: if ``tokens`` is negative or exceeds ``max_tokens``
                (such a request could never be satisfied).
        """
        if tokens < 0:
            raise ValueError("tokens must be non-negative")
        if tokens > self.max_tokens:
            raise ValueError("tokens exceeds bucket capacity max_tokens")

        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            with self.lock:
                self._add_tokens()

                if self.tokens >= tokens:
                    self.tokens -= tokens
                    return True

                # Time until enough tokens will have accumulated.
                wait = (tokens - self.tokens) / self.tokens_per_second

            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return False
                # Never sleep past the deadline.
                wait = min(wait, remaining)

            time.sleep(wait)

使用速率限制

class RateLimitedClient(OpenAIClient):
    """OpenAI client that throttles chat completions with a TokenBucket.

    Each ``chat_completion`` attempt consumes one bucket token, so
    ``tokens_per_minute`` limits *requests* per minute. It is unrelated to
    model tokens.
    """

    def __init__(
        self,
        api_key: str,
        tokens_per_minute: int = 60,
        acquire_timeout: Optional[float] = 60
    ):
        """
        Args:
            api_key: OpenAI API key.
            tokens_per_minute: Sustained request rate; also the bucket
                capacity, i.e. the largest burst allowed.
            acquire_timeout: Seconds to wait for a bucket token before raising
                RateLimitError; ``None`` waits indefinitely.
        """
        super().__init__(api_key)
        self.acquire_timeout = acquire_timeout
        self.rate_limiter = TokenBucket(
            tokens_per_second=tokens_per_minute / 60,
            max_tokens=tokens_per_minute
        )

    @retry_with_exponential_backoff()
    def chat_completion(self, messages: List[Dict[str, str]], **kwargs) -> Dict:
        """
        Rate-limited chat completion call.

        Every (re)try first takes a token from the bucket, then sends the
        request.

        Raises:
            RateLimitError: no token within ``acquire_timeout``, or an API 429,
                on the final attempt.
        """
        # Acquire outside handle_api_errors so this RateLimitError reaches the
        # retry decorator and the caller with its exact type, independent of
        # how handle_api_errors treats exceptions raised inside it.
        if not self.rate_limiter.acquire(timeout=self.acquire_timeout):
            raise RateLimitError("Failed to acquire rate limit token")

        return self._send_chat_completion(messages, **kwargs)

    @handle_api_errors
    def _send_chat_completion(self, messages: List[Dict[str, str]], **kwargs) -> Dict:
        """Issue the HTTP request; handle_api_errors maps requests errors to LLMError types."""
        return super().chat_completion(messages, **kwargs)

完整示例

客户端使用

# Build a client limited to roughly 60 chat requests per minute (token bucket).
client = RateLimitedClient(api_key="your-api-key", tokens_per_minute=60)

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Tell me about Python."},
]

try:
    response = client.chat_completion(
        messages=conversation,
        temperature=0.7,
        max_tokens=100,
    )

    # The reply text lives in the first choice's message.
    message = response["choices"][0]["message"]["content"]
    print(f"Assistant: {message}")
# Handlers run from most to least specific: the first three are LLMError subclasses.
except AuthenticationError:
    print("Invalid API key")
except RateLimitError:
    print("Rate limit exceeded")
except APIError as e:
    print(f"API error: {e}")
except LLMError as e:
    print(f"General error: {e}")

异步客户端

import asyncio
import aiohttp
from typing import Dict, List, Optional

class AsyncOpenAIClient(BaseLLMClient):
    """Asynchronous OpenAI API client built on aiohttp.

    Use it as an async context manager so the HTTP session is opened and
    closed::

        async with AsyncOpenAIClient(api_key) as client:
            reply = await client.chat_completion(messages)

    Status-code mapping and retrying happen inline in ``_post`` rather than
    through ``handle_api_errors`` / ``retry_with_exponential_backoff``. Those
    decorators are written around ``requests`` exceptions, while aiohttp
    raises its own exception types.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.openai.com/v1",
        max_retries: int = 3,
        base_delay: float = 1,
        max_delay: float = 60,
        timeout: Optional[float] = None
    ):
        """
        Args:
            api_key: OpenAI API key.
            base_url: API root URL.
            max_retries: Retries after the first attempt for 429s, connection
                errors and timeouts (negative values are treated as 0).
            base_delay: Delay before the first retry (seconds); doubles each retry.
            max_delay: Upper bound for a single retry delay (seconds).
            timeout: Total per-request timeout in seconds; ``None`` keeps
                aiohttp's default.
        """
        # BaseLLMClient.__init__ is deliberately not called: it would open a
        # synchronous requests.Session that this class never uses.
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        self.max_retries = max(max_retries, 0)
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.timeout = timeout
        self.session = None

    async def __aenter__(self):
        """Open the aiohttp session with auth and JSON headers."""
        session_kwargs = {}
        if self.timeout is not None:
            session_kwargs["timeout"] = aiohttp.ClientTimeout(total=self.timeout)
        self.session = aiohttp.ClientSession(
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            },
            **session_kwargs
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the session if open, and drop the reference so later use fails clearly."""
        if self.session:
            await self.session.close()
            self.session = None

    async def chat_completion(
        self,
        messages: List[Dict[str, str]],
        model: str = "gpt-3.5-turbo",
        **kwargs
    ) -> Dict:
        """
        Asynchronous chat completion call.

        Args:
            messages: Conversation as a list of message dicts.
            model: Model name (the API requires one).
            **kwargs: Extra JSON body fields (temperature, max_tokens, ...).

        Returns:
            The decoded JSON response.
        """
        return await self._post(
            "chat/completions",
            {"model": model, "messages": messages, **kwargs}
        )

    async def embeddings(
        self,
        text: str,
        model: str = "text-embedding-ada-002"
    ) -> List[float]:
        """Return ``response["data"][0]["embedding"]`` for ``text``."""
        data = await self._post("embeddings", {"model": model, "input": text})
        return data["data"][0]["embedding"]

    async def _post(self, path: str, payload: Dict) -> Dict:
        """
        POST ``payload`` to ``path`` with exponential-backoff retries.

        Raises:
            AuthenticationError: HTTP 401.
            RateLimitError: HTTP 429 on the final attempt.
            APIError: any other HTTP status >= 400.
            LLMError: session not started, connection failure or timeout on
                the final attempt, or another aiohttp client error.
        """
        if self.session is None:
            raise LLMError(
                "Session not started; use 'async with AsyncOpenAIClient(...)'"
            )
        endpoint = f"{self.base_url}/{path}"

        for attempt in range(self.max_retries + 1):
            try:
                return await self._post_once(endpoint, payload)
            except (RateLimitError, aiohttp.ClientConnectionError,
                    asyncio.TimeoutError) as e:
                if attempt == self.max_retries:
                    if isinstance(e, LLMError):
                        raise
                    # Check timeouts first: some aiohttp timeout errors are
                    # also connection errors.
                    if isinstance(e, asyncio.TimeoutError):
                        raise LLMError("API request timed out") from e
                    raise LLMError("Failed to connect to API server") from e
                delay = min(self.base_delay * (2 ** attempt), self.max_delay)
                await asyncio.sleep(delay)
            except aiohttp.ClientError as e:
                raise LLMError(f"Unexpected error: {e}") from e

    async def _post_once(self, endpoint: str, payload: Dict) -> Dict:
        """Send one POST and map error statuses to LLMError subclasses."""
        async with self.session.post(endpoint, json=payload) as response:
            status = response.status
            if status == 401:
                raise AuthenticationError("Invalid API key")
            if status == 429:
                raise RateLimitError("Rate limit exceeded")
            if status >= 400:
                body = await response.text()
                try:
                    error_data = json.loads(body)
                except ValueError:
                    error_data = {"error": body}
                if not isinstance(error_data, dict):
                    error_data = {"error": error_data}
                detail = error_data.get("error", "Unknown error")
                if isinstance(detail, dict):
                    # OpenAI-style bodies look like {"error": {"message": ...}}.
                    detail = detail.get("message", detail)
                raise APIError(f"API request failed: {detail}", status, error_data)
            return await response.json()

# Example: driving the async client from an event loop
async def main():
    """Send a single greeting through AsyncOpenAIClient and print the response payload."""
    greeting = [{"role": "user", "content": "Hello!"}]
    async with AsyncOpenAIClient("your-api-key") as client:
        reply = await client.chat_completion(greeting)
        print(reply)


asyncio.run(main())

最佳实践

  1. 错误处理:

    • 定义清晰的异常层次
    • 提供详细的错误信息
    • 适当的日志记录
  2. 重试策略:

    • 使用指数退避
    • 只重试可恢复的错误
    • 设置最大重试次数
  3. 速率限制:

    • 实现平滑的限制
    • 考虑并发情况
    • 提供超时机制
  4. 代码组织:

    • 使用装饰器分离关注点
    • 提供同步和异步接口
    • 保持代码可测试性

下一步

现在您已经学会了如何封装LLM API，接下来我们将通过实战练习来应用这些知识。