组合问题(非常详细,附带源码)
组合问题指的是从给定的 N 个数中找出任意 K 个数的所有组合。例如,从 [1, 2, 3, 4] 这 4 个数中找到任意 2 个数的所有组合,它们是:
再比如,从 [1,2,3,4] 这 4 个数中找到任意 3 个数的所有组合,它们是:
编程解决组合问题,多数读者最先想到的思路是用嵌套的循环结构。比如从 [1, 2, 3, 4] 中找到任意 2 个数的组合,对应的 C 语言实现代码为:
嵌套的循环结构解决组合问题,只适合 K 值较小的情况,比如 K 是 2 或者 3。当 K 值较大时,比如 10、20 甚至更大,循环嵌套的层数非常多,代码变得复杂、难以维护,算法的时间复杂度也会呈指数级增长,导致效率极低,这种情况下更推荐用回溯算法解决组合问题。
实现回溯算法,通常会采用递归的方式,下面是用回溯算法解决组合问题的 C 语言代码:
虽然 combinationUtil() 函数的参数数量比较多,但非常容易理解:
程序中,需要重点解释的是第 49 行的循环条件
如果可选元素的数量能满足需求,即
通常情况下,多数读者只会想到
下面是回溯算法解决组合问题的 Python 程序:
下面是回溯算法解决组合问题的 Java 程序:
1 2
1 3
1 4
2 3
2 4
3 4
再比如,从 [1,2,3,4] 这 4 个数中找到任意 3 个数的所有组合,它们是:
1 2 3
1 2 4
1 3 4
2 3 4
编程解决组合问题,多数读者最先想到的思路是用嵌套的循环结构。比如从 [1, 2, 3, 4] 中找到任意 2 个数的组合,对应的 C 语言实现代码为:
#include <stdio.h> int main() { int nums[] = { 1,2,3,4 }; // i 的值作为选择的第一个数 for (int i = 0; i < 4; i++) { // j 的值作为选择的第二个数 for (int j = i + 1; j < 4; j++) { printf("%d %d\n", nums[i], nums[j]); } } return 0; }类似的,查找任意 3 个数的组合,需要用 3 个嵌套的循环结构实现。也就是说,从 N 个数中找到任意 K 个数的所有组合,用循环结构实现的话,需要用 K 个嵌套的循环结构。
嵌套的循环结构解决组合问题,只适合 K 值较小的情况,比如 K 是 2 或者 3。当 K 值较大时,比如 10、20 甚至更大,循环嵌套的层数非常多,代码变得复杂、难以维护,算法的时间复杂度也会呈指数级增长,导致效率极低,这种情况下更推荐用回溯算法解决组合问题。
回溯算法解决组合问题
假设在 {0, 1, 2, 3} 这 4 个数中找到任意 3 个数的所有组合,下面是回溯算法的整个实现思路:从 {0,1,2,3} 里,选择元素 0; 从 {1,2,3} 里,选择元素 1; 从 {2,3} 里,选择元素 2,找到了 {0,1,2} 组合; 回溯到 {2,3},选择元素 3,找到了 {0,1,3} 组合; 回溯到 {2,3},没有元素可以选择,再次回溯; 回溯到 {1,2,3},选择元素 2; 从 {3} 里选择元素 3,找到了 {0,2,3} 组合; 回溯到 {3},没有元素可以选择,再次回溯; 回溯到 {1,2,3},只剩 3 可以选择,无法构成一个新的组合,直接回溯; 回溯到 {0,1,2,3},选择元素 1; 从 {2,3} 里,选择元素 2; 从 {3} 里选择元素 3,找到了 {1,2,3} 组合; 回溯到 {2,3},只剩 3 可以选择,无法构成一个新的组合,直接回溯; 回溯到 {0,1,2,3},只剩 2 和 3 可以选择,无法构成一个新的组合,算法执行结束。仔细观察整个过程不难发现,每选择一个元素 e,下次选择是在 e 之后的元素里进行。比如选择了元素 1 和 2,下次选择就是在 {3} 中进行,而不是 {0, 3}。这样做的好处是,可以确保每种组合只找到一次,不会重复。
实现回溯算法,通常会采用递归的方式,下面是用回溯算法解决组合问题的 C 语言代码:
#include <stdio.h> #include <stdlib.h> #define N 4 // 定义数组的大小 #define K 2 // 定义组合的大小 // 为实现回溯算法做前期准备的函数 void printCombination(int arr[], int n, int k); // 实现用回溯算法解决组合问题的递归函数 void combinationUtil(int arr[], int data[], int start, int end, int index, int k); int main() { int arr[N]; // 声明一个大小为 N 的数组 // 初始化数组为 {1, 2, 3, 4} for (int i = 0; i < N; i++) { arr[i] = i + 1; } // 调用 printCombination 函数来打印所有可能的组合 printCombination(arr, N, K); return 0; } // 用于初始化临时数组并开始组合过程 void printCombination(int arr[], int n, int k) { // 动态创建一个数组,用于存储选择了的元素 int* data = (int*)malloc(k * sizeof(int)); // 使用回溯生成所有组合 combinationUtil(arr, data, 0, n, 0, k); free(data); // 释放之前分配的内存 } // 在 arr 数组中的[start,end)下标范围内,找到 k 个可选择的元素 // data[] 数组用于存储被选择了的元素 // index 参数用来统计被选择的元素个数 void combinationUtil(int arr[], int data[], int start, int end, int index, int k) { // 如果已找到一个完整的组合 if (index == k) { // 打印当前组合 for (int j = 0; j < k; j++) { printf("%d ", data[j]); } printf("\n"); return; } // 递归生成所有可能的组合 for (int i = start; (i < end) && (end - i + 1 >= k - index); i++) { data[index] = arr[i]; // 将当前元素加入到当前组合中 combinationUtil(arr, data, i + 1, end, index + 1, k); // 递归调用以选择下一个元素 } }程序中自定义了两个函数,真正实现“回溯算法解决组合问题”的是 combinationUtil() 函数,而 printCombination() 函数存在的意义是创建一个数组,用来存储回溯过程被选择的元素,以便后续输出符合要求的组合。
虽然 combinationUtil() 函数的参数数量比较多,但非常容易理解:
- [start, end):start 和 end 用来表示 arr 数组中的下标范围,也是可选择的元素范围。递归过程中,end 的值不变,但 start 的值是变化的;
- arr 和 k:arr 数组用于存储所有的元素;k 用于指定要选择的元素数量。整个递归过程中,arr 和 k 都是不变的。
- data 和 index:data 数组用来存储被选择了的元素;index 用来记录已经选择了的元素数量。
程序中,需要重点解释的是第 49 行的循环条件
(i < end) && (end - i + 1 >= k - index)
,其中 (end - i + 1 >= k - index) 的含义是:
-
end - i + 1
:计算的是从当前元素(包括当前元素)到数组末尾的元素数量,也就是从当前元素开始,数组中还剩下多少个元素可以被选择。 -
k - index
:表示为了达到指定的组合大小,还需要选择多少个元素。
如果可选元素的数量能满足需求,即
end - i + 1 >= k - index
成立,可以继续筛选;反之,如果可选元素的数量无法满足需求,说明当前情况下已经无法找到符合条件的组合,直接回溯即可。通常情况下,多数读者只会想到
i < end
作为循环条件,程序也是能正常运行的。添加end - i + 1 >= k - index
作为循环条件,可以过滤掉一些无效的、没必要存在的递归过程,进一步优化程序的运行效率。下面是回溯算法解决组合问题的 Python 程序:
def print_combination(arr, n, k): """ 用于初始化临时列表并开始组合过程 """ # 创建一个列表,用于存储选择的元素 data = [0] * k # 使用回溯生成所有组合 combination_util(arr, data, 0, n, 0, k) def combination_util(arr, data, start, end, index, k): """ 在 arr 列表中 [start, end) 下标范围内,找到 k 个可选择的元素 data 列表用于存储被选择的元素 index 参数用来统计被选择的元素个数 """ # 如果已找到一个完整的组合 if index == k: # 打印当前组合 for i in data: print(i, end=' ') print() return # 递归生成所有可能的组合 for i in range(start, end): if end - i >= k - index: data[index] = arr[i] # 将当前元素加入到当前组合中 combination_util(arr, data, i + 1, end, index + 1, k) # 递归调用以选择下一个元素 # 定义列表的大小和组合的大小 N = 4 K = 2 if __name__ == "__main__": # 声明一个大小为 N 的列表并初始化为 {1, 2, 3, 4} arr = [i + 1 for i in range(N)] # 调用 print_combination 函数来打印所有可能的组合 print_combination(arr, N, K)
下面是回溯算法解决组合问题的 Java 程序:
public class Combination { public static void main(String[] args) { final int N = 4; // 定义数组的大小 final int K = 2; // 定义组合的大小 int[] arr = new int[N]; // 初始化数组为 {1, 2, 3, 4} for (int i = 0; i < N; i++) { arr[i] = i + 1; } // 调用 printCombination 函数来打印所有可能的组合 printCombination(arr, N, K); } // 用于初始化临时数组并开始组合过程 public static void printCombination(int[] arr, int n, int k) { int[] data = new int[k]; // 使用回溯生成所有组合 combinationUtil(arr, data, 0, n, 0, k); } // 在 arr 数组中的[start, end)下标范围内,找到 k 个可选择的元素 // data 数组用于存储被选择了的元素 // index 参数用来统计被选择的元素个数 public static void combinationUtil(int[] arr, int[] data, int start, int end, int index, int k) { // 如果已找到一个完整的组合 if (index == k) { // 打印当前组合 for (int i = 0; i < k; i++) { System.out.print(data[i] + " "); } System.out.println(); return; } // 递归生成所有可能的组合 for (int i = start; (i < end) && (end - i + 1 >= k - index); i++) { data[index] = arr[i]; // 将当前元素加入到当前组合中 combinationUtil(arr, data, i + 1, end, index + 1, k); // 递归调用以选择下一个元素 } } }运行结果,结果为:
1 2
1 3
1 4
2 3
2 4
3 4