/* ================================================================
 * ================================================================
 *
 * Implementation of the Grow-When-Required Neural Network.
 *
 * References:
 * [1] Marsland, S., Shapiro, J., & Nehmzow, U. (2002). A self-organising network that grows when required. Neural Networks, 15(8–9), 1041–1058.
 * [2] Marsland, S., Nehmzow, U., & Shapiro, J. (2005). On-line novelty detection for autonomous mobile robots. Robotics and Autonomous Systems, 51(2–3), 191–206.
 * [3] Neto, H. V., & Nehmzow, U. (2007). Real-time automated visual inspection using mobile robots. Journal of Intelligent and Robotic Systems, 49(3), 293–307.
 *
 * Author: L. Pitonakova (http://lenkaspace.net)
 * License: GNU General Public License. Please credit me when using my work.
 *
 * ================================================================
 * ================================================================
 */

#include "gwrNN.h"
#include <algorithm>
#include "../helpers/logger.h"



GWRNN::GWRNN() {
   //-- identity and id source for new neurons
   robotId = -1;
   neuronCounter = 0;

   //-- no input processed yet: no winner and no reporting values available
   winningNeuron = nullptr;
   currentWinningNeuron = -1;
   currentNoveltyValue = -1;
   currentWinningNeuronError = -1;
   currentWinningNeuronNeighboursError = -1;
}


void GWRNN::Init(argos::TConfigurationNode& t_node) {
   params.Init(t_node);
   //-- start with two (unconnected) neurons: ProcessInput needs two candidates to pick a pair of winners
   for (int seedNeuron = 0; seedNeuron < 2; ++seedNeuron) {
      AddNeuron();
   }
}

GWRNN::~GWRNN() {
   //-- the network owns every neuron and connection still listed in its containers
   for (Neuron* neuron : neurons) {
      delete neuron;
   }
   for (Connection* connection : connections) {
      delete connection;
   }
}

/**
 * One "update step" of the network
 */
