Appearance
生成器
生成器是 Python 中一种特殊的迭代器,可以按需生成值,节省内存。
生成器基础
什么是生成器
生成器函数是一种特殊的函数，使用 yield 关键字产出值。调用生成器函数会返回一个生成器对象；每次对该对象调用 next() 会产出一个值，并在两次调用之间保留函数的执行状态。
python
def simple_generator():
    """Yield the integers 1, 2 and 3, one per ``next()`` call.

    After the third value the generator is exhausted and a further
    ``next()`` raises ``StopIteration``.
    """
    yield 1
    yield 2
    yield 3
# 创建生成器对象
gen = simple_generator()
# 迭代
print(next(gen)) # 1
print(next(gen)) # 2
print(next(gen)) # 3
# print(next(gen)) # StopIteration
# 使用 for 循环
for value in simple_generator():
print(value)
生成器函数 vs 普通函数
python
# 普通函数 - 返回列表
def get_numbers_list(n):
    """Return the integers ``0 .. n-1`` as a fully materialised list.

    All ``n`` values exist in memory at once, in contrast to a generator
    that produces them on demand.
    """
    return list(range(n))
# 生成器函数 - 按需生成
def get_numbers_generator(n):
    """Yield the integers ``0 .. n-1`` lazily, one at a time.

    The returned generator can be consumed only once; iterating it a
    second time produces nothing.
    """
    for i in range(n):
        yield i
# 内存对比
import sys
lst = get_numbers_list(100000)
gen = get_numbers_generator(100000)
print(f"列表大小: {sys.getsizeof(lst)} bytes") # 很大
print(f"生成器大小: {sys.getsizeof(gen)} bytes") # 很小
# 生成器只能迭代一次
gen = get_numbers_generator(3)
print(list(gen)) # [0, 1, 2]
print(list(gen)) # []
yield 关键字
python
def counter(start, end):
    """Yield successive integers from ``start`` up to, but excluding, ``end``.

    Nothing is yielded when ``start >= end``.
    """
    current = start
    while current < end:
        yield current
        current += 1
for num in counter(0, 5):
print(num)
# yield 表达式
def accumulator():
    """Running-total coroutine driven by ``send()``.

    Each ``yield`` hands the current total to the caller and receives the
    next value to add. Receiving ``None`` (for example from a plain
    ``next()`` after priming) ends the generator.
    """
    total = 0
    while True:
        value = yield total  # receives whatever the caller passes to send()
        if value is None:
            break
        total += value
acc = accumulator()
next(acc) # 启动生成器
print(acc.send(10)) # 10
print(acc.send(20)) # 30
print(acc.send(30)) # 60
生成器表达式
基本语法
python
# 列表推导式
squares_list = [x ** 2 for x in range(10)]
# 生成器表达式
squares_gen = (x ** 2 for x in range(10))
print(type(squares_list)) # <class 'list'>
print(type(squares_gen)) # <class 'generator'>
# 迭代生成器
for square in squares_gen:
print(square)
# 转换为列表
squares = list(x ** 2 for x in range(10))
内存效率
python
import sys
# 列表 - 立即计算所有值
lst = [x ** 2 for x in range(10000)]
print(f"列表大小: {sys.getsizeof(lst)} bytes")
# 生成器 - 按需计算
gen = (x ** 2 for x in range(10000))
print(f"生成器大小: {sys.getsizeof(gen)} bytes")
# 在函数参数中使用
sum_squares = sum(x ** 2 for x in range(100))
max_square = max(x ** 2 for x in range(100))
嵌套生成器表达式
python
# 嵌套循环
pairs = ((x, y) for x in range(3) for y in range(3))
for pair in pairs:
print(pair)
# 条件过滤
evens = (x for x in range(20) if x % 2 == 0)
print(list(evens))
# 复杂表达式
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flat = (item for row in matrix for item in row)
print(list(flat))
生成器方法
send() 方法
python
def echo_generator():
    """Print every value sent into the generator, forever.

    The generator must be primed with ``next()`` before the first
    ``send()``; each ``yield`` produces ``None`` to the caller.
    """
    while True:
        received = yield
        print(f"收到: {received}")
gen = echo_generator()
next(gen) # 启动生成器
gen.send("Hello") # 收到: Hello
gen.send("World") # 收到: World
# send() 也可以发送值并获取下一个 yield
def accumulator():
    """Running-total coroutine that never terminates on its own.

    Each ``yield`` returns the current total. A value passed via
    ``send()`` is added to it; ``None`` (a plain ``next()``) leaves the
    total unchanged.
    """
    total = 0
    while True:
        value = yield total
        if value is not None:
            total += value
