package algs52; // section 5.2
import stdlib.*;
import algs13.Queue;
/* ***********************************************************************
 *  Compilation:  javac TrieST.java
 *  Execution:    java TrieST
 *                (main() reads data/shellsST.txt via StdIn.fromFile)
 *  Dependencies: StdIn.java StdOut.java Queue.java
 *
 *  A string symbol table for extended-ASCII strings (char values 0-255),
 *  implemented using a 256-way trie.
 *
 *  % java TrieST < shellsST.txt
 *  by 4
 *  sea 6
 *  sells 1
 *  she 0
 *  shells 3
 *  shore 7
 *  the 5
 *
 *************************************************************************/

public class TrieST<V> {
	private static final int R = 256;        // extended ASCII: only chars 0..R-1 can appear in keys

	private Node<V> root = new Node<>();     // becomes null once every key has been deleted

	/** A trie node: an optional value plus one child link per character of the alphabet. */
	private static class Node<V> {
		public V val;                        // non-null iff the path to this node spells a stored key
		// Java cannot create a generic array directly; the array only ever holds
		// Node<V> instances created by this class, so the unchecked conversion is safe.
		@SuppressWarnings("unchecked")
		public final Node<V>[] next = new Node[R];
	}

	/* **************************************************
	 * Is the key in the symbol table?
	 ****************************************************/

	/**
	 * Returns true if the symbol table holds a (non-null) value for the key.
	 *
	 * @param key the key to look up; must not be null
	 * @return {@code get(key) != null}
	 */
	public boolean contains(String key) {
		return get(key) != null;
	}

	/**
	 * Returns the value associated with the key, or null if the key is absent.
	 * A key containing a character outside the alphabet can never have been
	 * stored, so it is reported as absent instead of causing an index error.
	 *
	 * @param key the key to look up; must not be null
	 * @return the stored value, or null
	 */
	public V get(String key) {
		Node<V> x = get(root, key, 0);
		if (x == null) return null;
		return x.val;
	}

	// Return the node reached by following key[d..] from x, or null if the path breaks.
	private Node<V> get(Node<V> x, String key, int d) {
		if (x == null) return null;
		if (d == key.length()) return x;
		char c = key.charAt(d);
		if (c >= R) return null;             // no child link can exist for this character
		return get(x.next[c], key, d+1);
	}

	/* **************************************************
	 * Insert key-value pair into the symbol table.
	 ****************************************************/

	/**
	 * Inserts the key-value pair, overwriting any previous value for the key.
	 * A null value removes the key instead; this matches contains(), which
	 * treats a null value as absent, and it prunes nodes that only served that key.
	 *
	 * @param key the key; must not be null
	 * @param val the value, or null to remove the key
	 * @throws IllegalArgumentException if key contains a character outside 0..R-1
	 */
	public void put(String key, V val) {
		if (val == null) {
			delete(key);
			return;
		}
		checkAlphabet(key);
		root = put(root, key, val, 0);
	}

	// Insert key[d..] below x, creating missing nodes; returns the (possibly new) x.
	private Node<V> put(Node<V> x, String key, V val, int d) {
		if (x == null) x = new Node<>();
		if (d == key.length()) {
			x.val = val;
			return x;
		}
		char c = key.charAt(d);
		x.next[c] = put(x.next[c], key, val, d+1);
		return x;
	}

	// Reject keys that cannot be represented, before any node is allocated.
	private static void checkAlphabet(String key) {
		for (int i = 0; i < key.length(); i++) {
			if (key.charAt(i) >= R)
				throw new IllegalArgumentException("key contains a character outside the "
						+ R + "-character alphabet at index " + i + ": " + key);
		}
	}

	/**
	 * Returns the longest stored key that is a prefix of the query.
	 * The empty string is returned only when it is itself a stored key.
	 *
	 * @param query the string to match against; must not be null
	 * @return the longest key that prefixes query, or null if no key does
	 */
	public String longestPrefixOf(String query) {
		// -1 means "no key seen yet"; 0 would be ambiguous with the empty-string key
		int length = longestPrefixOf(root, query, 0, -1);
		if (length == -1) return null;
		return query.substring(0, length);
	}

