Skip to content

Stream API

Stream API 是 Java 8 引入的函数式编程特性,用于以声明式的方式处理集合等数据源中的元素。

Stream 概述

什么是 Stream

Stream 是数据渠道,用于操作数据源(集合、数组等)所生成的元素序列。

java
// Traditional approach: iterate explicitly and add each matching name by hand.
List<String> names = new ArrayList<>();
for (Person person : persons) {
    if (person.getAge() > 18) {
        names.add(person.getName());
    }
}

// Stream approach: the same selection written as a filter -> map -> collect pipeline.
// This is an alternative to the loop above; both declare `names`, so use one or the other.
List<String> names = persons.stream()
    .filter(p -> p.getAge() > 18)
    .map(Person::getName)
    .collect(Collectors.toList());

Stream 特点

  • 声明式编程:描述做什么,而不是怎么做
  • 链式操作:支持流水线操作
  • 内部迭代:无需手动迭代
  • 延迟执行:中间操作不会立即执行,只有在调用终端操作时才会真正运行
  • 可能并行:支持并行处理

创建 Stream

java
import java.util.stream.*;
import java.util.*;

// 从集合创建
List<String> list = Arrays.asList("a", "b", "c");
Stream<String> stream = list.stream();
Stream<String> parallelStream = list.parallelStream();

// 从数组创建
String[] array = {"a", "b", "c"};
Stream<String> stream2 = Arrays.stream(array);

// 使用 Stream.of
Stream<String> stream3 = Stream.of("a", "b", "c");

// 创建空 Stream
Stream<String> emptyStream = Stream.empty();

// 无限流
Stream<Double> randoms = Stream.generate(Math::random);
Stream<Integer> naturals = Stream.iterate(0, n -> n + 1);

// 有限流(Java 9+)
Stream<Integer> limited = Stream.iterate(0, n -> n < 10, n -> n + 1);

// 基本类型流
IntStream intStream = IntStream.range(1, 10);      // 1-9
IntStream intStream2 = IntStream.rangeClosed(1, 10); // 1-10
LongStream longStream = LongStream.of(1L, 2L, 3L);
DoubleStream doubleStream = DoubleStream.of(1.0, 2.0, 3.0);

// 从文件创建
Stream<String> lines = Files.lines(Paths.get("file.txt"));

中间操作

filter - 过滤

java
List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);

// Keep only the even numbers
List<Integer> evens = numbers.stream()
    .filter(n -> n % 2 == 0)
    .collect(Collectors.toList());
// [2, 4, 6, 8, 10]

// Chained filters: an element must satisfy every predicate to pass through
List<Integer> result = numbers.stream()
    .filter(n -> n > 3)
    .filter(n -> n < 8)
    .collect(Collectors.toList());
// [4, 5, 6, 7]

// Filter objects by a property
List<Person> adults = persons.stream()
    .filter(p -> p.getAge() >= 18)
    .collect(Collectors.toList());

map - 映射

java
List<String> names = Arrays.asList("alice", "bob", "charlie");

// Convert each element to upper case
List<String> upperNames = names.stream()
    .map(String::toUpperCase)
    .collect(Collectors.toList());
// [ALICE, BOB, CHARLIE]

// Map each string to its length
List<Integer> lengths = names.stream()
    .map(String::length)
    .collect(Collectors.toList());
// [5, 3, 7]

// Extract one property from each object
List<String> personNames = persons.stream()
    .map(Person::getName)
    .collect(Collectors.toList());

flatMap - 扁平化映射

java
List<List<Integer>> nestedLists = Arrays.asList(
    Arrays.asList(1, 2),
    Arrays.asList(3, 4),
    Arrays.asList(5, 6)
);

// Flatten: each inner list becomes a stream, and the streams are concatenated
List<Integer> flatList = nestedLists.stream()
    .flatMap(Collection::stream)
    .collect(Collectors.toList());
// [1, 2, 3, 4, 5, 6]

// Split each string on a single space and flatten the resulting words
List<String> words = Arrays.asList("Hello World", "Java Stream");
List<String> allWords = words.stream()
    .flatMap(s -> Arrays.stream(s.split(" ")))
    .collect(Collectors.toList());
// [Hello, World, Java, Stream]

distinct - 去重

