Skip to content

集合框架

Java 集合框架提供了一套性能优良、使用方便的接口和类,位于 java.util 包中。

集合框架概述

集合结构

Collection (接口)
├── List (接口) - 有序、可重复
│   ├── ArrayList
│   ├── LinkedList
│   ├── Vector
│   └── Stack
├── Set (接口) - 无序、去重
│   ├── HashSet
│   ├── LinkedHashSet
│   └── TreeSet
└── Queue (接口) - 队列
    ├── PriorityQueue
    ├── Deque
    │   ├── ArrayDeque
    │   └── LinkedList

Map (接口) - 键值对
├── HashMap
├── LinkedHashMap
├── TreeMap
├── Hashtable
└── Properties

List 接口

ArrayList

java
import java.util.ArrayList;
import java.util.List;

List<String> list = new ArrayList<>();

// 添加元素
list.add("苹果");
list.add("香蕉");
list.add("橙子");
list.add(1, "葡萄");  // 指定索引插入

// 获取元素
String first = list.get(0);

// 修改元素
list.set(0, "西瓜");

// 删除元素
list.remove(0);       // 按索引删除
list.remove("香蕉");   // 按对象删除

// 遍历
for (int i = 0; i < list.size(); i++) {
    System.out.println(list.get(i));
}

for (String fruit : list) {
    System.out.println(fruit);
}

// 使用迭代器
Iterator<String> iterator = list.iterator();
while (iterator.hasNext()) {
    String item = iterator.next();
    System.out.println(item);
}

// 使用 forEach (Java 8+)
list.forEach(System.out::println);

LinkedList

java
import java.util.LinkedList;

LinkedList<String> linkedList = new LinkedList<>();

// 添加元素
linkedList.add("第一");
linkedList.addFirst("最前");
linkedList.addLast("最后");

// 获取
String first = linkedList.getFirst();
String last = linkedList.getLast();

// 删除
linkedList.removeFirst();
linkedList.removeLast();

// 队列操作
linkedList.offer("新元素");    // 添加到末尾
linkedList.poll();            // 移除并返回头部
linkedList.peek();            // 查看头部元素

// ArrayList vs LinkedList
// ArrayList: 随机访问快,插入删除慢
// LinkedList: 插入删除快,随机访问慢

常见操作

java
List<Integer> numbers = new ArrayList<>();
numbers.add(1);
numbers.add(2);
numbers.add(3);
numbers.add(2);
numbers.add(4);

// 判断
boolean isEmpty = numbers.isEmpty();
boolean has2 = numbers.contains(2);

// 大小
int size = numbers.size();

// 查找索引
int index = numbers.indexOf(2);      // 首次出现位置
int lastIndex = numbers.lastIndexOf(2); // 最后一次出现位置

// 子列表
List<Integer> subList = numbers.subList(1, 3);  // [2, 3]

// 转换为数组
Object[] array = numbers.toArray();
Integer[] intArray = numbers.toArray(new Integer[0]);

// 批量操作
List<String> list1 = new ArrayList<>();
list1.addAll(numbers);         // 添加集合
boolean changed = list1.retainAll(numbers); // 保留交集
list1.clear();                 // 清空

// 排序
Collections.sort(numbers);    // 自然排序
Collections.reverse(numbers); // 反转

// 查找最大最小
int max = Collections.max(numbers);
int min = Collections.min(numbers);

// 填充
Collections.fill(numbers, 0);

Set 接口

HashSet

java
import java.util.HashSet;
import java.util.Set;

Set<String> set = new HashSet<>();

// 添加
set.add("苹果");
set.add("香蕉");
set.add("橙子");
set.add("苹果");  // 重复元素,不会添加

// 大小
System.out.println(set.size());  // 3

// 遍历
for (String fruit : set) {
    System.out.println(fruit);
}

// 判断
boolean hasApple = set.contains("苹果");
set.remove("香蕉");
boolean isEmpty = set.isEmpty();

// 集合运算
Set<Integer> set1 = new HashSet<>(Arrays.asList(1, 2, 3, 4));
Set<Integer> set2 = new HashSet<>(Arrays.asList(3, 4, 5, 6));

Set<Integer> union = new HashSet<>(set1);
union.addAll(set2);              // 并集: [1,2,3,4,5,6]

Set<Integer> intersection = new HashSet<>(set1);
intersection.retainAll(set2);    // 交集: [3, 4]

