计蒜客-圣诞树

圣诞节快到了,蒜头君准备做一棵大圣诞树。

这棵树被表示成一组被编号的结点和一些边的集合,树的结点从 1 到 n 编号,树的根永远是 1。每个结点都有一个自身特有的数值,称为它的权重,各个结点的权重可能不同。对于一棵做完的树来说,每条边都有一个价值 v_e,若设这条边 e 连接结点 i 和结点 j,且 i 为 j 的父结点(根是最老的祖先),则该边的价值 v_e = s_j × w_e,s_j表示结点 j 的所有子孙及它自己的权重之和,w_e 表示边 e 的权值。

现在蒜头君想造一棵树,他有 m 条边可以选择,使得树上所有边的总价值最小,并且所有的点都在树上,因为蒜头君喜欢大树。

输入格式

第一行输入两个整数 n和 m(0 ≤ n,m ≤50,000),表示结点总数和可供选择的边数。

接下来输入一行,输入 n 个整数,依次表示每个结点的权重。

接下来输入 m 行,每行输入 3 个正整数 a, b, c(1≤a,b,≤n,1≤c≤10,000),表示结点 a 和结点 b 之间有一条权值为 c 的边可供造树选择。

输出格式

输出一行,如果构造不出这样的树,请输出No Answer,否则输出一个整数,表示造树的最小价值。

样例输入

4 4
10 20 30 40
1 2 3
2 3 2
1 3 5
2 4 1

样例输出

370

#include <iostream>
#include <cstring>
#include <algorithm>
#include <set>

using namespace std;

const int N = 50000;
const int M = 50000;
const int INF = 0x3f3f3f3f;

struct edge {
    int v, w, next;
    edge() {}
    edge(int _v, int _w, int _next): v(_v), w(_w), next(_next) {}
} E[M << 1];

int p[N];
int cnt;

void init() {
    memset(p, -1, sizeof p);
    cnt = 0;
}

void insert(int u, int v, int w) {
    E[cnt] = edge(v, w, p[u]);
    p[u] = cnt++;
}

void insert2(int u, int v, int w) {
    insert(u, v, w);
    insert(v, u, w);
}

int dis[N];
bool vis[N];
int weight[N];
typedef pair<int, int> PII;
set<PII, less<PII> > min_heap;
int n, m;

bool dijstra(int u) {
    memset(dis, 0x3f, sizeof dis);
    memset(vis, false, sizeof vis);
    dis[u] = 0;
    min_heap.insert(make_pair(0, u));

    for (int i = 0; i < n; i++) {
        if (min_heap.size() == 0) {
            return false;
        }

        set<PII, less<PII> >::iterator it = min_heap.begin();
        int min_v = it->second;
        vis[min_v] = true;
        min_heap.erase(*it);

        for (int e = p[min_v]; e != -1; e = E[e].next) {
            int v = E[e].v;
            int w = E[e].w;
            if (!vis[v] && dis[min_v] + w < dis[v]) {
                min_heap.erase(make_pair(dis[v], v));
                dis[v] = dis[min_v] + w;
                min_heap.insert(make_pair(dis[v], v));
            }
        }
    }
    return true;
}

int main() {
    cin >> n >> m;

    for (int i = 1; i <= n; i++) {
        cin >> weight[i];
    }

    init();
    int a, b, c;
    for (int i = 1; i <= m; i++) {
        cin >> a >> b >> c;
        insert2(a, b, c);
    }
    if (dijstra(1) == false) {
        cout << "No Answer" << endl;
    } else {
        long long ans = 0;
        for (int i = 1; i <= n; i++) {
            ans += dis[i] * weight[i];
        }
        cout << ans << endl;
    }

    return 0;
}