package main;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Vector;

/**
 * Builds a call graph over a set of registered methods from their INVOKESPECIAL, INVOKESTATIC
 * and INVOKEVIRTUAL bytecodes. It uses that graph to detect call loops and to compute an order
 * in which every method comes after the registered methods it may call.
 *
 * <p>Only methods registered through {@link #addMethod} are ordered. A call whose target is not
 * a registered method imposes no ordering constraint. This includes a target whose class is
 * absent from the class map, and one that its class's {@code getMethods()} does not list.
 * {@link #hasDependencyLoop()} and {@link #dependencyOrder()} resolve calls identically, so if
 * the former returns {@code false} the latter cannot fail because of a loop.
 */
public class MethodDependencyChecker {
	
	/** Bytecode of every method to be ordered, keyed by owning class and method ID. */
	private final Map<FullMethodID, BytecodeModel[]> methods = new HashMap<FullMethodID, BytecodeModel[]>();
	/** Class models used to resolve constant-pool method references and superclass chains. */
	private final Map<ClassID, ClassModel> classes;
	
	/**
	 * @param classes class models keyed by class ID; the map is stored by reference, not copied
	 */
	public MethodDependencyChecker(HashMap<ClassID,ClassModel> classes) {
		this.classes = classes;
	}
	
	/**
	 * Registers a method for loop detection and ordering. Registering the same method again
	 * replaces its bytecode.
	 *
	 * @param classID   class that declares the method; its constant pool resolves the calls
	 * @param methodID  the method within that class
	 * @param bytecodes the method body; {@code null} is treated as a body with no calls
	 */
	public void addMethod(ClassID classID, MethodID methodID, BytecodeModel[] bytecodes) {
		methods.put(new FullMethodID(classID, methodID), bytecodes);
	}
	
	/**
	 * Reports whether some registered method can reach itself through calls to registered
	 * methods. Self-recursion counts as a loop.
	 *
	 * @return {@code true} if a call loop exists among the registered methods
	 * @throws IllegalStateException if a registered method with call instructions belongs to a
	 *                               class that is missing from the class map
	 */
	public boolean hasDependencyLoop() {
		// Shared across all roots so that each method's calls are explored at most once.
		Set<FullMethodID> onPath = new HashSet<FullMethodID>();
		Set<FullMethodID> finished = new HashSet<FullMethodID>();
		for (FullMethodID method : methods.keySet()) {
			if (dependencyLoopSearch(method, onPath, finished)) {
				return true;
			}
		}
		return false;
	}
	
	/**
	 * Depth-first search for a call loop reachable from {@code current}.
	 *
	 * @param current  method to explore
	 * @param onPath   methods on the current search path; reaching one again closes a loop
	 * @param finished methods whose reachable calls were fully explored without finding a loop
	 * @return {@code true} if a loop is reachable from {@code current}
	 */
	private boolean dependencyLoopSearch(FullMethodID current, Set<FullMethodID> onPath,
			Set<FullMethodID> finished) {
		if (!methods.containsKey(current) || finished.contains(current)) {
			// Unregistered methods are leaves; finished ones were already shown loop-free.
			return false;
		}
		if (!onPath.add(current)) {
			// Back edge to a method still on the path: a loop.
			return true;
		}
		for (FullMethodID callee : calleesOf(current)) {
			if (dependencyLoopSearch(callee, onPath, finished)) {
				return true;
			}
		}
		onPath.remove(current);
		finished.add(current);
		return false;
	}
	
	/**
	 * Orders the registered methods so that each one appears after every registered method it
	 * may call. Progress is printed to {@code System.out}.
	 *
	 * @return all registered methods in dependency order
	 * @throws IllegalStateException if the methods cannot be ordered because of a call loop, or
	 *                               if a caller's class is missing from the class map
	 */
	public FullMethodID[] dependencyOrder() {
		List<FullMethodID> order = new ArrayList<FullMethodID>();
		// Mirrors the contents of order for constant-time membership checks.
		Set<FullMethodID> placed = new HashSet<FullMethodID>();
		
		while (placed.size() < methods.size()) {
			boolean methodAdded = false;
			for (FullMethodID currentMethod : methods.keySet()) {
				if (placed.contains(currentMethod)) {
					continue;
				}
				System.out.println("Checking method calls in " + currentMethod + "...");
				Set<FullMethodID> methodsCalled = calleesOf(currentMethod);
				
				System.out.println("List of methods called:");
				for (FullMethodID method : methodsCalled) {
					System.out.println("\t" + method);
				}
				if (allCalleesPlaced(methodsCalled, placed)) {
					order.add(currentMethod);
					placed.add(currentMethod);
					System.out.println(currentMethod + " added to dependency order");
					methodAdded = true;
				}
			}
			if (!methodAdded) {
				// Every remaining method waits on another remaining method, so a loop exists.
				// Retrying would never make progress.
				System.out.println("No more methods added in this iteration");
				throw new IllegalStateException(
						"Cannot order methods; these are in, or depend on, a call loop: "
								+ describeUnplaced(placed));
			}
		}
		
		return order.toArray(new FullMethodID[order.size()]);
	}
	
