package algs44;
import stdlib.*;
import algs13.Stack;
/* ***********************************************************************
 *  Compilation:  javac XFloydWarshall.java
 *  Execution:    java XFloydWarshall V E
 *  Dependencies: EdgeWeightedDigraph.java DirectedEdge.java
 *                EdgeWeightedDirectedCycle.java Stack.java StdOut.java
 *
 *  Floyd-Warshall all-pairs shortest path algorithm.
 *
 *  % java XFloydWarshall 100 500
 *
 *  The triple loop checks for a negative cycle after relaxing each row and
 *  stops as soon as one is found; without this check, intermediate numbers
 *  can get exponentially large.
 *  Reference: "The Floyd-Warshall algorithm on graphs with negative cycles"
 *  by Stefan Hougardy
 *
 *************************************************************************/

public class XFloydWarshall {
	private final double[][] distTo;        // distTo[v][w] = length of shortest v->w path
	private final DirectedEdge[][] edgeTo;  // edgeTo[v][w] = last edge on shortest v->w path

	/**
	 * Computes all-pairs shortest paths in {@code G} with the Floyd-Warshall
	 * algorithm, stopping early if a negative cycle is detected.
	 *
	 * <p>Parallel edges are allowed: only the lightest edge between each ordered
	 * pair of vertices is used. A self-loop is ignored unless its weight is
	 * negative, in which case it is itself a negative cycle.
	 *
	 * @param G the edge-weighted digraph
	 */
	public XFloydWarshall(EdgeWeightedDigraph G) {
		int V = G.V();
		distTo = new double[V][V];
		edgeTo = new DirectedEdge[V][V];

		// initialize distances to infinity
		for (int v = 0; v < V; v++) {
			for (int w = 0; w < V; w++) {
				distTo[v][w] = Double.POSITIVE_INFINITY;
			}
		}

		// initialize distances from the digraph's edges
		for (int v = 0; v < V; v++) {
			for (DirectedEdge e : G.adj(v)) {
				// keep only the lightest of any parallel edges: a heavier edge
				// seen later must not overwrite a lighter one
				if (e.weight() < distTo[e.from()][e.to()]) {
					distTo[e.from()][e.to()] = e.weight();
					edgeTo[e.from()][e.to()] = e;
				}
			}
			// the empty path v->v has length 0; only a negative self-loop beats it
			if (distTo[v][v] >= 0.0) {
				distTo[v][v] = 0.0;
				edgeTo[v][v] = null;
			}
		}

		// Floyd-Warshall updates
		for (int i = 0; i < V; i++) {
			// compute shortest paths using only 0, 1, ..., i as intermediate vertices
			for (int v = 0; v < V; v++) {
				// edgeTo[v][i] is null only when i is unreachable from v, or when
				// i == v and distTo[v][v] == 0; in neither case can routing through
				// i shorten any v->w path, so the whole row can be skipped
				if (edgeTo[v][i] == null) {
					continue;
				}
				for (int w = 0; w < V; w++) {
					if (distTo[v][w] > distTo[v][i] + distTo[i][w]) {
						distTo[v][w] = distTo[v][i] + distTo[i][w];
						edgeTo[v][w] = edgeTo[i][w];
					}
				}
				// stop at the first negative cycle; continuing would let
				// intermediate distances grow exponentially large
				if (distTo[v][v] < 0.0) {
					return;
				}
			}
		}
	}