Set<Integer> difference = new HashSet<>(set1);
difference.removeAll(set2);     // 差集: [1, 2]

LinkedHashSet

java
import java.util.LinkedHashSet;

// 保持插入顺序
LinkedHashSet<String> linkedSet = new LinkedHashSet<>();
linkedSet.add("第一");
linkedSet.add("第二");
linkedSet.add("第三");

// 遍历顺序: 第一 -> 第二 -> 第三
for (String s : linkedSet) {
    System.out.println(s);
}

TreeSet

java
import java.util.TreeSet;
import java.util.NavigableSet;

// 自然顺序(需要元素实现 Comparable)
TreeSet<Integer> treeSet = new TreeSet<>();
treeSet.add(5);
treeSet.add(2);
treeSet.add(8);
treeSet.add(1);

// 遍历(已排序)
for (Integer num : treeSet) {
    System.out.println(num);  // 1, 2, 5, 8
}

// 范围查询
System.out.println(treeSet.lower(5));   // 小于5的最大值: 2
System.out.println(treeSet.floor(5));  // 小于等于5的最大值: 5
System.out.println(treeSet.higher(5)); // 大于5的最小值: 8
System.out.println(treeSet.ceiling(5)); // 大于等于5的最小值: 5

// 子集
NavigableSet<Integer> subSet = treeSet.subSet(2, true, 8, false);
System.out.println(subSet);  // [2, 5]

// 使用自定义比较器
TreeSet<String> customSet = new TreeSet<>(new Comparator<String>() {
    @Override
    public int compare(String s1, String s2) {
        return s2.compareTo(s1);  // 降序
    }
});

// 或使用 Lambda
TreeSet<String> lambdaSet = new TreeSet<>((s1, s2) -> s2.compareTo(s1));

Map 接口

HashMap

java
import java.util.HashMap;
import java.util.Map;

Map<String, Integer> map = new HashMap<>();

// 添加/修改
map.put("苹果", 10);
map.put("香蕉", 5);
map.put("橙子", 8);
map.put("苹果", 15);  // 修改已有键的值

// 获取
Integer count = map.get("苹果");  // 15
Integer notExist = map.get("葡萄"); // null

// 使用 getOrDefault
int value = map.getOrDefault("葡萄", 0);  // 0

// 判断
boolean hasKey = map.containsKey("苹果");
boolean hasValue = map.containsValue(10);

// 删除
map.remove("香蕉");
map.remove("苹果", 10);  // 仅当键值对匹配时删除

// 大小
int size = map.size();

// 遍历
// 方式1: 遍历键
for (String key : map.keySet()) {
    System.out.println(key + ": " + map.get(key));
}

// 方式2: 遍历值
for (Integer value : map.values()) {
    System.out.println(value);
}

// 方式3: 遍历键值对
for (Map.Entry<String, Integer> entry : map.entrySet()) {
    System.out.println(entry.getKey() + ": " + entry.getValue());
}

// 方式4: 使用 forEach (Java 8+)
map.forEach((k, v) -> System.out.println(k + ": " + v));

LinkedHashMap

java
import java.util.LinkedHashMap;

// 保持插入顺序
LinkedHashMap<String, Integer> linkedMap = new LinkedHashMap<>();
linkedMap.put("first", 1);
linkedMap.put("second", 2);
linkedMap.put("third", 3);

// 实现 LRU 缓存
LinkedHashMap<Integer, String> lruCache = new LinkedHashMap<>(16, 0.75f, true) {
    @Override
    protected boolean removeEldestEntry(Map.Entry eldest) {
        return size() > 3;  // 超过3个元素时移除最老的
    }
};

TreeMap

java
import java.util.TreeMap;
import java.util.Map;

TreeMap<String, Integer> treeMap = new TreeMap<>();

treeMap.put("c", 3);
treeMap.put("a", 1);
treeMap.put("b", 2);

// 遍历(按键排序)
for (String key : treeMap.keySet()) {
    System.out.println(key + ": " + treeMap.get(key));
}

// 范围查询
System.out.println(treeMap.firstKey());   // a
System.out.println(treeMap.lastKey());    // c
System.out.println(treeMap.lowerKey("b")); // a
System.out.println(treeMap.higherKey("b")); // c
System.out.println(treeMap.subMap("a", "c")); // {a=1, b=2}

