Skip to content

生成器

生成器是 Python 中一种特殊的迭代器，可以按需（惰性）生成值，从而节省内存。

生成器基础

什么是生成器

生成器函数是包含 yield 关键字的函数。调用它并不会立即执行函数体，而是返回一个生成器对象；之后每次对该对象调用 next()（或在 for 循环中迭代）时，函数运行到下一个 yield 并产出一个值，局部变量等执行状态会被保留，直到下一次继续执行。

python
def simple_generator():
    """Yield the integers 1, 2 and 3, pausing after each one."""
    for number in (1, 2, 3):
        yield number

# Calling a generator function only builds a generator object; no body code runs yet.
gen = simple_generator()

# Drive it by hand: each next() resumes the body up to the following yield.
for _ in range(3):
    print(next(gen))  # 1, then 2, then 3
# A fourth next(gen) would raise StopIteration.

# A for loop calls next() for us and stops cleanly on StopIteration.
for value in simple_generator():
    print(value)

生成器函数 vs 普通函数

python
# 普通函数 - 返回列表
def get_numbers_list(n):
    """Eagerly build and return the whole list [0, 1, ..., n - 1]."""
    return list(range(n))

# 生成器函数 - 按需生成
def get_numbers_generator(n):
    """Lazily yield 0, 1, ..., n - 1 one at a time without building a list."""
    for number in range(n):
        yield number

# Memory comparison: the list holds every element, the generator only its paused frame.
# Note: sys.getsizeof is shallow (it does not count the int objects a list
# references), so the real difference is even larger than what is printed.
import sys
lst = get_numbers_list(100000)
gen = get_numbers_generator(100000)

for label, obj in (("列表大小", lst), ("生成器大小", gen)):
    print(f"{label}: {sys.getsizeof(obj)} bytes")

# A generator can be consumed only once.
gen = get_numbers_generator(3)
print(list(gen))  # [0, 1, 2]
print(list(gen))  # [] - already exhausted

yield 关键字

python
def counter(start, end):
    """Yield start, start + 1, ... for as long as the value stays below end."""
    value = start
    while value < end:
        yield value
        value += 1

for number in counter(0, 5):
    print(number)  # 0, 1, 2, 3, 4 - each produced on demand

# yield 表达式
def accumulator():
    """Running-total coroutine.

    Yields the current total. Each value delivered with send() is added to
    it; receiving None (e.g. from a plain next() after priming) ends the
    generator.
    """
    total = 0
    while True:
        value = yield total  # whatever send() passes in arrives here
        if value is None:
            return
        total += value

acc = accumulator()
next(acc)  # prime: run up to the first yield so send() has somewhere to deliver
for amount in (10, 20, 30):
    print(acc.send(amount))  # 10, 30, 60

生成器表达式

基本语法

python
# List comprehension: every element is computed immediately.
squares_list = [x ** 2 for x in range(10)]

# Generator expression: same syntax in parentheses, values computed lazily.
squares_gen = (x ** 2 for x in range(10))

for obj in (squares_list, squares_gen):
    print(type(obj))  # <class 'list'>, then <class 'generator'>

# Iterating consumes the generator.
for square in squares_gen:
    print(square)

# Materialise a generator expression into a list.
squares = list(x ** 2 for x in range(10))

内存效率

python
import sys

# Eager: the list computes and stores all 10,000 squares up front.
lst = [x ** 2 for x in range(10000)]
# Lazy: the generator keeps only its state and computes squares on demand.
gen = (x ** 2 for x in range(10000))

# sys.getsizeof is shallow: it measures the container, not the ints it refers to.
print(f"列表大小: {sys.getsizeof(lst)} bytes")
print(f"生成器大小: {sys.getsizeof(gen)} bytes")

# A generator expression that is a function's sole argument needs no extra parentheses.
sum_squares = sum(x ** 2 for x in range(100))
max_square = max(x ** 2 for x in range(100))

嵌套生成器表达式

python
# Two for-clauses act like nested loops (the rightmost one varies fastest).
pairs = ((x, y) for x in range(3) for y in range(3))
for pair in pairs:
    print(pair)