java
List<Integer> numbers = Arrays.asList(1, 2, 2, 3, 3, 3, 4);

// Remove duplicates (equality is decided by equals())
List<Integer> distinct = numbers.stream()
    .distinct()
    .collect(Collectors.toList());
// [1, 2, 3, 4]

sorted - 排序

java
List<Integer> numbers = Arrays.asList(5, 2, 8, 1, 9);

// Natural ordering
List<Integer> sorted = numbers.stream()
    .sorted()
    .collect(Collectors.toList());
// [1, 2, 5, 8, 9]

// Descending order
List<Integer> descSorted = numbers.stream()
    .sorted(Comparator.reverseOrder())
    .collect(Collectors.toList());
// [9, 8, 5, 2, 1]

// Custom ordering by a single key
List<Person> sortedPersons = persons.stream()
    .sorted(Comparator.comparing(Person::getAge))
    .collect(Collectors.toList());

// Multi-field ordering: by age first, with name breaking ties.
// Given its own name so it does not redeclare `sorted` from the first example.
List<Person> sortedByAgeThenName = persons.stream()
    .sorted(Comparator.comparing(Person::getAge)
        .thenComparing(Person::getName))
    .collect(Collectors.toList());

limit 和 skip

java
List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);

// Take the first 5 elements
List<Integer> first5 = numbers.stream()
    .limit(5)
    .collect(Collectors.toList());
// [1, 2, 3, 4, 5]

// Skip the first 5 elements
List<Integer> skip5 = numbers.stream()
    .skip(5)
    .collect(Collectors.toList());
// [6, 7, 8, 9, 10]

// Pagination: `page` is 1-based, so skip the (page - 1) preceding pages and take one page
int page = 2, size = 5;
List<Integer> pageData = numbers.stream()
    .skip((page - 1) * size)
    .limit(size)
    .collect(Collectors.toList());
// [6, 7, 8, 9, 10]

peek - 查看元素

java
// peek runs an action on each element as it flows past without changing the
// stream; here it logs every element before the filter and every survivor after it.
List<Integer> result = numbers.stream()
    .peek(n -> System.out.println("处理: " + n))
    .filter(n -> n > 5)
    .peek(n -> System.out.println("过滤后: " + n))
    .collect(Collectors.toList());

终端操作

collect - 收集

java
// A stream can be consumed only once (a second terminal operation throws
// IllegalStateException), so every example below starts from a fresh names.stream().

// Collect into a List
List<String> list = names.stream().collect(Collectors.toList());

// Collect into a Set
Set<String> set = names.stream().collect(Collectors.toSet());

// Collect into a Map (this two-argument toMap throws IllegalStateException on duplicate keys)
Map<Integer, String> map = persons.stream()
    .collect(Collectors.toMap(
        Person::getId,
        Person::getName
    ));

// Collect into a specific collection type
LinkedList<String> linkedList = names.stream().collect(
    Collectors.toCollection(LinkedList::new)
);

// Join into one string
String joined = names.stream()
    .collect(Collectors.joining(", "));
// e.g. "Alice, Bob, Charlie"

// Join with a prefix and suffix
String joined2 = names.stream()
    .collect(Collectors.joining(", ", "[", "]"));
// e.g. "[Alice, Bob, Charlie]"

聚合操作

java
// Count the elements
long count = persons.stream().count();

// Maximum (Optional is empty when the stream has no elements)
Optional<Integer> max = numbers.stream()
    .max(Integer::compareTo);

// Minimum
Optional<Integer> min = numbers.stream()
    .min(Integer::compareTo);

// Sum: convert to an IntStream first to get the primitive sum()
int sum = numbers.stream()
    .mapToInt(Integer::intValue)
    .sum();

// Average
OptionalDouble avg = numbers.stream()
    .mapToInt(Integer::intValue)
    .average();

// Summary statistics: count, sum, average, max and min from a single pass
IntSummaryStatistics stats = numbers.stream()
    .mapToInt(Integer::intValue)
    .summaryStatistics();
stats.getCount();   // number of elements
stats.getSum();     // sum
stats.getAverage(); // average
stats.getMax();     // maximum
stats.getMin();     // minimum

归约操作

java
// Sum, starting from the identity value 0
int sum = numbers.stream()
    .reduce(0, Integer::sum);

