//-------------------------------------------------------------------------------
//** Part         : Convolutional Neural Network Accelerator
//** File name    : conv1_buf.v
//** Description  : 1st Convolution Layer for CNN MNIST dataset Input Buffer
//** Author       : Haosen Yu
//** Email        : haosenyu@hotmail.com
//** Revision     : Version 1.0
//** Date         : 2023-04-12 15:23:49
//** LastEditTime : 2024-05-08 17:08:24
//** 
//** Copyright (c) 2023 by Haosen Yu, All Rights Reserved. 
//-------------------------------------------------------------------------------

 // conv1_buf: line buffer that turns a pixel stream into 3x3 convolution windows.
 //
 // One sample is taken from data_in on every rising clock edge and stored in a
 // ring buffer that holds FILTER_SIZE (3) image rows of WIDTH pixels each.
 // Once the buffer has been filled, the module presents one 3x3 window per
 // clock on data_out_0..8 while new samples keep overwriting the oldest row.
 //
 // Parameters
 //   WIDTH     : pixels per image row  (expected >= FILTER_SIZE)
 //   HEIGHT    : rows per image        (expected >= FILTER_SIZE)
 //   DATA_BITS : bit width of one pixel
 //
 // Ports
 //   clk           : clock; all state changes on the rising edge
 //   rst_n         : synchronous, active-low reset
 //   data_in       : pixel stream; WIDTH consecutive samples are treated as one row
 //   data_out_0..8 : 3x3 window, 0..2 = top row, 3..5 = middle row, 6..8 = bottom row
 //   valid_out_buf : high while data_out_* holds a window that lies inside one row
 //
 // Timing notes
 //   * buf_idx resets to all ones, which is outside the buffer, so the sample
 //     present on the first rising edge after reset release is NOT stored;
 //     storage starts with the following sample.
 //   * Each image row yields WIDTH - FILTER_SIZE + 1 valid windows and each frame
 //     yields HEIGHT - FILTER_SIZE + 1 window rows; after that the module returns
 //     to the fill state until buf_idx next reaches the end of the buffer.
 //   * NOTE(review): buf_flag is not cleared at the end of a frame, so the row
 //     rotation of a following frame depends on HEIGHT - confirm against the
 //     upstream protocol before streaming several frames without a reset.
 module conv1_buf #(parameter WIDTH = 16, HEIGHT = 15, DATA_BITS = 24)(
   input clk,
   input rst_n,
   input [DATA_BITS - 1:0] data_in,
   output reg [DATA_BITS - 1:0] data_out_0, data_out_1, data_out_2,
   data_out_3, data_out_4, data_out_5,
   data_out_6, data_out_7, data_out_8,
   output reg valid_out_buf   // window-valid strobe
 );

 localparam FILTER_SIZE = 3;
 localparam BUF_DEPTH   = WIDTH * FILTER_SIZE;    // pixels held (3 rows)
 localparam ROW1        = WIDTH;                  // buffer offset of the 2nd stored row
 localparam ROW2        = WIDTH * 2;              // buffer offset of the 3rd stored row

 // Counter widths follow the parameters instead of being hard-coded.
 // IDX_BITS keeps one spare code so the all-ones reset value of buf_idx can
 // never alias a real buffer address, even when BUF_DEPTH is a power of two.
 localparam IDX_BITS = $clog2(BUF_DEPTH + 1);
 localparam W_BITS   = $clog2(WIDTH);
 localparam H_BITS   = $clog2(HEIGHT);

 reg [DATA_BITS - 1:0] buffer [0:BUF_DEPTH - 1];  // ring buffer of 3 image rows
 reg [IDX_BITS - 1:0]  buf_idx;                   // write pointer into buffer
 reg [W_BITS - 1:0]    w_idx;                     // column of the window's left edge
 reg [H_BITS - 1:0]    h_idx;                     // window rows emitted in this frame
 reg [1:0]             buf_flag;                  // which stored row is the top row (0..2)
 reg                   state;                     // 0: filling, 1: emitting windows

 always @(posedge clk) begin
   if(~rst_n) begin
     buf_idx       <= {IDX_BITS{1'b1}};           // "-1": one dummy cycle before address 0
     w_idx         <= 0;
     h_idx         <= 0;
     buf_flag      <= 0;
     state         <= 1'b0;
     valid_out_buf <= 1'b0;
     // Outputs carry no meaningful data until valid_out_buf rises.
     data_out_0 <= {DATA_BITS{1'bx}};
     data_out_1 <= {DATA_BITS{1'bx}};
     data_out_2 <= {DATA_BITS{1'bx}};
     data_out_3 <= {DATA_BITS{1'bx}};
     data_out_4 <= {DATA_BITS{1'bx}};
     data_out_5 <= {DATA_BITS{1'bx}};
     data_out_6 <= {DATA_BITS{1'bx}};
     data_out_7 <= {DATA_BITS{1'bx}};
     data_out_8 <= {DATA_BITS{1'bx}};
   end
   else begin
     // ---- write side: free-running ring-buffer pointer ----
     if(buf_idx == BUF_DEPTH - 1) begin
       buf_idx <= 0;
     end
     else begin
       buf_idx <= buf_idx + 1'b1;                 // all-ones wraps to 0 after reset
     end

     // The only out-of-range value is the all-ones reset code; skip that write
     // explicitly instead of relying on out-of-bounds array semantics.
     if(buf_idx < BUF_DEPTH) begin
       buffer[buf_idx] <= data_in;
     end

     // ---- read side ----
     if(!state) begin
       // Fill state: start emitting once the last buffer slot is being written.
       if(buf_idx == BUF_DEPTH - 1) begin
         state <= 1'b1;
       end
     end
     else begin
       w_idx <= w_idx + 1'b1;                     // default: slide window right; overridden at row end

       if(w_idx == WIDTH - FILTER_SIZE + 1) begin
         valid_out_buf <= 1'b0;                   // window would now cross the row boundary
       end
       else if(w_idx == WIDTH - 1) begin
         // End of row: rotate which stored row is on top and restart the column.
         w_idx <= 0;
         if(buf_flag == FILTER_SIZE - 1) begin
           buf_flag <= 0;
         end
         else begin
           buf_flag <= buf_flag + 1'b1;
         end

         // Exactly one h_idx assignment per path so that the frame-end reset
         // cannot be overridden by the increment.
         if(h_idx == HEIGHT - FILTER_SIZE) begin
           h_idx <= 0;                            // last window row of the frame done
           state <= 1'b0;                         // back to the fill state
         end
         else begin
           h_idx <= h_idx + 1'b1;
         end
       end
       else if(w_idx == 0) begin
         valid_out_buf <= 1'b1;                   // first window of a row
       end

       // Window selection: buf_flag tells which stored row is currently the
       // oldest, i.e. the top row of the 3x3 window.
       case(buf_flag)
         2'd0: begin
           data_out_0 <= buffer[w_idx];
           data_out_1 <= buffer[w_idx + 1];
           data_out_2 <= buffer[w_idx + 2];

           data_out_3 <= buffer[w_idx + ROW1];
           data_out_4 <= buffer[w_idx + 1 + ROW1];
           data_out_5 <= buffer[w_idx + 2 + ROW1];

           data_out_6 <= buffer[w_idx + ROW2];
           data_out_7 <= buffer[w_idx + 1 + ROW2];
           data_out_8 <= buffer[w_idx + 2 + ROW2];
         end
         2'd1: begin
           data_out_0 <= buffer[w_idx + ROW1];
           data_out_1 <= buffer[w_idx + 1 + ROW1];
           data_out_2 <= buffer[w_idx + 2 + ROW1];

           data_out_3 <= buffer[w_idx + ROW2];
           data_out_4 <= buffer[w_idx + 1 + ROW2];
           data_out_5 <= buffer[w_idx + 2 + ROW2];

           data_out_6 <= buffer[w_idx];
           data_out_7 <= buffer[w_idx + 1];
           data_out_8 <= buffer[w_idx + 2];
         end
         2'd2: begin
           data_out_0 <= buffer[w_idx + ROW2];
           data_out_1 <= buffer[w_idx + 1 + ROW2];
           data_out_2 <= buffer[w_idx + 2 + ROW2];

           data_out_3 <= buffer[w_idx];
           data_out_4 <= buffer[w_idx + 1];
           data_out_5 <= buffer[w_idx + 2];

           data_out_6 <= buffer[w_idx + ROW1];
           data_out_7 <= buffer[w_idx + 1 + ROW1];
           data_out_8 <= buffer[w_idx + 2 + ROW1];
         end
         default: ;                               // buf_flag == 3 is never reached; outputs hold
       endcase
     end
   end
 end
endmodule