组合问题(非常详细,附带源码)

 
组合问题指的是从给定的 N 个数中找出任意 K 个数的所有组合。例如,从 [1, 2, 3, 4] 这 4 个数中找到任意 2 个数的所有组合,它们是:

1 2
1 3
1 4
2 3
2 4
3 4

注意,组合内的元素是没有顺序的,这就意味着 [1, 2] 和 [2, 1] 是同一个组合。因此从 [1, 2, 3, 4] 中找到任意 2 个数的组合,一共有 6 组。

再比如,从 [1,2,3,4] 这 4 个数中找到任意 3 个数的所有组合,它们是:

1 2 3
1 2 4
1 3 4
2 3 4


编程解决组合问题,多数读者最先想到的思路是用嵌套的循环结构。比如从 [1, 2, 3, 4] 中找到任意 2 个数的组合,对应的 C 语言实现代码为:
#include <stdio.h>
int main()
{
    int nums[] = { 1,2,3,4 };
    // i 的值作为选择的第一个数
    for (int i = 0; i < 4; i++) {
        // j 的值作为选择的第二个数
        for (int j = i + 1; j < 4; j++) {
            printf("%d %d\n", nums[i], nums[j]);
        }
    }
    return 0;
}
类似的,查找任意 3 个数的组合,需要用 3 个嵌套的循环结构实现。也就是说,从 N 个数中找到任意 K 个数的所有组合,用循环结构实现的话,需要用 K 个嵌套的循环结构。

嵌套的循环结构解决组合问题,只适合 K 值较小的情况,比如 K 是 2 或者 3。当 K 值较大时,比如 10、20 甚至更大,循环嵌套的层数非常多,代码变得复杂、难以维护,算法的时间复杂度也会呈指数级增长,导致效率极低,这种情况下更推荐用回溯算法解决组合问题。

回溯算法解决组合问题

假设在 {0, 1, 2, 3} 这 4 个数中找到任意 3 个数的所有组合,下面是回溯算法的整个实现思路:
从 {0,1,2,3} 里,选择元素 0;
    从 {1,2,3} 里,选择元素 1;
        从 {2,3} 里,选择元素 2,找到了 {0,1,2} 组合;
        回溯到 {2,3},选择元素 3,找到了 {0,1,3} 组合;
        回溯到 {2,3},没有元素可以选择,再次回溯;
    回溯到 {1,2,3},选择元素 2;
        从 {3} 里选择元素 3,找到了 {0,2,3} 组合;
        回溯到 {3},没有元素可以选择,再次回溯;
    回溯到 {1,2,3},只剩 3 可以选择,无法构成一个新的组合,直接回溯;
回溯到 {0,1,2,3},选择元素 1;
    从 {2,3} 里,选择元素 2;
        从 {3} 里选择元素 3,找到了 {1,2,3} 组合;
    回溯到 {2,3},只剩 3 可以选择,无法构成一个新的组合,直接回溯;
回溯到 {0,1,2,3},只剩 2 和 3 可以选择,无法构成一个新的组合,算法执行结束。
仔细观察整个过程不难发现,每选择一个元素 e,下次选择是在 e 之后的元素里进行。比如选择了元素 1 和 2,下次选择就是在 {3} 中进行,而不是 {0, 3}。这样做的好处是,可以确保每种组合只找到一次,不会重复。

实现回溯算法,通常会采用递归的方式,下面是用回溯算法解决组合问题的 C 语言代码:
#include <stdio.h>
#include <stdlib.h>
#define N 4  // 定义数组的大小
#define K 2  // 定义组合的大小

// 为实现回溯算法做前期准备的函数
void printCombination(int arr[], int n, int k);
// 实现用回溯算法解决组合问题的递归函数
void combinationUtil(int arr[], int data[], int start, int end, int index, int k);

int main() {
    int arr[N];  // 声明一个大小为 N 的数组

    // 初始化数组为 {1, 2, 3, 4}
    for (int i = 0; i < N; i++) {
        arr[i] = i + 1;
    }
    // 调用 printCombination 函数来打印所有可能的组合
    printCombination(arr, N, K);
    return 0;
}

// 用于初始化临时数组并开始组合过程
void printCombination(int arr[], int n, int k) {
    // 动态创建一个数组,用于存储选择了的元素
    int* data = (int*)malloc(k * sizeof(int));

    // 使用回溯生成所有组合
    combinationUtil(arr, data, 0, n, 0, k);

    free(data);  // 释放之前分配的内存
}

