发布于  更新于 

CF1209E Rotate Columns

DP 题解 OI

这是一道最近 Codeforces 比赛的题目, 当时在场上昏昏欲睡, 连小的点都没有想出来, 现在再看一下.

题意

给你一个的矩阵, 可以对每一列的元素循环移位, 求每一行的最大值之和的最大值. .

分析

明示状压. 考虑定义状态表示考虑前列, 其中包括的行已经取了最大值的答案, 那么每次就可以枚举的子集来转移, 复杂度, 可以水过小数据.

观察分析发现, 答案至少取到每列按照最大值排序的前个值, 因此枚举状压的部分可以降到, 加上按最大值取前列复杂度是, 大数据可过.

代码

可以用以下方法取的子集(每次去掉最后一个二进制1):

1
2
3
4
for (int k = s; ; k = (k - 1) & s) {
...
if (k == 0) break;
}

用以下方法取的超集:

1
2
3
for (int k = s; k < (1 << n); k = (k + 1) | s) {
...
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <cassert>
#include <cmath>
using namespace std;

typedef long long int64;

const int INF = 0x3f3f3f3f;
const int MAXN = 14;
const int MAXS = (1 << 12) + 10;
const int MAXM = 2010;

int a[MAXN][MAXM];
int mxval[MAXM];
int dp[MAXN][MAXS];
int col[MAXM];

#include <cctype>
#include <cstdio>

template <typename T = int>
inline T read() {
T X = 0, w = 0;
char ch = 0;
while (!isdigit(ch))
{
w |= ch == '-';
ch = getchar();
}
while (isdigit(ch)) {
X = (X << 3) + (X << 1) + (ch ^ 48);
ch = getchar();
}
return w ? -X : X;
}

int val[MAXS]; // 预处理每次循环移位的结果

int main() {
int tcnt = read();
for (int T = 1; T <= tcnt; T++) {
int n = read();
int m = read();
memset(a, 0, sizeof a);
memset(col, 0, sizeof col);
memset(mxval, 0, sizeof mxval);
memset(dp, 0, sizeof dp);
for (int i = 0; i < n; i++) {
for (int j = 1; j <= m; j++) {
a[i][j] = read();
mxval[j] = max(mxval[j], a[i][j]);
}
}
for (int i = 1; i <= m; i++) {
col[i] = i;
}
sort(col + 1, col + m + 1, [](const int& a, const int& b) -> bool {
return mxval[a] > mxval[b];
});
for (int i = 1; i <= min(n, m); i++) {
memset(val, 0, sizeof val);
for (int s = 0; s < (1 << n); s++) {
for (int j = 0; j < n; j++) { // 暴力循环移位
int now = 0;
for (int k = 0; k < n; k++) {
if (s & (1 << k)) {
now += a[(j + k) % n][col[i]];
}
}
val[s] = max(val[s], now);
}
}
memset(dp[i], 0, sizeof dp[i]);
for (int s = 0; s < (1 << n); s++) {
for (int k = s; ; k = (k - 1) & s) {
dp[i][s] = max(dp[i][s], dp[i - 1][k] + val[s ^ k]);
// clog << dp[i][s] << ' ';
if (k == 0) break; // 至少执行一次
}
}
// clog << endl;
}
cout << dp[min(n, m)][(1 << n) - 1] << endl;
}
}