void GWRNN::ProcessInput(std::vector<float> input_) {
   if (params.produceLogOutput) LOGERR << std::endl;



   if (params.produceLogOutput) { Logger::Out("------ PROCESSING:", input_); }

   //----
   //-- this algorithm follows the steps of [2], with the proportionality parameter for neighbour neuron adjustment from [3]
   //----

   bool networkDidChange = false;
   int i,e;

   //Logger::Out("Input vec", input_);
   //-- find 2 neurons with weights that are the closest match to the input (step 3)
   float distance;
   float minDistance = 999999;
   float secondMinDistance = 999999;
   int minDistanceNodeIndex = -1;
   int secondMinDistanceNodeIndex = -1;

   //-- compute the distance of each neuron's input weights to the current input
   for (i=0; i<neurons.size(); i++) {
      distance = Helpers::GetVectorDistance(input_, neurons[i]->inWeights);
      //if (params.produceLogOutput) { LOGERR << i << ":" << " min d=" << minDistance << "(N" << minDistanceNodeIndex <<")" << " 2nd min d=" << secondMinDistance << "(N" << secondMinDistanceNodeIndex <<")"; }
      if (distance < minDistance) {
         secondMinDistance = minDistance;
         secondMinDistanceNodeIndex = minDistanceNodeIndex;
         minDistance = distance;
         minDistanceNodeIndex = i;
      } else if (distance < secondMinDistance) {
         secondMinDistance = distance;
         secondMinDistanceNodeIndex = i;
      }
      //if (params.produceLogOutput) LOGERR << " D=" << distance << " >>> NEW  min d=" << minDistance << "(N" << minDistanceNodeIndex <<")" << " 2nd min d=" << secondMinDistance << "(N" << secondMinDistanceNodeIndex <<")";
      //if (params.produceLogOutput) LOGERR << std::endl;
   }

   winningNeuron = neurons[minDistanceNodeIndex];
   Neuron* secondWinnningNode = neurons[secondMinDistanceNodeIndex];

   //------ create connection between the winning neurons if doesn't exist otherwise set it to 0 (step 4)
   Connection* connectionBetweenWinningNodes = GetConnectionBetweenNeurons(winningNeuron, secondWinnningNode);
   if (connectionBetweenWinningNodes != NULL) {
      //-- set the age of the connection to 0
      connectionBetweenWinningNodes->age = 0;
   } else {
      //-- add new connection between the winning neurons
      connectionBetweenWinningNodes = AddConnection(winningNeuron, secondWinnningNode);
      networkDidChange = true;
   }

   //------ calculate activation of the winning neuron (step 5)
   float winningNodeActivation = GetNeuronActivation(winningNeuron, input_);
   if (params.produceLogOutput) LOGERR << "Winning neuron activation " << winningNodeActivation << " habituation " << winningNeuron->habituation << std::endl;
   if (params.produceLogOutput) LOG << "Winner is " << winningNeuron->Print() << " (h=" << Helpers::FloatToString(winningNeuron->habituation) << ")" << std::endl;

   //------- calculate some reporting
   currentWinningNeuron = winningNeuron->id;
   currentNoveltyValue = winningNeuron->habituation;
   currentWinningNeuronError = Helpers::GetVectorDistance(input_, winningNeuron->inWeights);
   //-- average error (i.e. distance from input) of winning neurons's neighbours
   int numNeighbourNeurons = 0;
   currentWinningNeuronNeighboursError = 0;
   for (e=0; e<winningNeuron->connections.size(); e++) {
      Neuron* neighbourNeuron = winningNeuron->connections[e]->GetNeuronConnectedTo(winningNeuron);
      if (neighbourNeuron != NULL) {
         numNeighbourNeurons++;
         currentWinningNeuronNeighboursError += Helpers::GetVectorDistance(input_, neighbourNeuron->inWeights);
      }
   }
   //LOGERR << " winning " << winningNeuron->Print() << "  num neighbours " << numNeighbourNeurons << " total err" << currentWinningNeuronNeighboursError;
   if (numNeighbourNeurons > 0) {
      currentWinningNeuronNeighboursError /= numNeighbourNeurons;
   }
   //LOGERR << " avg err " << currentWinningNeuronNeighboursError << std::endl;

   if (params.produceLogOutput) Logger::Out("Winning neuron weights:",winningNeuron->inWeights);

   //------ grow network when required (step 6)
   float factor, proportionality;
   bool shouldExpandInputConnections = false;
   //-- [EXTRA:] when network is adaptive, new neuron will be added not only when winning neuron activation and habituation are low,
   //   but also when the input vector does not fit into the input weights vector of the neuron.
   if (params.useAdaptiveInWeightNumber && winningNeuron->inWeights.size() < input_.size()) {
      shouldExpandInputConnections = true;
   }
   if ((winningNodeActivation < params.activationThreshold && winningNeuron->habituation < params.habituationThreshold) || shouldExpandInputConnections) {

      //-- add a new neuron
      Neuron*newNode = AddNeuron();

      //-- create new weight vector as average of weights between the winning neuron weights and input
      for (i=0; i<input_.size(); i++) {
         if (i >= newNode->inWeights.size() && params.useAdaptiveInWeightNumber) {
            //-- add a new weight that is initially random, then adapted to move closer to input
            newNode->AddWeight();
         }
         if (i < newNode->inWeights.size()) {
            if (i < winningNeuron->inWeights.size()) {
               //-- average the weight between winning neuron and input
               newNode->inWeights[i] = (winningNeuron->inWeights[i] + input_[i])/2.0;
            } else {
               //-- winning neuron weight not available, add weight as average between random and input
               newNode->inWeights[i] = (Helpers::GetRandomFloat(0,1) + input_[i])/2.0;
            }
         }
      }
      if (params.produceLogOutput) Logger::Out("New neuron weights",newNode->inWeights);

      //-- delete connection between the two winning neurons
      DeleteConnection(connectionBetweenWinningNodes);

      //-- add connections between the new neuron and the two winning neurons
      AddConnection(newNode, winningNeuron);
      AddConnection(newNode, secondWinnningNode);

      networkDidChange = true;
   }


   //------ adapt the weights of the winning neuron (step 7)
   //       [EXTRA: add new weights when input array larger than the weight array
   factor = params.learningRate_winningNode * winningNeuron->habituation;
   if (params.produceLogOutput) LOGERR << "Adapting winning neuron weights (" << winningNeuron->Print() << ") by factor " << factor << std::endl;
   for (i=0; i<input_.size(); i++) {
      if (i >= winningNeuron->inWeights.size() && params.useAdaptiveInWeightNumber) {
         //-- add a new weight that is initially random, then adapted to move closer to input
         winningNeuron->AddWeight();
         networkDidChange = true;
      }
      //-- adapt the weight
      if (i < winningNeuron->inWeights.size()) {
         winningNeuron->inWeights[i] += factor * (input_[i] - winningNeuron->inWeights[i]);
      }
   }



   //------ adapt the neighbouring neurons (steps 7-9)
   std::vector<Neuron*> updatedNeurons;
   updatedNeurons.push_back(winningNeuron);
   for (e=0; e<winningNeuron->connections.size(); e++) {
      Neuron* neighbourNeuron = winningNeuron->connections[e]->GetNeuronConnectedTo(winningNeuron);
      if (neighbourNeuron != NULL) {
         if (params.useDishabituation) { updatedNeurons.push_back(neighbourNeuron); }

         //-- adapt the weights of the winning neuron neighbours (step 7)
         //   [use proportional learning rate as in Neto2007]
         proportionality = (params.updates_neighbourNeuronProportionality*GetNeuronActivation(neighbourNeuron, input_)) / winningNodeActivation;
         factor =  proportionality*params.learningRate_winningNode * neighbourNeuron->habituation;
         if (params.produceLogOutput) LOGERR << "Adapting neighb neuron weights (" << neighbourNeuron->Print() << ") by factor " << factor << std::endl;
         for (i=0; i<input_.size(); i++) {
            if (i >= neighbourNeuron->inWeights.size() && params.useAdaptiveInWeightNumber) {
               //-- add a new weight that is initially random, then adapted to move closer to input
               neighbourNeuron->AddWeight();
               networkDidChange = true;
            }
            //-- adapt the weight
            if (i < neighbourNeuron->inWeights.size()) {
               neighbourNeuron->inWeights[i] += factor * (input_[i] - neighbourNeuron->inWeights[i]);
            }
         }

         //-- age the connection (step 8)
         winningNeuron->connections[e]->age += 1;

         //-- reduce the strength of habituation synapses of the neighbouring neurons (step 9)
         factor = (params.habituationUpdate_alpha_neighbourNeurons*(params.initialNodeHabituation - neighbourNeuron->habituation) - params.habituationUpdate_S) / ( (1/proportionality)*params.habituationUpdate_tau_winningNeuron);
         if (params.produceLogOutput) LOGERR << "Adapting neighb neuron h (" << neighbourNeuron->Print() << ") by factor " << factor << " h=" << neighbourNeuron->habituation << std::endl;
         neighbourNeuron->habituation += factor;
      }
   }

   //------ adapt non-neighbouring neurons (dishabituation step)
   if (params.useDishabituation) {
      for (Neuron* neuron : neurons) {
         //if (!std::find_if(updatedNeurons.begin(), updatedNeurons.end(), [](Neuron* n) { return n->id == 0; })) {
         if(std::find(updatedNeurons.begin(), updatedNeurons.end(), neuron) == updatedNeurons.end()){
            factor = (params.habituationUpdate_alpha_neighbourNeurons*(params.initialNodeHabituation - neuron->habituation) - params.habituationUpdate_S_nonStimulus) / ( (1/proportionality)*params.habituationUpdate_tau_winningNeuron);
            neuron->habituation += factor;
            //if (params.produceLogOutput) LOGERR << " adapting non-neighb neuron h (" << neuron->Print() << ") by factor " << factor << " h=" << neuron->habituation << std::endl;
         }
      }
   }


   //------ reduce the strength of habituation synapse of the winning neuron (step 9)
   factor = (params.habituationUpdate_alpha_winningNeuron*(params.initialNodeHabituation - winningNeuron->habituation) - params.habituationUpdate_S) / params.habituationUpdate_tau_winningNeuron;
   if (params.produceLogOutput) LOGERR << "Adapting winning neuron h (" << winningNeuron->Print() << ") by factor " << factor << " h=" << winningNeuron->habituation << std::endl;
   winningNeuron->habituation += factor;

   //------ check if there are old connections and non-connecting neurons to delete (step 10)
   for (i=0; i<connections.size(); i++) {
      if (params.maxConnectionAge > 0 && connections[i]->age > params.maxConnectionAge) {
         if (params.produceLogOutput) LOGERR << "[!] Deleting connection with age " << connections[i]->age << "between";
         Neuron* neuron1 = connections[i]->neuron1;
         Neuron* neuron2 = connections[i]->neuron2;
         DeleteConnection(connections[i]);
         if (neuron1 != NULL) {
            if (params.produceLogOutput) LOGERR << neuron1->Print();
            if (neuron1->connections.size() == 0) {
               if (params.produceLogOutput) LOGERR << "(TO DEL)";
               DeleteNeuron(neuron1);
            }
         }
         if (neuron2 != NULL) {
            if (params.produceLogOutput) LOGERR << " and " << neuron2->Print();
            if (neuron2->connections.size() == 0) {
               if (params.produceLogOutput) LOGERR << "(TO DEL)";
               DeleteNeuron(neuron2);
            }
         }
         networkDidChange = true;
         if (params.produceLogOutput) LOGERR << std::endl;
      }
   }

   //------ output the state of the network
   int weightsTotal = 0;
   for (int i=0; i<neurons.size(); i++) {
      weightsTotal += neurons[i]->inWeights.size();
   }
   if (networkDidChange) {
      if (params.produceLogOutput) LOG << "CHANGE num of neurons: " << neurons.size() << std::endl;
      //-- log the change
      Logger::Instance()->LogEvent(Event::TYPE::NN_CHANGED,robotId,neurons.size(),weightsTotal,connections.size());
   }

   if (params.produceLogOutput) LOG << " Connections: " << connections.size() << " Nodes: " << neurons.size() << "  Weights: " << weightsTotal << "  Total connections: " << (connections.size()+weightsTotal) ;
   for (int i=0; i<connections.size(); i++) {
      if (params.produceLogOutput) LOG << " " << connections[i]->neuron1->Print(true) << "--" << connections[i]->neuron2->Print(true) << "(age " << connections[i]->age << ")";
   }

   if (params.produceLogOutput) LOG << std::endl;


}