// 在 arr 数组中的[start,end)下标范围内,找到 k 个可选择的元素
// data[] 数组用于存储被选择了的元素
// index 参数用来统计被选择的元素个数
void combinationUtil(int arr[], int data[], int start, int end, int index, int k) {
    // 如果已找到一个完整的组合
    if (index == k) {
        // 打印当前组合
        for (int j = 0; j < k; j++) {
            printf("%d ", data[j]);
        }
        printf("\n");
        return;
    }

    // 递归生成所有可能的组合
    for (int i = start; (i < end) && (end - i + 1 >= k - index); i++) {
        data[index] = arr[i];  // 将当前元素加入到当前组合中
        combinationUtil(arr, data, i + 1, end, index + 1, k);  // 递归调用以选择下一个元素
    }
}
程序中自定义了两个函数,真正实现“回溯算法解决组合问题”的是 combinationUtil() 函数,而 printCombination() 函数存在的意义是创建一个数组,用来存储回溯过程被选择的元素,以便后续输出符合要求的组合。

虽然 combinationUtil() 函数的参数数量比较多,但非常容易理解:
  • [start, end):start 和 end 用来表示 arr 数组中的下标范围,也是可选择的元素范围。递归过程中,end 的值不变,但 start 的值是变化的;
  • arr 和 k:arr 数组用于存储所有的元素;k 用于指定要选择的元素数量。整个递归过程中,arr 和 k 都是不变的。
  • data 和 index:data 数组用来存储被选择了的元素;index 用来记录已经选择了的元素数量。

程序中,需要重点解释的是第 49 行的循环条件(i < end) && (end - i + 1 >= k - index),其中 (end - i + 1 >= k - index) 的含义是:
  • end - i + 1:计算的是从当前元素(包括当前元素)到数组末尾的元素数量,也就是从当前元素开始,数组中还剩下多少个元素可以被选择。
  • k - index:表示为了达到指定的组合大小,还需要选择多少个元素。

如果可选元素的数量能满足需求,即end - i + 1 >= k - index成立,可以继续筛选;反之,如果可选元素的数量无法满足需求,说明当前情况下已经无法找到符合条件的组合,直接回溯即可。

通常情况下,多数读者只会想到i < end作为循环条件,程序也是能正常运行的。添加end - i + 1 >= k - index作为循环条件,可以过滤掉一些无效的、没必要存在的递归过程,进一步优化程序的运行效率。

下面是回溯算法解决组合问题的 Python 程序:
def print_combination(arr, n, k):
    """
    用于初始化临时列表并开始组合过程
    """
    # 创建一个列表,用于存储选择的元素
    data = [0] * k
    # 使用回溯生成所有组合
    combination_util(arr, data, 0, n, 0, k)

def combination_util(arr, data, start, end, index, k):
    """
    在 arr 列表中 [start, end) 下标范围内,找到 k 个可选择的元素
    data 列表用于存储被选择的元素
    index 参数用来统计被选择的元素个数
    """
    # 如果已找到一个完整的组合
    if index == k:
        # 打印当前组合
        for i in data:
            print(i, end=' ')
        print()
        return

    # 递归生成所有可能的组合
    for i in range(start, end):
        if end - i >= k - index:
            data[index] = arr[i]  # 将当前元素加入到当前组合中
            combination_util(arr, data, i + 1, end, index + 1, k)  # 递归调用以选择下一个元素

# 定义列表的大小和组合的大小
N = 4
K = 2

if __name__ == "__main__":
    # 声明一个大小为 N 的列表并初始化为 {1, 2, 3, 4}
    arr = [i + 1 for i in range(N)]
    # 调用 print_combination 函数来打印所有可能的组合
    print_combination(arr, N, K)

下面是回溯算法解决组合问题的 Java 程序:
public class Combination {

    public static void main(String[] args) {
        final int N = 4; // 定义数组的大小
        final int K = 2; // 定义组合的大小
        int[] arr = new int[N];

        // 初始化数组为 {1, 2, 3, 4}
        for (int i = 0; i < N; i++) {
            arr[i] = i + 1;
        }

        // 调用 printCombination 函数来打印所有可能的组合
        printCombination(arr, N, K);
    }

    // 用于初始化临时数组并开始组合过程
    public static void printCombination(int[] arr, int n, int k) {
        int[] data = new int[k];

        // 使用回溯生成所有组合
        combinationUtil(arr, data, 0, n, 0, k);
    }

    // 在 arr 数组中的[start, end)下标范围内,找到 k 个可选择的元素
    // data 数组用于存储被选择了的元素
    // index 参数用来统计被选择的元素个数
    public static void combinationUtil(int[] arr, int[] data, int start, int end, int index, int k) {
        // 如果已找到一个完整的组合
        if (index == k) {
            // 打印当前组合
            for (int i = 0; i < k; i++) {
                System.out.print(data[i] + " ");
            }
            System.out.println();
            return;
        }

        // 递归生成所有可能的组合
        for (int i = start; (i < end) && (end - i + 1 >= k - index); i++) {
            data[index] = arr[i]; // 将当前元素加入到当前组合中
            combinationUtil(arr, data, i + 1, end, index + 1, k); // 递归调用以选择下一个元素
        }
    }
}
运行结果,结果为:

1 2
1 3
1 4
2 3
2 4
3 4