设计模式
本节将介绍Python中常用的设计模式,这些模式能帮助我们更好地组织和管理代码。我们将重点关注在LLM应用开发中特别有用的设计模式。
单例模式
基本实现
class Singleton:
"""基本单例模式"""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
# 使用单例
s1 = Singleton()
s2 = Singleton()
print(s1 is s2) # 输出: True
线程安全单例
from threading import Lock
class ThreadSafeSingleton:
"""线程安全的单例模式"""
_instance = None
_lock = Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
# 双重检查锁定
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
"""
确保即使多次调用__init__也是安全的
"""
pass
# 配置管理器示例
class ConfigManager(ThreadSafeSingleton):
"""配置管理器"""
def __init__(self):
super().__init__()
self._config = {}
def set(self, key: str, value: str):
"""设置配置"""
self._config[key] = value
def get(self, key: str, default: str = None) -> str:
"""获取配置"""
return self._config.get(key, default)
# 使用配置管理器
config = ConfigManager()
config.set("api_key", "your-api-key")
# 在其他地方获取同一个实例
same_config = ConfigManager()
print(same_config.get("api_key")) # 输出: your-api-key
工厂模式
简单工厂
from abc import ABC, abstractmethod
from typing import Dict, List
class LLMClient(ABC):
"""LLM客户端接口"""
@abstractmethod
def generate(self, prompt: str) -> str:
"""生成文本"""
pass
class OpenAIClient(LLMClient):
"""OpenAI客户端"""
def generate(self, prompt: str) -> str:
return f"OpenAI: {prompt}"
class AnthropicClient(LLMClient):
"""Anthropic客户端"""
def generate(self, prompt: str) -> str:
return f"Anthropic: {prompt}"
class LLMFactory:
"""LLM客户端工厂"""
@staticmethod
def create(provider: str) -> LLMClient:
"""创建LLM客户端"""
if provider == "openai":
return OpenAIClient()
elif provider == "anthropic":
return AnthropicClient()
else:
raise ValueError(f"Unknown provider: {provider}")
# 使用工厂
client = LLMFactory.create("openai")
response = client.generate("Hello!")
工厂方法
class LLMClientFactory(ABC):
"""LLM客户端工厂接口"""
@abstractmethod
def create_client(self) -> LLMClient:
"""创建客户端"""
pass
def generate_text(self, prompt: str) -> str:
"""生成文本的模板方法"""
client = self.create_client()
return client.generate(prompt)
class OpenAIFactory(LLMClientFactory):
"""OpenAI客户端工厂"""
def create_client(self) -> LLMClient:
return OpenAIClient()
class AnthropicFactory(LLMClientFactory):
"""Anthropic客户端工厂"""
def create_client(self) -> LLMClient:
return AnthropicClient()
# 使用工厂方法
factory = OpenAIFactory()
response = factory.generate_text("Hello!")
观察者模式
基本实现
from abc import ABC, abstractmethod
from typing import List, Any
class Observer(ABC):
"""观察者接口"""
@abstractmethod
def update(self, subject: 'Subject'):
"""更新方法"""
pass
class Subject:
"""被观察者"""
def __init__(self):
self._observers: List[Observer] = []
self._state = None
def attach(self, observer: Observer):
"""添加观察者"""
self._observers.append(observer)
def detach(self, observer: Observer):
"""移除观察者"""
self._observers.remove(observer)
def notify(self):
"""通知所有观察者"""
for observer in self._observers:
observer.update(self)
@property
def state(self) -> Any:
"""获取状态"""
return self._state
@state.setter
def state(self, value: Any):
"""设置状态"""
self._state = value
self.notify()
# API调用监控示例
class APIMonitor(Subject):
"""API监控器"""
def __init__(self):
super().__init__()
self._calls = 0
self._errors = 0
def record_call(self):
"""记录API调用"""
self._calls += 1
self.state = {"calls": self._calls, "errors": self._errors}
def record_error(self):
"""记录错误"""
self._errors += 1
self.state = {"calls": self._calls, "errors": self._errors}
class AlertObserver(Observer):
"""告警观察者"""
def __init__(self, error_threshold: int):
self.error_threshold = error_threshold
def update(self, subject: Subject):
state = subject.state
if state["errors"] >= self.error_threshold:
print(f"Alert: Error count ({state['errors']}) "
f"exceeded threshold!")
class LogObserver(Observer):
"""日志观察者"""
def update(self, subject: Subject):
state = subject.state
print(f"Log: API calls: {state['calls']}, "
f"Errors: {state['errors']}")
# 使用观察者模式
monitor = APIMonitor()
monitor.attach(AlertObserver(error_threshold=3))
monitor.attach(LogObserver())
# 模拟API调用
monitor.record_call()
monitor.record_call()
monitor.record_error()
monitor.record_error()
monitor.record_error() # 触发告警
策略模式
基本实现
from abc import ABC, abstractmethod
from typing import List, Dict
class RetryStrategy(ABC):
"""重试策略接口"""
@abstractmethod
def should_retry(
self,
attempt: int,
error: Exception
) -> bool:
"""判断是否应该重试"""
pass
@abstractmethod
def get_delay(self, attempt: int) -> float:
"""获取重试延迟时间"""
pass
class ExponentialBackoff(RetryStrategy):
"""指数退避策略"""
def __init__(
self,
max_attempts: int,
base_delay: float
):
self.max_attempts = max_attempts
self.base_delay = base_delay
def should_retry(
self,
attempt: int,
error: Exception
) -> bool:
return (attempt < self.max_attempts and
isinstance(error, (TimeoutError, ConnectionError)))
def get_delay(self, attempt: int) -> float:
return self.base_delay * (2 ** attempt)
class FixedRetry(RetryStrategy):
"""固定间隔重试策略"""
def __init__(
self,
max_attempts: int,
delay: float
):
self.max_attempts = max_attempts
self.delay = delay
def should_retry(
self,
attempt: int,
error: Exception
) -> bool:
return attempt < self.max_attempts
def get_delay(self, attempt: int) -> float:
return self.delay
# API客户端示例
class APIClient:
"""API客户端"""
def __init__(
self,
retry_strategy: RetryStrategy
):
self.retry_strategy = retry_strategy
def call_api(self) -> str:
"""调用API"""
attempt = 0
while True:
try:
# 模拟API调用
if attempt < 2:
raise TimeoutError("API timeout")
return "Success"
except Exception as e:
if not self.retry_strategy.should_retry(
attempt, e
):
raise
import time
delay = self.retry_strategy.get_delay(attempt)
time.sleep(delay)
attempt += 1
# 使用策略模式
client = APIClient(
ExponentialBackoff(
max_attempts=3,
base_delay=1
)
)
try:
result = client.call_api()
print(result)
except Exception as e:
print(f"Failed after retries: {e}")
实际应用示例
模型选择器
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import random
class ModelStrategy(ABC):
"""模型选择策略"""
@abstractmethod
def select_model(
self,
prompt: str,
context: Dict
) -> str:
"""选择模型"""
pass
class CostBasedStrategy(ModelStrategy):
"""基于成本的选择策略"""
def __init__(self, budget: float):
self.budget = budget
self.costs = {
"gpt-4": 0.03,
"gpt-3.5-turbo": 0.002,
"claude-2": 0.01
}
def select_model(
self,
prompt: str,
context: Dict
) -> str:
# 根据预算选择模型
if context.get("importance") == "high":
if self.budget >= self.costs["gpt-4"]:
return "gpt-4"
return "gpt-3.5-turbo"
class PerformanceBasedStrategy(ModelStrategy):
"""基于性能的选择策略"""
def __init__(self):
self.performance_scores = {
"gpt-4": 0.95,
"gpt-3.5-turbo": 0.85,
"claude-2": 0.90
}
def select_model(
self,
prompt: str,
context: Dict
) -> str:
# 根据任务类型选择模型
task_type = context.get("task_type", "general")
if task_type == "coding":
return "gpt-4"
elif task_type == "writing":
return "claude-2"
return "gpt-3.5-turbo"
class LoadBalancingStrategy(ModelStrategy):
"""负载均衡策略"""
def __init__(self):
self.models = ["gpt-4", "gpt-3.5-turbo", "claude-2"]
self.current_index = 0
def select_model(
self,
prompt: str,
context: Dict
) -> str:
# 轮询选择模型
model = self.models[self.current_index]
self.current_index = (
self.current_index + 1
) % len(self.models)
return model
class ModelSelector:
"""模型选择器"""
def __init__(self, strategy: ModelStrategy):
self.strategy = strategy
def set_strategy(self, strategy: ModelStrategy):
"""设置策略"""
self.strategy = strategy
def select_model(
self,
prompt: str,
context: Dict = None
) -> str:
"""选择模型"""
if context is None:
context = {}
return self.strategy.select_model(prompt, context)
# 使用模型选择器
selector = ModelSelector(
CostBasedStrategy(budget=0.1)
)
# 高重要性任务
model = selector.select_model(
"Complex analysis",
{"importance": "high"}
)
print(f"Selected model: {model}")
# 切换到性能策略
selector.set_strategy(PerformanceBasedStrategy())
# 编码任务
model = selector.select_model(
"Write a function",
{"task_type": "coding"}
)
print(f"Selected model: {model}")
插件系统
from abc import ABC, abstractmethod
from typing import Dict, List, Any
import json
class Plugin(ABC):
"""插件接口"""
@abstractmethod
def process(
self,
input_data: Any
) -> Any:
"""处理数据"""
pass
class PluginManager:
"""插件管理器"""
def __init__(self):
self._plugins: Dict[str, Plugin] = {}
def register(
self,
name: str,
plugin: Plugin
):
"""注册插件"""
self._plugins[name] = plugin
def unregister(self, name: str):
"""注销插件"""
self._plugins.pop(name, None)
def get_plugin(
self,
name: str
) -> Optional[Plugin]:
"""获取插件"""
return self._plugins.get(name)
def process(
self,
name: str,
input_data: Any
) -> Any:
"""使用插件处理数据"""
plugin = self.get_plugin(name)
if plugin is None:
raise ValueError(f"Plugin {name} not found")
return plugin.process(input_data)
# 实现具体插件
class TranslationPlugin(Plugin):
"""翻译插件"""
def process(self, input_data: str) -> str:
# 实现翻译逻辑
return f"Translated: {input_data}"
class SummaryPlugin(Plugin):
"""摘要插件"""
def process(self, input_data: str) -> str:
# 实现摘要逻辑
return f"Summary: {input_data}"
class FormatPlugin(Plugin):
"""格式化插件"""
def process(self, input_data: Dict) -> str:
# 实现格式化逻辑
return json.dumps(input_data, indent=2)
# 使用插件系统
manager = PluginManager()
# 注册插件
manager.register("translate", TranslationPlugin())
manager.register("summarize", SummaryPlugin())
manager.register("format", FormatPlugin())
# 使用插件
text = "Hello, World!"
translated = manager.process("translate", text)
summarized = manager.process("summarize", text)
data = {"name": "John", "age": 30}
formatted = manager.process("format", data)
print(translated)
print(summarized)
print(formatted)
最佳实践
设计模式选择:
- 根据实际需求选择
- 避免过度设计
- 保持代码简单
模式实现:
- 遵循SOLID原则
- 保持接口清晰
- 考虑扩展性
性能考虑:
- 权衡抽象和性能
- 避免不必要的复杂性
- 适当使用缓存
代码维护:
- 添加充分的文档
- 编写单元测试
- 定期重构代码
下一步
现在您已经掌握了常用的设计模式,接下来我们将通过实战练习来应用这些知识。