The invention relates to artificial neural networks, and in particular, to a method of training an artificial neural network using sparse connectivity learning.
An artificial neural network is a network including multiple processing units arranged in layers and operating in parallel. Typically, a conventional artificial neural network is fully connected, that is, all processing units in one layer are connected to all processing units in the preceding layer. However, such network arrangements are often complex in structure, require excessive memory resources and power consumption, and suffer from overfitting.
According to one embodiment of the invention, a computing network includes a plurality of processing nodes. A method of training the computing network includes: a processing node in the plurality of processing nodes computing an output estimate according to a weight defined by a weight variable and a connectivity mask, the connectivity mask representing a connection between the processing node and a preceding processing node in the plurality of processing nodes and being derived from a connectivity variable; and adjusting connectivity variables according to an objective function to reduce a total number of connections between the plurality of processing nodes and reduce a performance loss indicative of how different the output estimate is from a target value.
These and other objectives of the present invention will no doubt become obvious to those of ordinary skill in the art after reading the following detailed description of the preferred embodiment that is illustrated in the various figures and drawings.
The artificial neural network 1 may include layers Lyr(1) to Lyr(J), J being a positive integer exceeding 1. The layer Lyr(1) may be referred to as an input layer, the layer Lyr(J) may be referred to as an output layer, and the layers Lyr(2) to Lyr(J−1) may be referred to as hidden layers. Each layer Lyr(j) may include a plurality of processing nodes coupled to a plurality of processing nodes in the preceding layer Lyr(j−1) via connections $C_1^j$ to $C_{|C_j|}^j$, where j is a layer index ranging from 2 to J, and $|C_j|$ is the total number of connections between the layer Lyr(j) and the preceding layer Lyr(j−1). The input layer Lyr(1) may contain processing nodes $N_1^1$ to $N_{|N_1|}^1$, where the superscript represents a layer index, the subscript represents a node index, and $|N_1|$ is the total number of processing nodes in the input layer Lyr(1). The processing nodes $N_1^1$ to $N_{|N_1|}^1$ may receive input data $x_1^1$ to $x_{|N_1|}^1$, respectively. Each hidden layer Lyr(j) among the hidden layers Lyr(2) to Lyr(J−1) may contain processing nodes $N_1^j$ to $N_{|N_j|}^j$, where $|N_j|$ is the total number of processing nodes in the hidden layer Lyr(j). The output layer Lyr(J) may contain processing nodes $N_1^J$ to $N_{|N_J|}^J$, where $|N_J|$ is the total number of processing nodes in the output layer Lyr(J). The processing nodes $N_1^J$ to $N_{|N_J|}^J$ may generate output estimates $y_1^J$ to $y_{|N_J|}^J$, respectively.
Each processing node in the layer Lyr(j) may be coupled to one or more processing nodes in the preceding layer Lyr(j−1) via the connections therebetween. Each connection may be associated with a weight, and the processing node may compute a weighted sum of one or more pieces of input data received from the processing nodes in the preceding layer Lyr(j−1). A connection associated with a weight of larger magnitude is more influential in generating the weighted sum than a connection associated with a weight of smaller magnitude. When a weight is 0, the connection associated with the weight may be regarded as eliminated from the artificial neural network 1. This yields network connectivity sparsity and reduces computational complexity, power consumption and operational costs. The artificial neural network 1 may be trained to have an optimized sparse network structure. That structure delivers output estimates $y_1^J$ to $y_{|N_J|}^J$ that closely match the respective target values $Y(1)$ to $Y(|N_J|)$, using a reduced or minimal number of the connections $C_1^2$ to $C_{|C_J|}^J$.
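As a minimal illustration of the weighted sum described above, the following Python sketch computes a node's weighted sum and skips connections whose weight is exactly 0. It also counts the multiply-accumulate operations actually performed, which shows how eliminated connections save computation. The function name `weighted_sum` and the skip-on-zero strategy are illustrative assumptions, not an implementation taken from the source.

```python
from typing import Sequence, Tuple


def weighted_sum(inputs: Sequence[float], weights: Sequence[float]) -> Tuple[float, int]:
    """Return (weighted sum, number of multiply-accumulates performed).

    A connection whose weight is exactly 0 is treated as eliminated and skipped,
    so it costs no multiplication.
    """
    if len(inputs) != len(weights):
        raise ValueError("inputs and weights must have the same length")
    total = 0.0
    macs = 0
    for x, w in zip(inputs, weights):
        if w == 0.0:
            continue  # eliminated connection: nothing to compute
        total += w * x
        macs += 1
    return total, macs


# Four incoming connections, two of which have been eliminated (weight 0).
print(weighted_sum([1.0, 2.0, 3.0, 4.0], [0.5, 0.0, -1.0, 0.0]))  # (-2.5, 2)
```

Half of the connections are zero, so only two of the four multiplications are carried out.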
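A processing node $N_k^j$ may compute an output estimate $y$ from input data $x$ and a weight $w$, as expressed by Equation (1):

$$y = w * x \tag{1}$$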
The input data $x$ may be of size 1×1. The weight $w$ may be referred to as a kernel and may also be of size 1×1. The operator "$*$" may represent a convolution operation. The output estimate $y$ may be passed to a subsequent processing node as its input data for computing a subsequent output estimate. The weight $w$ may be re-parameterized into a weight variable $\tilde{w}$ and a connectivity mask $m$, as expressed by Equation (2):
$$w = \tilde{w} \odot m \tag{2}$$
The connectivity mask $m$ may be a binary value representing the connectivity of the connection, with 1 representing a connection and 0 representing no connection. The weight variable $\tilde{w}$ may represent the strength of the connection, and "$\odot$" may represent element-wise multiplication. The connectivity mask $m$ may be derived by performing a unit step operation $H(\cdot)$ on a connectivity variable $\tilde{m}$, as expressed by Equation (3):

$$m = H(\tilde{m}) = \begin{cases} 1, & \tilde{m} > 0 \\ 0, & \tilde{m} \le 0 \end{cases} \tag{3}$$
The processing node $N_k^j$ may binarize the connectivity variable $\tilde{m}$ according to the unit step operation $H(\cdot)$ to generate the connectivity mask $m$. By re-parameterizing the weight $w$, the connectivity and the strength of a connection may be trained separately, by adjusting the connectivity variable $\tilde{m}$ and the weight variable $\tilde{w}$, respectively. If the connectivity variable $\tilde{m}$ is less than or equal to 0, the weight variable $\tilde{w}$ may be zero-masked to produce a zero weight $w$. If the connectivity variable $\tilde{m}$ exceeds 0, the weight variable $\tilde{w}$ may be assigned as the weight $w$.
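The re-parameterization of Equations (2) and (3) can be sketched in a few lines of Python. This is a minimal illustration operating on plain lists. The names `unit_step` and `effective_weights` are hypothetical helpers, and `m_tilde` and `w_tilde` stand for the connectivity variables $\tilde{m}$ and the weight variables $\tilde{w}$.

```python
from typing import List


def unit_step(m_tilde: float) -> int:
    """H(.) from Equation (3): 1 if m_tilde > 0, otherwise 0."""
    return 1 if m_tilde > 0 else 0


def effective_weights(w_tilde: List[float], m_tilde: List[float]) -> List[float]:
    """Equation (2): w = w_tilde (element-wise *) m, where m = H(m_tilde)."""
    if len(w_tilde) != len(m_tilde):
        raise ValueError("w_tilde and m_tilde must have the same length")
    # Keeping wt when the mask is 1 and writing 0.0 otherwise equals multiplying
    # by the binary mask, but avoids producing -0.0 for negative weight variables.
    return [wt if unit_step(mt) else 0.0 for wt, mt in zip(w_tilde, m_tilde)]


m_tilde = [0.5, -0.1, 0.0]
w_tilde = [0.8, -0.3, 1.2]
print([unit_step(v) for v in m_tilde])      # [1, 0, 0]
print(effective_weights(w_tilde, m_tilde))  # [0.8, 0.0, 0.0]
```

A connectivity variable of exactly 0 yields a mask of 0, which matches the "less than or equal to 0" rule above.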
In the artificial neural network 1, the connections $C_1^2$ to $C_{|C_J|}^J$ may be associated with connectivity variables $\tilde{m}_1^2$ to $\tilde{m}_{|C_J|}^J$ and weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$, respectively. The connectivity variables $\tilde{m}_1^2$ to $\tilde{m}_{|C_J|}^J$ and the weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$ may be trained according to an objective function to reduce the total number of the connections $C_1^2$ to $C_{|C_J|}^J$ while reducing a performance loss of the artificial neural network 1. The total number of the connections $C_1^2$ to $C_{|C_J|}^J$ may be computed by summing all the connectivity masks $m_1^2$ to $m_{|C_J|}^J$. The performance loss may represent how different the output estimates $y_1^J$ to $y_{|N_J|}^J$ are from the respective target values $Y(1)$ to $Y(|N_J|)$, and may be computed in the form of a cross entropy or a squared error. The objective function $L$ may be expressed as Equation (4):
$$L = CE + \lambda_1 \sum_{j=2}^{J} \sum_{i=1}^{|C_j|} m_i^j + \lambda_2 \sum_{j=2}^{J} \sum_{i=1}^{|C_j|} \left(\tilde{w}_i^j\right)^2 \tag{4}$$
where $CE$ is a cross entropy;
$\lambda_1$ is a connectivity decay coefficient;
$\lambda_2$ is a weight decay coefficient;
$j$ is a layer index;
$i$ is a mask index or a weight index;
$m_i^j$ is the ith connectivity mask of the jth layer;
$|C_j|$ is the total number of the connections of the jth layer; and
$\tilde{w}_i^j$ is the ith weight variable of the jth layer.
The objective function $L$ may include three terms:
(1) the cross entropy $CE$ between the output estimates $y_1^J$ to $y_{|N_J|}^J$ and the respective target values $Y(1)$ to $Y(|N_J|)$;
(2) an $L_0$ regularization term on the total number of the connections $C_1^2$ to $C_{|C_J|}^J$; and
(3) an $L_2$ regularization term on the weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$ associated with those connections.
In some embodiments, a sum of squared errors between the output estimates $y_1^J$ to $y_{|N_J|}^J$ and the respective target values $Y(1)$ to $Y(|N_J|)$ may replace the cross entropy $CE$ in the objective function $L$. The $L_0$ regularization term may be the product of the connectivity decay coefficient $\lambda_1$ and the sum of the connectivity masks $m_1^2$ to $m_{|C_J|}^J$. The $L_2$ regularization term may be the product of the weight decay coefficient $\lambda_2$ and the sum of the squares of the weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$. In some embodiments, the $L_2$ regularization term may be omitted from the objective function $L$. The artificial neural network 1 is trained to minimize the value of the objective function $L$. Therefore, the $L_0$ regularization term penalizes a large number of connections, and the $L_2$ regularization term penalizes large weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$. The larger the connectivity decay coefficient $\lambda_1$, the sparser the artificial neural network 1 becomes. The connectivity decay coefficient $\lambda_1$ may be set to a large constant, which pushes the connectivity variables $\tilde{m}_1^2$ to $\tilde{m}_{|C_J|}^J$ in the negative direction. This drives the connectivity masks $m_1^2$ to $m_{|C_J|}^J$ towards 0 and leads to a sparse connectivity structure of the artificial neural network 1. The connectivity mask $m_i^j$ associated with a connection $C_i^j$ may remain 1 only when the connection $C_i^j$ is important in reducing the cross entropy $CE$.
In this manner, a balance between reducing the cross entropy $CE$ and reducing the total number of connections may be achieved. The result is a sparse connectivity structure that still produces output estimates $y_1^J$ to $y_{|N_J|}^J$ substantially matching the target values $Y(1)$ to $Y(|N_J|)$. Similarly, the weight decay coefficient $\lambda_2$ may be set to a large constant to shrink the values of the weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$. Meanwhile, the cross entropy $CE$ ensures that important weight variables remain in the artificial neural network 1, leading to a simple and accurate model.
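Equation (4) can be evaluated directly once the performance loss is known. The Python sketch below is a minimal illustration under stated assumptions: the cross entropy value `ce` is computed elsewhere and passed in, and the per-layer nested-list layout and the function name `objective` are hypothetical.

```python
from typing import List


def unit_step(v: float) -> int:
    """H(.) from Equation (3)."""
    return 1 if v > 0 else 0


def objective(ce: float, m_tilde: List[List[float]], w_tilde: List[List[float]],
              lam1: float, lam2: float) -> float:
    """Equation (4): CE + lam1 * (number of connections) + lam2 * (sum of squared weight variables).

    m_tilde[k] and w_tilde[k] hold the connectivity variables and weight variables of
    one layer (layers Lyr(2) to Lyr(J)); ce is the already-computed cross entropy.
    """
    num_connections = sum(unit_step(v) for layer in m_tilde for v in layer)  # L0 term
    squared_weights = sum(w * w for layer in w_tilde for w in layer)        # L2 term
    return ce + lam1 * num_connections + lam2 * squared_weights


# Two layers with two and one connections; the second connection is masked off.
value = objective(ce=0.5, m_tilde=[[0.3, -0.2], [1.0]], w_tilde=[[0.5, 2.0], [1.0]],
                  lam1=0.1, lam2=0.01)
# 0.5 + 0.1 * 2 + 0.01 * (0.25 + 4.0 + 1.0) = 0.7525, up to floating-point rounding
```

As written in Equation (4), the $L_2$ term sums over all weight variables, including zero-masked ones. That is why the masked weight variable 2.0 still contributes 4.0 to the penalty.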
During training of the connectivity variables $\tilde{m}_1^2$ to $\tilde{m}_{|C_J|}^J$, the following operations may be performed:
(1) The input data $x_1^1$ to $x_{|N_1|}^1$ may be fed into the input layer Lyr(1) and forward-propagated from the layer Lyr(1) to the layer Lyr(J) to generate the output estimates $y_1^J$ to $y_{|N_J|}^J$.
(2) The errors between the output estimates $y_1^J$ to $y_{|N_J|}^J$ and the respective target values $Y(1)$ to $Y(|N_J|)$ may be computed.
(3) These errors may be back-propagated from the layer Lyr(J) to the layer Lyr(2) to compute connectivity variable gradients of the objective function $L$ with respect to the connectivity variables $\tilde{m}_1^2$ to $\tilde{m}_{|C_J|}^J$.
(4) The connectivity variables $\tilde{m}_1^2$ to $\tilde{m}_{|C_J|}^J$ may be adjusted according to these gradients, so as to reduce the total number of the connections $C_1^2$ to $C_{|C_J|}^J$ while reducing the performance loss of the artificial neural network 1.
Specifically, a connectivity variable $\tilde{m}$ may be adjusted until the corresponding connectivity variable gradient $\partial L/\partial \tilde{m}$ reaches 0, in order to find a local minimum of the objective function $L$. However, by the chain rule, computing the connectivity variable gradient $\partial L/\partial \tilde{m}$ involves differentiating the unit step function of Equation (3), whose derivative is 0 for almost all values of the connectivity variable $\tilde{m}$. The connectivity variable gradient $\partial L/\partial \tilde{m}$ would therefore be 0, stalling the training process and leaving the connectivity variable $\tilde{m}$ without updates. To keep the connectivity variable $\tilde{m}$ trainable, the unit step function is skipped during backpropagation. The connectivity variable gradient $\partial L/\partial \tilde{m}$ may then be redefined as the connectivity mask gradient $\partial L/\partial m$ of the objective function $L$ with respect to the connectivity mask $m$, as expressed by Equation (5):

$$\frac{\partial L}{\partial \tilde{m}} \triangleq \frac{\partial L}{\partial m} = \frac{\partial L}{\partial w} \odot \tilde{w} \tag{5}$$
In some embodiments, the connectivity mask gradient $\partial L/\partial m$ may be computed as the element-wise multiplication of the corresponding weight gradient $\partial L/\partial w$ and the corresponding weight variable $\tilde{w}$, as indicated in Equation (5). In this fashion, when a connection is determined to be negligible in reducing the cross entropy $CE$, its connectivity variable $\tilde{m}$ may be updated from a positive number to a negative number, and its connectivity mask $m$ may be updated from 1 to 0. When a connection is determined to be essential in reducing the cross entropy $CE$, its connectivity variable $\tilde{m}$ may be updated from a negative number to a positive number, and its connectivity mask $m$ may be updated from 0 to 1. In some embodiments, each mini-batch of input datasets may be input into the artificial neural network 1 to generate plural sets of output estimates $y_1^J$ to $y_{|N_J|}^J$. A mean error of the plural sets of output estimates may then be computed, and the connectivity variables $\tilde{m}_1^2$ to $\tilde{m}_{|C_J|}^J$ may be trained according to backpropagation of the mean error. In some embodiments, the connectivity variable gradient $\partial L/\partial \tilde{m}$ or the connectivity mask gradient $\partial L/\partial m$ may be normalized to a standard deviation of 1 for each mini-batch of input datasets. This avoids scale differences between the gradient and the weight variable $\tilde{w}$.
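The connectivity update described above can be sketched as follows. This is a minimal Python illustration of Equation (5), the per-mini-batch normalization, and a plain gradient-descent step. Several choices are assumptions not specified by the text: normalization divides by the population standard deviation over the whole gradient vector without subtracting the mean, the optimizer is plain gradient descent, and the helper names and numeric values are hypothetical.

```python
import statistics
from typing import List


def connectivity_grad(weight_grad: List[float], w_tilde: List[float]) -> List[float]:
    """Equation (5): use dL/dm = dL/dw (element-wise *) w_tilde in place of dL/dm_tilde.

    Skipping the unit step in the backward pass (a straight-through estimate) keeps
    the gradient non-zero, even though H'(m_tilde) is 0 almost everywhere.
    """
    return [g * wt for g, wt in zip(weight_grad, w_tilde)]


def normalize_unit_std(grad: List[float]) -> List[float]:
    """Rescale one mini-batch's gradient to a population standard deviation of 1."""
    if not grad:
        return []
    sd = statistics.pstdev(grad)
    return list(grad) if sd == 0 else [g / sd for g in grad]


def update_connectivity(m_tilde: List[float], grad: List[float], lr: float) -> List[float]:
    """Plain gradient-descent step on the connectivity variables."""
    return [mt - lr * g for mt, g in zip(m_tilde, grad)]


# Hypothetical values for two connections.
grad = normalize_unit_std(connectivity_grad(weight_grad=[0.4, 0.1], w_tilde=[0.5, -2.0]))
m_tilde = update_connectivity([0.1, -0.1], grad, lr=0.5)
print([1 if v > 0 else 0 for v in m_tilde])  # [0, 1]: first pruned, second restored
```

In this example the first connection's mask flips from 1 to 0 and the second's from 0 to 1, mirroring the two cases described in the paragraph above.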
Similarly, during training of the weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$, weight variable gradients of the objective function $L$ with respect to these weight variables may be computed by backpropagation of the errors. The weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$ may then be adjusted according to the weight variable gradients, so as to reduce the weight variables while reducing the performance loss of the artificial neural network 1. Specifically, a weight variable $\tilde{w}$ may be adjusted until the corresponding weight variable gradient $\partial L/\partial \tilde{w}$ reaches 0, in order to find a local minimum of the objective function $L$. According to Equation (2) and the chain rule, the weight variable gradient $\partial L/\partial \tilde{w}$ may be expressed by Equation (6):

$$\frac{\partial L}{\partial \tilde{w}} = \frac{\partial L}{\partial w} \odot m \tag{6}$$

According to Equation (6), the weight variable gradient $\partial L/\partial \tilde{w}$ is 0 whenever the connectivity mask $m$ is 0. The weight variable $\tilde{w}$ would then no longer be updated, and its training would stall. To keep the weight variable $\tilde{w}$ trainable, the weight variable gradient $\partial L/\partial \tilde{w}$ may be redefined during backpropagation as the weight gradient $\partial L/\partial w$ of the objective function $L$ with respect to the weight $w$, as expressed by Equation (7):

$$\frac{\partial L}{\partial \tilde{w}} \triangleq \frac{\partial L}{\partial w} \tag{7}$$
By redefining the weight variable gradient $\partial L/\partial \tilde{w}$ as the weight gradient $\partial L/\partial w$, the weight variable $\tilde{w}$ may remain trainable even when the connectivity mask $m$ is 0. The weight gradient $\partial L/\partial w$ may be obtained by backpropagation, and the weight variable $\tilde{w}$ may be updated according to the weight gradient $\partial L/\partial w$ regardless of whether the connectivity mask $m$ is 1 or 0. In this fashion, the weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$ may continue to be trained even if some of them are temporarily zero-masked.
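The difference between Equations (6) and (7) can be shown with a single gradient step. This minimal Python sketch uses hypothetical helper names, values and learning rate, and assumes plain gradient descent. It compares the exact chain-rule gradient, which freezes a zero-masked weight variable, with the redefined gradient, which keeps it moving.

```python
from typing import List


def weight_var_grad_chain_rule(weight_grad: List[float], mask: List[int]) -> List[float]:
    """Equation (6): dL/dw_tilde = dL/dw (element-wise *) m; zero wherever m = 0."""
    return [g * m for g, m in zip(weight_grad, mask)]


def weight_var_grad_redefined(weight_grad: List[float]) -> List[float]:
    """Equation (7): use dL/dw directly, so masked weight variables still receive updates."""
    return list(weight_grad)


def sgd_step(w_tilde: List[float], grad: List[float], lr: float) -> List[float]:
    """Plain gradient-descent step on the weight variables."""
    return [wt - lr * g for wt, g in zip(w_tilde, grad)]


w_tilde = [1.0, 1.0]
weight_grad = [0.3, -0.6]   # dL/dw from backpropagation (hypothetical values)
mask = [1, 0]               # second connection currently zero-masked

frozen = sgd_step(w_tilde, weight_var_grad_chain_rule(weight_grad, mask), lr=0.1)
live = sgd_step(w_tilde, weight_var_grad_redefined(weight_grad), lr=0.1)
# With Equation (6) the masked weight variable stays at 1.0; with Equation (7) it moves to about 1.06.
```

The unmasked first weight variable receives the same update under both rules. Only the masked one behaves differently.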
The artificial neural network 1 separates each weight $w$ into a connectivity variable $\tilde{m}$ and a weight variable $\tilde{w}$. It trains the connectivity variables $\tilde{m}$ to form a sparse connectivity structure, and trains the weight variables $\tilde{w}$ to form a simple model. Further, to keep both kinds of variables trainable, two gradients are redefined: the connectivity variable gradient $\partial L/\partial \tilde{m}$ is replaced by the connectivity mask gradient $\partial L/\partial m$, and the weight variable gradient $\partial L/\partial \tilde{w}$ is replaced by the weight gradient $\partial L/\partial w$. The resulting sparse connectivity structure of the artificial neural network 1 can significantly reduce computational complexity, memory requirements and power consumption.
A training method 300 of the artificial neural network 1 may include the following steps:
Step S302: The processing node $N_k^j$ computes an output estimate according to a weight $w$ defined by a weight variable $\tilde{w}$ and a connectivity mask $m$, the connectivity mask $m$ being derived from a connectivity variable $\tilde{m}$;
Step S304: Adjust the connectivity variables $\tilde{m}_1^2$ to $\tilde{m}_{|C_J|}^J$ according to the objective function $L$ to reduce the total number of connections and reduce the performance loss;
Step S306: Adjust the weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$ according to the objective function $L$ to reduce the sum of the squares of the weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$.
Explanations of Steps S302 to S306 are provided in the preceding paragraphs and will not be repeated here. The training method 300 trains the connectivity variables $\tilde{m}_1^2$ to $\tilde{m}_{|C_J|}^J$ and the weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$ separately. This produces the artificial neural network 1 that is sparse in connectivity, simple in structure, and accurate in output prediction.
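The steps above can be sketched end to end on a deliberately tiny problem. The sketch uses a single linear processing node with three input connections, trained on the mean of half the squared error (the text allows a squared error in place of the cross entropy). It rests on several assumptions. The dataset, initial values and hyperparameters are hypothetical. Steps S304 and S306 are applied simultaneously in each iteration. The optional gradient normalization is omitted. The $L_0$ term is assumed to contribute a constant $\lambda_1$ to the connectivity gradient under the same straight-through treatment as Equation (5), which the text does not spell out. The target depends only on the first input, so the other two connections are expected to be pruned.

```python
from typing import List, Tuple

# Hypothetical toy data: four samples with mutually orthogonal input columns.
# The target depends only on the first input (Y = 2 * x0).
INPUTS: List[Tuple[float, float, float]] = [
    (1.0, 1.0, 1.0),
    (1.0, -1.0, -1.0),
    (-1.0, 1.0, -1.0),
    (-1.0, -1.0, 1.0),
]
TARGETS: List[float] = [2.0 * x[0] for x in INPUTS]


def step(v: float) -> float:
    """Unit step H(.): 1.0 if v > 0 else 0.0."""
    return 1.0 if v > 0 else 0.0


def train(m_tilde: List[float], w_tilde: List[float], steps: int = 100, lr: float = 0.1,
          lam1: float = 0.05, lam2: float = 0.0) -> Tuple[List[int], List[float]]:
    m_tilde, w_tilde = list(m_tilde), list(w_tilde)
    n = len(INPUTS)
    for _ in range(steps):
        # Step S302: forward pass with w = w_tilde * H(m_tilde) (Equations (2) and (3)).
        w = [wt * step(mt) for wt, mt in zip(w_tilde, m_tilde)]
        errors = [sum(wi * xi for wi, xi in zip(w, x)) - y for x, y in zip(INPUTS, TARGETS)]
        # Weight gradient dL/dw of the mini-batch mean of 0.5 * squared error.
        grad_w = [sum(e * x[i] for e, x in zip(errors, INPUTS)) / n for i in range(len(w))]
        # Step S304: Equation (5) plus the assumed straight-through lam1 term of the L0 penalty.
        grad_m = [g * wt + lam1 for g, wt in zip(grad_w, w_tilde)]
        # Step S306: Equation (7) plus the gradient 2 * lam2 * w_tilde of the L2 term.
        grad_wt = [g + 2.0 * lam2 * wt for g, wt in zip(grad_w, w_tilde)]
        m_tilde = [mt - lr * g for mt, g in zip(m_tilde, grad_m)]
        w_tilde = [wt - lr * g for wt, g in zip(w_tilde, grad_wt)]
    masks = [int(step(mt)) for mt in m_tilde]
    return masks, [wt * m for wt, m in zip(w_tilde, masks)]


masks, weights = train(m_tilde=[0.1, 0.1, 0.1], w_tilde=[0.5, 0.5, 0.5])
print(masks)    # [1, 0, 0]
print(weights)  # first weight close to 2.0, the other two 0.0
```

The two irrelevant connections are pruned within the first few iterations. The useful connection's connectivity variable grows while its weight variable converges toward 2.0, illustrating the balance between the $L_0$ penalty and the performance loss.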
The artificial neural network 1 and the training method 300 are utilized to train the connectivity variables $\tilde{m}_1^2$ to $\tilde{m}_{|C_J|}^J$ and the weight variables $\tilde{w}_1^2$ to $\tilde{w}_{|C_J|}^J$, producing sparse network connectivity while delivering accurate outputs.
Those skilled in the art will readily observe that numerous modifications and alterations of the device and method may be made while retaining the teachings of the invention. Accordingly, the above disclosure should be construed as limited only by the metes and bounds of the appended claims.
This application claims the benefit of U.S. Provisional Application No. 62/851,652, filed on May 23, 2019, and included herein by reference in its entirety.
Number | Date | Country
---|---|---
62851652 | May 2019 | US