Skip to content

装饰器

装饰器是 Python 的一项高级特性，用于在不改动原有代码的前提下修改或增强函数和类的行为。

装饰器基础

什么是装饰器

装饰器本质上是一个可调用对象（通常是函数）：它接受一个函数（或类）作为参数，并返回一个用来替换它的对象——通常是包装后的新函数，也可以是原对象本身。

python
def my_decorator(func):
    """Return a replacement for *func* that prints a line before and after it runs."""
    def wrapper():
        print("函数执行前")  # runs before the original function
        func()
        print("函数执行后")  # runs after the original function

    return wrapper


@my_decorator
def say_hello():
    print("Hello!")


# The @ syntax above is shorthand for:
#     say_hello = my_decorator(say_hello)

say_hello()
# Output:
# 函数执行前
# Hello!
# 函数执行后

被装饰函数带参数（*args 与 **kwargs）

python
def my_decorator(func):
    """Wrap *func* so that any positional and keyword arguments are forwarded.

    The wrapper prints a line before and after the call and hands back
    whatever *func* returned, so the decorated function keeps its result.
    """
    def wrapper(*args, **kwargs):
        print("函数执行前")
        outcome = func(*args, **kwargs)  # forward every argument unchanged
        print("函数执行后")
        return outcome

    return wrapper


@my_decorator
def add(a, b):
    return a + b


result = add(5, 3)
print(result)  # 8

保留函数元信息

python
from functools import wraps

def my_decorator(func):
    """Pass-through decorator that keeps the wrapped function's metadata."""

    @wraps(func)  # copy __name__, __doc__, etc. from func onto wrapper
    def wrapper(*args, **kwargs):
        """Wrapper docstring; @wraps replaces it with func's docstring."""
        return func(*args, **kwargs)

    return wrapper


@my_decorator
def greet(name):
    """问候函数"""
    print(f"Hello, {name}!")


print(greet.__name__)  # greet  (without @wraps this would be "wrapper")
print(greet.__doc__)   # 问候函数

装饰器类型

简单装饰器

python
from functools import wraps


def log_calls(func):
    """Log every call to *func*: its name, its arguments and its return value.

    Args:
        func: the function to wrap.

    Returns:
        A wrapper (metadata copied via functools.wraps) that prints three
        lines per call and returns whatever *func* returned.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        print(f"调用 {func.__name__}")
        print(f"参数: args={args}, kwargs={kwargs}")
        result = func(*args, **kwargs)
        print(f"返回: {result}")
        return result
    return wrapper


@log_calls
def add(a, b):
    return a + b


add(5, 3)
# Output:
# 调用 add
# 参数: args=(5, 3), kwargs={}
# 返回: 8

带参数的装饰器

python
from functools import wraps


def repeat(times):
    """Decorator factory: run the decorated function *times* times per call.

    Args:
        times: number of repetitions. 0 (or a negative number) means the
            function is never called and an empty list is returned.

    Returns:
        A decorator whose wrapper returns the list of all results, in call order.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            return [func(*args, **kwargs) for _ in range(times)]
        return wrapper
    return decorator


@repeat(times=3)
def greet(name):
    return f"Hello, {name}!"


print(greet("World"))
# ['Hello, World!', 'Hello, World!', 'Hello, World!']

类装饰器

python
class CountCalls:
    """Class-based decorator that counts calls to the wrapped function.

    Each call increments ``count``, prints a progress line and then delegates
    to the original function, returning its result. The instance carries the
    wrapped function's metadata (``__name__``, ``__doc__``, ``__wrapped__``,
    ...). It also works on methods; in that case one counter is shared by all
    instances of the owning class.
    """

    def __init__(self, func):
        from functools import update_wrapper

        # Copy func's metadata onto this instance -- the class-based
        # equivalent of @functools.wraps. Done first so that entries copied
        # from func.__dict__ cannot overwrite the attributes assigned below.
        update_wrapper(self, func)
        self.func = func
        self.count = 0

    def __call__(self, *args, **kwargs):
        self.count += 1
        print(f"第 {self.count} 次调用 {self.func.__name__}")
        return self.func(*args, **kwargs)

    def __get__(self, instance, owner=None):
        # Instances of this class are not functions, so without __get__ a
        # decorated method would be called without its ``self`` argument.
        # Binding returns a method object that passes ``instance`` first.
        if instance is None:
            return self
        from types import MethodType

        return MethodType(self, instance)


