Source for de.webdings.jannis.neuralnet.Teacher

   1: /* Teacher.java - Copyright (c) 2005 by Stefan Thesing
   2:  <p>This file is part of Jannis.</p>
   3:  <p>Jannis is free software; you can redistribute it and/or modify
   4:  it under the terms of the GNU General Public License as published by
   5:  the Free Software Foundation; either version 2 of the License, or
   6:  (at your option) any later version.</p>
   7: <p>Jannis is distributed in the hope that it will be useful,
   8: but WITHOUT ANY WARRANTY; without even the implied warranty of
   9: MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10: GNU General Public License for more details.</p>
  11: <p>You should have received a copy of the GNU General Public License
  12: along with Jannis; if not, write to the<br>
  13: Free Software Foundation, Inc.,<br>
  14: 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA<br>
  15: */
  16: package de.webdings.jannis.neuralnet;
  17: 
  18: import java.io.IOException;
  19: 
  20: import de.webdings.jannis.exceptions.PatternCreateException;
  21: import de.webdings.jannis.exceptions.PatternGiverReaderCommunicationException;
  22: import de.webdings.tools.files.TextFiles;
  23: 
  24: /**
  25:  * Teacher is used to train a neural net. Currently only
  26:  * training by backpropagation of errors for 2-layer-nets
  27:  * is supported.
  28:  * 
  29:  * @author Stefan Thesing<br>
  30:  * Website: <a href="http://www.webdings.de">http://www.webdings.de</a>
  31:  * @version 0.1 11.08.2005
  32:  */
  33: public class Teacher {
  34:     //attributes
  35:     /**
  36:      * Teacher compares the output the net is supposed to produce to the output actually 
  37:      * produced by the net and changes synapse weights slowly into a direction
  38:      * that brings the net closer to producing the desired output.
  39:      */
  40:     protected Pattern[] desiredOutput;
  41:     private Neuron[][] layers;
  42:     private int counter;
  43:     
  44:     //constructors
  45:     /**
  46:      * @param fileNameDesiredOutput
  47:      * @param net
  48:      * @throws PatternCreateException
  49:      * @throws IOException
  50:      */
  51:     public Teacher(String fileNameDesiredOutput, NeuralNet net) throws PatternCreateException, IOException {
  52:         this(fileNameDesiredOutput, net.getLayers());
  53:     }
  54:     
  55:     /**
  56:      * @param fileNameDesiredOutput
  57:      * @param layers
  58:      * @throws PatternCreateException
  59:      * @throws IOException
  60:      */
  61:     public Teacher(String fileNameDesiredOutput, Neuron[][] layers) throws PatternCreateException, IOException {
  62:       this.desiredOutput = PatternConverter.strToPattern(TextFiles.readFromFile(fileNameDesiredOutput),layers[layers.length-1].length);
  63:       this.layers = layers;
  64:     }
  65: 
  66:     /**
  67:      * @param desiredOutput
  68:      * @param net
  69:      */
  70:     public Teacher(Pattern[] desiredOutput, NeuralNet net) {
  71:         this(desiredOutput, net.getLayers());
  72:     }
  73:     
  74:     /**
  75:      * @param desiredOutput
  76:      * @param layers
  77:      */
  78:     public Teacher(Pattern[] desiredOutput, Neuron[][] layers) {
  79:       this.desiredOutput = desiredOutput;
  80:       this.layers = layers;
  81:       this.counter = 0;
  82:     }
  83: 
  84:     //methods
  85:     /**
  86:      * @return the number of actual and desired output 
  87:      * patterns that have already been compared
  88:      */
  89:     int amountCompared() {
  90:       return counter;
  91:     }
  92: 
  93:     /**
  94:      * compares the actual output produced by the net to
  95:      * the desired output
  96:      */
  97:     void compareOutputToDesiredOutput() {
  98:       for(int i=0;i<desiredOutput[0].entries.length;++i) {
  99:         layers[layers.length-1][i].setShouldHaveFired(desiredOutput[counter].entries[i]);
 100:       }
 101:     }
 102: 
 103:     /**
 104:      * There are many possible combination of states 
 105:      * the parameters can have. Teacher only modifies
 106:      * synapse weights in two cases:
 107:      * <table border=1>
 108:      *  <tr>
 109:      *   <th>#</th>
 110:      *   <th>Description</th>
 111:      *   <th>Modification</th>
 112:      *  </tr>
 113:      *  <tr>
 114:      *   <td>Case 1</td>
 115:      *   <td>the target has fired, but it wasn't supposed 
 116:      *       to fire, the source has fired</td>
 117:      *   <td>decrease the synapse weight by 0.1</td>
 118:      *  </tr>
 119:      *  <tr>
 120:      *   <td>Case 2</td>
 121:      *   <td>the target didn't fire, but it was supposed to
 122:      *       fire, the source has fired</td>
 123:      *   <td>increase the synapse weight by 0.1</td>
 124:      *  </tr>
 125:      * </table>  
 126:      * @param targetFired
 127:      * @param targetShouldHaveFired
 128:      * @param sourceFired
 129:      * @param synapse
 130:      */
 131:     void adjustWeights(boolean targetFired, boolean targetShouldHaveFired, boolean sourceFired, Synapse synapse) {
 132:             if(targetFired && targetShouldHaveFired && sourceFired) {
 133:               //This was used in an attempt to implement
 134:               //a backpropagation training method for 
 135:               //nets with more than two layers. It will
 136:               //stay here for the time being, although it
 137:               //doesn't do anything functional for now.
 138:               synapse.getSource().setShouldHaveFired(true);
 139:             }
 140: 
 141:             //Case 1:
 142:             if(targetFired && !targetShouldHaveFired && sourceFired) {
 143:               synapse.setWeight(synapse.getWeight()-0.1f);
 144:             }
 145: 
 146:             //Case 2:
 147:             if(!targetFired && targetShouldHaveFired && sourceFired) {
 148:               synapse.setWeight(synapse.getWeight()+0.1f);
 149:               synapse.getSource().setShouldHaveFired(true);
 150:             }
 151: 
 152:     }
 153: 
 154:     private void checkNetBackwards() {
 155:       Neuron currentNeuron;
 156:       Neuron potentialSource;
 157:       //Start with the output layer and count down to the 
 158:       //first hidden layer
 159:       for(int i=layers.length-1;i>0;--i) {
 160:         //process every neuron of the current layer
 161:         for(int j=0;j<layers[i].length;++j) {
 162:           currentNeuron = layers[i][j];
 163:           //process the layer before the current one
 164:           for(int k=0;k<layers[i-1].length;++k) {
 165:             potentialSource = layers[i-1][k];
 166:             //check every neuron of that layer for a connection to the 
 167:             //curront neuron
 168:             for(int l=0;l<potentialSource.getConnections().length;++l) {
 169:               //if the potential source targets (among others) 
 170:               //the current neuron, the synapse weight of 
 171:               //the connection is modified (this happens in
 172:               //the method 'adjustWeights'
 173:               if(potentialSource.getConnections()[l].getTarget() == currentNeuron) {
 174:                 adjustWeights(currentNeuron.hasFired(), currentNeuron.getShouldHaveFired(), potentialSource.hasFired(), potentialSource.getConnections()[l]);
 175:               }
 176:             }
 177:           }
 178: 
 179:         }
 180:       }
 181:     }
 182:     
 183:     /**
 184:      * starts comparing the actual output produced by the
 185:      * net with desired ouput and then backpropagates the
 186:      * error.
 187:      * @throws PatternGiverReaderCommunicationException
 188:      */
 189:     public void teach() throws PatternGiverReaderCommunicationException {
 190:         if(counter >= desiredOutput.length) {
 191:           throw new PatternGiverReaderCommunicationException("An error occured while teaching!");
 192:         } else {
 193:           this.compareOutputToDesiredOutput();
 194:           this.checkNetBackwards();
 195:           ++counter;
 196:         }
 197:     }
 198: 
 199:     /**
 200:      * @return the desired output the net is supposed to
 201:      * produced
 202:      */
 203:     public Pattern[] getDesiredOutput() {
 204:         return desiredOutput;
 205:     }
 206: }

© 2005 by Stefan Thesing;
Verbatim copying and redistribution of this entire page are permitted provided this notice is preserved.