// Without an identity value the result is an Optional (empty for an empty stream)
Optional<Integer> sum2 = numbers.stream()
    .reduce(Integer::sum);

// Maximum
Optional<Integer> max = numbers.stream()
    .reduce(Integer::max);

// String concatenation. Seeding with "" would leave a leading space
// (" " + first), so reduce without an identity and fall back to "" when empty.
String result = names.stream()
    .reduce((left, right) -> left + " " + right)
    .orElse("");

// Sum of squares: square in map(), then reduce with the associative
// Integer::sum so the accumulator stays valid for parallel streams too.
int sumOfSquares = numbers.stream()
    .map(n -> n * n)
    .reduce(0, Integer::sum);

匹配操作

java
// True when every element satisfies the predicate
boolean allAdult = persons.stream()
    .allMatch(p -> p.getAge() >= 18);

// True when at least one element satisfies the predicate
boolean hasChild = persons.stream()
    .anyMatch(p -> p.getAge() < 18);

// True when no element satisfies the predicate
boolean noneMatch = persons.stream()
    .noneMatch(p -> p.getAge() < 0);

查找操作

java
// Find the first element that passes the filter
Optional<String> first = names.stream()
    .filter(s -> s.startsWith("A"))
    .findFirst();

// Find any matching element (may be cheaper than findFirst on a parallel
// stream, because it is free to return whichever match is found)
Optional<String> any = names.stream()
    .filter(s -> s.startsWith("A"))
    .findAny();

遍历操作

java
// forEach: run an action on every element
names.stream().forEach(System.out::println);

// forEachOrdered: keeps encounter order even on a parallel stream
names.parallelStream()
    .forEachOrdered(System.out::println);

toArray

java
// Convert to an array; the constructor reference supplies the array type
String[] array = names.stream().toArray(String[]::new);

// Convert to a primitive array via IntStream
int[] intArray = numbers.stream()
    .mapToInt(Integer::intValue)
    .toArray();

分组与分区

分组

java
// Group by a property
Map<String, List<Person>> byCity = persons.stream()
    .collect(Collectors.groupingBy(Person::getCity));

// Multi-level grouping: by city, then by gender within each city
Map<String, Map<String, List<Person>>> byCityAndGender = persons.stream()
    .collect(Collectors.groupingBy(
        Person::getCity,
        Collectors.groupingBy(Person::getGender)
    ));

// Group and count the members of each group
Map<String, Long> countByCity = persons.stream()
    .collect(Collectors.groupingBy(
        Person::getCity,
        Collectors.counting()
    ));

// Group and sum a numeric property per group
Map<String, Integer> sumByCity = persons.stream()
    .collect(Collectors.groupingBy(
        Person::getCity,
        Collectors.summingInt(Person::getSalary)
    ));

// Group and keep the maximum element of each group (by salary)
Map<String, Optional<Person>> maxByCity = persons.stream()
    .collect(Collectors.groupingBy(
        Person::getCity,
        Collectors.maxBy(Comparator.comparing(Person::getSalary))
    ));

// Group and transform: collect only the names within each group
Map<String, List<String>> namesByCity = persons.stream()
    .collect(Collectors.groupingBy(
        Person::getCity,
        Collectors.mapping(Person::getName, Collectors.toList())
    ));

分区

java
// Split into two groups keyed by the predicate's result
Map<Boolean, List<Person>> partitioned = persons.stream()
    .collect(Collectors.partitioningBy(p -> p.getAge() >= 18));
// {false=[minors], true=[adults]}

// Partition and count the members of each side
Map<Boolean, Long> countByAdult = persons.stream()
    .collect(Collectors.partitioningBy(
        p -> p.getAge() >= 18,
        Collectors.counting()
    ));

并行流

创建并行流

java
// Create a parallel stream directly from a collection
Stream<String> parallelStream = list.parallelStream();

// Switch an existing sequential stream to parallel
Stream<String> parallelStream2 = list.stream().parallel();

// Check whether a stream is parallel
boolean isParallel = stream.isParallel();

// Switch back to a sequential stream
Stream<String> sequential = parallelStream.sequential();

并行流示例

java
// 并行求和
long sum = numbers.parallelStream()
    .mapToLong(Long::longValue)
    .sum();

// 并行处理大数据
long count = Files.lines(Paths.get("large.txt"))
    .parallel()
    .filter(line -> line.contains("error"))
    .count();