# An if-clause filters the produced values.
evens = (x for x in range(20) if x % 2 == 0)
print([*evens])

# Flatten a matrix row by row.
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flat = (item for row in matrix for item in row)
print([*flat])

生成器方法

send() 方法

python
def echo_generator():
    """Print every value delivered through send(); yields only None."""
    while True:
        message = yield
        print("收到: {}".format(message))

gen = echo_generator()
next(gen)  # advance to the first bare yield so the generator can receive values
for word in ("Hello", "World"):
    gen.send(word)  # prints 收到: Hello, then 收到: World

# send() 也可以发送值并获取下一个 yield
def accumulator():
    """Never-ending running total: values from send() are added, None is ignored."""
    total = 0
    while True:
        value = yield total
        if value is None:
            continue
        total += value

acc = accumulator()
print(next(acc))      # 0 - priming also returns the first yielded total
for amount in (10, 20):
    print(acc.send(amount))   # 10, then 30

throw() 方法

python
def error_handler():
    """Print every value sent in until a ValueError is thrown into the generator.

    A ValueError delivered with throw() is caught and reported, after which the
    generator finishes. The finally clause prints a cleanup message however
    the generator exits.
    """
    try:
        while True:
            value = yield
            print(f"处理: {value}")
    except ValueError as e:
        print(f"捕获异常: {e}")
    finally:
        print("清理资源")

gen = error_handler()
next(gen)  # prime: run to the first yield
gen.send("数据1")  # prints 处理: 数据1
try:
    # Pass an exception instance: the throw(type, value) form is deprecated
    # since Python 3.12.
    gen.throw(ValueError("出错了"))  # prints 捕获异常: 出错了, then 清理资源
except StopIteration:
    # The generator handled the error and returned without yielding again,
    # so throw() reports that it is exhausted by raising StopIteration.
    pass

close() 方法

python
def infinite_counter():
    """Yield 0, 1, 2, ... forever; print a message when the generator shuts down."""
    n = 0
    try:
        while True:
            yield n
            n += 1
    finally:
        # Runs when close() raises GeneratorExit at the paused yield.
        print("生成器关闭")

gen = infinite_counter()
for _ in range(2):
    print(next(gen))  # 0, then 1
gen.close()  # raises GeneratorExit inside; the finally clause prints 生成器关闭
# After close(), next(gen) raises StopIteration.

yield from

基本用法

python
def sub_generator():
    """Yield 1, 2, 3 - the part main_generator delegates to."""
    for number in (1, 2, 3):
        yield number

def main_generator():
    """Delegate to sub_generator() for 1-3, then yield 4 and 5 directly."""
    yield from sub_generator()
    for number in (4, 5):
        yield number

for number in main_generator():
    print(number)  # 1, 2, 3 come from the delegate, then 4, 5

# 嵌套生成器
def flatten(nested):
    """Recursively yield the leaves of arbitrarily nested lists.

    Only ``list`` instances are descended into; tuples, strings and other
    iterables are yielded as single items.
    """
    for item in nested:
        if not isinstance(item, list):
            yield item
            continue
        yield from flatten(item)

nested = [1, [2, 3], [4, [5, 6]]]
flattened = list(flatten(nested))
print(flattened)  # [1, 2, 3, 4, 5, 6]

双向通信

python
def sub_gen():
    """Yield a label once, then return a string built from the value sent back."""
    received = yield "子生成器"
    return "子生成器返回: {}".format(received)

def main_gen():
    """Forward everything to sub_gen(), then yield a message built from its return value."""
    outcome = yield from sub_gen()
    yield "主生成器收到: {}".format(outcome)

gen = main_gen()
first = next(gen)
print(first)               # 子生成器 - yielded by sub_gen through the delegation
reply = gen.send("数据")   # routed to sub_gen, whose return value reaches main_gen
print(reply)               # 主生成器收到: 子生成器返回: 数据

itertools 模块

无限迭代器

python
from itertools import count, cycle, repeat

# count(start, step): an endless arithmetic progression - here 0, 2, 4, ...
evens = count(0, 2)
i = next(evens)
while i <= 10:
    print(i)  # 0, 2, 4, 6, 8, 10
    i = next(evens)

