做《算法进阶》时,我遇到了我从未涉及到的概率问题 研究了很久,终于学会概率dp和期望值 看一下这道题 【题意】 给定一个无向连通图,其节点编号为1到N,其边的权值为非负整数。 试求出一条从1号节点都N号节点的路径,使得该路径上经过的边的权值的XOR和最大。 该路径可以重复经过某些节点或边,当一条边在路径中出现多次时,其权值在计算XOR和时也应被重复计算相应多的次数。 直接求解上述问题比较困难,于是你决定使用非完美算法。 具体来说,从1号节点开始,以相等的概率,随机选择与当前节点相关联的某条边,并沿着这条边走到下一个节点,重复这个过程直到走到N号节点为止,便得到一条从1号节点到N号节点的路径。 显然得到每条这样的路径的概率是不同的,并且每条这样的路径的XOR和也不一样。 现在请你求出该算法得到的路径的XOR和的期望值。 【输入格式】 第一行包含两个整数N和M,表示节点数和边数。 接下来M行,每行包含三个整数u,v,w,表示存在一条边(u,v),权值为w。 图中可能存在重边或自环。 【输出格式】 输出包含一个实数,表示XOR和的期望值,结果保留三位小数。 【数据范围】 2≤N≤100, M≤10000, 1≤u,v≤N, 0≤w≤10^9 【输入样例】 2 2 1 1 2 1 2 3 【输出样例】 2.333
因为不能直接将这个期望值当成一个整体来算,所以我们将它拆分。 期望值其实就是平均数,它=所有的(可能值X概率) 又因为是异或运算,所以我们不妨将它拆分为二进制的每一个位 先设f[i]为i位是1的概率。 所以 解释一下第一条公式: 先是枚举每一条与u相连的边(连到v) 如果这条边二进制拆分以后所求的一位是1 就要加上v点0的概率,也就是1-f[v] 否则就是要加上v点1的概率,也就是f[v] 只有这样,我们才能使u是1 参考代码(有修改):
//Author:XuHt #include <cmath> #include <cstdio> #include <vector> #include <cstring> #include <iostream> using namespace std; const int N = 106; int n, m; double a[N][N], b[N], ans; vector<pair<int, int> > e[N]; void work() { for (int i = 1; i < n; i++) { /* int now = i;算法进阶里的标程有这一句,实际上可以不用 for (int j = i + 1; j < n; j++) if (fabs(a[j][i]) > fabs(a[now][i])) now = j; for (int j = 0; j <= n; j++) swap(a[i][j], a[now][j]);*/ for (int j = i + 1; j <= n; j++) { double rate = a[j][i] / a[i][i]; for (int k = 0; k <= n; k++) a[j][k] = a[i][k] * rate - a[j][k]; } } for (int i = n; i; i--) { for (int j = i + 1; j <= n; j++) a[i][0] -= a[i][j] * b[j]; b[i] = a[i][0] / a[i][i]; } } int main() { cin >> n >> m; for (int i = 1; i <= m; i++) { int x, y, z; scanf("%d %d %d", &x, &y, &z); e[x].push_back(make_pair(y, z)); if (x != y) e[y].push_back(make_pair(x, z)); } for (int i = 0; i < 31; i++) { memset(a, 0, sizeof(a)); memset(b, 0, sizeof(b)); //高斯消元数组的构造(第三条公式除以dg[u]) for (int x = 1; x <= n; x++) a[x][x] = 1; for (int x = 1; x < n; x++) { int s = e[x].size(); for (int j = 0; j < s; j++) { int y = e[x][j].first, z = e[x][j].second; double w = 1.0 / s; if ((z >> i) & 1) { a[x][y] += w; a[x][0] += w; } else a[x][y] -= w; } } work(); ans += b[1] * (1 << i); } printf("%.3f\n", ans); return 0; }