@CountCalls
def say_hello():
    print("Hello!")


say_hello()  # 第 1 次调用 say_hello
say_hello()  # 第 2 次调用 say_hello
say_hello()  # 第 3 次调用 say_hello

装饰类的装饰器

python
def add_methods(cls):
    """Class decorator: attach a ``new_method`` instance method to *cls*.

    The class is modified in place and returned, so the decorated name still
    refers to the very same class object.
    """
    def new_method(self):
        return "新方法"

    setattr(cls, "new_method", new_method)
    return cls


@add_methods
class MyClass:
    def existing_method(self):
        return "已有方法"


obj = MyClass()
print(obj.existing_method())  # 已有方法
print(obj.new_method())       # 新方法

常用装饰器模式

计时装饰器

python
import time
from functools import wraps

def timer(func):
    """Print how long each call to *func* took, then return its result.

    Elapsed time is measured with time.perf_counter(), a monotonic,
    high-resolution clock meant for timing intervals. time.time() follows the
    adjustable system clock and is not suitable for measuring durations.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"{func.__name__} 执行时间: {elapsed:.4f}秒")
        return result
    return wrapper

@timer
def slow_function():
    """Sleep for one second, then return "完成"."""
    time.sleep(1)
    return "完成"


slow_function()  # prints e.g.: slow_function 执行时间: 1.0012秒

缓存装饰器

python
from functools import wraps

def memoize(func):
    """Cache *func*'s return values, keyed by the tuple of positional arguments.

    Only positional, hashable arguments are supported, and the cache is never
    evicted. Every call prints whether the value was computed ("计算") or
    served from the cache ("从缓存获取").
    """
    memo = {}

    @wraps(func)
    def wrapper(*args):
        if args not in memo:
            print(f"计算: {args}")
            value = memo[args] = func(*args)
            return value
        print(f"从缓存获取: {args}")
        return memo[args]

    return wrapper

@memoize
def fibonacci(n):
    """Return the n-th Fibonacci number (fibonacci(0) == 0, fibonacci(1) == 1)."""
    return n if n < 2 else fibonacci(n - 1) + fibonacci(n - 2)


print(fibonacci(10))

# The standard library ships a ready-made, bounded memoizing decorator.
from functools import lru_cache


@lru_cache(maxsize=128)
def fibonacci2(n):
    """Same recurrence as fibonacci, cached by functools.lru_cache."""
    return n if n < 2 else fibonacci2(n - 1) + fibonacci2(n - 2)

重试装饰器

python
import time
from functools import wraps

def retry(max_attempts=3, delay=1, exceptions=(Exception,)):
    """Decorator factory that re-runs a function when it raises.

    Args:
        max_attempts: total number of calls to try; must be at least 1.
        delay: seconds to sleep between attempts.
        exceptions: exception class or tuple of classes that trigger a retry.
            Any other exception propagates immediately.

    Returns:
        A decorator. The wrapped function returns the first successful result.
        If every attempt fails, it re-raises the exception from the last
        attempt with its original traceback.

    Raises:
        ValueError: if max_attempts < 1. Previously this produced a confusing
            TypeError ("raise None") at call time.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts 必须至少为 1")

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    if attempt == max_attempts:
                        raise  # out of attempts: propagate the last failure
                    print(f"第 {attempt} 次尝试失败,{delay}秒后重试")
                    time.sleep(delay)
        return wrapper
    return decorator


@retry(max_attempts=3, delay=2, exceptions=(ValueError,))
def risky_function():
    """Raise ValueError about 70% of the time; otherwise return "成功"."""
    import random
    if random.random() < 0.7:
        raise ValueError("随机错误")
    return "成功"

