package linMap.geneNetwork;

import java.util.*;
import java.io.*;

/**
 * A <code>Network</code> object represents a dynamic (gene) network.
 * 
 * @author nic
 */
public class Network implements Cloneable, Serializable {

	private static final long serialVersionUID = 2162825861785394142L;

	/**
	 * The layers of the network, in the order they were added. A layer's
	 * position in this list is the ID returned for it by
	 * {@link #addLayer(int, LayerType, boolean)}.
	 */
	protected List<Layer> nodes;

	/** ID to assign to the next layer created by addLayer. */
	private int nextLayerIndex = 0;

	/**
	 * An enumerated type for the varieties of possible network layers.
	 * (currently only INPUT and SIGMOID).
	 * 
	 * @author nic
	 * 
	 */
	public enum LayerType {
		INPUT, SIGMOID
	}

	/**
	 * Default network constructor: creates a network with no layers.
	 */
	public Network() {
		nodes = new ArrayList<Layer>();
	}

	@Override
	public String toString() {
		return getClass().getName() + "[nodes=" + nodes + "]";
	}

	/**
	 * Create a copy of this network whose layer list is new and holds a clone
	 * of every layer (via {@code Layer.clone()}). Scalar fields such as the
	 * next layer ID are copied as-is.
	 * NOTE(review): whether a cloned layer still references the original
	 * layers it was linked to depends on {@code Layer.clone()}, which is not
	 * visible here; {@link #activate()} itself resolves inputs by ID against
	 * the copy's own layer list.
	 * 
	 * @return the cloned network
	 */
	@Override
	public Object clone() {
		Network copy;
		try {
			copy = (Network) super.clone();
		} catch (CloneNotSupportedException e) {
			// Cannot happen: this class implements Cloneable. Failing loudly
			// here avoids the NullPointerException a null copy would cause.
			throw new AssertionError(e);
		}
		copy.nodes = new ArrayList<Layer>(this.nodes.size());
		for (Layer layer : this.nodes) {
			copy.nodes.add((Layer) layer.clone());
		}
		return copy;
	}

	/**
	 * Get the number of layers in the network.
	 * 
	 * @return number of layers in network
	 */
	public int getLayerCount() {
		return nodes.size();
	}

	/**
	 * Get the number of nodes in one layer.
	 * 
	 * @param index
	 *            the ID (list position) of the layer
	 * @return the size of that layer
	 * @throws IndexOutOfBoundsException
	 *             if no layer has the given ID
	 */
	public int getLayerSize(int index) {
		return nodes.get(index).getSize();
	}

	/**
	 * Get the number of nodes in the network.
	 * 
	 * @return number of nodes in network
	 */
	public int getNodeCount() {
		int count = 0;
		for (Layer layer : nodes) {
			count += layer.getSize();
		}
		return count;
	}

	/**
	 * Get the number of weights (excluding biases) in the network.
	 * 
	 * @return number of weights in network
	 */
	public int getWeightCount() {
		int count = 0;
		for (Layer layer : nodes) {
			count += layer.getWeightCount();
		}
		return count;
	}

	/**
	 * Get the size of the input (first) layer.
	 * 
	 * @return number of nodes in the first layer
	 * @throws IndexOutOfBoundsException
	 *             if the network has no layers
	 */
	public int getInputSize() {
		return nodes.get(0).getSize();
	}

	/**
	 * Get the size of the output (last) layer.
	 * 
	 * @return number of nodes in the last layer
	 * @throws IndexOutOfBoundsException
	 *             if the network has no layers
	 */
	public int getOutputSize() {
		return nodes.get(nodes.size() - 1).getSize();
	}

	/**
	 * Add a new layer of nodes to a network.
	 * 
	 * @param size
	 *            the number of nodes in the new layer
	 * @param type
	 *            the type of the new layer
	 * @param biased
	 *            <code>true</code> if the new layer has bias terms (not
	 *            passed on for INPUT layers)
	 * @return the index of the new layer (for future identification)
	 * @throws IllegalArgumentException
	 *             if {@code type} is a layer type this method does not handle
	 */
	public int addLayer(int size, LayerType type, boolean biased) {
		Layer newLayer;
		switch (type) {
		case INPUT:
			newLayer = new InputLayer(nextLayerIndex, size);
			break;
		case SIGMOID:
			newLayer = new SigmoidLayer(nextLayerIndex, size, biased);
			break;
		default:
			// Only reachable if LayerType gains a constant not handled above.
			// Fail fast instead of adding a null layer that would break
			// every later traversal of the network.
			throw new IllegalArgumentException(
					"linMap.geneNetwork.Network.addLayer() -- unknown layer type: "
							+ type);
		}

		nodes.add(newLayer);
		return nextLayerIndex++;
	}

