package writerIdentification;

import java.util.ArrayList;
import java.util.List;

import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.TermCriteria;
import org.opencv.ml.KNearest;
import org.opencv.ml.Ml;
import org.opencv.ml.SVM;
import org.opencv.utils.Converters;

/**
 * Trains a KNN or an SVM classifier on feature rows and measures how many
 * testing rows are classified correctly.
 *
 * <p>Typical call order, as implied by the data dependencies between the
 * methods: {@link #setImageResults(List)} fills the training and testing
 * matrices, {@link #createLabels()} builds one label per training row, then
 * {@link #kNearestNeighbors()} or {@link #supportVectorMachines()} trains and
 * evaluates, and {@link #getCorrectPercent()} returns the result.
 *
 * <p>Both evaluations compare the prediction for testing row {@code i} with
 * the label of training row {@code i}; this is only meaningful when both
 * matrices list the same classes in the same order with the same number of
 * rows per class.
 */
public class Training {
	
	// Correct-answer counter while evaluating; percentage (0-100) once an evaluation finished.
	private double correctAnswerNum;
	
	private SVM svm;
	private KNearest knn;
	private int numberOfNeighbors = 1;
	
	// Number of consecutive training rows that share one label (see createLabels).
	// NOTE(review): presumably has to match the number of training markers in setImageResults - confirm.
	private int trainingImgNum = 2;
	
	// NOTE(review): never written by any method visible in this class; getImageResults() returns it as-is.
	private Mat imageResults = new Mat(0, 0, CvType.CV_32F);
	private List<Integer> resultLabels = new ArrayList<Integer>();
	
	private Mat trainingList = new Mat(0, 0, CvType.CV_32F);
	private Mat testingList = new Mat(0, 0, CvType.CV_32F);
	
	
	
	/**
	 * Creates the labels for SVM and KNN training: training row {@code i} gets
	 * the label {@code i / trainingImgNum}.
	 *
	 * <p>Previously created labels are discarded first, so calling this method
	 * repeatedly always leaves exactly one label per training row.
	 */
	public void createLabels() {
		resultLabels.clear();
		for(int i = 0; i < trainingList.rows(); i++) {
			resultLabels.add(i/trainingImgNum);
		}
	}
	
	/**
	 * Creates a KNN classifier, trains it on the training rows and evaluates it
	 * on the testing rows. The share of correctly classified testing rows is
	 * stored as a percentage and can be read with {@link #getCorrectPercent()}.
	 *
	 * <p>Failures are reported on standard output and leave the stored
	 * percentage at 0.
	 */
	public void kNearestNeighbors() {
		
		// Output buffers are reused for every query and released at the end.
		Mat res = new Mat();
		Mat neighbours = new Mat();
		Mat dist = new Mat();
		
		try {
			
			knn = KNearest.create();
			knn.train(trainingList, Ml.ROW_SAMPLE, Converters.vector_int_to_Mat(resultLabels));
			
			int testRows = testingList.rows();
			requireLabelsFor(testRows);
			
			correctAnswerNum = 0;
			for (int i = 0; i < testRows; i++) {
				Mat oneFeature = testingList.row(i);
				int testLabel = resultLabels.get(i);
				
				float p = knn.findNearest(oneFeature, numberOfNeighbors, res, neighbours, dist);
				oneFeature.release();
				System.out.println(testLabel + " " + p);
				
				if(testLabel == res.get(0, 0)[0]) {
					correctAnswerNum++;
				}
			}
			
			System.out.println(correctAnswerNum);
			correctAnswerNum = toPercent(correctAnswerNum, testRows);
		}catch (Exception e) {
			// Do not leave a stale percentage or a raw partial count behind.
			correctAnswerNum = 0;
			System.out.println("Error at KNN training: " + e);
		}finally {
			res.release();
			neighbours.release();
			dist.release();
		}
		
	}
	
