package linMap.geneNetwork;

import java.util.ArrayList;
import java.util.Random;

import junit.framework.TestCase;
import linMap.geneNetwork.NetTools;
import linMap.geneNetwork.Network;

/**
 * JUnit 3 tests for {@link Network}: adding and linking layers, cloning, and activation.
 *
 * <p>Each test runs against the fixture built by {@link #setUp()}. The fixture's two layers
 * are created but not linked; a test that needs connections links them itself.
 */
public class NetworkTest extends TestCase {

	/** Node count of the fixture's input layer. */
	public static final int INPUT_SIZE=3;
	/** Node count of the fixture's sigmoid layer. */
	public static final int OUTPUT_SIZE=2;

	/** Absolute tolerance used when comparing activations with hand-computed targets. */
	private static final double TOLERANCE = 1e-8;

	/**
	 * Expected sigmoid-layer output for {@link #sampleInput()} under {@link #weightValues}.
	 * The weighted sums are 0.8 and 2.6; the targets are the logistic function of those sums.
	 */
	private static final double[] FEED_FORWARD_TARGET = { 0.689974481, 0.93086158 };

	private Network network;
	private int inputId;
	private int outputId;

	/** Weights indexed as [sigmoid node][input node], filled with 0,1,2 / 3,4,5 by setUp. */
	double[][] weightValues;

	/**
	 * Builds a fresh two-layer network (input layer, then sigmoid layer) and the weight
	 * matrix that the linking tests pass to {@code linkLayers}.
	 */
	protected void setUp() throws Exception {
		super.setUp();
		network = new Network();
		inputId = network.addLayer(INPUT_SIZE, Network.LayerType.INPUT, false);
		outputId = network.addLayer(OUTPUT_SIZE, Network.LayerType.SIGMOID, true);
		weightValues = new double[OUTPUT_SIZE][INPUT_SIZE];
		for (int i=0; i<OUTPUT_SIZE; i++) {
			for (int j=0; j<INPUT_SIZE; j++) {
				weightValues[i][j] = i*INPUT_SIZE+j;
			}
		}
	}

	/** Returns a new array holding the input vector shared by the activation tests. */
	private static double[] sampleInput() {
		return new double[] { 0.1, 0.2, 0.3 };
	}

	/**
	 * Asserts that the first {@code expected.length} entries of {@code actual} match
	 * {@code expected} within {@code delta}, naming the offending index on failure.
	 */
	private static void assertActivations(double[] expected, double[] actual, double delta) {
		for (int i=0; i<expected.length; i++) {
			assertEquals("activation[" + i + "]", expected[i], actual[i], delta);
		}
	}

	/** Prints each array of {@code state} on its own line, every element followed by a space. */
	private static void printState(ArrayList<double[]> state) {
		for (double[] row : state) {
			// Iterate over the row's own length so rows of different sizes are dumped correctly.
			for (double value : row) {
				System.out.print(value + " ");
			}
			System.out.println();
		}
	}

	/**
	 * Test method for 'linMap.geneNetwork.Network.clone()'.
	 *
	 * Checks that a clone prints identically to its source, that changing or activating the
	 * clone leaves the source's string form untouched, and that a source and its clone stay
	 * in step when both are activated the same way.
	 */
	public void testClone() {
		network.linkLayers(inputId, outputId, weightValues);
		String originalString = network.toString();

		// A fresh clone must be indistinguishable from the original.
		Network clonedNetwork = (Network)this.network.clone();
		assertEquals(originalString, clonedNetwork.toString());

		// Structural changes to the clone must not leak back into the original.
		int newLayerId = clonedNetwork.addLayer(4, Network.LayerType.SIGMOID, true);
		clonedNetwork.linkLayers(outputId, newLayerId);
		assertEquals(originalString, network.toString());
		// NOTE(review): the original author expected clonedNetwork.toString() to differ from
		// originalString at this point; that is not asserted here -- confirm before adding it.

		// Activating the clone must not disturb the original either.
		double[] input = sampleInput();
		clonedNetwork.activate(input);
		assertEquals(originalString, network.toString());

		// Original and clone, driven with the same inputs, must end in the same state.
		Network network2 = (Network)this.network.clone();
		network.activate(input);
		network2.activate(input);
		network.activate(input);
		network2.activate(input);
		System.out.println(network);
		System.out.println(network2);
		assertEquals(network.toString(), network2.toString());

		// Same check on a recurrent network built from a fixed seed.
		Random generator = new Random(123);
		Network netA = NetTools.createSimpleRecurrentNet(2,2,2,generator,2.0);
		Network netB = (Network)netA.clone();
		System.out.println(netA);
		System.out.println(netB);
		netA.activate();
		netB.activate();
		System.out.println(netA);
		System.out.println(netB);
		assertEquals(netA.toString(), netB.toString());
	}