/**
 * Activation of a neuron for the given input [1]: exp(-d), where d is the distance between the
 * input and the neuron's input weights. Equals 1 when d is 0 and decreases as d grows.
 */
float GWRNN::GetNeuronActivation(Neuron* neuron_, std::vector<float> input_) {
   const auto distanceToInput = Helpers::GetVectorDistance(input_, neuron_->inWeights);
   return exp(-distanceToInput);
}

/* ==================================================================================== */
/* ========================= Network manipulation and search */
/* ==================================================================================== */

/**
 * Creates a neuron with the configured initial habituation and number of input weights,
 * registers it in the network and returns it.
 * Its id is the current value of neuronCounter, which is then incremented.
 */
GWRNN::Neuron* GWRNN::AddNeuron() {
   const auto neuronId = neuronCounter++;
   Neuron* created = new Neuron(neuronId, params.initialNodeHabituation, params.initialNumOfInWeights);
   neurons.push_back(created);
   if (params.produceLogOutput) {
      LOGERR << "[] Added neuron " << created->Print() << std::endl;
   }
   return created;
}

/**
 * Deletes a neuron: deletes all of its connections, removes it from the network and frees it.
 * The pointer must not be used by the caller afterwards. A NULL argument is ignored.
 */
void GWRNN::DeleteNeuron(Neuron* neuron_) {
   if (neuron_ == NULL) return;
   if (params.produceLogOutput) LOGERR << "Deleting neuron " << neuron_->Print();
   //-- delete record of connections. DeleteConnection erases the connection from neuron_->connections,
   //   so always take the last remaining one instead of indexing forward (which would skip connections).
   while (!neuron_->connections.empty()) {
      const size_t remaining = neuron_->connections.size();
      DeleteConnection(neuron_->connections.back());
      if (neuron_->connections.size() >= remaining) {
         //-- defensive: the connection did not name neuron_ as an end neuron, so it was not
         //   detached from it; drop the stale record so the loop always terminates
         neuron_->connections.pop_back();
      }
   }
   if (params.produceLogOutput) LOGERR << " performing " << std::endl;
   neurons.erase(std::remove(neurons.begin(), neurons.end(), neuron_), neurons.end());
   //-- never leave the winner pointer dangling
   if (winningNeuron == neuron_) {
      winningNeuron = NULL;
   }
   //-- the network owns its neurons (the destructor only frees those still listed), so free it here
   delete neuron_;
}