// 使用自定义比较器
TreeMap<Integer, String> customMap = new TreeMap<>((a, b) -> b - a);  // 降序

Hashtable

java
import java.util.Hashtable;

// 线程安全,不允许 null 键和值
Hashtable<String, Integer> table = new Hashtable<>();
table.put("key", 1);

// 已废弃,不推荐使用
// 如需线程安全,使用 ConcurrentHashMap

Queue 接口

PriorityQueue

java
import java.util.PriorityQueue;
import java.util.Iterator;

PriorityQueue<Integer> pq = new PriorityQueue<>();

// 添加元素
pq.add(5);
pq.add(2);
pq.add(8);
pq.offer(1);

// 获取队首(不删除)
Integer peek = pq.peek();  // 1

// 获取并删除队首
Integer poll = pq.poll();  // 1

// 遍历(不保证顺序)
for (Integer num : pq) {
    System.out.println(num);
}

// 使用自定义比较器(最大堆)
PriorityQueue<Integer> maxHeap = new PriorityQueue<>((a, b) -> b - a);

// 最小堆(默认)
PriorityQueue<Integer> minHeap = new PriorityQueue<>();

// 优先队列用于 Top K 问题
PriorityQueue<Integer> topK = new PriorityQueue<>(Comparator.reverseOrder());

ArrayDeque

java
import java.util.ArrayDeque;

ArrayDeque<String> deque = new ArrayDeque<>();

// 作为栈使用
deque.push("A");
deque.push("B");
deque.push("C");
System.out.println(deque.pop());  // C (LIFO)

// 作为队列使用
deque.clear();
deque.add("A");
deque.add("B");
deque.add("C");
System.out.println(deque.poll());  // A (FIFO)

// 双端操作
deque.addFirst("first");
deque.addLast("last");
System.out.println(deque.getFirst());
System.out.println(deque.getLast());

Collections 工具类

排序与查找

java
List<Integer> list = new ArrayList<>(Arrays.asList(5, 2, 8, 1, 9));

// 排序
Collections.sort(list);              // 自然排序
Collections.reverse(list);           // 反转
Collections.shuffle(list);           // 随机打乱

// 查找
int index = Collections.binarySearch(list, 5);  // 二分查找(需先排序)
Collections.fill(list, 0);           // 填充

// 最大最小
int max = Collections.max(list);
int min = Collections.min(list);

同步控制

java
// 同步集合(线程安全,但性能较差)
List<String> synchronizedList = Collections.synchronizedList(new ArrayList<>());
Map<String, Integer> synchronizedMap = Collections.synchronizedMap(new HashMap<>());
Set<String> synchronizedSet = Collections.synchronizedSet(new HashSet<>());

// 推荐使用并发集合
List<String> concurrentList = new CopyOnWriteArrayList<>();
Map<String, Integer> concurrentMap = new ConcurrentHashMap<>();
Set<String> concurrentSet = new ConcurrentHashMap().newKeySet();

不可变集合

java
// 创建不可变集合
List<String> immutableList = Collections.unmodifiableList(Arrays.asList("a", "b", "c"));

// Java 9+ 简洁写法
List<String> list = List.of("a", "b", "c");
Set<String> set = Set.of("a", "b", "c");
Map<String, Integer> map = Map.of("a", 1, "b", 2);

// 尝试修改会抛出 UnsupportedOperationException
// immutableList.add("d");  // 运行时异常

Stream API

基本使用

java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);

// 创建 Stream
List<Integer> list = Arrays.asList(1, 2, 3);
Stream<Integer> stream = list.stream();

// 中间操作
List<Integer> filtered = numbers.stream()
    .filter(n -> n % 2 == 0)    // 过滤偶数
    .collect(Collectors.toList());

List<Integer> mapped = numbers.stream()
    .map(n -> n * 2)            // 翻倍
    .collect(Collectors.toList());

List<Integer> distinct = numbers.stream()
    .distinct()                 // 去重
    .collect(Collectors.toList());

List<Integer> sorted = numbers.stream()
    .sorted()                   // 排序
    .collect(Collectors.toList());

List<Integer> limited = numbers.stream()
    .limit(5)                   // 取前5个
    .collect(Collectors.toList());

List<Integer> skipped = numbers.stream()
    .skip(5)                    // 跳过前5个
    .collect(Collectors.toList());