	/** Returns true when every callee that is itself a registered method has been placed. */
	private boolean allCalleesPlaced(Set<FullMethodID> callees, Set<FullMethodID> placed) {
		for (FullMethodID callee : callees) {
			if (methods.containsKey(callee) && !placed.contains(callee)) {
				return false;
			}
		}
		return true;
	}
	
	/** Lists the registered methods that are not yet placed, comma separated, for error messages. */
	private String describeUnplaced(Set<FullMethodID> placed) {
		StringBuilder unplaced = new StringBuilder();
		for (FullMethodID method : methods.keySet()) {
			if (!placed.contains(method)) {
				if (unplaced.length() > 0) {
					unplaced.append(", ");
				}
				unplaced.append(method);
			}
		}
		return unplaced.toString();
	}
	
	/**
	 * Resolves every method that {@code caller}'s invoke instructions may reach.
	 * {@link #hasDependencyLoop()} and {@link #dependencyOrder()} both build their call graph
	 * from this helper.
	 *
	 * @return possible call targets; empty if {@code caller} is unregistered or has no body
	 */
	private Set<FullMethodID> calleesOf(FullMethodID caller) {
		Set<FullMethodID> callees = new HashSet<FullMethodID>();
		BytecodeModel[] bytecodes = methods.get(caller);
		if (bytecodes == null) {
			return callees;
		}
		for (BytecodeModel bytecode : bytecodes) {
			if (bytecode instanceof BytecodeModel.INVOKESPECIAL) {
				// TODO: need to handle special case of invokespecial. Only the class named in
				// the constant-pool entry is checked, so a target inherited from a superclass
				// (e.g. a super.m() call) is not found.
				int cpIndex = ((BytecodeModel.INVOKESPECIAL) bytecode).getMethodIndex();
				addDirectTarget(callees, methodRefAt(caller, cpIndex));
			} else if (bytecode instanceof BytecodeModel.INVOKESTATIC) {
				int cpIndex = ((BytecodeModel.INVOKESTATIC) bytecode).getMethodIndex();
				addDirectTarget(callees, methodRefAt(caller, cpIndex));
			} else if (bytecode instanceof BytecodeModel.INVOKEVIRTUAL) {
				int cpIndex = ((BytecodeModel.INVOKEVIRTUAL) bytecode).getMethodIndex();
				addVirtualTargets(callees, methodRefAt(caller, cpIndex));
			}
		}
		return callees;
	}
	
	/**
	 * Adds the referenced method if its class is known and lists it in {@code getMethods()}.
	 * Only such methods are added (per the original note: real, non-abstract, non-native ones).
	 * A reference into an unknown class is skipped rather than dereferenced.
	 */
	private void addDirectTarget(Set<FullMethodID> callees, CPEntry.MethodRef ref) {
		ClassModel targetClass = classes.get(ref.getClassID());
		if (targetClass != null && targetClass.getMethods().contains(ref.getMethodID())) {
			callees.add(new FullMethodID(ref.getClassID(), ref.getMethodID()));
		}
	}
	
	/**
	 * Adds every known class that is the referenced class, or a subclass of it, and lists the
	 * referenced method in {@code getMethods()}. These are the candidate dispatch targets.
	 */
	private void addVirtualTargets(Set<FullMethodID> callees, CPEntry.MethodRef ref) {
		ClassID referencedClass = ref.getClassID();
		MethodID methodID = ref.getMethodID();
		for (Map.Entry<ClassID, ClassModel> candidate : classes.entrySet()) {
			if (isSubclass(candidate.getKey(), referencedClass)
					&& candidate.getValue().getMethods().contains(methodID)) {
				callees.add(new FullMethodID(candidate.getKey(), methodID));
			}
		}
	}
	
	/**
	 * Looks up a method reference in the constant pool of {@code caller}'s class.
	 *
	 * @throws IllegalStateException if the caller's class is missing from the class map
	 */
	private CPEntry.MethodRef methodRefAt(FullMethodID caller, int cpIndex) {
		ClassModel callerClass = classes.get(caller.classID);
		if (callerClass == null) {
			throw new IllegalStateException("No class model for " + caller.classID
					+ ", needed to resolve calls in " + caller);
		}
		return (CPEntry.MethodRef) callerClass.getCPEntry(cpIndex);
	}
	
	/**
	 * Returns true if {@code childID} equals {@code parentID}, or if walking the superclass chain
	 * through known classes reaches {@code parentID}. The walk stops at the first class missing
	 * from the map or without a superclass.
	 */
	private boolean isSubclass(ClassID childID, ClassID parentID) {
		if (childID.equals(parentID)) {
			return true;
		}
		ClassModel child = classes.get(childID);
		if (child != null && child.hasSuper()) {
			return isSubclass(child.getSuperName(), parentID);
		}
		return false;
	}
}