# cycle(iterable): replay the items of a finite iterable forever
colors = cycle(["红", "绿", "蓝"])
for i, color in enumerate(colors):
    if i >= 6:
        break
    print(color)  # 红, 绿, 蓝, 红, 绿, 蓝

# repeat(obj, times): the same object a fixed number of times
for item in repeat("Hello", 3):
    print(item)  # Hello, Hello, Hello

迭代器工具

python
from itertools import (
    chain, islice, takewhile, dropwhile,
    groupby, starmap, zip_longest
)

# chain: concatenate several iterables into one stream
list1 = [1, 2, 3]
list2 = [4, 5, 6]
print(list(chain(list1, list2)))  # [1, 2, 3, 4, 5, 6]

# islice(iterable, start, stop, step): lazy slicing
print(list(islice(range(20), 2, 10, 2)))  # [2, 4, 6, 8]

# takewhile: yield items while the predicate holds; stop at the first failure
print(list(takewhile(lambda x: x < 5, range(10))))  # [0, 1, 2, 3, 4]

# dropwhile: skip items while the predicate holds, then yield everything after
print(list(dropwhile(lambda x: x < 5, range(10))))  # [5, 6, 7, 8, 9]

# groupby: groups *consecutive* items with equal keys, so the input must
# already be sorted (or at least clustered) by that same key.
data = [("a", 1), ("a", 2), ("b", 1), ("b", 2)]
for key, group in groupby(data, lambda x: x[0]):
    print(key, list(group))

# starmap: call a function with each tuple unpacked as positional arguments
print(list(starmap(pow, [(2, 3), (3, 2), (10, 2)])))  # [8, 9, 100]

# zip_longest: zip unequal-length iterables, padding with fillvalue
print(list(zip_longest([1, 2], [3, 4, 5], fillvalue=0)))
# [(1, 3), (2, 4), (0, 5)]

排列组合

python
from itertools import (
    combinations,
    combinations_with_replacement,
    permutations,
    product,
)

items = [1, 2, 3]

# permutations(items, r): ordered arrangements of r distinct elements
print(list(permutations(items, 2)))
# [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]

# combinations(items, r): unordered selections without repetition
print(list(combinations(items, 2)))
# [(1, 2), (1, 3), (2, 3)]

# combinations_with_replacement(items, r): unordered selections, repeats allowed
print(list(combinations_with_replacement(items, 2)))
# [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]

# product(*iterables): Cartesian product
print(list(product([1, 2], [3, 4])))
# [(1, 3), (1, 4), (2, 3), (2, 4)]

实践示例

文件处理

python
def read_large_file(filename, chunk_size=1024):
    """Yield successive text chunks of at most chunk_size characters.

    The file is opened on the first next() call and closed by the with-block
    once the generator finishes or is closed.
    """
    with open(filename, "r", encoding="utf-8") as f:
        # iter(callable, sentinel) keeps calling f.read until it returns "" (EOF).
        for chunk in iter(lambda: f.read(chunk_size), ""):
            yield chunk

def read_lines(filename):
    """Yield each line of a UTF-8 text file with trailing whitespace removed.

    Note: rstrip() drops all trailing whitespace (spaces, tabs), not only
    the line terminator.
    """
    with open(filename, "r", encoding="utf-8") as f:
        yield from (line.rstrip() for line in f)

def filter_lines(filename, pattern):
    """Yield the lines containing the substring pattern, trailing whitespace removed."""
    with open(filename, "r", encoding="utf-8") as f:
        matching = (line for line in f if pattern in line)
        for line in matching:
            yield line.rstrip()

# Usage - "large.txt" and "log.txt" are placeholder files that must exist.
for chunk in read_large_file("large.txt"):
    # Replace with real per-chunk processing. No process() function exists at
    # this point in the snippet, so calling one here would raise NameError.
    print(len(chunk))

for line in filter_lines("log.txt", "ERROR"):
    print(line)

数据管道

python
def generate_numbers(n):
    """Pipeline source: lazily yield 0, 1, ..., n - 1."""
    for value in range(n):
        yield value