// 终端操作
long count = numbers.stream().count();
boolean hasEven = numbers.stream().anyMatch(n -> n % 2 == 0);
boolean allPositive = numbers.stream().allMatch(n -> n > 0);
boolean noneNegative = numbers.stream().noneMatch(n -> n < 0);
Integer sum = numbers.stream().reduce(0, Integer::sum);
Integer max = numbers.stream().max(Integer::compareTo).orElse(0);
Integer min = numbers.stream().min(Integer::compareTo).orElse(0);

分组与分区

java
List<String> words = Arrays.asList("apple", "banana", "cherry", "date", "elderberry");

// 按长度分组
Map<Integer, List<String>> groupedByLength = words.stream()
    .collect(Collectors.groupingBy(String::length));

// {3=[date], 5=[apple], 6=[banana, cherry], 10=[elderberry]}

// 计数
Map<Integer, Long> countByLength = words.stream()
    .collect(Collectors.groupingBy(String::length, Collectors.counting()));

// 分区(按条件分为两组)
Map<Boolean, List<String>> partitioned = words.stream()
    .collect(Collectors.partitioningBy(s -> s.length() > 5));

// {false=[apple, banana, cherry, date], true=[elderberry]}

// 多种统计
IntSummaryStatistics stats = numbers.stream()
    .collect(Collectors.summarizingInt(Integer::intValue));

System.out.println(stats.getSum());    // 总和
System.out.println(stats.getAverage()); // 平均值
System.out.println(stats.getMax());    // 最大值
System.out.println(stats.getMin());    // 最小值
System.out.println(stats.getCount());  // 数量

并行 Stream

java
// 并行处理,提高性能
long start = System.currentTimeMillis();
long count = IntStream.range(1, 1000000)
    .parallel()
    .filter(n -> n % 2 == 0)
    .count();
long end = System.currentTimeMillis();

System.out.println("并行执行耗时: " + (end - start) + "ms");

实践示例

学生成绩管理

java
import java.util.*;
import java.util.stream.Collectors;

class Student {
    private String name;
    private int score;
    private String grade;
    
    public Student(String name, int score, String grade) {
        this.name = name;
        this.score = score;
        this.grade = grade;
    }
    
    public String getName() { return name; }
    public int getScore() { return score; }
    public String getGrade() { return grade; }
}

public class StudentManager {
    public static void main(String[] args) {
        List<Student> students = Arrays.asList(
            new Student("张三", 85, "A"),
            new Student("李四", 92, "A"),
            new Student("王五", 78, "B"),
            new Student("赵六", 65, "C"),
            new Student("钱七", 55, "D"),
            new Student("孙八", 88, "A"),
            new Student("周九", 72, "B")
        );
        
        // 按成绩排序
        List<Student> sorted = students.stream()
            .sorted(Comparator.comparingInt(Student::getScore).reversed())
            .collect(Collectors.toList());
        
        System.out.println("按成绩排序:");
        sorted.forEach(s -> System.out.println(s.getName() + ": " + s.getScore()));
        
        // 按班级分组
        Map<String, List<Student>> byGrade = students.stream()
            .collect(Collectors.groupingBy(Student::getGrade));
        
        System.out.println("\n按班级分组:");
        byGrade.forEach((grade, list) -> {
            System.out.println(grade + "班: " + list.stream()
                .map(Student::getName)
                .collect(Collectors.joining(", ")));
        });
        
        // 统计各班平均分
        Map<String, Double> avgByGrade = students.stream()
            .collect(Collectors.groupingBy(
                Student::getGrade,
                Collectors.averagingInt(Student::getScore)
            ));
        
        System.out.println("\n各班平均分:");
        avgByGrade.forEach((k, v) -> System.out.println(k + "班: " + v));
        
        // 找出最高分学生
        Student topStudent = students.stream()
            .max(Comparator.comparingInt(Student::getScore))
            .orElse(null);
        
        System.out.println("\n最高分学生: " + topStudent.getName() + 
            " (" + topStudent.getScore() + "分)");
        
        // 成绩统计
        IntSummaryStatistics stats = students.stream()
            .collect(Collectors.summarizingInt(Student::getScore));
        
        System.out.println("\n成绩统计:");
        System.out.println("最高分: " + stats.getMax());
        System.out.println("最低分: " + stats.getMin());
        System.out.println("平均分: " + String.format("%.2f", stats.getAverage()));
        System.out.println("总人数: " + stats.getCount());
    }
}