并行流注意事项

java
// Avoid shared mutable state
List<Integer> list = new ArrayList<>();
IntStream.range(0, 1000)
    .parallel()
    .forEach(i -> {
        // Wrong: ArrayList is not thread-safe, so concurrent add() calls race
        // list.add(i);
    });

// Correct: let collect() build the result instead of mutating a list from the lambda
List<Integer> safeList = IntStream.range(0, 1000)
    .parallel()
    .boxed()
    .collect(Collectors.toList());

// Alternatively, add to a synchronized list (elements may arrive in any order)
List<Integer> syncList = Collections.synchronizedList(new ArrayList<>());
IntStream.range(0, 1000)
    .parallel()
    .forEach(syncList::add);

原始类型流

IntStream / LongStream / DoubleStream

java
// Create
IntStream intStream = IntStream.range(1, 10);
LongStream longStream = LongStream.rangeClosed(1, 10);
DoubleStream doubleStream = DoubleStream.of(1.0, 2.0, 3.0);

// Convert from an object stream
IntStream ages = persons.stream()
    .mapToInt(Person::getAge);

// Convert back to an object stream. This uses up intStream: it cannot be
// operated on again afterwards.
Stream<Integer> boxed = intStream.boxed();

// Common operations. Each terminal operation consumes its stream (reusing one
// throws IllegalStateException), so every call gets a fresh IntStream.
int sum = IntStream.range(1, 10).sum();
OptionalDouble avg = IntStream.range(1, 10).average();
OptionalInt max = IntStream.range(1, 10).max();
OptionalInt min = IntStream.range(1, 10).min();

// Summary statistics: count, sum, min, max and average from one pass
IntSummaryStatistics stats = IntStream.range(1, 10).summaryStatistics();

自定义收集器

实现收集器

java
/**
 * Collector that concatenates character sequences into a single String,
 * inserting a delimiter between pieces.
 *
 * <p>The accumulator only inserts the delimiter when the buffer already holds
 * text, so an element appended to an empty buffer is never preceded by one.
 */
public class StringJoiner implements Collector<CharSequence, StringBuilder, String> {

    private final String delimiter;

    /**
     * @param delimiter text placed between joined pieces
     */
    public StringJoiner(String delimiter) {
        this.delimiter = delimiter;
    }

    /** Supplies a fresh, empty buffer for each accumulation. */
    @Override
    public Supplier<StringBuilder> supplier() {
        return () -> new StringBuilder();
    }

    /** Appends one element to a buffer, delimiter first if the buffer is non-empty. */
    @Override
    public BiConsumer<StringBuilder, CharSequence> accumulator() {
        return this::appendElement;
    }

    /** Merges two partial buffers into the left one. */
    @Override
    public BinaryOperator<StringBuilder> combiner() {
        return this::merge;
    }

    /** Converts the final buffer to the resulting String. */
    @Override
    public Function<StringBuilder, String> finisher() {
        return buffer -> buffer.toString();
    }

    /** Declares no collector characteristics. */
    @Override
    public Set<Characteristics> characteristics() {
        return Collections.emptySet();
    }

    // Adds the delimiter only when something has been written already.
    private void appendElement(StringBuilder buffer, CharSequence element) {
        if (buffer.length() != 0) {
            buffer.append(delimiter);
        }
        buffer.append(element);
    }

    // Joins two partial results; the delimiter is needed only when both sides hold text.
    private StringBuilder merge(StringBuilder left, StringBuilder right) {
        boolean bothHaveText = left.length() > 0 && right.length() > 0;
        if (bothHaveText) {
            left.append(delimiter);
        }
        return left.append(right);
    }
}

// Usage: pass an instance of the custom collector to collect()
String result = names.stream()
    .collect(new StringJoiner(", "));

实践示例

数据处理

java
/**
 * Stream-based helpers for common list-processing tasks.
 */
public class DataProcessor {

    /**
     * Returns the highest-paid adults.
     *
     * @param persons people to search
     * @param limit   maximum number of results
     * @return people aged 18 or over, ordered by salary descending, at most {@code limit} of them
     */
    public List<Person> getTopAdults(List<Person> persons, int limit) {
        return persons.stream()
            .filter(p -> p.getAge() >= 18)
            .sorted(Comparator.comparing(Person::getSalary).reversed())
            .limit(limit)
            .collect(Collectors.toList());
    }