	// find the key in the subtrie rooted at x that is the longest
	// prefix of the query string, starting at the dth character;
	// length is the longest match found so far (-1 if none)
	private int longestPrefixOf(Node<V> x, String query, int d, int length) {
		if (x == null) return length;
		if (x.val != null) length = d;
		if (d == query.length()) return length;
		char c = query.charAt(d);
		if (c >= R) return length;           // no stored key can continue past this character
		return longestPrefixOf(x.next[c], query, d+1, length);
	}

	/**
	 * Returns all keys in ascending character order.
	 *
	 * @return the keys, in order
	 */
	public Iterable<String> keys() {
		return keysWithPrefix("");
	}

	/**
	 * Returns all keys that start with the given prefix, in ascending character order.
	 *
	 * @param prefix the prefix; must not be null
	 * @return the matching keys (empty if none)
	 */
	public Iterable<String> keysWithPrefix(String prefix) {
		Queue<String> queue = new Queue<>();
		Node<V> x = get(root, prefix, 0);
		collect(x, new StringBuilder(prefix), queue);
		return queue;
	}

	// Enqueue every key in the subtrie at x; prefix spells the path to x and is
	// restored to its entry length before returning.
	private void collect(Node<V> x, StringBuilder prefix, Queue<String> queue) {
		if (x == null) return;
		if (x.val != null) queue.enqueue(prefix.toString());
		int len = prefix.length();
		for (int c = 0; c < R; c++) {
			if (x.next[c] == null) continue; // skip absent links without building a string
			prefix.append((char) c);
			collect(x.next[c], prefix, queue);
			prefix.setLength(len);
		}
	}

	/**
	 * Returns all keys that match the pattern, in ascending character order.
	 * The pattern must match the whole key. Each '.' matches any single character;
	 * every other character matches only itself.
	 *
	 * @param pat the pattern; must not be null
	 * @return the matching keys (empty if none)
	 */
	public Iterable<String> keysThatMatch(String pat) {
		Queue<String> q = new Queue<>();
		collect(root, new StringBuilder(), pat, q);
		return q;
	}

	// Enqueue keys below x matching pat, where prefix (the path to x) has already
	// matched pat[0..prefix.length()); prefix is restored before returning.
	private void collect(Node<V> x, StringBuilder prefix, String pat, Queue<String> q) {
		if (x == null) return;
		int d = prefix.length();
		if (d == pat.length()) {
			if (x.val != null) q.enqueue(prefix.toString());
			return;
		}
		char next = pat.charAt(d);
		int lo, hi;                          // inclusive range of child links to explore
		if (next == '.') {
			lo = 0;
			hi = R - 1;
		} else if (next < R) {
			lo = next;                       // a literal character has exactly one candidate link
			hi = next;
		} else {
			return;                          // a character outside the alphabet matches nothing
		}
		for (int c = lo; c <= hi; c++) {
			if (x.next[c] == null) continue;
			prefix.append((char) c);
			collect(x.next[c], prefix, pat, q);
			prefix.setLength(d);
		}
	}

	/**
	 * Removes the key and its value, if present, and prunes nodes left with
	 * neither a value nor children. Does nothing if the key is absent.
	 *
	 * @param key the key to remove; must not be null
	 */
	public void delete(String key) {
		root = delete(root, key, 0);
	}

	// Delete key[d..] below x; returns x, or null if x should be unlinked.
	private Node<V> delete(Node<V> x, String key, int d) {
		if (x == null) return null;
		if (d == key.length()) x.val = null;
		else {
			char c = key.charAt(d);
			if (c >= R) return x;            // such a key cannot be stored; nothing changes
			x.next[c] = delete(x.next[c], key, d+1);
		}
		if (x.val != null) return x;
		for (int c = 0; c < R; c++)
			if (x.next[c] != null)
				return x;
		return null;
	}


	// test client
	public static void main(String[] args) {
		StdIn.fromFile("data/shellsST.txt");

		// build symbol table from standard input
		TrieST<Integer> st = new TrieST<>();
		for (int i = 0; !StdIn.isEmpty(); i++) {
			String key = StdIn.readString();
			st.put(key, i);
		}

		// print results
		for (String key : st.keys()) {
			StdOut.println(key + " " + st.get(key));
		}
	}
}
