-- Author      : Manuel De Girardi
-- Date        : 2010/12/9
-- Version     : 0.0.0pre-alpha_010
-- Description : artificial neural network generator for MidiSurf
-------------------------------------------------------------------------------
with Ada.Text_Io;                       use Ada.Text_Io;
with Ada.Strings, Ada.Strings.Fixed;
use Ada.Strings, Ada.Strings.Fixed;
with Calendar;                          use Calendar;
with Calendar.Formatting;
with Pragmarc.Ansi_Tty_Control;         use Pragmarc.Ansi_Tty_Control;
with Libsens.Neural_Chord.IO;
package body Libsens.Neural_Chord.Trainner is

   ----------------------------------------------------------------------------
   -- Train_From_File
   --
   -- Trains a REM_NN network on the patterns stored in Filename & ".bin" and
   -- saves the resulting weights (Save_Weights, weight file Filename & ".wgt").
   --
   -- The ".bin" file is read through Register_Io as consecutive record pairs:
   -- record 2*P-1 is the input and record 2*P the desired output of pattern P.
   -- The pattern count (Data_Length) is therefore half the record count; a
   -- trailing unpaired record is ignored.
   --
   -- Filename  : base name of the ".bin" data file and ".wgt" weight file.
   -- Reuse     : passed to the network as New_Random_Weights => not Reuse,
   --             i.e. True means "do not start from random weights".
   -- Converged : training stops once the RMS output error is <= Converged.
   -- Max_Epoch : upper bound on the number of passes over all patterns.
   --
   -- A status screen (date, epoch, RMS error, progress bar) is redrawn on
   -- standard output before every pass.
   --
   -- Any exception is propagated unchanged, but the data file is closed first.
   ----------------------------------------------------------------------------
   procedure Train_From_File(Filename : in String; Reuse : in Boolean; Converged : in Real; Max_Epoch : in Positive) is
      Bar_Columns : constant := 80;  -- full width of the progress bar
   begin
      Register_Io.Open(Reg_File, Register_Io.In_File, Filename & ".bin");
      ----------------------------------------------------------------
      -- Each pattern occupies two records (input, desired output).
      Data_Length := Natural(Register_Io.Size(Reg_File)) / 2;
      declare
         -- Wall-clock time shown on the status line.
         Hour       : Formatting.Hour_Number     := 0;
         Minute     : Formatting.Minute_Number   := 0;
         Second     : Formatting.Second_Number   := 0;
         Sub_Second : Formatting.Second_Duration := 0.0;

         -- Get_Input callback handed to the REM_NN instance: reads the input
         -- record (2*Pattern-1) and desired-output record (2*Pattern).
         procedure Get_Input (Pattern : in Positive;
                              Input   : out Node_Set;
                              Desired : out Node_Set) is
         begin
            Register_Io.Read(Reg_File, T_Register(Input), Register_Io.Count(Pattern*2-1));
            Register_Io.Read(Reg_File, T_Register(Desired), Register_Io.Count(Pattern*2));
         end Get_Input;

         package Mutan_REM_NN_Trai is new REM_NN(Num_Input_Nodes => T_Register'Length,
                                                 Num_Hidden_Nodes => T_Register'Length/7,
                                                 Num_Output_Nodes => T_Register'Length,
                                                 New_Random_Weights => not Reuse,
                                                 Weight_File_Name => Filename & ".wgt",
                                                 Input_To_Output_Connections => True,
                                                 Num_Patterns => Data_Length,
                                                 Get_Input => Get_Input);
         Response       : Mutan_REM_NN_Trai.Output_Set := (others => 0.0);
         -- Desired outputs cached in memory so error computation does not
         -- re-read the file for every pattern of every epoch.
         Desired_Output : array (1..Data_Length) of Mutan_REM_NN_Trai.Output_Set;
         Date_String    : String(1..80) := (others => ' ');
         Banner         : String(1..80) := (others => ' ');
         -- Starts above any sensible Converged so the first screen shows
         -- little progress; overwritten after each pass.
         RMS_Error      : Real := 10.0;
         Sum_Squares    : Real := 0.0;  -- sum over patterns of per-pattern MSE
         Progress       : Real := 0.0;  -- Converged / RMS_Error
         Bar_Width      : Integer := 0;
         Epoch          : Natural := 0; -- number of completed training passes
      begin
         for I in Desired_Output'Range loop
            -- Desired output of pattern I is record 2*I (see Get_Input).
            Register_Io.Read(Reg_File, T_Register(Desired_Output(I)), Register_Io.Count(2 * I));
         end loop;

         -- The banner never changes: format it once.
         Move("Welcome to Ultrason arpeggiator Network Generator." , Banner, Ada.Strings.Error, Center);

         loop
            -- Formatting.Split (not Calendar.Split, which returns year/month/
            -- day) yields hour/minute/second; both it and Formatting.Image use
            -- the default Time_Zone => 0, so the two displays agree.
            Formatting.Split(Clock, Hour, Minute, Second, Sub_Second);
            -- Move pads the rest of Date_String with spaces itself.
            Move(Formatting.Image(Clock), Date_String, Ada.Strings.Error, Center);
            Put (Clear_Screen);
            Put_Line(Bold_Mode & Banner & Normal_Mode);
            Put_Line(Date_String);
            New_Line;

            Put_Line("Training artificial neural network " & " length=" & Integer'Image(Data_Length));
            Put ("Epoch");
            Put (Integer'Image (Epoch) );
            Put(" => RMS_Error: ");

            Real_Io.Put(RMS_Error);
            -- Fraction of the way to the target error (< 1 while training
            -- continues, since the loop exits once RMS_Error <= Converged).
            Progress := Converged / RMS_Error;
            Put_Line(Integer'Image(Integer(Progress * 100.0)) &
                       '%' &
                       Integer'Image(Hour) &
                       ':' &
                       Integer'Image(Minute) &
                       ':' &
                       Integer'Image(Second) &
                       ':' &
                       Duration'Image(Sub_Second) );

            -- Progress bar: Progress scaled to Bar_Columns, clamped so it can
            -- neither overflow the line nor divide by Converged.
            Bar_Width := Integer(Real'Min(Progress, 1.0) * Real(Bar_Columns));
            if Bar_Width > 0 then
               Put(Reverse_Video & Bar_Width * ' ' & Normal_Mode);
            end if;

            Sum_Squares := 0.0;
        All_Patterns :
            for Pattern in 1..Data_Length loop
               -- NOTE(review): presumably one training step per call; the
               -- REM_NN contract is not visible here — confirm.
               Mutan_REM_NN_Trai.Train;
               Mutan_REM_NN_Trai.Respond (Pattern, Response);
               declare
                  -- Squared (not signed) errors, so over- and under-shooting
                  -- outputs cannot cancel each other out.
                  Pattern_Error : Real := 0.0;
               begin
                  for I in Response'Range loop
                     Pattern_Error := Pattern_Error + (Desired_Output(Pattern)(I) - Response(I)) ** 2;
                  end loop;
                  Sum_Squares := Sum_Squares + Pattern_Error / Real(Response'Length);
               end;
            end loop All_Patterns;
            RMS_Error := Real_Math.Sqrt(Sum_Squares / Real(Data_Length));

            -- Count the pass before testing the limit so that at most
            -- Max_Epoch passes are run (previously Max_Epoch + 1).
            Epoch := Epoch + 1;
            exit when RMS_Error <= Converged or else Epoch >= Max_Epoch;
         end loop;
         Mutan_REM_NN_Trai.Save_Weights;
      end;
      Register_Io.Close(Reg_File);
      ----------------------------------------------------------------
   exception
      when others =>
         -- Do not leak the data file when reading/training fails.
         if Register_Io.Is_Open(Reg_File) then
            Register_Io.Close(Reg_File);
         end if;
         raise;
   end Train_From_File;


end Libsens.Neural_Chord.Trainner;