    /**
     * Computes the average salary per department.
     *
     * @param persons people to aggregate
     * @return map from department to the mean salary of its members
     */
    public Map<String, Double> getAverageSalaryByDepartment(List<Person> persons) {
        return persons.stream()
            .collect(Collectors.groupingBy(
                Person::getDepartment,
                Collectors.averagingDouble(Person::getSalary)
            ));
    }

    /**
     * Finds the values that occur more than once.
     *
     * @param list values to inspect; must not contain {@code null}, because
     *             {@code groupingBy} rejects null keys
     * @return each duplicated value once, in order of first occurrence
     */
    public <T> List<T> findDuplicates(List<T> list) {
        // LinkedHashMap keeps first-occurrence order, so the result is deterministic.
        Map<T, Long> occurrences = list.stream()
            .collect(Collectors.groupingBy(
                Function.identity(),
                LinkedHashMap::new,
                Collectors.counting()
            ));
        return occurrences.entrySet().stream()
            .filter(e -> e.getValue() > 1)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    }

    /**
     * Flattens a list of lists into a single list, preserving order.
     *
     * @param nestedLists lists to concatenate
     * @return all inner elements in encounter order
     */
    public <T> List<T> flatten(List<List<T>> nestedLists) {
        return nestedLists.stream()
            .flatMap(Collection::stream)
            .collect(Collectors.toList());
    }

    /**
     * Splits a list into consecutive batches.
     *
     * @param list      list to split
     * @param batchSize maximum number of elements per batch; must be positive
     * @return batches in order; every batch except possibly the last has
     *         {@code batchSize} elements. Each batch is a {@code subList} view of {@code list}.
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public <T> List<List<T>> partition(List<T> list, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
        }
        int total = list.size();
        // Ceiling division in long arithmetic: total + batchSize can exceed Integer.MAX_VALUE.
        int batchCount = (int) (((long) total + batchSize - 1) / batchSize);
        return IntStream.range(0, batchCount)
            .mapToObj(i -> {
                int from = i * batchSize;
                // Compute the end index in long to avoid overflow on the last batch.
                int to = (int) Math.min((long) from + batchSize, total);
                return list.subList(from, to);
            })
            .collect(Collectors.toList());
    }
}

文件处理

java
/**
 * Line-oriented text file analyses built on {@code Files.lines}.
 *
 * <p>Every method opens the file inside try-with-resources, because the stream
 * returned by {@code Files.lines} holds the file open until it is closed.
 */
public class FileAnalyzer {

    /**
     * Counts how often each word occurs, ignoring case.
     *
     * @param filePath path of the file to read
     * @return map from lower-cased word to its number of occurrences; words are
     *         the non-empty tokens obtained by splitting each line on whitespace
     * @throws IOException if the file cannot be opened
     */
    public Map<String, Long> wordFrequency(String filePath) throws IOException {
        try (Stream<String> lines = Files.lines(Paths.get(filePath))) {
            return lines
                .flatMap(line -> Arrays.stream(line.split("\\s+")))
                // split yields an empty token for blank lines and leading whitespace
                .filter(word -> !word.isEmpty())
                .map(String::toLowerCase)
                .collect(Collectors.groupingBy(
                    Function.identity(),
                    Collectors.counting()
                ));
        }
    }

    /**
     * Finds a line of maximal length.
     *
     * @param filePath path of the file to read
     * @return a longest line, or an empty Optional if the file has no lines
     * @throws IOException if the file cannot be opened
     */
    public Optional<String> findLongestLine(String filePath) throws IOException {
        try (Stream<String> lines = Files.lines(Paths.get(filePath))) {
            return lines.max(Comparator.comparingInt(String::length));
        }
    }

    /**
     * Returns the lines that contain the text {@code "ERROR"} (case-sensitive).
     *
     * @param logFile path of the log file to read
     * @return matching lines in file order
     * @throws IOException if the file cannot be opened
     */
    public List<String> filterErrors(String logFile) throws IOException {
        try (Stream<String> lines = Files.lines(Paths.get(logFile))) {
            return lines
                .filter(line -> line.contains("ERROR"))
                .collect(Collectors.toList());
        }
    }
}