acc = accumulator()
print(next(acc)) # 0
print(acc.send(10)) # 10
print(acc.send(20)) # 30
throw() 方法
python
def error_handler():
    """Print values received via ``send()`` until a ``ValueError`` is thrown in.

    A ``ValueError`` raised at the paused ``yield`` (for example through
    ``throw()``) is caught and reported, the ``finally`` clause prints the
    cleanup message, and the generator then finishes. Because it finishes
    without yielding again, the ``throw()`` call that injected the error
    raises ``StopIteration`` in the caller.
    """
    try:
        while True:
            value = yield
            print(f"处理: {value}")
    except ValueError as e:
        print(f"捕获异常: {e}")
    finally:
        # Runs on handled errors, on close(), and on any other exit path.
        print("清理资源")
gen = error_handler()
next(gen)
gen.send("数据1") # 处理: 数据1
gen.throw(ValueError, "出错了") # 捕获异常: 出错了 / 清理资源（生成器随即结束，throw() 会抛出 StopIteration）
close() 方法
python
def infinite_counter():
    """Yield 0, 1, 2, ... without end.

    When the generator is closed, the ``finally`` clause prints a shutdown
    message before the generator exits.
    """
    count = 0
    try:
        while True:
            yield count
            count += 1
    finally:
        print("生成器关闭")
gen = infinite_counter()
print(next(gen)) # 0
print(next(gen)) # 1
gen.close() # 生成器关闭
# next(gen) # StopIteration
yield from
基本用法
python
def sub_generator():
    """Yield 1, 2 and 3; used as the delegate in the ``yield from`` example."""
    yield 1
    yield 2
    yield 3
def main_generator():
    """Yield everything from ``sub_generator()``, then 4 and 5."""
    # Delegate to the sub-generator until it is exhausted.
    yield from sub_generator()
    yield 4
    yield 5
for value in main_generator():
print(value) # 1, 2, 3, 4, 5
# 嵌套生成器
def flatten(nested):
    """Yield the leaf items of an arbitrarily nested list, depth first.

    Only ``list`` instances are descended into; every other item
    (including tuples and strings) is yielded unchanged.
    """
    for item in nested:
        if isinstance(item, list):
            yield from flatten(item)
        else:
            yield item
nested = [1, [2, 3], [4, [5, 6]]]
print(list(flatten(nested))) # [1, 2, 3, 4, 5, 6]
双向通信
python
def sub_gen():
    """Yield one greeting, then return a string built from the value sent in.

    The return value travels on ``StopIteration.value`` and becomes the
    result of a delegating ``yield from`` expression.
    """
    result = yield "子生成器"
    return f"子生成器返回: {result}"
def main_gen():
    """Delegate to ``sub_gen()`` and then yield a message wrapping its return value.

    Values sent into this generator while delegation is active are
    forwarded to the sub-generator by ``yield from``.
    """
    value = yield from sub_gen()
    yield f"主生成器收到: {value}"
gen = main_gen()
print(next(gen)) # 子生成器
print(gen.send("数据")) # 主生成器收到: 子生成器返回: 数据
itertools 模块
无限迭代器
python
from itertools import count, cycle, repeat
# count - 无限计数
for i in count(0, 2): # 从0开始,步长2
if i > 10:
break
print(i) # 0, 2, 4, 6, 8, 10
# cycle - 无限循环
colors = cycle(["红", "绿", "蓝"])
for i, color in enumerate(colors):
if i >= 6:
break
print(color) # 红, 绿, 蓝, 红, 绿, 蓝
# repeat - 重复
for item in repeat("Hello", 3):
print(item) # Hello, Hello, Hello
迭代器工具
python
from itertools import (
chain, islice, takewhile, dropwhile,
groupby, starmap, zip_longest
)
# chain - 链接多个迭代器
list1 = [1, 2, 3]
list2 = [4, 5, 6]
print(list(chain(list1, list2))) # [1, 2, 3, 4, 5, 6]
# islice - 切片
print(list(islice(range(20), 2, 10, 2))) # [2, 4, 6, 8]
# takewhile - 满足条件时取值
print(list(takewhile(lambda x: x < 5, range(10)))) # [0, 1, 2, 3, 4]
# dropwhile - 不满足条件后取值
print(list(dropwhile(lambda x: x < 5, range(10)))) # [5, 6, 7, 8, 9]
# groupby - 分组
from itertools import groupby
data = [("a", 1), ("a", 2), ("b", 1), ("b", 2)]
for key, group in groupby(data, lambda x: x[0]):
print(key, list(group))
# zip_longest - 不等长压缩
print(list(zip_longest([1, 2], [3, 4, 5], fillvalue=0)))
# [(1, 3), (2, 4), (0, 5)]
排列组合
python
from itertools import permutations, combinations, combinations_with_replacement, product
# permutations - 排列
print(list(permutations([1, 2, 3], 2)))
# [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]
# combinations - 组合
print(list(combinations([1, 2, 3], 2)))
# [(1, 2), (1, 3), (2, 3)]
# combinations_with_replacement - 可重复组合
print(list(combinations_with_replacement([1, 2, 3], 2)))
# [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
# product - 笛卡尔积
print(list(product([1, 2], [3, 4])))
# [(1, 3), (1, 4), (2, 3), (2, 4)]
实践示例
文件处理
python
def read_large_file(filename, chunk_size=1024):
    """Yield the contents of a UTF-8 text file in chunks.

    Args:
        filename: Path of the file to read.
        chunk_size: Maximum number of characters requested per read.

    Yields:
        Non-empty strings, in file order. The file is closed when the
        generator is exhausted or closed.
    """
    with open(filename, "r", encoding="utf-8") as f:
        # iter(callable, sentinel) keeps calling read() until it returns "".
        yield from iter(lambda: f.read(chunk_size), "")
