How to set the initial values of the weight and bias parameters of each layer

Let's look at how to set the initial values of the weight and bias parameters of each layer.

The sample in "Calculations in the middle layer: converting m inputs to n outputs" used arbitrary values for the parameters of each layer.

In fact, there are known ways of choosing initial parameter values that make learning easier.

How to set initial parameter values that make learning easier

Initial value of bias, which is a constant term

The initial value of the bias "$b", which is a constant term, is set to 0.

my $b = [
  0,
  0,
  0,
];
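
When the number of outputs is not fixed at three, the same zero bias can be built with Perl's repetition operator. This is a minimal sketch; the variable $n (the number of outputs) is an assumption introduced for illustration.

```perl
use strict;
use warnings;

# Number of outputs (hypothetical value for illustration)
my $n = 3;

# Array reference holding $n zeros, one bias per output
my $b = [(0) x $n];

print "(@$b)\n";
```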

Initial value of weight: He initialization

I will explain the initial value of the weight.

First, assume that the initial values of the weights are chosen randomly. The remaining question is which random distribution to draw them from.

Next, let m be the number of inputs. Draw each weight from a normal distribution with a mean of 0 and a standard deviation of "√(2 / m)". This is called He initialization (the initial value of He).
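
To get a feel for the size of this standard deviation, it can be computed for a few input counts. This is a minimal sketch; the hash %sigma_for and the chosen values of m are illustrative assumptions, not part of the original sample.

```perl
use strict;
use warnings;

# Standard deviation sqrt(2 / m) for a few illustrative input counts
my %sigma_for;
for my $m (3, 100, 1000) {
  $sigma_for{$m} = sqrt(2 / $m);
  printf "m = %d, sigma = %.4f\n", $m, $sigma_for{$m};
}
```

For m = 3 this gives about 0.8165. The more inputs a layer has, the narrower the spread of its initial weights.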

Don't be alarmed when mathematical terms come up. All a software engineer needs is a function that returns random values following a normal distribution.

Defining what the normal distribution is can be left to mathematicians and statisticians.

# Function that returns a random number following a normal distribution
# $m is the mean, $sigma is the standard deviation
sub randn {
  my ($m, $sigma) = @_;
  my ($r1, $r2) = (rand(), rand());
  # rand() can return 0 and log(0) is undefined, so draw again
  while ($r1 == 0) { $r1 = rand(); }
  return $m + $sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2);
}

# Number of inputs
my $m = 3;

# Initial value of weight
my $w_init_value = randn(0, sqrt(2 / $m));

Let's call this function 100 times and print each result. What values come out?

use strict;
use warnings;

# Function that returns a random number following a normal distribution
# $m is the mean, $sigma is the standard deviation
sub randn {
  my ($m, $sigma) = @_;
  my ($r1, $r2) = (rand(), rand());
  # rand() can return 0 and log(0) is undefined, so draw again
  while ($r1 == 0) { $r1 = rand(); }
  return $m + $sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2);
}

# Number of inputs
my $m = 3;

for (my $i = 0; $i < 100; $i++) {
  # Initial value of weight
  my $w_init_value = randn(0, sqrt(2 / $m));
  print "$w_init_value\n";
}

Sample output.

0.374333645619917
-0.375115681029835
-1.17350573904798
0.895734411096277
-2.45164088502897
0.21248128329348
-0.0612350673715995
-1.58015005611323
0.505326889380243
-0.599354648099927
0.0587914154389874
1.04458353028099
0.479197565615491
0.612832356529786
0.576332381567135
0.264180405541669
0.481818440413769
-0.0832424817326749
0.428427943276402
0.160565847218012
-0.737467690504583
-0.239595285587741
0.710885587813834
-0.726817014379143
-0.55686611444891
1.05190934231497
-0.2010231522538
0.827332432871366
0.295783340949512
0.0930341726064953
-0.133498161233327
-0.115434247541999
-0.0747196705878676
0.126032374885604
2.25242259680805
-0.402090471436117
-0.29409432069849
0.0108104319532479
1.94316372084671
0.353921285897365
-0.135311049425417
1.28256925366589
0.934265904171614
-0.353899801206181
-0.0177130218202863
0.876570751158322
0.93130446050329
0.565140594046581
0.0568720224046188
-1.3106498157376
0.294591391101043
0.302177509705135
-0.418728932729381
-0.890311118269986
-0.169201690565162
-0.726949054345823
-2.15992967386058
-0.0121016476938005
1.13903710052864
-0.334434247842771
-0.293389954684819
0.251851723770236
-0.65162175207737
-0.0542221278450117
-1.11968362212673
-0.756353974237116
-0.92565540502834
0.145613214400026
0.29465043488826
-0.451796034009891
0.815640908027956
0.148295326632814
-0.0099329408191555
0.78619620747882
0.599917749077822
0.62396415920206
-0.784434692735128
0.737015631093392
-1.27206848071826
1.024780494607
0.600957083530113
0.426820457119949
-0.475642547197279
-0.342178324138676
1.35634621998836
-0.168375479156423
-0.895146617883268
-0.275289833219868
-0.00169255909802365
-0.0556849290838152
0.670938086243496
0.231285743569269
0.560481150229232
0.122633939226874
0.190196194044245
-1.18023990429147
1.49262860806553
-1.13516745715629
0.612636896807322
0.99369257431466

Get a feel for the values. Most fall between about -1 and 1, with a few beyond ±2.

Next, let's count how many values fall into each bin of width 0.1.

use strict;
use warnings;

# Function that returns a random number following a normal distribution
# $m is the mean, $sigma is the standard deviation
sub randn {
  my ($m, $sigma) = @_;
  my ($r1, $r2) = (rand(), rand());
  # rand() can return 0 and log(0) is undefined, so draw again
  while ($r1 == 0) { $r1 = rand(); }
  return $m + $sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2);
}

# Number of inputs
my $m = 3;

my $count = {};
for (my $i = 0; $i < 100; $i++) {
  # Initial value of weight
  my $w_init_value = randn(0, sqrt(2 / $m));

  # Truncate to one decimal place (int truncates toward zero)
  my $range = int($w_init_value * 10) / 10;

  $count->{$range}++;
}

for my $range (sort { $a <=> $b } keys %$count) {
  print "Range: $range, Count: $count->{$range}\n";
}

Sample output.

Range: -2, Count: 3
Range: -1.6, Count: 2
Range: -1.4, Count: 1
Range: -1.3, Count: 1
Range: -1.2, Count: 3
Range: -1.1, Count: 2
Range: -1, Count: 2
Range: -0.9, Count: 3
Range: -0.8, Count: 1
Range: -0.7, Count: 2
Range: -0.6, Count: 6
Range: -0.5, Count: 3
Range: -0.4, Count: 3
Range: -0.3, Count: 1
Range: -0.2, Count: 7
Range: -0.1, Count: 3
Range: 0, Count: 8
Range: 0.1, Count: 4
Range: 0.2, Count: 3
Range: 0.3, Count: 8
Range: 0.4, Count: 6
Range: 0.5, Count: 5
Range: 0.6, Count: 2
Range: 0.7, Count: 3
Range: 0.8, Count: 2
Range: 0.9, Count: 4
Range: 1, Count: 1
Range: 1.1, Count: 3
Range: 1.3, Count: 3
Range: 1.4, Count: 3
Range: 1.5, Count: 1
Range: 1.8, Count: 1

Values near 0 occur most often, and the count falls toward 0 as the absolute value grows. This is roughly what a normal distribution looks like.

Even if you don't understand the theory of the normal distribution, you can picture how the values are spread out.
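
One rough way to check that the generated values match the intended distribution is to compute the sample mean and standard deviation over many draws. This is a minimal sketch, not part of the original sample; the sample size of 10000 and the variable names are arbitrary assumptions.

```perl
use strict;
use warnings;

# Random number following a normal distribution (Box-Muller transform)
sub randn {
  my ($mean, $sigma) = @_;
  my ($r1, $r2) = (rand(), rand());
  while ($r1 == 0) { $r1 = rand(); }
  return $mean + $sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2);
}

my $m      = 3;      # Number of inputs
my $trials = 10000;  # Sample size (arbitrary choice for this sketch)

# Accumulate the sum and the sum of squares of the generated values
my ($sum, $sum_sq) = (0, 0);
for (1 .. $trials) {
  my $w = randn(0, sqrt(2 / $m));
  $sum    += $w;
  $sum_sq += $w * $w;
}

my $sample_mean = $sum / $trials;
my $sample_sd   = sqrt($sum_sq / $trials - $sample_mean ** 2);

printf "mean: %.3f, standard deviation: %.3f (target: %.3f)\n",
  $sample_mean, $sample_sd, sqrt(2 / $m);
```

The printed mean should land close to 0 and the standard deviation close to √(2 / m), although the exact figures change on every run.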

Sample to set weights

Let's write a sample that sets the weights and computes multiple outputs from multiple inputs.

It is the looping sample from "Calculations in the middle layer: converting m inputs to n outputs", rewritten to use the initial values described above. The initial weights are also dumped to standard error with Data::Dumper.

use strict;
use warnings;
use Data::Dumper;

my $x = [0.5, 0.8];
my $y = [0, 0, 0];
my $x_len = @$x;
my $y_len = @$y;

my $w = [];
for (my $i = 0; $i < $x_len * $y_len; $i++) {
  my $w_init_value = randn(0, sqrt(2 / $x_len));
  push @$w, $w_init_value;
}

print STDERR Dumper($w);

my $b = [
  0,
  0,
  0
];

for (my $y_index = 0; $y_index < $y_len; $y_index++) {
  my $total = 0;
  for (my $x_index = 0; $x_index < $x_len; $x_index++) {
    $total += ($w->[$x_len * $y_index + $x_index] * $x->[$x_index]);
  }
  $total += $b->[$y_index];
  # ReLU: pass positive values through, clamp the rest to 0
  $y->[$y_index] = $total > 0 ? $total : 0;
}

print "($y->[0], $y->[1], $y->[2])\n";

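# Function that returns a random number following a normal distribution
# $m is the mean, $sigma is the standard deviation
sub randn {
  my ($m, $sigma) = @_;
  my ($r1, $r2) = (rand(), rand());
  # rand() can return 0 and log(0) is undefined, so draw again
  while ($r1 == 0) { $r1 = rand(); }
  return $m + $sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2);
}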

Example output.

$VAR1 = [
          '0.152110884289137',
          '2.40437412125725',
          '1.47474871999698',
          '0.30283258213298',
          '-0.274498796676187',
          '-1.27516026508991'
        ];
(1.99955473915037, 0.979640425704874, 0)
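
The weight-generation loop in the sample above could also be wrapped in a reusable function. This is a minimal sketch under that assumption; the name he_init_weights is hypothetical, and it returns a flat array reference in the same layout as $w.

```perl
use strict;
use warnings;

# Random number following a normal distribution (Box-Muller transform)
sub randn {
  my ($mean, $sigma) = @_;
  my ($r1, $r2) = (rand(), rand());
  while ($r1 == 0) { $r1 = rand(); }
  return $mean + $sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2);
}

# Hypothetical helper: He-initialized weights for $m inputs and $n outputs,
# returned as a flat array reference with $m * $n elements
sub he_init_weights {
  my ($m, $n) = @_;
  my $sigma = sqrt(2 / $m);
  return [map { randn(0, $sigma) } 1 .. $m * $n];
}

# Two inputs and three outputs, as in the sample above
my $w = he_init_weights(2, 3);
print scalar(@$w), " weights generated\n";
```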

Initial value of weight: Xavier initialization

In addition to He initialization, there is a method called Xavier initialization, which draws the weights at random from a normal distribution with a mean of 0 and a standard deviation of "1 / sqrt(n)", where n is the number of inputs to the layer.

It is typically used when the activation function is something other than ReLU, such as sigmoid or tanh.
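
Only the standard deviation passed to the random number function changes compared with the He version. This is a minimal sketch, assuming n means the number of inputs to the layer and reusing the same Box-Muller style randn function; the variable names are illustrative.

```perl
use strict;
use warnings;

# Random number following a normal distribution (Box-Muller transform)
sub randn {
  my ($mean, $sigma) = @_;
  my ($r1, $r2) = (rand(), rand());
  while ($r1 == 0) { $r1 = rand(); }
  return $mean + $sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2);
}

# Number of inputs to the layer (assumed meaning of n)
my $n = 3;

# Standard deviation for Xavier initialization
my $sigma = 1 / sqrt($n);

# Initial value of weight
my $w_init_value = randn(0, $sigma);
print "$w_init_value\n";
```

For the same number of inputs, this standard deviation is smaller than the He one by a factor of √2.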