	/**
	 * Returns true if the digraph contains a negative-weight cycle, i.e. some
	 * vertex has a negative shortest distance to itself.
	 *
	 * @return whether a negative cycle was detected
	 */
	public boolean hasNegativeCycle() {
		for (int v = 0; v < distTo.length; v++) {
			if (distTo[v][v] < 0.0) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Returns a negative cycle, found in the predecessor graph of the first
	 * vertex {@code v} with {@code distTo[v][v] < 0}.
	 *
	 * @return the edges of a negative cycle, or {@code null} if there is none
	 */
	public Iterable<DirectedEdge> negativeCycle() {
		int V = edgeTo.length;
		for (int v = 0; v < V; v++) {
			if (distTo[v][v] < 0.0) {
				// negative cycle in v's predecessor graph
				EdgeWeightedDigraph spt = new EdgeWeightedDigraph(V);
				for (int w = 0; w < V; w++) {
					if (edgeTo[v][w] != null) {
						spt.addEdge(edgeTo[v][w]);
					}
				}
				EdgeWeightedDirectedCycle finder = new EdgeWeightedDirectedCycle(spt);
				assert finder.hasCycle();
				return finder.cycle();
			}
		}
		return null;
	}

	/**
	 * Returns true if there is a path from {@code v} to {@code w}.
	 *
	 * @param v source vertex
	 * @param w destination vertex
	 * @return whether the shortest v->w distance is finite
	 */
	public boolean hasPath(int v, int w) {
		return distTo[v][w] < Double.POSITIVE_INFINITY;
	}

	/**
	 * Returns the length of a shortest path from {@code v} to {@code w}.
	 *
	 * @param v source vertex
	 * @param w destination vertex
	 * @return the shortest distance, or {@code Double.POSITIVE_INFINITY} if
	 *         {@code w} is unreachable; not meaningful when a negative cycle exists
	 */
	public double dist(int v, int w) {
		return distTo[v][w];
	}

	/**
	 * Returns a shortest path from {@code v} to {@code w}, in order from v to w.
	 *
	 * @param v source vertex
	 * @param w destination vertex
	 * @return the path's edges (empty when {@code v == w}), or {@code null} if
	 *         there is no v->w path or the digraph has a negative cycle
	 */
	public Iterable<DirectedEdge> path(int v, int w) {
		if (!hasPath(v, w) || hasNegativeCycle()) {
			return null;
		}
		// walk predecessor edges back from w; pushing onto a stack yields v->w order
		Stack<DirectedEdge> path = new Stack<>();
		for (DirectedEdge e = edgeTo[v][w]; e != null; e = edgeTo[v][e.from()]) {
			path.push(e);
		}
		return path;
	}

	/**
	 * Checks the optimality conditions: when there is no negative cycle, no
	 * edge v->w can shorten any computed distance distTo[i][w]. Prints the
	 * first violating edge to standard error. Not currently called.
	 *
	 * @param G the digraph the distances were computed for
	 * @return true if the conditions hold (or a negative cycle exists)
	 */
	private boolean check(EdgeWeightedDigraph G) {
		if (hasNegativeCycle()) {
			return true;
		}
		for (int v = 0; v < G.V(); v++) {
			for (DirectedEdge e : G.adj(v)) {
				int w = e.to();
				for (int i = 0; i < G.V(); i++) {
					if (distTo[i][w] > distTo[i][v] + e.weight()) {
						System.err.println("edge " + e + " is eligible");
						return false;
					}
				}
			}
		}
		return true;
	}

	/**
	 * Builds a random digraph with V vertices and E edges (args[0] and args[1];
	 * parallel edges allowed, self-loops non-negative), runs Floyd-Warshall, and
	 * prints the distance matrix followed by either a negative cycle or every
	 * shortest path.
	 */
	public static void main(String[] args) {
		if (args.length < 2) {
			System.err.println("Usage: java XFloydWarshall V E");
			return;
		}

		// random graph with V vertices and E edges, parallel edges allowed
		int V = Integer.parseInt(args[0]);
		int E = Integer.parseInt(args[1]);
		EdgeWeightedDigraph G = new EdgeWeightedDigraph(V);
		for (int i = 0; i < E; i++) {
			int v = (int) (V * Math.random());
			int w = (int) (V * Math.random());
			double weight = Math.round(100 * (Math.random() - 0.15)) / 100.0;
			if (v == w) G.addEdge(new DirectedEdge(v, w, Math.abs(weight)));
			else        G.addEdge(new DirectedEdge(v, w, weight));
		}

		StdOut.println(G);

		// run Floyd-Warshall algorithm
		XFloydWarshall spt = new XFloydWarshall(G);

		// print all-pairs shortest path distances
		StdOut.format("     ");
		for (int v = 0; v < G.V(); v++) {
			StdOut.format("%6d ", v);
		}
		StdOut.println();
		for (int v = 0; v < G.V(); v++) {
			StdOut.format("%3d: ", v);
			for (int w = 0; w < G.V(); w++) {
				if (spt.hasPath(v, w)) StdOut.format("%6.2f ", spt.dist(v, w));
				else                   StdOut.format("   Inf ");
			}
			StdOut.println();
		}

		// print negative cycle
		if (spt.hasNegativeCycle()) {
			StdOut.println("Negative cost cycle:");
			for (DirectedEdge e : spt.negativeCycle()) {
				StdOut.println(e);
			}
			StdOut.println();
		}

		// print all-pairs shortest paths
		else {
			for (int v = 0; v < G.V(); v++) {
				for (int w = 0; w < G.V(); w++) {
					if (spt.hasPath(v, w)) {
						StdOut.format("%d to %d (%5.2f)  ", v, w, spt.dist(v, w));
						for (DirectedEdge e : spt.path(v, w)) {
							StdOut.print(e + "  ");
						}
						StdOut.println();
					}
					else {
						StdOut.format("%d to %d          no path\n", v, w);
					}
				}
			}
		}

	}

}