	/**
	 * Creates an SVM classifier, trains it on the training rows and evaluates
	 * it on the testing rows. The share of correctly classified testing rows is
	 * stored as a percentage and can be read with {@link #getCorrectPercent()}.
	 *
	 * <p>Failures are reported on standard output and leave the stored
	 * percentage at 0.
	 */
	public void supportVectorMachines() {
		
		Mat results = new Mat();
		
		try {
			svm = SVM.create();
			svm.setKernel(SVM.INTER);
			svm.setType(SVM.C_SVC);
			svm.setTermCriteria(new TermCriteria(TermCriteria.MAX_ITER, 100, 1e-6));
			
			svm.train(trainingList, Ml.ROW_SAMPLE, Converters.vector_int_to_Mat(resultLabels));
			
			svm.predict(testingList, results, 0);
			
			int testRows = results.rows();
			requireLabelsFor(testRows);
			
			correctAnswerNum = 0;
			for (int i = 0; i < testRows; i++){
				int expectedLabel = resultLabels.get(i);
				double predictedLabel = results.get(i, 0)[0];
				if(predictedLabel == expectedLabel) {
					correctAnswerNum++;
				}
				// Print the label that is actually compared instead of a hard-coded i/2.
				System.out.println(expectedLabel+". - " + predictedLabel);
			}
			System.out.println(correctAnswerNum);
			correctAnswerNum = toPercent(correctAnswerNum, testRows);
		}catch (Exception e) {
			// Do not leave a stale percentage or a raw partial count behind.
			correctAnswerNum = 0;
			System.out.println("Error at SVM training: " + e);
		}finally {
			results.release();
		}
		
	}
	
	/**
	 * @return the percentage (0-100) of correctly classified testing rows from
	 *         the most recent evaluation
	 */
	public double getCorrectPercent() {
		return correctAnswerNum;
	}

	/**
	 * Resets the collected rows and labels to empty containers. The stored
	 * percentage and the trained classifiers are left untouched.
	 */
	public void setToDefault() {
		this.imageResults = new Mat(0, 0, CvType.CV_32F);
		this.resultLabels = new ArrayList<Integer>();
		this.trainingList = new Mat(0, 0, CvType.CV_32F);
		this.testingList = new Mat(0, 0, CvType.CV_32F);
	}

	/**
	 * @return the image result matrix held by this object
	 */
	public Mat getImageResults() {
		return imageResults;
	}
	
	/**
	 * Sorts the records into the training and testing lists.
	 *
	 * <p>Element 0 of a record is its name; the remaining elements are parsed
	 * as numbers and form one matrix row. The number of columns is taken from
	 * the first record. A record whose name contains one of the training
	 * markers is appended to the training list, every other record to the
	 * testing list.
	 *
	 * @param records one list per record: the name followed by the numeric values
	 */
	public void setImageResults(List<List<String>> records) {
		
		if(records.isEmpty()) {
			return;
		}
		
		// Name fragments that mark a training record. Other pairs used before:
		// {"_1.","_2."}, {"_1.","_3."}, {"_2.","_3."}, {"_2.","_4."}, {"_3.","_4."}
		String[] trainingPictureLabels = {"_1.","_4."};
		
		// Every row gets the width of the first record so all rows can be stacked.
		int recordSize = records.get(0).size();
		
		for(List<String> record : records) {
			
			Mat row = new Mat(1, recordSize-1, CvType.CV_32F);
			
			for(int j = 1; j < recordSize; j++) {
				row.put(0, j-1, Double.parseDouble(record.get(j)));
			}
			
			if(isTrainingRecord(record.get(0), trainingPictureLabels)) {
				this.trainingList.push_back(row);
			}else {
				this.testingList.push_back(row);
			}
			
			// push_back copied the values, the temporary row is no longer needed.
			row.release();
		}

	}
	
	// Tells whether the record name contains at least one of the training markers.
	private static boolean isTrainingRecord(String name, String[] markers) {
		for(String marker : markers) {
			if(name.contains(marker)) {
				return true;
			}
		}
		return false;
	}
	
	// Fails with a descriptive message when there are fewer labels than rows to evaluate.
	private void requireLabelsFor(int testRows) {
		if(testRows > resultLabels.size()) {
			throw new IllegalStateException("Cannot evaluate " + testRows + " testing rows with only "
					+ resultLabels.size() + " labels");
		}
	}
	
	// Converts a count of correct answers to a percentage; 0 when nothing was evaluated (avoids NaN).
	private static double toPercent(double correct, int total) {
		if(total == 0) {
			return 0;
		}
		return (correct*100) / total;
	}
	
}