def read_lines(filename):
    """Yield each line of a UTF-8 text file with trailing whitespace removed.

    Note that ``rstrip()`` drops all trailing whitespace, not only the
    line terminator.
    """
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            yield line.rstrip()
def filter_lines(filename, pattern):
    """Yield the lines of a UTF-8 text file that contain ``pattern``.

    ``pattern`` is matched as a plain substring (not a regular
    expression) against the raw line; matching lines are yielded with
    trailing whitespace removed.
    """
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            if pattern in line:
                yield line.rstrip()
# 使用
for chunk in read_large_file("large.txt"):
process(chunk)
for line in filter_lines("log.txt", "ERROR"):
print(line)
数据管道
python
def generate_numbers(n):
    """Yield the integers ``0 .. n-1``; the source stage of the pipeline."""
    for i in range(n):
        yield i
def filter_even(numbers):
    """Yield only the even values from ``numbers``, preserving order."""
    for num in numbers:
        if num % 2 == 0:
            yield num
def square(numbers):
    """Yield the square of each value in ``numbers``, preserving order."""
    for num in numbers:
        yield num ** 2
def pipeline():
    """Build the lazy pipeline: numbers 0..19 -> keep evens -> square.

    Returns:
        The final generator; no work is done until it is iterated.
    """
    numbers = generate_numbers(20)
    evens = filter_even(numbers)
    squares = square(evens)
    return squares
# 使用
for value in pipeline():
print(value) # 0, 4, 16, 36, 64, 100, 144, 196, 256, 324
# 使用生成器表达式
result = (x ** 2 for x in range(20) if x % 2 == 0)
print(list(result))
斐波那契数列
python
def fibonacci():
    """Yield the Fibonacci sequence 0, 1, 1, 2, 3, ... without end.

    The generator never terminates by itself; callers must bound the
    iteration (for example with ``itertools.islice``).
    """
    a, b = 0, 1
    while True:
        yield a
        a, b = b, a + b
# 获取前10个
from itertools import islice
fib = fibonacci()
print(list(islice(fib, 10))) # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
# 有限版本
def fibonacci_n(n):
"""前n个斐波那契数"""
a, b = 0, 1
for _ in range(n):
yield a
a, b = b, a + b
树遍历
python
class TreeNode:
    """Binary-tree node holding a value and optional left/right children."""

    def __init__(self, value, left=None, right=None):
        """Store the payload and child links; a missing child is ``None``."""
        self.value = value
        self.left = left
        self.right = right
def inorder_traversal(node):
    """Yield node values in in-order: left subtree, node, right subtree.

    A falsy ``node`` (such as ``None``) yields nothing, which terminates
    the recursion at missing children.
    """
    if node:
        yield from inorder_traversal(node.left)
        yield node.value
        yield from inorder_traversal(node.right)
def preorder_traversal(node):
    """Yield node values in pre-order: node, left subtree, right subtree.

    A falsy ``node`` (such as ``None``) yields nothing, which terminates
    the recursion at missing children.
    """
    if node:
        yield node.value
        yield from preorder_traversal(node.left)
        yield from preorder_traversal(node.right)
def postorder_traversal(node):
    """Yield node values in post-order: left subtree, right subtree, node.

    A falsy ``node`` (such as ``None``) yields nothing, which terminates
    the recursion at missing children.
    """
    if node:
        yield from postorder_traversal(node.left)
        yield from postorder_traversal(node.right)
        yield node.value