/**
 * Creates an (undirected) connection between two neurons, registers it in the network and in
 * both end neurons' connection lists, and returns it.
 */
GWRNN::Connection* GWRNN::AddConnection(Neuron*neuron1_, Neuron* neuron2_) {
   Connection* link = new Connection(neuron1_, neuron2_);
   connections.push_back(link);

   //-- each end neuron also records the connection so that its neighbours can be found
   neuron1_->connections.push_back(link);
   neuron2_->connections.push_back(link);

   if (params.produceLogOutput) {
      LOGERR << "[] Added connection between neurons " << neuron1_->Print() << " and " << neuron2_->Print() << std::endl;
   }
   return link;
}

/**
 * Deletes a connection: detaches it from both end neurons, removes it from the network's
 * connection list and frees it. The pointer must not be used by the caller afterwards.
 * A NULL argument is ignored.
 */
void GWRNN::DeleteConnection(Connection* connection_) {
   if (connection_ == NULL) return;

   //-- remove every occurrence of the connection from an end neuron's record
   auto detachFrom = [connection_](Neuron* neuron) {
      if (neuron != NULL) {
         neuron->connections.erase(std::remove(neuron->connections.begin(), neuron->connections.end(), connection_), neuron->connections.end());
      }
   };
   detachFrom(connection_->neuron1);
   detachFrom(connection_->neuron2);

   connections.erase(std::remove(connections.begin(), connections.end(), connection_), connections.end());

   //-- the network owns its connections (the destructor only frees those still listed),
   //   so free it here instead of leaking it
   delete connection_;
}