def filter_even(numbers):
    """Pipeline stage: pass through only the even items of numbers."""
    for candidate in numbers:
        if candidate % 2 != 0:
            continue
        yield candidate

def square(numbers):
    """Pipeline stage: yield the square of each incoming number."""
    for value in numbers:
        squared = value ** 2
        yield squared

def pipeline():
    """Chain generate_numbers -> filter_even -> square; nothing runs until iterated."""
    return square(filter_even(generate_numbers(20)))

# Usage: pull values through the lazy pipeline one at a time.
for value in pipeline():
    print(value)  # 0, 4, 16, 36, 64, 100, 144, 196, 256, 324

# The same pipeline written as one generator expression.
result = (x ** 2 for x in range(20) if x % 2 == 0)
print([*result])

斐波那契数列

python
def fibonacci():
    """Yield the Fibonacci sequence 0, 1, 1, 2, 3, ... without end."""
    current, following = 0, 1
    while True:
        yield current
        current, following = following, current + following

# Take just the first 10 values from the infinite generator.
from itertools import islice
fib = fibonacci()
first_ten = list(islice(fib, 10))
print(first_ten)  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]

# 有限版本
def fibonacci_n(n):
    """Yield the first n Fibonacci numbers (nothing when n <= 0)."""
    prev, curr = 0, 1
    for _ in range(n):
        yield prev
        prev, curr = curr, prev + curr

树遍历

python
class TreeNode:
    """Binary-tree node holding a value and optional left/right children."""

    def __init__(self, value, left=None, right=None):
        self.value, self.left, self.right = value, left, right

def inorder_traversal(node):
    """Yield node values in in-order: left subtree, node, right subtree.

    Uses an explicit stack instead of recursive ``yield from``. Deep or
    degenerate (list-like) trees therefore do not hit the recursion limit,
    and each value is produced in amortised O(1) instead of being relayed
    through one generator frame per tree level. A falsy ``node`` yields
    nothing.
    """
    stack = []
    current = node
    while stack or current:
        # Walk as far left as possible, remembering the path back up.
        while current:
            stack.append(current)
            current = current.left
        current = stack.pop()
        yield current.value
        current = current.right

def preorder_traversal(node):
    """Yield node values in pre-order: node, left subtree, right subtree.

    Uses an explicit stack instead of recursive ``yield from``. Deep trees
    therefore do not hit the recursion limit, and each value costs amortised
    O(1). A falsy ``node`` yields nothing.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        if not current:
            continue
        yield current.value
        # Push right first so the left subtree is popped (visited) first.
        stack.append(current.right)
        stack.append(current.left)

def postorder_traversal(node):
    """Yield node values in post-order: left subtree, right subtree, node.

    Uses an explicit stack plus a "last visited" marker instead of recursive
    ``yield from``. Deep trees therefore do not hit the recursion limit, and
    each value costs amortised O(1). A falsy ``node`` yields nothing.
    """
    stack = []
    current = node
    last_visited = None
    while stack or current:
        if current:
            stack.append(current)
            current = current.left
            continue
        peek = stack[-1]
        # Descend right only if that subtree exists and was not just finished.
        if peek.right and peek.right is not last_visited:
            current = peek.right
        else:
            yield peek.value
            last_visited = stack.pop()

# Usage: build the sample tree
#         1
#        / \
#       2   3
#      / \   \
#     4   5   6
root = TreeNode(1,
    TreeNode(2, TreeNode(4), TreeNode(5)),
    TreeNode(3, None, TreeNode(6))
)

print(list(inorder_traversal(root)))    # [4, 2, 5, 1, 3, 6]
print(list(preorder_traversal(root)))   # [1, 2, 4, 5, 3, 6]
print(list(postorder_traversal(root)))  # [4, 5, 2, 6, 3, 1]

分批处理

python
def batch(iterable, size):
    """Yield consecutive lists of up to ``size`` items from ``iterable``.

    The final list may be shorter. A single iterator is shared across
    batches, so one-shot iterators and generators work too.

    Raises:
        ValueError: if ``size`` is less than 1 (raised on the first next()).
            A zero size would otherwise yield nothing and silently discard
            the entire input.
    """
    from itertools import islice  # keep this snippet self-contained

    # size=None keeps islice's "take everything" meaning (a single batch).
    if size is not None and size < 1:
        raise ValueError(f"size must be at least 1, got {size!r}")
    iterator = iter(iterable)
    while True:
        # Named ``chunk`` so it does not shadow the function name ``batch``.
        chunk = list(islice(iterator, size))
        if not chunk:
            return
        yield chunk

# Usage: split 0..19 into groups of five.
data = range(20)
for batch_data in batch(data, 5):
    print(batch_data)
# Output, one list per line:
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]
# [10, 11, 12, 13, 14]
# [15, 16, 17, 18, 19]

# 分块处理大列表
def process_large_dataset(data, batch_size=1000):
    """Lazily yield process_batch(chunk) for each batch_size-sized chunk of data."""
    for processed in map(process_batch, batch(data, batch_size)):
        yield processed

def process_batch(batch):
    """Return a new list containing the square of every element of batch."""
    return [value ** 2 for value in batch]

并行处理

python
from concurrent.futures import ThreadPoolExecutor
from itertools import islice

def parallel_map(func, iterable, max_workers=4, chunk_size=100):
    """Apply func to every item of iterable on a thread pool, yielding results in input order.

    Input is pulled chunk_size items at a time. Each chunk is mapped across
    up to max_workers threads, and all of its results are yielded before the
    next chunk is read, so only one chunk of input is buffered at a time.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        source = iter(iterable)
        # iter(callable, sentinel) keeps slicing until an empty list comes back.
        for chunk in iter(lambda: list(islice(source, chunk_size)), []):
            yield from executor.map(func, chunk)