	/**
	 * Link two existing layers by a weight matrix
	 * 
	 * @param sId
	 *            the ID of the source layer
	 * @param tId
	 *            the ID of the target layer
	 */
	public void linkLayers(int sId, int tId) {
		double[][] nullValues = null;
		this.linkLayers(sId, tId, nullValues);
	}

	/**
	 * Link two existing layers by a weight matrix with given values. If either
	 * ID does not name an existing layer, an error is printed for each missing
	 * layer and no link is created.
	 * 
	 * @param sId
	 *            the ID of the source layer
	 * @param tId
	 *            the ID of the target layer
	 * @param values
	 *            a two-dimensional array of weight values (may be
	 *            <code>null</code>, as passed by the two-argument overload)
	 */
	public void linkLayers(int sId, int tId, double[][] values) {
		boolean sourceExists = isValidLayerId(sId);
		boolean targetExists = isValidLayerId(tId);
		if (!sourceExists) {
			System.out.println("ERROR: " + getClass().getName()
					+ "-- source layer " + sId + " does not exist");
		}
		if (!targetExists) {
			System.out.println("ERROR: " + getClass().getName()
					+ "-- target layer " + tId + " does not exist");
		}
		// Only link when both ends exist; previously a missing source layer
		// still resulted in addInput(null, values) on the target.
		if (sourceExists && targetExists) {
			nodes.get(tId).addInput(nodes.get(sId), values);
		}
	}

	/**
	 * Check whether an ID refers to a layer currently in the network.
	 * 
	 * @param id
	 *            the layer ID (list position)
	 * @return <code>true</code> if {@code id} is a valid index into the layers
	 */
	private boolean isValidLayerId(int id) {
		return id >= 0 && id < nodes.size();
	}

	/**
	 * Set the activation of all nodes in the network. (eg., to reset a network
	 * to 0.0)
	 * 
	 * @param value
	 *            the new activation value
	 */
	public void setAllActivations(double value) {
		for (Layer layer : nodes) {
			double[] activation = new double[layer.getSize()];
			Arrays.fill(activation, value);
			layer.setCurrentActivation(activation);
		}
	}

	/**
	 * Set the activation of the input (first) layer. Prints an error and does
	 * nothing if the network has no layers.
	 * 
	 * @param activation
	 *            the new input activation
	 */
	public void setInput(double[] activation) {
		if (nodes.size() == 0) {
			System.out.println("ERROR: " + getClass().getName()
					+ "-- network must have at least one layer to set input");
		} else {
			nodes.get(0).setCurrentActivation(activation);
		}
	}

	/**
	 * Get the activation of the output (last) layer.
	 * 
	 * @return the array returned by the last layer's
	 *         {@code getCurrentActivation()} (NOTE(review): whether this is a
	 *         copy depends on Layer, which is not visible here), or an empty
	 *         array if the network has no layers
	 */
	public double[] getOutput() {
		if (nodes.size() == 0) {
			return new double[0];
		}
		return nodes.get(nodes.size() - 1).getCurrentActivation();
	}

	/**
	 * Update the network activation state on the basis of its current values.
	 * <p>
	 * Layers are processed in list order. For each layer, the current
	 * activations of its input layers (looked up by ID) are collected and
	 * passed to the layer's {@code activate}, and then the layer's
	 * {@code update} is called before moving on to the next layer.
	 * NOTE(review): if {@code update()} commits the new activation, later
	 * layers see the already-updated values of earlier layers within the same
	 * step; confirm this is intended rather than a two-phase (activate all,
	 * then update all) scheme.
	 */
	public void activate() {
		for (Layer layer : nodes) {
			Set<Integer> inputIds = layer.getInputIds();
			Map<Integer, double[]> inputActs = new HashMap<Integer, double[]>();
			for (int id : inputIds) {
				inputActs.put(id, nodes.get(id).getCurrentActivation());
			}
			layer.activate(inputActs);
			layer.update();
		}
	}

	/**
	 * Update the network activation state on the basis of the passed input
	 * values.
	 * 
	 * @param input
	 *            the new input values
	 * @return the new output values, as returned by {@link #getOutput()}
	 */
	public double[] activate(double[] input) {
		setInput(input);
		activate();
		return getOutput();
	}

	/**
	 * Overwrite the current activation of the first {@code state.size()}
	 * layers and then activate the network.
	 * <p>
	 * Element {@code i} of {@code state} is stored directly (not copied) as
	 * the current activation of layer {@code i}, bypassing
	 * {@code setCurrentActivation}.
	 * 
	 * @param state
	 *            per-layer activations, in layer order
	 * @return the new output values, as returned by {@link #getOutput()}
	 * @throws IndexOutOfBoundsException
	 *             if {@code state} has more entries than the network has
	 *             layers
	 */
	public double[] activate(List<double[]> state) {
		for (int i = 0, n = state.size(); i < n; i++) {
			nodes.get(i).currentActivation = state.get(i);
		}
		activate();
		return getOutput();
	}
}