# 使用
root = TreeNode(1,
TreeNode(2, TreeNode(4), TreeNode(5)),
TreeNode(3, None, TreeNode(6))
)
print(list(inorder_traversal(root))) # [4, 2, 5, 1, 3, 6]
print(list(preorder_traversal(root))) # [1, 2, 4, 5, 3, 6]
分批处理
python
def batch(iterable, size):
    """Yield consecutive lists of at most ``size`` items from ``iterable``.

    Args:
        iterable: Any iterable; it is consumed lazily, one batch at a time.
        size: Maximum number of items per batch; must be at least 1.

    Yields:
        Lists of ``size`` items; the final list may be shorter.

    Raises:
        ValueError: If ``size`` is less than 1. As this is a generator,
            the error surfaces on first iteration.
    """
    if size < 1:
        # A zero size would otherwise silently discard all input.
        raise ValueError("size must be at least 1")
    iterator = iter(iterable)
    while True:
        # Named ``chunk`` so the local does not shadow this function's name.
        chunk = list(islice(iterator, size))
        if not chunk:
            break
        yield chunk
# 使用
data = range(20)
for batch_data in batch(data, 5):
print(batch_data)
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]
# [10, 11, 12, 13, 14]
# [15, 16, 17, 18, 19]
# 分块处理大列表
def process_large_dataset(data, batch_size=1000):
    """Lazily process ``data`` in batches.

    Args:
        data: Iterable of items to process.
        batch_size: Maximum number of items handed to ``process_batch``
            at once.

    Yields:
        The result of ``process_batch`` for each batch, in input order.
    """
    for batch_data in batch(data, batch_size):
        yield process_batch(batch_data)
def process_batch(batch):
return [x ** 2 for x in batch]
并行处理
python
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
def parallel_map(func, iterable, max_workers=4, chunk_size=100):
    """Apply ``func`` to every item of ``iterable`` using a thread pool.

    The input is consumed in chunks of ``chunk_size`` items; each chunk
    is mapped across the pool and its results are yielded before the next
    chunk is read.

    Args:
        func: Callable applied to each item.
        iterable: Items to process.
        max_workers: Number of worker threads in the pool.
        chunk_size: Number of items submitted to the pool at a time.

    Yields:
        ``func(item)`` for each item, in input order.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for chunk in batch(iterable, chunk_size):
            # Executor.map returns results in the order of its inputs.
            yield from executor.map(func, chunk)
def batch(iterable, size):
    """Yield consecutive lists of at most ``size`` items from ``iterable``.

    Args:
        iterable: Any iterable; it is consumed lazily, one batch at a time.
        size: Maximum number of items per batch; must be at least 1.

    Yields:
        Lists of ``size`` items; the final list may be shorter.

    Raises:
        ValueError: If ``size`` is less than 1. As this is a generator,
            the error surfaces on first iteration.
    """
    if size < 1:
        # A zero size would otherwise silently discard all input.
        raise ValueError("size must be at least 1")
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, size))
        if not chunk:
            break
        yield chunk
# 使用
def process(x):
    """Return ``x`` squared; the per-item work used in the parallel example."""
    return x ** 2
data = range(1000)
for result in parallel_map(process, data):
print(result)
日志解析
python
import re
from datetime import datetime
from collections import namedtuple
# One parsed log line: a datetime, the level string, and the message text.
LogEntry = namedtuple("LogEntry", ["timestamp", "level", "message"])


def parse_log_file(filename):
    """Lazily parse a UTF-8 log file into ``LogEntry`` records.

    Lines are expected to look like
    ``[YYYY-MM-DD HH:MM:SS] [LEVEL] message``; lines that do not match
    this shape are skipped.

    Args:
        filename: Path of the log file.

    Yields:
        A ``LogEntry`` for each matching line, in file order.

    Raises:
        ValueError: From ``datetime.strptime`` when a line matches the
            shape but its timestamp is not a valid date/time.
    """
    # Compiled once here rather than passed to re.match on every line.
    pattern = re.compile(
        r"\[(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})\] \[(\w+)\] (.+)"
    )
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            match = pattern.match(line.strip())
            if match:
                timestamp = datetime.strptime(match.group(1), "%Y-%m-%d %H:%M:%S")
                level = match.group(2)
                message = match.group(3)
                yield LogEntry(timestamp, level, message)
def filter_by_level(entries, level):
    """Yield the entries whose ``level`` attribute equals ``level``.

    The comparison is an exact, case-sensitive equality check.
    """
    for entry in entries:
        if entry.level == level:
            yield entry
def filter_by_time(entries, start, end):
    """Yield the entries whose ``timestamp`` lies in ``[start, end]``.

    Both bounds are inclusive.
    """
    for entry in entries:
        if start <= entry.timestamp <= end:
            yield entry
# 使用管道
entries = parse_log_file("app.log")
errors = filter_by_level(entries, "ERROR")
for error in errors:
print(error)