package linMap.geneNetwork;

import java.util.*;
import java.io.*;

/**
 * A <code>Layer</code> object represents a layer of nodes in a gene network.
 * @author nic
 */
abstract class Layer implements Cloneable, Serializable {

	private int id;
	private int size;
	private boolean biased;

	// map from source layer ID to the weight matrix of connections from that
	// layer into this one; weights[t][s] links source node s to target node t
	protected Map<Integer, WeightMatrix> inputWeights;
	
	// one bias term per node of this layer
	protected double[] biasValues;
	// activation state returned by getCurrentActivation(); update() overwrites
	// it in place (the array object is never replaced)
	protected double[] currentActivation;
	// staging buffer for the next activation state; copied into
	// currentActivation by update()
	protected double[] newActivation;

	/**
	 * Default layer constructor.  Creates an unbiased layer with ID 0 and 
	 * size 0.  The activation, bias and weight storage are empty but 
	 * non-null, so every other method can be called on the new layer.
	 */
	public Layer() {
		this(0, 0, false);
	}
	
	/**
	 * Basic layer constructor.
	 * @param id the ID of the new layer
	 * @param size the number of nodes in the new layer
	 * @param biased true if the new layer has bias terms
	 */
	public Layer(int id, int size, boolean biased) {
		this.id = id;
		this.size = size;
		this.currentActivation = new double[this.size];
		this.newActivation = new double[this.size];
		this.biased = biased;
		this.biasValues = new double[this.size];
		this.inputWeights = new HashMap<Integer, WeightMatrix>();
	}
	
	/**
	 * Two layers are equal if they are of the same concrete class and have 
	 * the same ID, size, bias flag, bias values, current activation and new 
	 * activation.  Input weights are not compared.
	 * <p>
	 * NOTE(review): <code>hashCode()</code> is not overridden, so the 
	 * <code>equals</code>/<code>hashCode</code> contract is not honoured.  
	 * Adding a public <code>hashCode()</code> changes the default 
	 * serialVersionUID of this <code>Serializable</code> class and would 
	 * break deserialisation of previously saved layers.  Pin the current 
	 * value as an explicit serialVersionUID (see the <code>serialver</code> 
	 * tool) before adding it.
	 */
	@Override
	public boolean equals(Object otherObject) {
		if (this == otherObject) return true;
		if (otherObject == null) return false;
		if (this.getClass() != otherObject.getClass()) return false;
		Layer other = (Layer) otherObject;
		return this.id == other.id 
			&& this.size == other.size
			&& this.biased == other.biased
			&& Arrays.equals(this.biasValues, other.biasValues)
			&& Arrays.equals(this.currentActivation, other.currentActivation)
			&& Arrays.equals(this.newActivation, other.newActivation);
	}
	
	@Override
	public String toString() {
		return getClass().getName() 
			+ "[id=" + id 
			+ ",size=" + size
			+ ",biased=" + biased
			+ ",biasValues=" + Arrays.toString(biasValues)
			+ ",currentActivation=" + Arrays.toString(currentActivation)
			+ ",newActivation=" + Arrays.toString(newActivation)
			+ ",inputWeights=" + inputWeights
			+ "]";
	}

	/**
	 * Create a deep copy of this layer.  The bias and activation arrays are 
	 * copied, and every input weight matrix is cloned into a new map.
	 * @return an independent copy of this layer
	 */
	@Override
	public Object clone() {
		final Layer copy;
		try {
			copy = (Layer) super.clone();
		} catch (CloneNotSupportedException e) {
			// Unreachable: Layer implements Cloneable.  Fail loudly rather than
			// continuing with a null copy.
			throw new AssertionError(e);
		}
		copy.biasValues = biasValues.clone();
		copy.currentActivation = currentActivation.clone();
		copy.newActivation = newActivation.clone();
		// super.clone() shares the map, so build a fresh one with cloned matrices
		copy.inputWeights = new HashMap<Integer, WeightMatrix>();
		for (Map.Entry<Integer, WeightMatrix> entry : this.inputWeights.entrySet()) {
			copy.inputWeights.put(entry.getKey(), (WeightMatrix) entry.getValue().clone());
		}
		return copy;
	}
	