	/**
	 * Test method for 'linMap.geneNetwork.Network.addLayer(int, LayerType, boolean)'.
	 *
	 * The fixture added two layers, so the counts and the second layer's id are checked.
	 */
	public void testAddLayer() {
		assertEquals(2, network.getLayerCount());
		assertEquals(INPUT_SIZE+OUTPUT_SIZE, network.getNodeCount());
		assertEquals(1, outputId);
	}

	/**
	 * Test method for 'linMap.geneNetwork.Network.linkLayers(int, int)'.
	 *
	 * Linking the two layers must yield one weight per (input node, sigmoid node) pair.
	 */
	public void testLinkLayersIntInt() {
		network.linkLayers(inputId, outputId);
		assertEquals(INPUT_SIZE*OUTPUT_SIZE, network.getWeightCount());
	}

	/**
	 * Test method for 'linMap.geneNetwork.Network.linkLayers(int, int, double[][])'.
	 *
	 * Same weight-count check as above, with explicit weight values supplied.
	 */
	public void testLinkLayersIntIntDoubleArray() {
		network.linkLayers(inputId, outputId, weightValues);
		assertEquals(INPUT_SIZE*OUTPUT_SIZE, network.getWeightCount());
	}

	/**
	 * Test method for 'linMap.geneNetwork.Network.setAllActivations(double)'.
	 *
	 * Every output activation must equal the value that was set, exactly.
	 */
	public void testSetAllActivations() {
		double value = 0.89;
		network.setAllActivations(value);
		double[] output = network.getOutput();
		for (int i=0; i<OUTPUT_SIZE; i++) {
			// Explicit zero delta: the two-argument form would box and compare as Objects.
			assertEquals("output[" + i + "]", value, output[i], 0.0);
		}
	}

	/**
	 * Test method for 'linMap.geneNetwork.Network.activate()'.
	 *
	 * First checks a plain feed-forward pass, then adds a link from the sigmoid layer to
	 * itself and checks a second pass driven by an all-zero input.
	 */
	public void testActivate() {
		network.linkLayers(inputId, outputId, weightValues);
		network.setInput(sampleInput());
		network.activate();
		assertActivations(FEED_FORWARD_TARGET, network.getOutput(), TOLERANCE);

		double[][] newWeightValues = { { 1.0, 1.0 }, { 1.0, 1.0 } };
		network.linkLayers(outputId, outputId, newWeightValues);
		double[] newInput = { 0.0, 0.0, 0.0 };
		network.setInput(newInput);
		network.activate();
		// Numerically this is the logistic function of the sum of the two previous outputs
		// (0.689974481 + 0.93086158); presumably the self-link feeds back the prior activations.
		double[] newTargetOutput = { 0.834910401, 0.834910401 };
		assertActivations(newTargetOutput, network.getOutput(), TOLERANCE);
	}

	/**
	 * Test method for 'linMap.geneNetwork.Network.activate(double[])'.
	 *
	 * The returned array must hold the same feed-forward targets as in testActivate.
	 */
	public void testActivateDoubleArray() {
		network.linkLayers(inputId, outputId, weightValues);
		double[] output = network.activate(sampleInput());
		assertActivations(FEED_FORWARD_TARGET, output, TOLERANCE);
	}

	/**
	 * Test method for Network.activate applied to a list of per-layer activation arrays.
	 *
	 * Builds a recurrent network from a fixed seed, activates it on an explicit state
	 * (input values, then hidden values) and checks the second array of the list against
	 * recorded target values after the call.
	 */
	public void testActivateDoubleArrayArray() {
		Network netA = NetTools.createSimpleRecurrentNet(2,2,2,1234,2.0);
		ArrayList<double[]> state = new ArrayList<double[]>(2);
		double[] input = { 0.0, 1.0 };
		double[] hidden = { 1.5, 2.0 };
		state.add(input);
		state.add(hidden);

		printState(state);
		netA.activate(state);
		printState(state);
		System.out.println(netA);

		double[] tgtHidden = { 0.988625568255113, 0.9361075589156772 };
		assertActivations(tgtHidden, state.get(1), 1e-7);
	}

}