def batch(iterable, size):
    """Yield consecutive lists of up to ``size`` items from ``iterable``.

    The final list may be shorter.

    Raises:
        ValueError: if ``size`` is less than 1 (raised on the first next()).
            A zero size would otherwise yield nothing and silently discard
            the entire input.
    """
    from itertools import islice  # keep this snippet self-contained

    # size=None keeps islice's "take everything" meaning (a single batch).
    if size is not None and size < 1:
        raise ValueError(f"size must be at least 1, got {size!r}")
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, size))
        if not chunk:
            return
        yield chunk

# 使用
def process(x):
    """Return x squared (stand-in for a real per-item workload)."""
    return pow(x, 2)

data = range(1000)
results = parallel_map(process, data)  # lazy: no work starts until iteration
for result in results:
    print(result)

日志解析

python
import re
from datetime import datetime
from collections import namedtuple

LogEntry = namedtuple("LogEntry", ["timestamp", "level", "message"])

def parse_log_file(filename):
    """Yield a LogEntry for every well-formed line of a UTF-8 log file.

    Expected line format: ``[YYYY-MM-DD HH:MM:SS] [LEVEL] message``.
    Lines that do not match the pattern are skipped. Lines whose timestamp
    is not a real calendar date/time (e.g. ``2024-02-30``) are also skipped,
    instead of letting strptime's ValueError abort the whole stream.
    """
    # Compile once per call rather than resolving the pattern on every line.
    line_re = re.compile(r"\[(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})\] \[(\w+)\] (.+)")

    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            match = line_re.match(line.strip())
            if not match:
                continue
            raw_time, level, message = match.groups()
            try:
                timestamp = datetime.strptime(raw_time, "%Y-%m-%d %H:%M:%S")
            except ValueError:
                # Shaped like a timestamp but not a valid date/time: treat the
                # line like any other malformed line.
                continue
            yield LogEntry(timestamp, level, message)

def filter_by_level(entries, level):
    """Lazily keep only the entries whose ``level`` attribute equals level."""
    yield from (entry for entry in entries if entry.level == level)

def filter_by_time(entries, start, end):
    """Lazily keep entries whose ``timestamp`` lies within [start, end], inclusive."""
    yield from (
        entry for entry in entries
        if start <= entry.timestamp <= end
    )

# Usage: chain the stages lazily; "app.log" is a placeholder that must exist.
entries = parse_log_file("app.log")
errors = filter_by_level(entries, "ERROR")  # no file I/O happens until iteration
for error in errors:
    print(error)