package astar;

import java.util.ListIterator;
import java.util.Iterator;
import java.awt.geom.Point2D;
import java.util.LinkedList;
import java.awt.Point;

/**
 * A* path search over the nodes of an {@link AStarMap}.
 *
 * <p>Search state is stored on the nodes themselves ({@code prev}, {@code G},
 * {@code H}, {@code curList}) and in the static {@link #base} counter, so this
 * class is not safe for concurrent searches over the same nodes.
 */
public class AStarSearch
{
    /**
     * Generation counter for the per-node list markers. During a search a node
     * whose {@code curList} equals {@code base + 1} is on the open list and one
     * whose {@code curList} equals {@code base + 2} is closed. The counter is
     * advanced by 2 after each completed search so markers left on nodes by
     * earlier searches no longer match either value.
     */
    private static int base = 0;

    /**
     * Finds a path from {@code source} to {@code dest}.
     *
     * @param source start position; its x/y are used to look up the start node
     * @param dest   goal position; its x/y are used to look up the goal node
     * @param map    map supplying the nodes and the line-of-sight test
     * @return a list starting with {@code source}, followed by midpoints of the
     *         nodes on the route, and ending with {@code dest}, with redundant
     *         intermediate waypoints removed; or {@code null} when the goal node
     *         is missing or impassable, the start node is missing, or no route
     *         exists
     */
    public static LinkedList<Point> findPath(final Point source, final Point dest, final AStarMap map) {
        final AStarNode sourceNode = map.getNode(source.x, source.y);
        final AStarNode destNode = map.getNode(dest.x, dest.y);
        // A missing node is reported as "no path" rather than failing with a
        // NullPointerException further down.
        // NOTE(review): whether getNode can actually return null is not visible here.
        if (sourceNode == null || destNode == null || !destNode.passable) {
            return null;
        }

        // base is not modified until the search is over, so the markers are fixed.
        final int open = AStarSearch.base + 1;
        final int closed = AStarSearch.base + 2;

        final BinaryHeap openList = new BinaryHeap();
        sourceNode.prev = null;
        sourceNode.G = 0;
        sourceNode.H = distanceBetween(sourceNode, destNode);
        openList.add(sourceNode);
        sourceNode.curList = open;

        // Expand nodes until the goal has been closed or nothing is left to expand.
        while (destNode.curList != closed && !openList.isEmpty()) {
            final AStarNode current = openList.removeSmallest();
            current.curList = closed;
            for (final AStarNode child : current.neighbors) {
                if (child.curList == closed || !child.passable) {
                    continue;
                }
                final int candidateG = current.G + distanceBetween(current, child);
                if (child.curList == open) {
                    // Already queued: re-parent only when this route is strictly cheaper.
                    if (candidateG < child.G) {
                        child.G = candidateG;
                        child.prev = current;
                        openList.update(child);
                    }
                }
                else {
                    // First visit in this search: initialise costs before queueing.
                    child.prev = current;
                    child.G = candidateG;
                    child.H = distanceBetween(child, destNode);
                    child.curList = open;
                    openList.add(child);
                }
            }
        }

        LinkedList<Point> newPath = null;
        if (destNode.curList == closed) {
            // Walk the prev links back from the goal, prepending so the list
            // ends up ordered from source to dest.
            newPath = new LinkedList<Point>();
            newPath.addFirst(dest);
            for (AStarNode cur = destNode; cur != null; cur = cur.prev) {
                newPath.addFirst(cur.getMidpoint());
            }
            newPath.addFirst(source);
            flattenPath(newPath, map);
        }
        // Invalidate every marker written during this search.
        AStarSearch.base += 2;
        return newPath;
    }

    /**
     * Returns the distance between the midpoints of two nodes, truncated to an int.
     */
    private static int distanceBetween(final AStarNode from, final AStarNode to) {
        return (int)from.getMidpoint().distance(to.getMidpoint());
    }

    /**
     * Removes intermediate waypoints that can be skipped according to
     * {@link AStarMap#validPath}. The first and last points are never removed.
     *
     * <p>Each pass slides over every consecutive triple (anchor, middle, next):
     * when {@code validPath(anchor, next)} holds the middle point is dropped and
     * the same anchor is tried against the following point; otherwise the window
     * advances by one. Passes repeat until one removes nothing, so on return no
     * three consecutive points can be shortened any further. Every pair of points
     * made adjacent by a removal has been accepted by {@code validPath}.
     *
     * @param path path to shorten in place
     * @param map  map providing the {@code validPath} test
     */
    private static void flattenPath(final LinkedList<Point> path, final AStarMap map) {
        boolean changed = true;
        while (changed && path.size() >= 3) {
            changed = false;
            final ListIterator<Point> iter = path.listIterator();
            Point anchor = iter.next();
            Point middle = iter.next();
            while (iter.hasNext()) {
                final Point next = iter.next();
                if (map.validPath(anchor, next)) {
                    // Step back over next and middle, delete middle, then step
                    // forward again so the cursor sits just after next.
                    iter.previous();
                    iter.previous();
                    iter.remove();
                    iter.next();
                    changed = true;
                }
                else {
                    anchor = middle;
                }
                middle = next;
            }
        }
    }
}
