import java.util.Random;
import phw.util.Ask;
import jcsp.lang.*;

/**
 * Interactive benchmark driver for the {@code Matrix} multiplication routines.
 *
 * <p>Prompts for a random seed, an element limit and the dimensions A, B, C. It fills
 * X[A][B] and Y[B][C] with {@code Matrix.randomise} and computes a reference product
 * ZZ = X*Y with {@code Matrix.multiply}. It then times three variants, each writing into Z:
 * <ul>
 *   <li>{@code Matrix.seqMultiply};</li>
 *   <li>{@code Matrix.parMultiply};</li>
 *   <li>one {@code Parallel} built once by {@code Matrix.makeParMultiply} and re-run each time.</li>
 * </ul>
 * Each variant is run the requested number of times. Its result is compared against ZZ
 * with {@code Matrix.same}, and the mean and sample standard deviation of the per-run
 * times read from a {@code CSTimer} are printed.
 */
public class TestMatrix {

  public static void main (String[] args) {

    long seed = Ask.Int ("seed = ", 1, 100);
    Random random = new Random (seed);

    double limit = Ask.Int ("limit = ", 1, 2000);

    int A = Ask.Int ("A = ", 3, 4096);
    int B = Ask.Int ("B = ", 3, 4096);
    int C = Ask.Int ("C = ", 3, 4096);

    // final so the anonymous Runnables below can capture them.
    final double[][] X = new double[A][B];
    final double[][] Y = new double[B][C];
    final double[][] Z = new double[A][C];
    double[][] ZZ = new double[A][C];

    System.out.println ("Initialising arrays X[A][B], Y[B][C] ...");
    Matrix.randomise (X, limit, random);
    Matrix.randomise (Y, limit, random);

    System.out.println ("Printing array X[A][B], ...");
    Matrix.print (X, 3, 3);
    System.out.println ("\nPrinting array Y[B][C] ...");
    Matrix.print (Y, 3, 3);

    CSTimer tim = new CSTimer ();

    // Reference product ZZ that every benchmarked variant is checked against.
    System.out.println ("\nMultiplying X*Y --> ZZ ...");
    long t0 = tim.read ();
    Matrix.multiply (X, Y, ZZ);
    long t1 = tim.read ();
    System.out.println ("Completed in " + (t1 - t0) + " milliseconds ...");
    System.out.println ("\nPrinting array ZZ[A][C] ...");
    Matrix.print (ZZ, 3, 3);

    int nBenchmarks = Ask.Int ("\nnumber of benchmarks = ", 1, 2000);

    String shape = "\n[" + A + "][" + B + "] * [" + B + "][" + C + "] ==> [" + A + "][" + C + "]";

    ////////////////////////////////////////////////////////////////////////////////////

    benchmark ("\n(SEQ) Multiplying X*Y --> Z ...",
               new Runnable () {
                 public void run () { Matrix.seqMultiply (X, Y, Z); }
               },
               nBenchmarks, tim, Z, ZZ, shape);
    // Scramble Z so the next variant's check cannot pass on stale results.
    Matrix.randomise (Z, limit, random);

    ////////////////////////////////////////////////////////////////////////////////////

    benchmark ("\n(PAR) Multiplying X*Y --> Z ...",
               new Runnable () {
                 public void run () { Matrix.parMultiply (X, Y, Z); }
               },
               nBenchmarks, tim, Z, ZZ, shape);
    Matrix.randomise (Z, limit, random);

    ////////////////////////////////////////////////////////////////////////////////////

    // Built once, outside the timed region, so only repeated execution is measured.
    final Parallel par = Matrix.makeParMultiply (X, Y, Z);

    benchmark ("\n(reuse-PAR) Multiplying X*Y --> Z ...",
               new Runnable () {
                 public void run () { par.run (); }
               },
               nBenchmarks, tim, Z, ZZ, shape);
    Matrix.randomise (Z, limit, random);

  }

  /**
   * Times {@code multiply} {@code nBenchmarks} times and reports the results.
   *
   * <p>Each run prints the banner and its elapsed time. After all runs it prints Z,
   * checks Z against ZZ, and prints the mean and sample standard deviation.
   *
   * @param banner      line printed before every run
   * @param multiply    the multiplication under test; expected to write its result into Z
   * @param nBenchmarks number of timed runs (at least 1)
   * @param tim         timer used to read start and end times
   * @param Z           output array the variant writes into
   * @param ZZ          reference result
   * @param shape       pre-formatted dimension summary line
   */
  private static void benchmark (String banner, Runnable multiply, int nBenchmarks,
                                 CSTimer tim, double[][] Z, double[][] ZZ, String shape) {
    double sum = 0.0d, sumsq = 0.0d;

    for (int i = 0; i < nBenchmarks; i++) {
      System.out.println (banner);
      long t0 = tim.read ();
      multiply.run ();
      long t1 = tim.read ();
      long time = t1 - t0;
      System.out.println ("[" + (i + 1) + "/" + nBenchmarks +
                          "] Completed in " + time + " milliseconds ...");
      sum += time;
      sumsq += (double) time * time;
    }

    System.out.println ("\nPrinting array Z[A][C] ...");
    Matrix.print (Z, 3, 3);

    checkResult (Z, ZZ);

    System.out.println (shape);
    System.out.println ("number of benchmarks = " + nBenchmarks);
    System.out.println ("mean = " + (sum / nBenchmarks));
    System.out.println ("standard deviation = " + sampleStdev (sum, sumsq, nBenchmarks));
  }

  /** Prints whether Z matches the reference ZZ according to {@code Matrix.same}. */
  private static void checkResult (double[][] Z, double[][] ZZ) {
    System.out.println ("\nChecking array Z[A][C] against ZZ[A][C] ...");
    if (Matrix.same (Z, ZZ)) {
      System.out.println ("... checked OK");
    } else {
      System.out.println ("... check FAILED");
    }
  }

  /**
   * Computes the sample standard deviation of n values from their sum and sum of squares.
   *
   * <p>With fewer than two samples the deviation is undefined; 0.0 is returned instead of
   * NaN (the original code divided by n - 1 == 0). Rounding can make the variance
   * numerator slightly negative, so it is clamped at zero before the square root.
   */
  private static double sampleStdev (double sum, double sumsq, int n) {
    if (n < 2) {
      return 0.0d;
    }
    double top = sumsq - ((sum * sum) / n);
    return Math.sqrt (Math.max (0.0d, top) / (n - 1));
  }

}