验证装饰器

python
from functools import wraps

def validate_types(**expected_types):
    """Decorator factory that type-checks selected arguments on every call.

    Args:
        **expected_types: parameter name -> type (anything isinstance accepts
            as its second argument). Parameters not listed are not checked.

    Returns:
        A decorator. The wrapped function raises TypeError when a listed
        argument (after defaults are applied) is not an instance of its
        expected type, or when the call does not match the signature.
    """
    import inspect

    def decorator(func):
        # The signature never changes, so compute it once per decorated
        # function rather than on every call.
        sig = inspect.signature(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Map positional and keyword arguments onto parameter names.
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()

            for name, value in bound.arguments.items():
                if name in expected_types and not isinstance(value, expected_types[name]):
                    raise TypeError(
                        f"参数 '{name}' 应为 {expected_types[name]} 类型,"
                        f"实际为 {type(value)}"
                    )

            return func(*args, **kwargs)
        return wrapper
    return decorator


@validate_types(name=str, age=int)
def greet(name, age):
    print(f"{name} 今年 {age} 岁")


greet("张三", 25)  # passes validation
# greet("张三", "25")  # raises TypeError: age is not an int

权限装饰器

python
from functools import wraps

def require_permission(permission):
    """Decorator factory: reject calls whose first argument lacks *permission*.

    The decorated function must take the user as its first positional
    argument. The user is looked up with ``user.get("permissions", [])``, so
    a dict-like object is expected and a missing key means no permissions.

    Raises (from the wrapped function):
        PermissionError: if *permission* is not among the user's permissions.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(user, *args, **kwargs):
            granted = user.get("permissions", [])
            if permission in granted:
                return func(user, *args, **kwargs)
            raise PermissionError(f"需要权限: {permission}")

        return wrapper

    return decorator


@require_permission("admin")
def delete_user(user, user_id):
    print(f"删除用户 {user_id}")


admin = {"name": "管理员", "permissions": ["admin", "read", "write"]}
user = {"name": "普通用户", "permissions": ["read"]}

delete_user(admin, 123)  # allowed: admin holds the "admin" permission
# delete_user(user, 123)  # raises PermissionError

单例装饰器

python
from functools import wraps

def singleton(cls):
    """Class decorator that makes every call return one shared instance of *cls*.

    The first call constructs the instance with its arguments. Later calls
    return that cached instance and ignore their arguments.

    Note: the decorated name is rebound to a factory *function*, not a class,
    so e.g. ``isinstance(obj, Database)`` would raise TypeError. The
    check-then-create step uses no locking.
    """
    instances = {}

    @wraps(cls)
    def get_instance(*args, **kwargs):
        if cls in instances:
            return instances[cls]
        instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class Database:
    def __init__(self, connection_string):
        self.connection_string = connection_string
        print(f"创建数据库连接: {connection_string}")


db1 = Database("mysql://...")  # constructs the instance and prints once
db2 = Database("mysql://...")  # returns the cached instance; prints nothing
print(db1 is db2)  # True

属性验证装饰器

python
def validated_property(name, type_, min_val=None, max_val=None):
    """Create a property whose setter enforces a type and optional bounds.

    Args:
        name: public attribute name, used in error messages and to derive
            the backing attribute ``_<name>``.
        type_: required type (anything isinstance accepts). Note that bool
            values pass an ``int`` check, since bool subclasses int.
        min_val, max_val: optional inclusive bounds; None disables a bound.

    Returns:
        A property object. Reading it before any assignment yields None.

    Raises (on assignment):
        TypeError: the value is not an instance of ``type_``.
        ValueError: the value is below ``min_val`` or above ``max_val``.
    """
    backing = f"_{name}"

    def fget(self):
        return getattr(self, backing, None)

    def fset(self, value):
        if not isinstance(value, type_):
            raise TypeError(f"{name} 必须是 {type_} 类型")
        below = min_val is not None and value < min_val
        if below:
            raise ValueError(f"{name} 不能小于 {min_val}")
        above = max_val is not None and value > max_val
        if above:
            raise ValueError(f"{name} 不能大于 {max_val}")
        setattr(self, backing, value)

    return property(fget=fget, fset=fset)


class Person:
    name = validated_property("name", str)
    age = validated_property("age", int, min_val=0, max_val=150)

    def __init__(self, name, age):
        # Both assignments go through the validating setters.
        self.name = name
        self.age = age


p = Person("张三", 25)
# p.age = -5    # raises ValueError
# p.age = "25"  # raises TypeError

装饰器组合

多个装饰器

python
# Illustration only: decorator1/2/3 are placeholders, not defined here.
@decorator1
@decorator2
@decorator3
def func():
    pass

# Equivalent to:
# func = decorator1(decorator2(decorator3(func)))

# Application order is bottom-up: decorator3 wraps func first, decorator1 last.
# At call time the outermost wrapper runs first. Code each wrapper runs
# *before* calling the next one executes decorator1 -> decorator2 ->
# decorator3 -> func. Code it runs *after* that call executes in the reverse
# order, with decorator1's last.

装饰器示例

python
import time
from functools import wraps

def timer(func):
    """Print how long each call to *func* took, then return its result.

    Uses time.perf_counter(), a monotonic high-resolution clock intended for
    measuring intervals (time.time() follows the adjustable system clock).
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        print(f"耗时: {time.perf_counter() - start:.4f}秒")
        return result
    return wrapper

def log(func):
    """Print "调用 <name>" before and "完成 <name>" after each call to *func*.

    The wrapped function's return value is passed through unchanged.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        print(f"调用 {func.__name__}")
        value = func(*args, **kwargs)
        print(f"完成 {func.__name__}")
        return value

    return wrapper

@timer  # applied last: the outermost wrapper, so its timing includes log's work
@log    # applied first: wraps process directly
def process():
    """Simulate one second of work and return "结果"."""
    time.sleep(1)
    return "结果"


process()
# Output (elapsed time varies):
# 调用 process
# 完成 process
# 耗时: 1.0012秒

实践示例

API 路由装饰器

python
from functools import wraps

class Router:
    """Minimal route table mapping "METHOD:path" keys to handler functions."""

    def __init__(self):
        self.routes = {}

    def route(self, path, methods=("GET",)):
        """Return a decorator that registers the decorated handler for *path*.

        Args:
            path: URL path, used verbatim in the route key.
            methods: iterable of HTTP method names, or a single method string
                such as "POST". Defaults to GET only. An immutable tuple is
                used instead of a list literal, so the default can never be
                shared and mutated across calls.

        Returns:
            A decorator that stores the handler under "METHOD:path" for every
            method and returns the (functools.wraps-preserving) handler wrapper.
        """
        if isinstance(methods, str):
            # Without this, "POST" would be iterated character by character,
            # registering bogus keys such as "P:/users".
            methods = (methods,)

        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            for method in methods:
                self.routes[f"{method}:{path}"] = wrapper

            return wrapper
        return decorator

    def get(self, path):
        """Shortcut for ``route(path, ("GET",))``."""
        return self.route(path, ("GET",))

    def post(self, path):
        """Shortcut for ``route(path, ("POST",))``."""
        return self.route(path, ("POST",))


router = Router()


@router.get("/users")
def list_users():
    return ["user1", "user2"]


@router.post("/users")
def create_user():
    return {"status": "created"}


print(router.routes)

上下文管理装饰器

python
from contextlib import contextmanager
from functools import wraps

def context_manager(func):
    """Turn a generator function into a context-manager factory.

    *func* should yield once: the yielded value is bound by ``with ... as``,
    and code after the yield runs when the block exits. This is exactly what
    contextlib.contextmanager provides, so the work is delegated to it once,
    at decoration time, instead of re-wrapping *func* on every call.
    contextmanager also copies *func*'s metadata onto the returned factory.
    """
    return contextmanager(func)

@context_manager
def managed_file(filename, mode):
    """Yield *filename* opened with *mode*; the file is closed when the with-block exits."""
    handle = open(filename, mode)
    try:
        yield handle
    finally:
        handle.close()  # runs on normal exit and when the with-body raises


with managed_file("test.txt", "w") as f:
    f.write("Hello")

注册装饰器

python
class Registry:
    """Name -> function registry populated through a decorator."""

    def __init__(self):
        self._functions = {}

    def register(self, name=None):
        """Register a function under *name* (default: the function's __name__).

        Supports ``@registry.register()``, ``@registry.register("alias")`` and
        the bare form ``@registry.register``. The function itself is returned
        unchanged in every case.
        """
        if callable(name):
            # Bare @registry.register: Python passed the function itself as
            # *name*. Without this branch, the decorated name would silently
            # become the inner `decorator` and nothing would be registered.
            func = name
            self._functions[func.__name__] = func
            return func

        def decorator(func):
            key = name or func.__name__
            self._functions[key] = func
            return func
        return decorator

    def call(self, name, *args, **kwargs):
        """Call the function registered under *name* with the given arguments.

        Raises:
            ValueError: if nothing is registered under *name*.
        """
        if name not in self._functions:
            raise ValueError(f"未注册的函数: {name}")
        return self._functions[name](*args, **kwargs)

    def list_functions(self):
        """Return the registered names, in registration order."""
        return list(self._functions.keys())


registry = Registry()


@registry.register()
def greet(name):
    return f"Hello, {name}!"


@registry.register("custom_name")
def farewell(name):
    return f"Goodbye, {name}!"


print(registry.call("greet", "World"))       # Hello, World!
print(registry.call("custom_name", "World"))  # Goodbye, World!

状态机装饰器

python
from functools import wraps

def state_machine(initial_state):
    """Class decorator factory that adds state tracking with on-enter hooks.

    The decorated class is replaced by a subclass whose instances start in
    *initial_state*. Every state change records the new state and, if the
    instance defines ``on_enter_<new_state>``, calls that hook with the
    previous state. A change happens either via ``transition(new_state)`` or
    by assigning ``self._state`` inside the class's own methods.

    Previously the class was wrapped in a separate proxy object. Assignments
    like ``self._state = "open"`` inside the class then landed on the inner
    instance, so the tracked state never changed and no hook ran.

    Args:
        initial_state: state set before the class's own ``__init__`` runs;
            no hook is called for it.

    Returns:
        A decorator mapping a class to its state-tracking subclass.
    """
    def decorator(cls):
        class StateMachine(cls):
            def __init__(self, *args, **kwargs):
                # Write the backing value directly so the initial state does
                # not trigger an on_enter_* hook.
                self.__dict__["_sm_state"] = initial_state
                super().__init__(*args, **kwargs)

            @property
            def _state(self):
                return self.__dict__["_sm_state"]

            @_state.setter
            def _state(self, new_state):
                # Route plain assignments (``self._state = ...`` in the
                # decorated class) through transition so hooks fire.
                self.transition(new_state)

            def transition(self, new_state):
                """Switch to *new_state*, then call ``on_enter_<new_state>(old_state)`` if defined."""
                old_state = self.__dict__["_sm_state"]
                self.__dict__["_sm_state"] = new_state
                hook = getattr(self, f"on_enter_{new_state}", None)
                if hook is not None:
                    hook(old_state)

        # Present the subclass under the decorated class's own name.
        StateMachine.__name__ = cls.__name__
        StateMachine.__qualname__ = cls.__qualname__
        StateMachine.__module__ = cls.__module__
        StateMachine.__doc__ = cls.__doc__
        return StateMachine
    return decorator


@state_machine("idle")
class Door:
    def open(self):
        print("开门")
        self._state = "open"  # goes through the state machine; fires on_enter_open

    def close(self):
        print("关门")
        self._state = "closed"

    def on_enter_open(self, old_state):
        print(f"从 {old_state} 进入 open 状态")


door = Door()
door.open()  # 开门 / 从 idle 进入 open 状态