/**
 * Returns the first connection (in network order) that joins the two neurons, in either
 * direction, or a null pointer if there is none.
 */
GWRNN::Connection* GWRNN::GetConnectionBetweenNeurons(Neuron* neuron1_, Neuron* neuron2_) {
   auto match = std::find_if(connections.begin(), connections.end(),
                             [neuron1_, neuron2_](Connection* candidate) {
                                return candidate->DoesConnectNeurons(neuron1_, neuron2_);
                             });
   return (match != connections.end()) ? *match : nullptr;
}

/* ==================================================================================== */
/* ========================= NNConnection */
/* ==================================================================================== */

/**
 * True when this undirected connection joins exactly the two given neurons, in either order.
 */
bool GWRNN::Connection::DoesConnectNeurons(Neuron*neuron1_, Neuron* neuron2_) {
   const bool sameOrder = (neuron1 == neuron1_) && (neuron2 == neuron2_);
   const bool swappedOrder = (neuron1 == neuron2_) && (neuron2 == neuron1_);
   return sameOrder || swappedOrder;
}

/**
 * Given one end neuron of this connection, returns the neuron at the other end.
 * Returns NULL when neuron_ is at neither end.
 */
GWRNN::Neuron* GWRNN::Connection::GetNeuronConnectedTo(Neuron* neuron_) {
   Neuron* otherEnd = NULL;
   if (neuron_ == neuron1) {
      otherEnd = neuron2;
   } else if (neuron_ == neuron2) {
      otherEnd = neuron1;
   }
   return otherEnd;
}




/* ==================================================================================== */
/* ========================= READ PARAMETERS */
/* ==================================================================================== */


/**
 * Reads the GWRNN parameters from the ARGoS XML configuration node t_node.
 * Attributes are read in the order below. Any CARGoSException raised while reading
 * (presumably a missing or unparsable attribute) is re-thrown nested inside
 * "Error initialising GWRNN parameters.".
 */
void GWRNN::Params::Init(argos::TConfigurationNode& t_node) {
   try {
      //-- logging and network-structure options
      GetNodeAttribute(t_node, "produceLogOutput", produceLogOutput);
      GetNodeAttribute(t_node, "initialNumOfInWeights", initialNumOfInWeights);
      GetNodeAttribute(t_node, "useAdaptiveInWeightNumber", useAdaptiveInWeightNumber);
      //-- connections older than this are pruned by ProcessInput; a value <= 0 disables pruning
      GetNodeAttribute(t_node, "maxConnectionAge", maxConnectionAge);
      GetNodeAttribute(t_node, "useDishabituation", useDishabituation);

      //-- habituation value given to every newly created neuron (fixed, not read from XML)
      initialNodeHabituation = 1.0;
      //-- thresholds that decide when ProcessInput grows a new neuron
      GetNodeAttribute(t_node, "activationThreshold", activationThreshold);
      GetNodeAttribute(t_node, "habituationThreshold", habituationThreshold);
      //-- NOTE: the XML attribute name ("..._winningNeuron") differs from the member name ("..._winningNode")
      GetNodeAttribute(t_node, "learningRate_winningNeuron", learningRate_winningNode);
      GetNodeAttribute(t_node, "habituationUpdate_tau_winningNeuron", habituationUpdate_tau_winningNeuron);
      //-- one XML attribute sets the habituation alpha of both the winning neuron and its neighbours
      GetNodeAttribute(t_node, "habituationUpdate_alpha", habituationUpdate_alpha_winningNeuron);
      GetNodeAttribute(t_node, "habituationUpdate_alpha", habituationUpdate_alpha_neighbourNeurons);
      GetNodeAttribute(t_node, "updates_neighbourNeuronProportionality", updates_neighbourNeuronProportionality);

      //-- S term subtracted in the habituation updates of ProcessInput (fixed, not read from XML)
      //   NOTE(review): habituationUpdate_S_nonStimulus (used by the dishabituation step) is not set
      //   here — confirm it has a default value in the class declaration.
      habituationUpdate_S = 1;

   }
   catch(CARGoSException& ex) {
      THROW_ARGOSEXCEPTION_NESTED("Error initialising GWRNN parameters.", ex);
   }
}

