Viewing contents of file '../idllib/astron/contrib/beck/train_nnet.pro'
;+
;*NAME:
; TRAIN_NNET.PRO
;
;*PURPOSE:
; Trains (computes) the values of the weights used by the neural-network
; classifier.
;
;*CALLING SEQUENCE:
; TRAIN_NNET, n_pat, n_in, n_hid, n_out, train_set, classes, bias_hid, $
; w_hid, bias_out, w_out
;
;*INPUTS:
; n_pat - number of training patterns (INT scalar).
; n_in - number of input neurons (INT scalar).
; n_hid - number of hidden neurons (INT scalar).
; n_out - number of output neurons (INT scalar).
; train_set - training data ( DBLARR[n_in,n_pat] ).
; classes - classification of each training pattern ( INTARR[n_pat] ).
; Values must lie in the range 0 to n_out-1.
;
;*OUTPUTS:
; bias_hid - bias weights on the hidden neurons ( DBLARR[n_hid] ).
; w_hid - weights between input & hidden layers ( DBLARR[n_in,n_hid] ).
; bias_out - bias weights on the output neurons ( DBLARR[n_out] ).
; w_out - weights between hidden & output layers ( DBLARR[n_hid,n_out] ).
;
;*KEYWORD PARAMETERS:
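; outfile - set this keyword to write the computed weights to a FITS
; file. If the value is a scalar string it is used as the file
; name; otherwise the file is named 'weights.fits'.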
; alpha - learning rate, default=0.15.
; mu - momentum term, default=0.10.
;
;*EXAMPLE:
; This example uses the neural network as a stellar spectral classifier.
; It could be used to classify any type of data, if the data could
; be input as a normalized vector.
; --------------------------------------------------------------------
; You have a set of 10 flux & wavelength calibrated spectra. If
; necessary, resample the spectra to the same dispersion (eg. nm/pixel).
; Extract the same wavelength region from all spectra. Normalize. Make
; sure all pixel values are between 0 and 1.0. Stack all spectra into
; a single 2-D array. This is the training set (see input variable
; "train_set" above). If each spectrum has 200 pixels, then the size of
; train_set will be (200,10), so n_pat = 10 and n_in = 200.
;
; Create an integer vector ("classes", above) of 10 elements; each
; element is a number that designates the spectral type of the
; corresponding spectrum in the training set, matched by subscript:
;
; classes(0) <====> train_set(*,0)
;
; It is helpful to generate a lookup table:
;
; class SP type
; ----- -------
; 0 M0V
; 1 M1V
; 2 M1.5V
; 3 M2V
; 4 M3V
; 5 M4V
; 6 M5V
;
; Example of classes vector:
;
; IDL> classes = [0,1,2,2,3,4,4,5,6,6]
;
; Note that in this case some spectral types have more than one example.
; It is a good idea to have as many examples of each spectral type as
; possible; this allows the neural net to generalize better and to
; ignore noise.
;
; CAUTION: Two examples of the same spectral type that are very different
; in appearance due to noise, poor calibration, etc. may cause the
; network not to converge to a solution.
;
; In this example the number of output neurons (n_out) is equal to 7.
; Set n_hid to some number between n_in and n_out; in this example,
; 100 would be a good choice.
;
; Ready to run:
; IDL> train_nnet, 10, 200, 100, 7, train_set, classes, $
; bias_hid, w_hid, bias_out, w_out
;
;*OPERATIONAL NOTES:
;
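; While the program is running, it prints to standard output, every
; second training epoch (iteration), the epoch number and the total
; absolute error summed over all training patterns and all output units.
; Training repeats until no single pattern has a total absolute output
; error greater than 0.4; the final epoch, error and learning rate are
; then printed.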
;
;*HISTORY:
; Version 1.0 Terry Beck
; Advanced Computer Concepts, Inc. 21 Apr 1999
;-
;___________________________________________________________________________
pro train_nnet, n_pat, n_in, n_hid, n_out, train_set, classes, $
    bias_hid, w_hid, bias_out, w_out, outfile=outfile, $
    alpha=alpha, mu=mu, tol=tol, max_epoch=max_epoch
;
; Back-propagation training with momentum; the weights are updated after
; every training pattern. Arguments and the OUTFILE, ALPHA and MU
; keywords are described in the header. Two further optional keywords:
;
;   tol       - a pattern counts as learned when the sum of the absolute
;               errors of its output units is <= tol (default=0.4).
;               Whole epochs are repeated until every pattern is learned.
;   max_epoch - upper limit on the number of epochs; 0 or unset means no
;               limit. If the limit is reached before convergence an
;               informational message is issued and the current weights
;               are returned.
;
; N_ELEMENTS (not KEYWORD_SET) is used for the defaults so that an
; explicit 0, e.g. mu=0 to switch momentum off, is honoured.
;
if n_elements(alpha) eq 0 then alpha = 0.15
if n_elements(mu) eq 0 then mu = 0.10
if n_elements(tol) eq 0 then tol = 0.4
if n_elements(max_epoch) eq 0 then max_epoch = 0L
;
; Validate the class labels. They index the output-unit axis of the
; target array, and IDL silently clips out-of-range subscript arrays,
; which would quietly train the wrong output unit.
;
if n_elements(classes) lt n_pat then $
  message, 'CLASSES must have at least N_PAT elements'
cls = long(classes(0:n_pat-1))
if (min(cls) lt 0) or (max(cls) ge n_out) then $
  message, 'CLASSES values must lie in the range 0 to N_OUT-1'
;
; initialize weights uniformly in [-0.5, 0.5)
;
bias_hid = double(randomu(seed,n_hid)) - 0.5
w_hid    = double(randomu(seed,n_in,n_hid)) - 0.5
bias_out = double(randomu(seed,n_out)) - 0.5
w_out    = double(randomu(seed,n_hid,n_out)) - 0.5
;
; target array: targ(p,c) = 1 for the class of pattern p, 0 elsewhere
;
targ = fltarr(n_pat,n_out)
pidx = lindgen(n_pat)
targ(pidx,cls) = 1.0
;
; Previous weights for the momentum term mu*(w - w_old). They start equal
; to the initial weights, so the first update carries no momentum (the
; previous weight change is zero). Starting them at 0 would add a
; spurious mu*w to every weight on the first step.
;
v0_old = bias_hid
v_old  = w_hid
w0_old = bias_out
w_old  = w_out
;
; work arrays and index vectors (LINDGEN avoids 16-bit overflow)
;
delw   = dblarr(n_hid,n_out)
delv   = dblarr(n_in,n_hid)
del_in = dblarr(n_hid)
k = lindgen(n_out)
i = lindgen(n_in)
;.......n_pat > number of training patterns.
;.......n_out > number of output units.
;.......n_hid > number of hidden units.
;.......n_in  > number of input units.
;.......x     > input units.
;.......z     > hidden units.
;.......y     > output units.
;.......v     > weights from input to hidden units. (w_hid)
;.......v0    > bias on hidden units. (bias_hid)
;.......w     > weights from hidden to output units. (w_out)
;.......w0    > bias on output units. (bias_out)
;.......t     > target vector.
print, "Training begins..."
print
epoch = 0L
repeat begin
  sum_error = 0.0d
  flag = 0
  for pat = 0L, n_pat-1 do begin
    ;
    ; feed-forward phase
    ;
    input = train_set(*,pat)
    nnet, bias_hid, w_hid, bias_out, w_out, $
          input, z, y
    ;
    ; back-propagation of error phase
    ;
    diff = reform(targ(pat,k) - y(k))
    pat_err = total(abs(diff))
    if (pat_err gt tol) then flag = 1
    sum_error = sum_error + pat_err
    ; error term of the output units
    del = diff*y(k)*(1 - y(k))
    delw0 = alpha*del
    for j = 0L, n_hid-1 do delw(j,k) = alpha*del*z(j)
    ; error propagated back to the hidden units (uses the old w_out)
    for j = 0L, n_hid-1 do del_in(j) = total(del*w_out(j,k))
    delz = dblarr(n_hid)
    for j = 0L, n_hid-1 do delz(j) = del_in(j)*z(j)*(1 - z(j))
    delv0 = alpha*delz
    for j = 0L, n_hid-1 do delv(i,j) = alpha*delz(j)*input(i)
    ;
    ; weight adjustment phase: gradient step plus momentum, where
    ; (w - w_old) is the weight change made for the previous pattern
    ;
    v_new    = bias_hid + delv0 + mu*(bias_hid - v0_old)
    v0_old   = bias_hid
    bias_hid = v_new
    v_new    = w_hid + delv + mu*(w_hid - v_old)
    v_old    = w_hid
    w_hid    = v_new
    ;
    w_new    = bias_out + delw0 + mu*(bias_out - w0_old)
    w0_old   = bias_out
    bias_out = w_new
    w_new    = w_out + delw + mu*(w_out - w_old)
    w_old    = w_out
    w_out    = w_new
  endfor
  ;
  ; output phase: progress report every second epoch
  ;
  epoch = epoch + 1
  if (epoch mod 2) eq 0 then print, epoch, sum_error
  converged = (flag eq 0)
  if converged then begin
    print, epoch, sum_error
    print,'learning rate = ', alpha
  endif
endrep until converged or ((max_epoch gt 0) and (epoch ge max_epoch))
;
if converged eq 0 then $
  message, 'Stopped after ' + strtrim(epoch,2) + $
           ' epochs without converging (sum_error = ' + $
           strtrim(sum_error,2) + ')', /informational
;
; Write the weights if requested. A scalar string gives the file name;
; anything else falls back to 'weights.fits'. A local copy is used so the
; caller's OUTFILE variable is not overwritten.
;
if keyword_set(outfile) then begin
  fname = outfile
  s = size(outfile)
  if (s(0) ne 0) or (s(s(0)+1) ne 7) then fname = 'weights.fits'
  nnet_write_weights, bias_hid, w_hid, bias_out, w_out, $
                      outfile=fname
endif
end