Commit 10b5fe02 authored by Jean-Marc Valin

Training now stops when stuck in a minimum

parent 2b68c4ef
@@ -80,7 +80,7 @@ MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int
          std = .001;
       std = 1/sqrt(inDim*std);
       for (k=0;k<topo[1];k++)
-         net->weights[0][k*(topo[0]+1)+j+1] = randn(4*std);
+         net->weights[0][k*(topo[0]+1)+j+1] = randn(std);
    }
    net->in_rate[0] = 1;
    for (j=0;j<topo[1];j++)
@@ -223,7 +223,7 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
 {
    int i, j;
    int e;
-   float last_rms = 1e10;
+   float best_rms = 1e10;
    int inDim, outDim, hiddenDim;
    int *topo;
    double *W0, *W1, *best_W0, *best_W1;
@@ -241,6 +241,8 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
    pthread_t thread[NB_THREADS];
    int samplePerPart = nbSamples/NB_THREADS;
    int count_worse=0;
+   int count_retries=0;
+
    topo = net->topo;
    inDim = net->topo[0];
    hiddenDim = net->topo[1];
@@ -313,10 +315,10 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
       float mean_rate = 0, min_rate = 1e10;
       rms = (rms/(outDim*nbSamples));
       error_rate = (error_rate/(outDim*nbSamples));
-      fprintf (stderr, "%f (%f %f) ", error_rate, rms, last_rms);
-      if (rms < last_rms)
+      fprintf (stderr, "%f (%f %f) ", error_rate, rms, best_rms);
+      if (rms < best_rms)
       {
-         last_rms = rms;
+         best_rms = rms;
          for (i=0;i<W0_size;i++)
          {
             best_W0[i] = W0[i];
@@ -328,10 +330,12 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
             best_W1_rate[i] = W1_rate[i];
          }
          count_worse=0;
-      } else if (rms > last_rms) {
+         count_retries=0;
+      } else {
          count_worse++;
-         if (count_worse>20)
+         if (count_worse>30)
          {
+            count_retries++;
             count_worse=0;
             for (i=0;i<W0_size;i++)
             {
@@ -344,13 +348,15 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
             for (i=0;i<W1_size;i++)
             {
                W1[i] = best_W1[i];
-               best_W1_rate[i] *= .7;
+               best_W1_rate[i] *= .8;
                if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15;
                W1_rate[i] = best_W1_rate[i];
                W1_grad[i] = 0;
             }
          }
       }
+      if (count_retries>10)
+         break;
       for (i=0;i<W0_size;i++)
       {
          if (W0_oldgrad[i]*W0_grad[i] > 0)
@@ -386,7 +392,10 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
          W1[i] += W1_grad[i]*W1_rate[i];
       }
       mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2];
-      fprintf (stderr, "%g (min %g) %d\n", mean_rate, min_rate, e);
+      fprintf (stderr, "%g %d", mean_rate, e);
+      if (count_retries)
+         fprintf(stderr, " %d", count_retries);
+      fprintf(stderr, "\n");
       if (stopped)
          break;
    }
@@ -403,7 +412,7 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
    free(W1_grad);
    free(W0_rate);
    free(W1_rate);
-   return last_rms;
+   return best_rms;
 }
 
 int main(int argc, char **argv)
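Side note: below is a small self-contained C sketch of the stopping rule this commit introduces, for readers who want the idea outside the diff: remember the best weights seen so far, roll back to them with a smaller step size after 30 epochs without improvement, and stop training once 10 consecutive rollbacks fail to find a better point. All names and the quadratic toy objective here (toy_epoch, NWEIGHTS, the scalar rate) are illustrative assumptions, not code from mlp_train.c.

/* Sketch only: hypothetical names, not part of mlp_train.c. */
#include <stdio.h>
#include <string.h>
#include <math.h>

#define NWEIGHTS 8

/* Stand-in for one training epoch: a plain gradient step on the toy
   objective f(w) = sum_i w_i^2, returning the resulting RMS value. */
static double toy_epoch(double *w, int n, double rate)
{
   double rms = 0;
   int i;
   for (i=0;i<n;i++)
   {
      w[i] -= rate*2*w[i];
      rms += w[i]*w[i];
   }
   return sqrt(rms/n);
}

int main(void)
{
   double w[NWEIGHTS], best_w[NWEIGHTS];
   double rate = 1., best_rate = 1.;   /* deliberately too large at first */
   double best_rms = 1e10;
   int count_worse = 0, count_retries = 0;
   int e, i;
   for (i=0;i<NWEIGHTS;i++)
      w[i] = 1;
   for (e=0;e<100000;e++)
   {
      double rms = toy_epoch(w, NWEIGHTS, rate);
      if (rms < best_rms)
      {
         /* New best point: save it and reset both counters. */
         best_rms = rms;
         best_rate = rate;
         memcpy(best_w, w, sizeof(w));
         count_worse = 0;
         count_retries = 0;
      } else if (++count_worse > 30) {
         /* Stuck for 30 epochs: go back to the best weights and retry
            with a smaller step, as the diff does with best_W0/best_W1. */
         count_retries++;
         count_worse = 0;
         memcpy(w, best_w, sizeof(w));
         best_rate *= .8;
         if (best_rate < 1e-15) best_rate = 1e-15;
         rate = best_rate;
      }
      if (count_retries > 10)
         break;   /* ten rollbacks in a row: treat this as a minimum */
   }
   printf("stopped after %d epochs, best rms %g, final rate %g\n",
          e, best_rms, rate);
   return 0;
}

The thresholds mirror the diff (30 worse epochs, a .8 decay on the saved rates, 10 retries); the real trainer keeps a separate rate per weight in W0_rate/W1_rate and decays best_W0_rate/best_W1_rate element by element, whereas the sketch uses a single scalar rate for brevity.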