	/**
	 * Layer ID accessor.
	 * @return the layer ID
	 */
	public int getId() {
		return this.id;
	}
	
	/**
	 * Layer size accessor.
	 * @return the number of nodes in the layer
	 */
	public int getSize() {
		return this.size;
	}
	
	/**
	 * Layer bias accessor.
	 * @return <code>true</code> if the layer is biased
	 */
	public boolean getBiased() {
		return this.biased;
	}
	
	/**
	 * Layer activation accessor.
	 * <p>
	 * This returns the layer's internal array, not a copy.  Its contents 
	 * change when {@link #update()} or {@link #setCurrentActivation(double[])} 
	 * run, because both overwrite it in place.  Callers must not modify it.
	 * @return the live current activation state of the layer
	 */
	public double[] getCurrentActivation() {
		return currentActivation;
	}

	/**
	 * Layer activation mutator.  Copies <code>activation</code> into the 
	 * layer's current activation array.
	 * @param activation the new activation vector; must have exactly 
	 * 			{@link #getSize()} elements
	 * @throws IllegalArgumentException if the length does not match the layer size
	 */
	protected void setCurrentActivation(double[] activation) {
		// An explicit check replaces the former assert, which let a short
		// array be partially copied when assertions were disabled.
		if (activation.length != this.currentActivation.length) {
			throw new IllegalArgumentException("activation has length " 
					+ activation.length + "; expected " + this.currentActivation.length);
		}
		System.arraycopy(activation, 0, this.currentActivation, 0, activation.length);
	}
	
	/**
	 * Randomise all biases to normally distributed values with 
	 * mean 0.0 and standard deviation <code>weightRange</code>.  This applies 
	 * regardless of {@link #getBiased()}.
	 * @param generator 
	 * 				the <code>Random</code> used to generate
	 * @param weightRange 
	 * 				the standard deviation of the random values
	 */
	protected void setBiasesRandom(Random generator, double weightRange) {
		for (int i=0; i<this.size; i++) {
			this.biasValues[i] = generator.nextGaussian() * weightRange;
		}
	}
	
	/**
	 * Set the bias of each node to the negated, scaled sum of its input 
	 * weights: <code>bias[t] = -(bias * sum of weights[t][s] over all 
	 * input matrices and all s)</code>.
	 * @param bias
	 * 			the proportion by which the sum of input weights is scaled
	 */
	protected void setBiasesTotal(double bias) {
		for (int tIndex=0; tIndex<this.size; tIndex++) {
			double total = 0.0;

			for (WeightMatrix wm : this.inputWeights.values()) {

				for (int sIndex=0, n=wm.getSourceSize(); sIndex<n; sIndex++) {
					total += wm.weights[tIndex][sIndex];
				}
			}
			this.biasValues[tIndex] = -(bias * total);
		}
	}
	
	/**
	 * Calculate the number of input weights into this layer.
	 * @return the sum of the weight counts of all input weight matrices 
	 */
	protected int getWeightCount() {
		int count = 0;
		for (WeightMatrix wm : this.inputWeights.values()) {
			count += wm.getWeightCount();
		}
		return count;
	}
	
	/**
	 * Add a new input weight matrix from <code>sLayer</code>.  This replaces 
	 * any existing matrix registered under the same source layer ID.
	 * @param sLayer the (existing) source layer which is being linked from
	 */
	protected void addInput(Layer sLayer) {
		WeightMatrix wm = new WeightMatrix(this.size, sLayer.getSize());
		this.inputWeights.put(sLayer.getId(), wm);
	}
		
