算法 - 石子合并


来源: https://www.acwing.com/problem/content/284/

描述:

设有N堆石子排成一排,其编号为1,2,3,…,N。

每堆石子有一定的质量,可以用一个整数来描述,现在要将这N堆石子合并成为一堆。

每次只能合并相邻的两堆,合并的代价为这两堆石子的质量之和,合并后与这两堆石子相邻的石子将和新堆相邻,合并时由于选择的顺序不同,合并的总代价也不相同。

例如有4堆石子分别为 1 3 5 2, 我们可以先合并1、2堆,代价为4,得到4 5 2, 又合并 1,2堆,代价为9,得到9 2 ,再合并得到11,总代价为4+9+11=24;

如果第二步是先合并2,3堆,则代价为7,得到4 7,最后一次合并代价为11,总代价为4+7+11=22。

问题是:找出一种合理的方法,使总的代价最小,输出最小代价。


首先是因为是相邻的区间进行合并,所以就是一个分界点,分成两个部分,一个部分左边,一部分右边。这样就符合题意的相邻两个部分进行合并,这样只用以2堆,3堆,4堆,一直到长度为n堆进行合并就行了。因为是区间DP问题,所以状态表示为$f[i][j]$。

区间dp分析:

状态表示$f[i][j]$

  • 集合: 区间i ~ j 合并所需要的代价
  • 属性: Min

状态计算:

  • 以k为分界点,把区间分成两部分,k可以取值1, 2, 3, 4…, k - 1
  • 合并的代价就是区间的长度的价值
  • 状态转移方程就是: $f[i][j] = min(f[i][k] + f[k + 1][j] + s[j] - s[i - 1]), k = 1, 2, 3, 4…, k - 1$

写区间dp问题,注意就是区间的写法。如何写区间,调整区间。

伪代码:

1
2
3
4
5
6
7
8
9
10
for (len = 2; len <= n; len++)
for (i = 1; i + len - 1 <= n; i++)
// 枚举所有区间长度,从2开始
left = i, right = i + len - 1
for(k = 1; k < right; k++)
//状态转移方程
f[left][right] = min(f[left][k] + f[k + 1][right] + s[right] - s[left - 1])

//答案 区间 1 ~ n
f[1][n]

C++实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#include <iostream>
#include <algorithm>

using namespace std;

const int N = 310;

int n;
int s[N];
int f[N][N];

int main() {
cin >> n;
for(int i = 1; i <= n; i++) cin >> s[i];
for(int i = 1; i <= n; i++) s[i] += s[i - 1];

for(int len = 2; len <= n; len++)
for(int i = 1; i + len - 1 <= n; i++) {
int l = i, r = i + len - 1;
f[l][r] = 1e8;
for(int k = l; k < r; k++)
f[l][r] = min(f[l][r], f[l][k] + f[k + 1][r] + s[r] - s[l - 1]);
}

cout << f[1][n] << endl;
return 0;
}

Python 实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

N = 310
def main():
# f 表示状态, s表示数据和
f = [[0]*N for _ in range(N)]
s = [0]*N
n = int(input())
# 写入数据从角标1开始
s[1:] = map(int, input().split())

for i in range(1, n + 1):
s[i] += s[i - 1]


for len in range(2, n + 1):
for i in range(1, n):
if i + len - 1 <= n:
l, r = i, i + len - 1
f[l][r] = 1e8
for k in range(l, r):
# 状态转移方程
f[l][r] = min(f[l][r], f[l][k] + f[k + 1][r] + s[r] - s[l - 1])

# 表示从1 ~ n区间
print(f[1][n])

main()