//-------------------------------------------------------------------------------
//** Part         : Convolutional Neural Network Accelerator
//** File name    : maxpool_relu.v
//** Description  : MaxPooling for CNN
//**                Activation Function for CNN - ReLU Function
//** Author       : Haosen Yu
//** Email        : haosenyu@hotmail.com
//** Revision     : Version 1.0
//** Date         : 2023-04-12 14:32:04
//** LastEditTime : 2024-05-15 11:46:03
//** 
//** Copyright (c) 2023 by Haosen Yu, All Rights Reserved. 
//-------------------------------------------------------------------------------

// maxpool_relu
//
// Max-pooling followed by ReLU for eight parallel convolution channels.
//
// Every clock cycle with valid_in high delivers one sample per channel.
// Samples are consumed in pairs (tracked by `flag`); each pair belongs to the
// buffer slot selected by `pcount`. After OUT_WIDTH pairs the slot counter
// wraps and `state` advances to the next pass (0 -> 1 -> 2 -> 0). A slot thus
// sees 2 samples on each of 3 passes, i.e. 6 samples per pooled output
// ("pooling size 3*2").
//
//   sample 1        : loaded into the slot, overwriting the previous window
//   samples 2 .. 5  : slot <= max(slot, sample)
//   sample 6        : max_value_N <= ReLU(max(slot, sample)), valid_out_relu = 1
//
// Parameters
//   CONV_BIT      : bit width of the signed input samples and of the outputs
//   OUT_WIDTH     : number of buffer slots (pooled outputs per pass)
//   OUT_HEIGHT    : not referenced inside this module; kept for interface
//                   compatibility with existing instantiations
//   OUT_WIDTH_BIT : bit width of the slot counter `pcount`
//
// Notes
//   * Reset is synchronous: rst_n is only sampled on the rising edge of clk.
//   * max_value_N and the buffers are not cleared by reset; max_value_N is
//     only meaningful in cycles where valid_out_relu is 1.
module maxpool_relu #(parameter CONV_BIT = 25, OUT_WIDTH = 7, OUT_HEIGHT = 4, OUT_WIDTH_BIT = 3) (
	input clk,
	input rst_n,	// synchronous reset, active low (sampled on posedge clk)
	input valid_in,
	input signed [CONV_BIT - 1 : 0] conv_out_1, conv_out_2, conv_out_3, conv_out_4, conv_out_5, conv_out_6, conv_out_7, conv_out_8,
	output reg [CONV_BIT - 1 : 0] max_value_1, max_value_2, max_value_3, max_value_4, max_value_5, max_value_6, max_value_7, max_value_8,
	output reg valid_out_relu
);

// One running-maximum buffer per channel, one slot per pooled output position.
reg signed [CONV_BIT - 1:0] buffer1 [0:OUT_WIDTH - 1];
reg signed [CONV_BIT - 1:0] buffer2 [0:OUT_WIDTH - 1];
reg signed [CONV_BIT - 1:0] buffer3 [0:OUT_WIDTH - 1];
reg signed [CONV_BIT - 1:0] buffer4 [0:OUT_WIDTH - 1];
reg signed [CONV_BIT - 1:0] buffer5 [0:OUT_WIDTH - 1];
reg signed [CONV_BIT - 1:0] buffer6 [0:OUT_WIDTH - 1];
reg signed [CONV_BIT - 1:0] buffer7 [0:OUT_WIDTH - 1];
reg signed [CONV_BIT - 1:0] buffer8 [0:OUT_WIDTH - 1];

reg [OUT_WIDTH_BIT - 1:0] pcount;	// current buffer slot
reg [1:0] state;			// current pass over the slots: 0, 1 or 2
reg flag;				// 0 = first sample of a pair, 1 = second sample

// Signed maximum of the value held in a buffer slot and the incoming sample.
function signed [CONV_BIT - 1:0] pool_max;
	input signed [CONV_BIT - 1:0] held;
	input signed [CONV_BIT - 1:0] sample;
	begin
		pool_max = (held < sample) ? sample : held;
	end
endfunction

// ReLU applied to the signed maximum of `held` and `sample`.
// The sign bit is bit CONV_BIT-1; a negative winner is clamped to zero.
function [CONV_BIT - 1:0] relu_max;
	input signed [CONV_BIT - 1:0] held;
	input signed [CONV_BIT - 1:0] sample;
	reg signed [CONV_BIT - 1:0] winner;
	begin
		winner = (held < sample) ? sample : held;
		relu_max = winner[CONV_BIT - 1] ? {CONV_BIT{1'b0}} : winner;
	end
endfunction

always @(posedge clk) begin
	if(~rst_n) begin
		valid_out_relu <= 0;
		pcount <= 0;
		state <= 0;
		flag <= 0;
	end
	else if(valid_in == 1'b1) begin
		// ---- sample / slot / pass bookkeeping ----
		flag <= ~flag;
		if(flag == 1'b1) begin	// second sample of the pair: move to next slot
			if(pcount == OUT_WIDTH - 1) begin
				pcount <= 0;
				if(state == 2'b10)
					state <= 0;
				else
					state <= state + 1'b1;
			end
			else begin
				pcount <= pcount + 1'b1;
			end
		end

		// ---- pooling datapath ----
		if((state == 2'b00) && (flag == 1'b0)) begin
			// 1st sample of a window: start a fresh maximum
			valid_out_relu <= 0;
			buffer1[pcount] <= conv_out_1;
			buffer2[pcount] <= conv_out_2;
			buffer3[pcount] <= conv_out_3;
			buffer4[pcount] <= conv_out_4;
			buffer5[pcount] <= conv_out_5;
			buffer6[pcount] <= conv_out_6;
			buffer7[pcount] <= conv_out_7;
			buffer8[pcount] <= conv_out_8;
		end
		else if((state == 2'b00) || (state == 2'b01) || (flag == 1'b0)) begin
			// 2nd .. 5th sample: update the running maximum
			valid_out_relu <= 0;
			buffer1[pcount] <= pool_max(buffer1[pcount], conv_out_1);
			buffer2[pcount] <= pool_max(buffer2[pcount], conv_out_2);
			buffer3[pcount] <= pool_max(buffer3[pcount], conv_out_3);
			buffer4[pcount] <= pool_max(buffer4[pcount], conv_out_4);
			buffer5[pcount] <= pool_max(buffer5[pcount], conv_out_5);
			buffer6[pcount] <= pool_max(buffer6[pcount], conv_out_6);
			buffer7[pcount] <= pool_max(buffer7[pcount], conv_out_7);
			buffer8[pcount] <= pool_max(buffer8[pcount], conv_out_8);
		end
		else begin
			// 6th sample (last pass, second sample): final compare + ReLU
			valid_out_relu <= 1;
			max_value_1 <= relu_max(buffer1[pcount], conv_out_1);
			max_value_2 <= relu_max(buffer2[pcount], conv_out_2);
			max_value_3 <= relu_max(buffer3[pcount], conv_out_3);
			max_value_4 <= relu_max(buffer4[pcount], conv_out_4);
			max_value_5 <= relu_max(buffer5[pcount], conv_out_5);
			max_value_6 <= relu_max(buffer6[pcount], conv_out_6);
			max_value_7 <= relu_max(buffer7[pcount], conv_out_7);
			max_value_8 <= relu_max(buffer8[pcount], conv_out_8);
		end
	end
	else begin
		// no input this cycle: nothing new to report
		valid_out_relu <= 0;
	end
end
endmodule