	/**
	 * Add a new input weight matrix from <code>sLayer</code> with 
	 * specified values.  This replaces any existing matrix registered under 
	 * the same source layer ID.
	 * @param sLayer the (existing) source layer which is being linked from
	 * @param values a two-dimensional array of weight values indexed 
	 * 			<code>[target][source]</code>, or <code>null</code> to leave 
	 * 			the new matrix as constructed.  It must have at least one row 
	 * 			per target node, and each of those rows must have exactly one 
	 * 			entry per source node.
	 * @throws IllegalArgumentException if <code>values</code> has the wrong shape
	 */
	protected void addInput(Layer sLayer, double[][] values) {
		WeightMatrix wm = new WeightMatrix(this.size, sLayer.getSize());
		if (values != null) {
			int targetSize = wm.getTargetSize();
			int sourceSize = wm.getSourceSize();
			// Validate up front so that a bad shape gives a clear error.
			if (values.length < targetSize) {
				throw new IllegalArgumentException("expected " + targetSize 
						+ " rows of weights but got " + values.length);
			}
			for (int tIndex=0; tIndex<targetSize; tIndex++) {
				if (values[tIndex].length != sourceSize) {
					throw new IllegalArgumentException("weight row " + tIndex 
							+ " has " + values[tIndex].length 
							+ " entries; expected " + sourceSize);
				}
				// A single copy per row; the original copied each row twice.
				System.arraycopy(values[tIndex], 0, wm.weights[tIndex], 0, sourceSize);
			}
		}
		this.inputWeights.put(sLayer.getId(), wm);
	}
	
	/**
	 * Randomly set all weights to values drawn from a normal distribution 
	 * with mean 0.0 and standard deviation <code>weightRange</code>.
	 * @param generator 
	 * 				the <code>Random</code> used to generate
	 * @param weightRange 
	 * 				the standard deviation of the random values
	 */
	protected void setWeightsRandom(Random generator, double weightRange) {
		for (WeightMatrix wm : this.inputWeights.values()) {
			wm.randomise(generator, weightRange);
		}
	}
	
	/**
	 * Accessor for source layer IDs.
	 * <p>
	 * The returned set is the live key set of the input weight map.  It 
	 * reflects later calls to <code>addInput</code>, and removing an ID from 
	 * it also removes that input's weight matrix.
	 * @return the set of source layer IDs
	 */
	protected Set<Integer> getInputIds() {
		return this.inputWeights.keySet();
	}
	
	/**
	 * Activate the layer.  Calculates the new activation state of a layer 
	 * given the current input.
	 * @param input the current input
	 */
	protected abstract void activate(double[] input);

	/**
	 * Activate the layer from the activations of its source layers.  
	 * Calculates the new activation state of the layer.
	 * @param acts map from source layer ID to that layer's activation vector 
	 * 			(presumably in the form expected by 
	 * 			{@link #calculateInput(Map)}; confirm in subclasses)
	 */
	protected abstract void activate(Map<Integer, double[]> acts);
	
	/**
	 * Update the activation state of the layer.  Sets the current activation of 
	 * the layer to the new activation by copying it in place.
	 */
	protected void update() {
		System.arraycopy(newActivation, 0, currentActivation, 0, this.size);
	}
	
	/**
	 * Sum the weighted source activations arriving at each node of this layer.
	 * @param acts
	 * 			map from source layer ID to that layer's activation vector.  It 
	 * 			must contain an entry for every ID in {@link #getInputIds()}.
	 * @return a new array of length {@link #getSize()}.  Element 
	 * 			<code>t</code> is the sum, over all input matrices and source 
	 * 			nodes <code>s</code>, of 
	 * 			<code>weights[t][s] * activation[s]</code>.  Bias values are 
	 * 			not added here.
	 */
	protected double[] calculateInput(Map<Integer, double[]> acts) {
		double[] input = new double[this.size];
		for (Map.Entry<Integer, WeightMatrix> entry : this.inputWeights.entrySet()) {
			WeightMatrix wm = entry.getValue();
			// Look up the source activation once per matrix, not once per weight.
			double[] sourceAct = acts.get(entry.getKey());
			int sourceSize = wm.getSourceSize();
			for (int tIndex=0; tIndex<input.length; tIndex++) {
				for (int sIndex=0; sIndex<sourceSize; sIndex++) {
					input[tIndex] += wm.weights[tIndex][sIndex] * sourceAct[sIndex];
				}
			}
		}
		return input;
	}
}
