Neural networks are used in the field of machine learning and artificial intelligence (AI). A neural network comprises a plurality of nodes which are interconnected by links, sometimes referred to as edges. The input edges of one or more nodes form the input of the network as a whole, and the output edges of one or more other nodes form the output of the network as a whole, whilst the output edges of various nodes within the network form the input edges to other nodes. Each node represents a function of its input edge(s) weighted by a respective weight; the result being output on its output edge(s). The weights can be gradually tuned based on a set of experience data (e.g. training data) to tend towards a state where the network will output a desired value for a given input.
Typically, the nodes are arranged into layers with at least an input and an output layer. A “deep” neural network comprises one or more intermediate or “hidden” layers in between the input layer and the output layer. The neural network can take input data and propagate the input data through the layers of the network to generate output data. Certain nodes within the network perform operations on the data, and the result of those operations is passed to other nodes, and so on.
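This layer-by-layer propagation can be illustrated with a minimal sketch: a toy fully-connected network with one hidden layer. The tanh non-linearity and the random weights are illustrative choices only, not features of any particular embodiment.

```python
import numpy as np

def forward(x, layers):
    """Propagate an input vector through a list of (weights, biases) layers.

    Each node computes a weighted sum of its inputs followed by a
    non-linearity (here tanh), as described above.
    """
    a = x
    for W, b in layers:
        a = np.tanh(W @ a + b)
    return a

# Toy network: 3 inputs -> 4 hidden nodes -> 2 outputs.
rng = np.random.default_rng(0)
layers = [(rng.normal(size=(4, 3)), np.zeros(4)),
          (rng.normal(size=(2, 4)), np.zeros(2))]
y = forward(np.array([1.0, 0.5, -0.2]), layers)
```

Each `(W, b)` pair plays the role of one layer's weights; training, described below, amounts to adjusting these values.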
At some or all the nodes of the network, the input to that node is weighted by a respective weight. A weight may define the connectivity between a node in a given layer and the nodes in the next layer of the neural network. A weight can take the form of a scalar or a probabilistic distribution. When the weights are defined by a distribution, as in a Bayesian model, the neural network can be fully probabilistic and captures the concept of uncertainty. The values of the connections 106 between nodes may also be modelled as distributions. This is illustrated schematically in
The network learns by operating on data input at the input layer, and, based on the input data, adjusting the weights applied by some or all of the nodes in the network. There are different learning approaches, but in general there is a forward propagation through the network from left to right in
The input to the network is typically a vector, each element of the vector representing a different corresponding feature. E.g. in the case of image recognition the elements of this feature vector may represent different pixel values, or in a medical application the different features may represent different symptoms. The output of the network may be a scalar or a vector. The output may represent a classification, e.g. an indication of whether a certain object such as an elephant is recognized in the image, or a diagnosis of the patient in the medical example.
Training in this manner is sometimes referred to as a supervised approach. Other approaches are also possible, such as a reinforcement approach wherein each data point is not initially labelled. The learning algorithm begins by guessing the corresponding output for each point, and is then told whether it was correct, gradually tuning the weights with each such piece of feedback. Another example is an unsupervised approach where input data points are not labelled at all and the learning algorithm is instead left to infer its own structure in the experience data.
The present disclosure recognizes that accounting for causal relationships between variables when performing missing value imputation both reveals relationships between the variables and improves the accuracy of the imputation. Also provided is a scalable way to discover causal relationships among variables even in situations where data values for certain variables are not observed. This can be useful, for example, in a healthcare setting for determining the condition of patients based on sensor data, or for diagnosing faults with equipment (e.g. electrical equipment). In an example, the causal relationship between biological measurements from medical sensor data is determined.
According to one aspect disclosed herein, there is provided a computer-implemented method of machine learning. The method comprises receiving an input vector comprising values of variables. The method then comprises using a first neural network to encode the variables of the input vector into a plurality of latent vectors. The plurality of latent vectors can then be input into a second neural network comprising a graph neural network, wherein the graph neural network is parametrized by a graph comprising edge probabilities indicating causal relationships between the variables, in order to determine a computed vector value. The method then comprises tuning the edge probabilities of the graph, one or more parameters of the first neural network and one or more parameters of the second neural network to minimise a loss function, wherein the loss function comprises a measure of difference between the input vector and the computed vector value and a function of the graph.
In some examples, the values of the variables of the input vector may not be fully observed. This is common in real world scenarios where a value may not be obtained for each variable of an input vector.
The method can optimize a graph showing causal relationships between variables of input information. The method can also be used in some examples to impute missing values.
To assist understanding of embodiments of the present disclosure and to illustrate how such embodiments may be put into effect, reference is made, by way of example only, to the accompanying drawings in which:
The following will present a method of determining causal relationships between variables in input vectors. In some examples, the method can also be used to impute missing values of input vectors.
First however there is described an example system in which the presently disclosed techniques may be implemented. There is also provided an overview of the principles behind graph neural networks and variational auto encoders, based upon which embodiments may be built or expanded.
The computing apparatus 200 comprises at least a controller 202, an interface (e.g. a user interface) 204, and an artificial intelligence (AI) algorithm 206. The controller 202 is operatively coupled to each of the interface 204 and the AI algorithm 206.
Each of the controller 202, interface 204 and AI algorithm 206 may be implemented in the form of software code embodied on computer readable storage and run on processing apparatus comprising one or more processors such as CPUs, work accelerator co-processors such as GPUs, and/or other application specific processors, implemented on one or more computer terminals or units at one or more geographic sites. The storage on which the code is stored may comprise one or more memory devices employing one or more memory media (e.g. electronic or magnetic media), again implemented on one or more computer terminals or units at one or more geographic sites. In embodiments, one, some or all of the controller 202, interface 204 and AI algorithm 206 may be implemented on the server. Alternatively, a respective instance of one, some or all of these components may be implemented in part or even wholly on each of one, some or all of the one or more user terminals. In further examples, the functionality of the above-mentioned components may be split between any combination of the user terminals and the server. Again, it is noted that, where required, distributed computing techniques are in themselves known in the art. It is also not excluded that one or more of these components may be implemented in dedicated hardware.
The controller 202 comprises a control function for coordinating the functionality of the interface 204 and the AI algorithm 206. The interface 204 refers to the functionality for receiving and/or outputting data. The interface 204 may comprise a user interface (UI) for receiving and/or outputting data to and/or from one or more users, respectively; or it may comprise an interface to a UI on another, external device. Alternatively, the interface may be arranged to collect data from and/or output data to an automated function implemented on the same apparatus or an external device. In the case of an external device, the interface 204 may comprise a wired or wireless interface for communicating, via a wired or wireless connection respectively, with the external device. The interface 204 may comprise one or more constituent types of interface, such as a voice interface and/or a graphical user interface. The interface 204 may present a UI front end to the user(s) through one or more I/O modules on their respective user device(s), e.g. speaker and microphone, touch screen, etc., depending on the type of user interface. The logic of the interface may be implemented on a server and output to the user through the I/O module(s) on his/her user device(s). Alternatively, some or all of the logic of the interface 204 may also be implemented on the user device(s) 102 itself/themselves.
The controller 202 is configured to control the AI algorithm 206 to perform operations in accordance with the embodiments described herein. It will be understood that any of the operations disclosed herein may be performed by the AI algorithm 206 under control of the controller 202. The controller 202 may, for example, collect experience data from the user and/or an automated process via the interface 204, pass it to the AI algorithm 206, receive predictions back from the AI algorithm 206, and output the predictions to the user and/or automated process through the interface 204.
The AI algorithm 206 comprises a machine-learning model 208, comprising one or more constituent statistical models such as one or more neural networks.
Each node 104 represents a function of the input value(s) received on its input edges(s) 106i, the outputs of the function being output on the output edge(s) 106o of the respective node 104, such that the value(s) output on the output edge(s) 106o of the node 104 depend on the respective input value(s) according to the respective function. The function of each node 104 is also parametrized by one or more respective parameters w, sometimes also referred to as weights (not necessarily weights in the sense of multiplicative weights, though that is certainly one possibility). Thus, the relation between the values of the input(s) 106i and the output(s) 106o of each node 104 depends on the respective function of the node and its respective weight(s).
Each weight could simply be a scalar value. Alternatively, as shown in
As shown in
The different weights of the various nodes 104 in the neural network 100 can be gradually tuned based on a set of experience data (e.g. training data), so as to tend towards a state where the output 108o of the network will produce a desired value for a given input 108i. For instance, before being used in an actual application, the neural network 100 may first be trained for that application. Training comprises inputting experience data in the form of training data to the inputs 108i of the graph and then tuning the weights w of the nodes 104 based on feedback from the output(s) 108o of the graph. The training data comprises multiple different input data points, each comprising a value or vector of values corresponding to the input edge or edges 108i of the graph 100.
For instance, consider a simple example as in
The classification Y could be a scalar or a vector. For instance, in the simple example of the elephant-recognizer, Y could be a single binary value representing either elephant or not elephant, or a soft value representing a probability or confidence that the image comprises an image of an elephant. Or similarly, if the neural network 100 is being used to test for a particular condition, Y could be a single binary value representing whether the subject has the condition or not, or a soft value representing a probability or confidence that the subject has the condition in question. As another example, Y could comprise a “1-hot” vector, where each element represents a different animal or condition. E.g. Y=[1, 0, 0, . . . ] represents an elephant, Y=[0, 1, 0, . . . ] represents a hippopotamus, Y=[0, 0, 1, . . . ] represents a rhinoceros, etc. Or if soft values are used, Y=[0.81, 0.12, 0.05, . . . ] represents an 81% confidence that the image comprises an image of an elephant, 12% confidence that it comprises an image of a hippopotamus, 5% confidence of a rhinoceros, etc.
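The 1-hot and soft output encodings described above can be sketched as follows; the class names and confidence values are illustrative only.

```python
import numpy as np

ANIMALS = ["elephant", "hippopotamus", "rhinoceros"]

def soft_to_label(y_soft):
    """Map a vector of per-class confidences to the most likely class."""
    return ANIMALS[int(np.argmax(y_soft))]

def one_hot(index, n_classes):
    """1-hot encoding: a single 1 at the class index, zeros elsewhere."""
    v = np.zeros(n_classes)
    v[index] = 1.0
    return v

label = soft_to_label(np.array([0.81, 0.12, 0.05]))  # most confident class
target = one_hot(0, 3)                               # training label for "elephant"
```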
In the training phase, the true value of Yi for each data point i is known. With each training data point i, the AI algorithm 206 measures the resulting output value(s) at the output edge or edges 108o of the graph, and uses this feedback to gradually tune the different weights w of the various nodes 104 so that, over many observed data points, the weights tend towards values which make the output(s) 108o (Y) of the graph 100 as close as possible to the actual observed value(s) in the experience data across the training inputs (for some measure of overall error). I.e. with each piece of input training data, the predetermined training output is compared with the actual observed output of the graph. This comparison provides the feedback which, over many pieces of training data, is used to gradually tune the weights of the various nodes 104 in the graph toward a state whereby the actual output 108o of the graph will closely match the desired or expected output for a given input 108i. Examples of such feedback techniques include for instance stochastic backpropagation.
Once trained, the neural network 100 can then be used to infer a value of the output 108o (Y) for a given value of the input vector 108i (X), or vice versa.
Explicit training based on labelled training data is sometimes referred to as a supervised approach. Other approaches to machine learning are also possible. For instance, another example is the reinforcement approach. In this case, the neural network 100 begins making predictions of the classification Yi for each data point i, at first with little or no accuracy. After making the prediction for each data point i (or at least some of them), the AI algorithm 206 receives feedback (e.g. from a human) as to whether the prediction was correct, and uses this to tune the weights so as to perform better next time. Another example is referred to as the unsupervised approach. In this case the AI algorithm receives no labelling or feedback and is instead left to infer its own structure in the experienced input data.
The one or more inference networks are arranged to receive the observed feature vector X as an input and encode it into a latent vector Z (a representation in a latent space). The one or more generative networks 208p are arranged to receive the latent vector Z and decode back to the original feature space X.
The latent vector Z is a compressed (i.e. encoded) representation of the information contained in the input observations X. In a VAE, no one element of the latent vector Z necessarily represents directly any real world quantity, but the vector Z as a whole represents the information in the input data in compressed form. It could be considered conceptually to represent abstract features abstracted from the input data X, such as “wrinklyness of skin” and “trunk-like-ness” in the example of elephant recognition (though no one element of the latent vector can necessarily be mapped onto any one such factor, and rather the latent vector Z as a whole encodes such abstract information). The decoder 404 is arranged to decode the latent vector Z back into values in a real-world feature space, i.e. back to an uncompressed form representing the actual observed properties (e.g. pixel values). In some examples, see e.g.
The weights w of the one or more inference networks 208q are labelled herein φ, whilst the weights w of the one or more generative networks 208p are labelled θ. Each node 104 applies its own respective weight as illustrated in
When using a VAE, with each data point in the training data (or more generally each data point in the experience data during learning), the weights φ and θ are tuned so that the VAE 208 learns to encode the feature vector X into the latent space Z and back again. For instance, this may be done by minimizing a measure of divergence between qφ(Zi|Xi) and pθ(Xi|Zi), where qφ(Zi|Xi) is a function parameterized by φ representing a vector of the probabilistic distributions of the elements of Zi output by the encoder 208q given the input values of Xi, whilst pθ(Xi|Zi) is a function parameterized by θ representing a vector of the probabilistic distributions of the elements of Xi output by the decoder 208p given Zi. The symbol “|” means “given”. The model is trained to reconstruct Xi and therefore maintains a distribution over Xi. At the “input side”, the value of Xi is known, and at the “output side”, the likelihood of Xi under the output distribution of the model is evaluated. The input values of X are sampled from the input data distribution. p(X|Z) is part of the model distribution, and the goal in a VAE of the algorithm 206 is to make p(X) close to the input data distribution. p(X, Z) may be referred to as the model of the decoder, whilst p(Z|X) may be referred to as the posterior or exact posterior, and q(Z|X) as the approximate posterior. p(Z) and q(Z) may be referred to as priors.
For instance, this may be done by minimizing the Kullback-Leibler (KL) divergence between qφ(Zi|Xi) and pθ(Xi|Zi). The minimization may be performed using an optimization function such as an ELBO (evidence lower bound) function, which uses cost function minimization based on gradient descent. However, in general other metrics and functions are also known in the art for tuning the encoder and decoder neural networks of a VAE.
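For a concrete illustration, when the approximate posterior over Z is a diagonal Gaussian and the prior on Z is a standard normal (the standard VAE setting, used here only as an example), the KL term has a well-known closed form:

```python
import numpy as np

def kl_gaussian_standard_normal(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over the latent
    dimensions. This closed-form term is what is typically minimized
    (as one part of the ELBO) when the latent prior is a standard normal.
    """
    return 0.5 * np.sum(sigma**2 + mu**2 - 1.0 - 2.0 * np.log(sigma))

# A perfect match between posterior and prior gives zero divergence.
kl = kl_gaussian_standard_normal(np.zeros(4), np.ones(4))
```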
The requirement to learn to encode to Z and back again amounts to a constraint placed on the overall neural network 208 of the VAE formed from the constituent neural networks 208q, 208p. This is the general principle of an autoencoder. The purpose of forcing the autoencoder to learn to encode and then decode a compressed form of the data, is that this can achieve one or more advantages in the learning compared to a generic neural network; such as learning to ignore noise in the input data, making better generalizations, or because when far away from a solution the compressed form gives better gradient information about how to quickly converge to a solution. In a variational autoencoder, the latent vector Z is subject to an additional constraint that it follows a predetermined form (type) of probabilistic distribution such as a multidimensional Gaussian distribution or gamma distribution.
Nonetheless, an issue with existing machine learning models is that their imputation methods are agnostic to causality. VAEs, for example, do not consider causal relationships between input variables.
To address this, the present disclosure provides a machine learning model that can discover relationships between variables given partial observation and can be used to provide missing value imputation at the same time. In examples, causal discovery is used to help the task of missing value imputation. The causal structure of data is a powerful source of information for real-world decision making, and it can improve and complement other learning tasks. However, historically causality and machine learning research have evolved separately. One of the main challenges in real-world machine learning is the presence of missing data. In some examples, causal discovery can help the task of missing value imputation since the relationships between variables are paramount for such a task.
Some examples are scalable, which is useful as the number of possible causal graphs grows exponentially with the number of variables. In some examples, it is not necessary to exhaustively enumerate all possible DAGs for the causal graphs G to find the “best” one, which helps to improve scalability. Some examples perform causal discovery in the presence of missing values, which is not considered in standard approaches. Some examples seek to model complex relationships between variables, for which flexible deep learning models are required. Some examples can discover causal relationships between groups of variables. Variables may be grouped in a smaller number of semantically coherent pre-defined groups. For example, one setting in which such a need arises is in the education domain. Education data can contain student responses to thousands of individual questions, where each question belongs to a broader topic. It is insightful to find relationships between topics instead of individual questions to help teachers adjust the curriculum. For instance, if there exists a causal relationship from one topic to another, the former should be taught earlier in the curriculum. Also, educational data is inherently sparse since it is not feasible to ask every question to every student.
A further example where the method could be applied is in healthcare. For example, if a food log of a set of people (subjects) is recorded, variables could be the weight of consumed apple, consumed orange, consumed banana, consumed cucumber, consumed broccoli and consumed meat for each subject. Further variables could also be blood pressure and blood sugar of each subject. The method could be used to determine the relationship between each of the consumed types of food, blood pressure and blood sugar. Further, if it was desired to determine the relationship between fruit intake, vegetable intake, meat intake, blood pressure and blood sugar, the consumed fruit variables could be grouped together and the consumed vegetable variables could be grouped together when determining causal relationships.
Some examples provide an approach to simultaneously tackle missing data imputation and causal discovery. Some examples provide two outputs in one framework. This is accomplished by inferring a generative model that leverages a structured latent space and a decoder based on Graph Neural Networks (GNN). Namely, the structured latent space endows each variable with its own latent subspace, and the interactions between the subspaces are regulated by a GNN whose behaviour depends on a graph of causal relationships between variables. Some examples leverage continuous optimization of the causal structure to achieve scalability, can be used in the presence of missing data, and can make use of deep learning architectures for increased flexibility.
Moreover, the causal structure can be learned at different levels of granularity when the variables are organised in groups.
An example of information that may be put into a system is shown at 520 in
Each data point i in
In the specific example shown, gender has a causal effect on sugar intake. Sugar intake has a causal effect on diabetes. Age has a causal effect on sugar intake and diabetes.
From
In some examples, each edge connecting the nodes (variables) of graph G can carry a probability that one variable has a causal effect on another. Such an example is shown in
In the example of graph G 721 shown in
In the specific example shown, gender has a probability of 0.4 of having a causal effect on sugar intake. Sugar intake has a probability of 0.9 of having a causal effect on diabetes. Age has a probability of 0.7 of having a causal effect on sugar intake and a probability of 0.3 of having a causal effect on diabetes.
Similar causal relationships are shown in
The causal relationships shown in
An example of a cyclic causal relationship is shown in
Feature vector X 930 (comprising variable A 932a, variable B 932b, variable C 932c and variable D 932d) is input into a first neural network 934. Variable A 932a, variable B 932b, variable C 932c and variable D 932d may be any suitable variable, for example gender, age, diabetes, and sugar intake. First neural network 934 may comprise one or more inference networks. First neural network 934 may comprise an encoder. First neural network 934 may be similar to the inference network 1234 of
First neural network 934 can act as an element-wise encoder for each variable of X 930. In other words, each element of X 930 is encoded to a respective latent vector. The first neural network 934 outputs: latent vector ZA 936a corresponding to variable A 932a; latent vector ZB 936b corresponding to variable B 932b; latent vector ZC 936c corresponding to variable C 932c; latent vector ZD 936d corresponding to variable D 932d.
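An element-wise encoding of this kind may be sketched as below, with each per-variable encoder reduced to a single linear map producing a mean and log-variance. The parameter shapes and the reparameterized sampling are illustrative simplifications, not the specific architecture of the disclosure.

```python
import numpy as np

def elementwise_encode(x, params, latent_dim=2, rng=None):
    """Encode each scalar variable of x into its own latent vector.

    Each variable has its own small encoder (here a single linear map
    per variable) producing a mean and log-variance, from which one
    latent sample is drawn. `params` is a list of (w_mu, w_logvar)
    weight vectors, one pair per variable.
    """
    if rng is None:
        rng = np.random.default_rng(0)
    latents = []
    for xi, (w_mu, w_logvar) in zip(x, params):
        mu = w_mu * xi
        std = np.exp(0.5 * w_logvar * xi)
        latents.append(mu + std * rng.normal(size=latent_dim))
    return np.stack(latents)          # shape: (num_variables, latent_dim)

x = np.array([0.3, 1.2, -0.7, 0.0])  # variables A, B, C, D
params = [(np.ones(2), np.zeros(2)) for _ in x]
Z = elementwise_encode(x, params)
```

In a real model each variable's encoder would be a small neural network; the sketch only shows the shape of the computation: one latent vector per input variable.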
Each latent vector 936a, 936b, 936c and 936d may be input into a second neural network 938. Second neural network 938 comprises a GNN 940 that can operate on graph inputs. GNNs can be used where each latent vector ZA 936a, ZB 936b, ZC 936c and ZD 936d is defined by a distribution (e.g. in a Bayesian model). GNN 940 may comprise weights θGNN. Second neural network 938 may also comprise other neural networks, comprising other weights θ. GNN 940 is parametrized by graph G 921. G 921 dictates the graph over which the GNN 940 operates. G 921 comprises edge probabilities showing causal relationships between the variables A 932a, B 932b, C 932c and D 932d.
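A single message-passing step of such a graph-parametrized network can be sketched as follows. This is a simplified stand-in, not the decoder architecture of the disclosure: latent vectors are plain arrays, the transformations are single linear maps, and the only graph dependence is the scaling of messages by the edge probabilities of G.

```python
import numpy as np

def gnn_layer(Z, G, W_msg, W_self):
    """One message-passing step over latent vectors Z, weighted by the
    edge probabilities in G.

    G[j, i] is the probability that variable j has a causal effect on
    variable i; the message from Z[j] into node i is scaled by it.
    """
    messages = Z @ W_msg.T          # transform each latent vector
    aggregated = G.T @ messages     # node i receives sum_j G[j, i] * msg_j
    return np.tanh(Z @ W_self.T + aggregated)

d, k = 4, 2
Z = np.ones((d, k))
G = np.array([[0.0, 0.4, 0.0, 0.0],   # G[j, i]: probability of edge j -> i
              [0.0, 0.0, 0.9, 0.0],
              [0.0, 0.0, 0.0, 0.0],
              [0.0, 0.7, 0.3, 0.0]])
Z_next = gnn_layer(Z, G, np.eye(k), np.eye(k))
```

Node 0 has no incoming edges in this G, so its update depends only on its own latent vector.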
Second neural network 938 then outputs computed vector {circumflex over (X)} 944 comprising computed variable {circumflex over (A)} 942a, computed variable {circumflex over (B)} 942b, computed variable Ĉ 942c and computed variable {circumflex over (D)} 942d. A feature vector X may be input into the system of
In some examples, variables A 932a, B 932b, C 932c and D 932d are scalar values. First neural network 934 may convert these scalar values to latent variables ZA 936a, ZB 936b, ZC 936c and ZD 936d that have a probabilistic distribution. During this conversion, first neural network 934 may introduce noise into variables A 932a, B 932b, C 932c and D 932d when converting them to latent variables ZA 936a, ZB 936b, ZC 936c and ZD 936d. The distributions may be represented in the form of a set of samples or a set of parameters parameterizing the distribution (e.g. the mean μ and standard deviation σ or variance σ2). The probabilistic distributions of latent variables ZA 936a, ZB 936b, ZC 936c and ZD 936d may be input into GNN 940. Either of GNN 940 or second neural network 938 may then convert the latent variables ZA 936a, ZB 936b, ZC 936c and ZD 936d to scalar output values {circumflex over (A)} 942a, {circumflex over (B)} 942b, Ĉ 942c and {circumflex over (D)} 942d of output vector {circumflex over (X)} 944.
The values of graph G showing causal relationships between the variables A 932a, B 932b, C 932c and D 932d can be tuned using the method described below. The tuning may comprise an attempt to optimize the values of G 921. In some examples, the values of the edge probabilities of graph G 921, which show causal relationships between the variables, are optimized using the below method. In some examples, the below method can also be used to tune (attempt to optimize) values of at least one of θGNN, θ and φ.
When a vector such as vector X 930 is received to be input into the system of
A loss function is then determined based on the difference between input vector X 930 and output vector {circumflex over (X)} 944, where {circumflex over (X)} 944 is the output of the system for the input {tilde over (X)} value. For example, the loss function could be determined by |X−{circumflex over (X)}|2. Any other suitable equation for determining an amount of difference between X 930 and output vector {circumflex over (X)} 944 may be used, however.
In some examples, the loss function may additionally include a function that penalises cyclic relationships between variables in X when determining G 921. This function may increase the value of the loss function when a cyclic relationship is present in G 921. This function can enforce that G 921 is a Directed Acyclic Graph (DAG). The function can be considered to be DAG(G), where DAG(G) penalises any cycles between variables in graph G and in some examples removes all cyclicity in graph G 921. By removing cyclicity in graph G, the loss function can be minimized more efficiently to determine values of G, θGNN, θ and φ, as the minimization of the loss function will converge to a solution in fewer computational steps. The method may comprise a constraint that G 921 does not comprise any cyclic causal relationships between variables of G 921.
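One known function with this property is the trace-exponential acyclicity penalty from the NOTEARS line of work, tr(exp(G∘G)) − d, which is zero exactly when G is acyclic and grows with the strength of any cycles. It is used here purely as an illustration; the disclosure only requires some DAG(G) that penalises cycles. A sketch, with the matrix exponential approximated by a truncated power series:

```python
import numpy as np

def dag_penalty(G):
    """NOTEARS-style acyclicity penalty: tr(exp(G * G)) - d.

    Zero for an acyclic weighted graph; positive when cycles exist,
    so adding it to the loss pushes the learned graph towards a DAG.
    The matrix exponential's trace is approximated by a truncated
    power series, which is exact for nilpotent (acyclic) matrices.
    """
    d = G.shape[0]
    A = G * G                      # elementwise square keeps entries >= 0
    term, trace = np.eye(d), float(d)
    for k in range(1, d + 1):
        term = term @ A / k
        trace += np.trace(term)
    return trace - d

acyclic = np.triu(np.ones((3, 3)), k=1)      # strictly upper-triangular: a DAG
cyclic = np.array([[0.0, 1.0], [1.0, 0.0]])  # a 2-cycle
```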
In some examples the loss function may also comprise a further function for regularizing G 921. The function may comprise a measure of difference between an estimated value of a posterior function of G 921, q(G), and a prior value of G 921, p(G). In some examples, the prior value may be a human prior value input for G 921 by a user. In other examples, p(G) may be set arbitrarily. As such, the user can input previous expectations of G 921 into the loss function (e.g. if the user expects sugar causes diabetes, they could reflect this in p(G)). The measure of difference may be determined using any suitable algorithm, for example a Kullback-Leibler (KL) divergence function. The measure of difference may be expressed as KL[q(G)∥p(G)], for example.
A combination of DAG(G) and KL[q(G)∥p(G)] may be considered to be a regularization function of G 921, reg(G).
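Assuming each edge of G is modelled as an independent Bernoulli variable (an assumption for illustration; the disclosure does not fix the form of q(G)), the term KL[q(G)∥p(G)] can be computed edge-wise:

```python
import numpy as np

def bernoulli_kl(q, p, eps=1e-9):
    """KL[q(G) || p(G)] when each edge is an independent Bernoulli.

    q and p are matrices of edge probabilities: q the learned
    posterior, p a prior encoding the user's expectations (e.g. a high
    prior probability on a sugar -> diabetes edge).
    """
    q = np.clip(q, eps, 1 - eps)   # guard the logarithms
    p = np.clip(p, eps, 1 - eps)
    return np.sum(q * np.log(q / p) + (1 - q) * np.log((1 - q) / (1 - p)))

q_edges = np.array([[0.5, 0.9],
                    [0.1, 0.5]])
kl = bernoulli_kl(q_edges, np.full_like(q_edges, 0.5))
```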
The loss function may also comprise a regularization of Z (i.e., of the values ZA 936a, ZB 936b, ZC 936c and ZD 936d). The function may comprise a measure of difference between an estimated value of Z, q(Z), and a prior value of Z, p(Z). The prior value may be a human prior value input for Z by a user. In other examples, p(Z) may be set arbitrarily. As such, the user can input previous expectations on Z into the loss function. For example, a user may have an expectation of how the values of Z are distributed. For example, a user may have an expectation of the type of distribution (e.g. a normal distribution), the mean of a distribution and/or a variance of the distribution of the Z values. The measure of difference may be determined using any suitable algorithm, for example a Kullback-Leibler (KL) divergence function. The measure of difference may be expressed as KL[q(Z)∥p(Z)], for example.
KL[q(Z)∥p(Z)] may be considered to be a regularization function of Z, reg(Z).
The loss function of the method, L, may in some examples be considered to be:

L=|X−{circumflex over (X)}|2+reg(G)+reg(Z)
In some examples, reg (Z) is not included in the loss function.
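Putting the pieces together, a minimal sketch of evaluating the loss for one input vector follows, with the regularization terms passed in as precomputed scalars (how reg(G) and reg(Z) are obtained is described above; the numbers here are illustrative):

```python
import numpy as np

def loss(x, x_hat, reg_G, reg_Z=0.0):
    """L = |X - X_hat|^2 + reg(G) + reg(Z).

    reg_G combines the DAG penalty and KL[q(G)||p(G)]; reg_Z is
    KL[q(Z)||p(Z)] and, per the text, may be omitted (default 0).
    """
    reconstruction = np.sum((x - x_hat) ** 2)
    return reconstruction + reg_G + reg_Z

x = np.array([1.0, 0.0, 2.0])
x_hat = np.array([0.9, 0.1, 2.1])
L = loss(x, x_hat, reg_G=0.05)
```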
The method can comprise tuning the values of one or more of G, θGNN, θ and φ in order to minimise the loss function L.
In some examples, tuning the values of one or more of G (e.g. edge probabilities of graph G 921), θGNN (parameters of the GNN 940 of the decoder), θ (parameters of the decoder 938) and φ (parameters of the encoder 934) in order to minimise the loss function L can comprise performing a gradient step for G, θGNN, θ and φ to minimise the loss function. The gradient step may involve the following steps:
The loss function may be minimized over a plurality of N input vectors, to provide tuned values of G, θGNN, θ and φ. This provides further tuning of edge probabilities of the graph G 921 and also provides further tuning of θGNN, θ and φ.
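The gradient-based tuning can be illustrated with a deliberately tiny example: a single edge probability updated by gradient descent, with the gradient approximated by finite differences. A practical implementation would use automatic differentiation over all of G, θGNN, θ and φ, and a continuous relaxation for the discrete graph; the quadratic "loss" below is a stand-in.

```python
import numpy as np

def toy_loss(g_edge):
    """Stand-in loss as a function of one edge probability; pretend the
    data favor an edge probability of 0.9."""
    return (g_edge - 0.9) ** 2

def gradient_step(g_edge, lr=0.5, h=1e-6):
    """One descent step using a central finite-difference gradient."""
    grad = (toy_loss(g_edge + h) - toy_loss(g_edge - h)) / (2 * h)
    return np.clip(g_edge - lr * grad, 0.0, 1.0)  # keep a valid probability

g = 0.2
for _ in range(50):          # repeated steps drive the loss down
    g = gradient_step(g)
```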
By minimising the loss function L, optimized values for G 921 can be determined. This allows a user to review the optimized values of G to discover causal relationships between variables A 932a, B 932b, C 932c and D 932d.
In some examples, the loss function may only operate on variables that are present in input vector X and not on variables that are missing from the input vector. For example, if input vector X 930 only had values for variables A 932a, C 932c and D 932d, but not for B 932b, the loss function would not be applied to variable B. This is because the missing values do not provide information about how to tune G, θGNN, θ and φ.
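Restricting the reconstruction term to observed variables can be sketched with a boolean mask (the values are illustrative; the placeholder for the missing variable B contributes nothing):

```python
import numpy as np

def masked_reconstruction_loss(x, x_hat, observed):
    """Squared-error loss over observed variables only.

    `observed` is a boolean mask; missing entries contribute nothing,
    since they carry no information for tuning the model.
    """
    diff = (x - x_hat) ** 2
    return np.sum(np.where(observed, diff, 0.0))

x     = np.array([1.0, 0.0, 2.0, -1.0])         # value for B is a placeholder
x_hat = np.array([0.8, 5.0, 2.0, -1.5])
observed = np.array([True, False, True, True])  # variable B is missing
L = masked_reconstruction_loss(x, x_hat, observed)
```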
Once the values of G, θGNN, θ and φ have been tuned using the above described method, the system of
In some examples, variables of input feature vectors may be grouped together to provide a latent vector for the group of variables. This can be useful as the overall amount of computation is reduced, as latent vectors only have to be determined for each group.
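Such grouping can be sketched as follows, with the per-group encoder reduced to a trivial averaging stand-in (a real model would use a neural encoder per group; the groups and measurement values here are illustrative):

```python
import numpy as np

def encode_groups(x, groups, latent_dim=2):
    """Encode pre-defined groups of variables into one latent vector each.

    Instead of one latent vector per variable, all variables in a group
    (e.g. several related measurements) share a latent vector, so only
    len(groups) latent vectors are computed.
    """
    return np.stack([np.full(latent_dim, np.mean(x[idx])) for idx in groups])

x = np.array([5.1, 5.3, 4.9, 120.0, 118.0, 72.0, 70.0])
groups = [[0, 1, 2],   # e.g. three related measurements
          [3, 4],      # two related measurements
          [5, 6]]      # two related measurements
Z = encode_groups(x, groups)  # one latent vector per group
```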
An example where variables A, B, C, D, E, F and H of vector X 1030 are grouped is shown in
Each group provides a corresponding latent vector (latent vectors Z1 1036a, Z2 1036b and Z3 1036c). The latent vectors can then be used in a similar way to the latent vectors shown in
In some examples, variables may be grouped based on what they represent. For example, when determining a patient's health, a first measurement for heart rate and a second measurement for heart rate made using an alternative method may be grouped together, as they are closely linked. As an example, group 1046a could group together three different measures for cholesterol level, group 1046b could group together two different measures for blood pressure and group 1046c could group together two different measures for heart rate. The causal graph G for
In some examples, the variables of input vector X could represent one or more sensor values.
In some examples, the variables of input vector X could represent one or more sensor values for at least one health monitoring device. The above-described method could then discover causal relationships between sensor values for the at least one health monitoring device (e.g. between blood pressure and body temperature). Further, if one of the sensors is malfunctioning, the method can impute any missing values from the sensors. The discovered causal relationships could be used to diagnose a patient. The imputed missing values could also be used to diagnose a patient with a health condition.
In a further example, the variables of input vector X in
In some examples, the variables of input vector X could represent one or more sensor values representing states of a device or a system of devices. The above-described method could then discover causal relationships between sensor values for the device or the system of devices. For example, a causal relationship graph G could show that leaving the device on for a long period of time causes overheating, which causes a device to power off and cause a fault in a network. The causal relationship G could be used to determine reasons why a device or system of devices has malfunctioned.
Further, if one or more of the sensors is malfunctioning, the method can impute any missing values from the sensors. The imputed missing values could also be used to work out if a device or system of devices is malfunctioning and could be used to work out why the device or system of devices is malfunctioning.
In a further example, the above-described method of finding causal relationships and missing values could be applied in a power system or other industrial system. Variables may comprise weather and/or air conditions. Further variables may comprise efficiency levels and the type of electrical generator (e.g. wind, solar, tidal, etc.). The system can be used to determine the causal relationships between the variables, for example to determine which types of electrical generator are most efficient under certain conditions. Further, the method can be used even when values for certain variables have not been observed, and the missing values can be predicted by the method.
At 1150, the method comprises receiving an input vector comprising values of variables. In some examples, the input vector may have missing values.
In some examples, between 1150 and 1152, one or more values are removed from the input vector. The input vector with the values removed may then be used to determine the computed vector in 1154.
At 1152, the method comprises using a first neural network to encode the variables of the input vector into a plurality of latent vectors.
At 1154, the method comprises inputting the plurality of latent vectors into a second neural network comprising a graph neural network, wherein the graph neural network is parametrized by a graph comprising edge probabilities indicating causal relationships between the variables, in order to determine a computed vector value.
At 1156, the method comprises tuning the edge probabilities of the graph, one or more parameters of the first neural network and one or more parameters of the second neural network to minimise a loss function, wherein the loss function comprises a measure of difference between the input vector and the computed vector value and a function of the graph.
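The three steps above can be sketched end-to-end with simple linear maps standing in for the two neural networks (all shapes, weights and names here are hypothetical; a real implementation would use learned networks):

```python
import numpy as np

rng = np.random.default_rng(0)
D, K = 3, 2                            # number of variables, latent size per variable
x = rng.normal(size=D)                 # input vector of variable values (1150)
W_enc = rng.normal(size=(D, K)) * 0.1  # stand-in for the first (encoder) network
W_dec = rng.normal(size=(K, 1)) * 0.1  # stand-in for the second network's read-out
G = np.full((D, D), 0.5)               # edge probabilities of the graph
np.fill_diagonal(G, 0.0)

def forward(x, W_enc, G, W_dec):
    Z = x[:, None] * W_enc             # encode: one latent vector per variable (1152)
    Z = G @ Z                          # graph-weighted mixing, GNN stand-in (1154)
    return (Z @ W_dec).ravel()         # read out the computed vector

x_hat = forward(x, W_enc, G, W_dec)
loss = float(np.sum((x - x_hat) ** 2)) # reconstruction term to be minimised (1156)
```

Tuning at 1156 would then adjust G, W_enc and W_dec by gradient descent on this loss (plus the graph regularization terms described earlier).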
Input vector xn 1230 is input into first neural network 1234 to output latent vectors zn,1 1236a, zn,2 1236b and zn,3 1236c. The latent vectors are input into GNN 1221. GNN 1221 is dictated by a causal relationship graph G, which comprises edge probabilities between each pair of variables. The edge probabilities may be tuned as described above. The output from GNN 1221 is then input into neural network 1238. This provides an output vector {circumflex over (x)}n 1244 comprising variables {circumflex over (x)}n1 1244a, {circumflex over (x)}n2 1244b and {circumflex over (x)}n3 1244c. As such, the missing value 1232b in xn 1230 is imputed in {circumflex over (x)}n2 1244b of {circumflex over (x)}n 1244.
Input vectors may be input for N data points over n=1, . . . , N for
We propose VICAUSE (missing value imputation with causal discovery), an approach to simultaneously tackle missing data imputation and causal discovery. VICAUSE provides two outputs in one framework. This is accomplished by inferring a generative model that leverages a structured latent space and a decoder based on GNNs. Namely, the structured latent space endows each variable with its own latent subspace, and the interactions between the subspaces are regulated by a GNN whose behavior depends on the graph of causal relationships, see
We describe VICAUSE for discovering causal relationship between variables first, and then present the extension to groups of variables.
Our goal is to develop a model that jointly learns to impute missing values and to find causal relationships between variables. The input to VICAUSE is an N×D training set X={xn}n=1N with N data points and D variables, which may contain missing values. The observed and unobserved training values are denoted XO and XU, respectively. In this work, we assume data are either missing completely at random (MCAR) or missing at random (MAR). The output of VICAUSE is i) a model that is able to impute missing values for a previously unseen test sample {tilde over (x)}∈ℝD, and ii) a directed graph representing the causal relationships between the D variables. The graph is represented by its adjacency matrix G, i.e. a D×D matrix whose element Gij is 1 if there exists a causal relationship from the i-th variable to the j-th, and is 0 otherwise.
VICAUSE aims to discover the underlying causal relationships given partially observed data, and the learned model can also be used to impute missing data for test samples. We use a score-based approach for causal discovery. Inspired by Bayesian approaches, our score is defined as the posterior probability of G given the partially observed training data, subject to the constraint that G forms a directed acyclic graph (DAG). Thus, our objective is:
To optimize over the causal structure with the DAG constraint in Equation 1, we resort to recent continuous optimization techniques. Namely, it has been shown (“Zheng, X., Aragam, B., Ravikumar, P., and Xing, E. P. (2018). DAGs with NO TEARS: Continuous Optimization for Structure Learning. In Advances in Neural Information Processing Systems”) that G represents a DAG if and only if the non-negative quantity
equals zero. To leverage this DAG-ness characterisation, we introduce a regulariser based on this non-negative quantity to favour the DAG-ness of the solution, i.e.
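This DAG-ness characterisation (the trace of the matrix exponential of the Hadamard square G∘G equals D if and only if G encodes a DAG) can be checked numerically. The sketch below uses a truncated power series for the matrix exponential, which is adequate for small graphs; the function name is illustrative:

```python
import numpy as np

def dagness(G, terms=20):
    # h(G) = tr(exp(G ∘ G)) - D: zero iff G encodes a DAG (Zheng et al., 2018).
    D = G.shape[0]
    A = G * G                  # elementwise (Hadamard) square
    E = np.eye(D)
    term = np.eye(D)
    for k in range(1, terms):
        term = term @ A / k    # accumulates A^k / k!
        E = E + term
    return float(np.trace(E) - D)

chain = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]], dtype=float)  # A→B→C (acyclic)
cycle = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=float)  # A→B→C→A (cyclic)
```

The acyclic graph yields exactly zero (its weighted adjacency is nilpotent), while any cycle contributes a strictly positive penalty, which is what makes the quantity usable as a differentiable regulariser.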
The model used to compute the score needs to handle partial observations. In addition, with the learned model we can impute missing values given any observations. Thus, given a test sample {tilde over (x)}∈ℝD with partially observed variables, we can estimate the distribution over {tilde over (x)}U (the unobserved values) given {tilde over (x)}O (the observed ones) using the learned model (Equation 9). Next, we present our model design, and then the training and imputation procedures.
Generative model. We assume that the observations in X are generated given the relationships G and exogenous noise Z.
We use deep learning, in particular a Graph Neural Network (GNN) for fθ, to provide a highly flexible model of the generative process.
Amortized variational inference. The true posterior distribution over Z and G cannot be obtained in closed form in Equation 3, since we use a deep learning architecture. Therefore, we resort to efficient amortized variational inference as in Kingma, D. and Welling, M. (2013). Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114; Kingma, D. P., Welling, M., et al. (2019). An introduction to variational autoencoders. Foundations and Trends® in Machine Learning, 12(4):307-392; and Zhang, C., Bütepage, J., Kjellström, H., and Mandt, S. (2018). Advances in variational inference. IEEE Transactions on Pattern Analysis and Machine Intelligence, 41(8):2008-2026. Here, we consider a fully factorized variational distribution q(Z, G)=q(G)Πn=1Nqϕ(Zn|xn), where qϕ(Zn|xn) is a Gaussian whose mean and (diagonal) covariance matrix are given by the encoder. For q(G) we consider the product of independent Bernoulli distributions over the edges, that is, each edge is present with a probability Gij ∈(0, 1) to be estimated. With this formulation, the evidence lower bound (ELBO) is
Next, we dive into our choice of generator (decoder), which uses a GNN to regulate the interactions between the variables. Then, we focus on the inference network (encoder), which respects the variable-wise structure of the latent space.
Generator. The generator (also known as decoder) takes Zn and G as input, and outputs the reconstructed {circumflex over (x)}n=fθ(Zn, G), where θ are the decoder parameters. We partition the exogenous noise Zn into D parts, where zn,d is the exogenous noise for each variable d=1, . . . , D. Notice that this defines a variable wise structured latent space. The decoder regulates the interactions between variables with a GNN whose behavior is determined by the relationships in G. Specifically this is done in two steps: GNN message passing layers and a final readout layer yielding the reconstructed sample.
GNN message passing in the generator. In message passing, the information flows between nodes in T consecutive node-to-edge (n2e) and edge-to-node (e2n) operations [8, 15]. At the t-th step, each edge i→j has a representation (or embedding) hi→j summarizing the information sent from node i to node j. Since we are interested in the imputation task, where we may want to predict the value of the parents from their children only, we also introduce a backward embedding. This is denoted hi→jb, and codifies the information that the i→j edge lets flow from the j-th to the i-th node (for symmetry, the “standard” embedding is called here the forward embedding and denoted hi→jf). Specifically, the n2e and e2n operations used in VICAUSE are
Here, t refers to the t-th iteration of message passing (that is, Z(0)=Zn, notice that we omit the subindex n for simplicity). Finally, MLPf, MLPb and MLPe2n are MLPs to be estimated. Interestingly, Eqs. (5)-(6) link together the imputation and causal discovery tasks, since the information flow between two nodes (i.e. variables) is proportional to the weight of the corresponding edge.
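One n2e/e2n round can be sketched as follows, with simple stand-ins for MLPf, MLPb and MLPe2n (the dimensions, stand-in functions and names are all hypothetical; they only illustrate the dataflow, including the G-proportional scaling of messages noted above):

```python
import numpy as np

rng = np.random.default_rng(0)
D, K = 3, 2                                 # nodes (variables), latent size
Z = rng.normal(size=(D, K))                 # Z(0) = Zn
G = rng.uniform(size=(D, D))                # edge weights (probabilities)
W_e2n = rng.normal(size=(2 * K, K)) * 0.1   # stand-in weights for MLP_e2n
mlp_f = lambda v: np.tanh(v)                # stand-in for MLP_f (forward)
mlp_b = lambda v: np.tanh(-v)               # stand-in for MLP_b (backward)

def message_passing_step(Z):
    # n2e: each edge i→j gets forward/backward embeddings scaled by G[i, j]
    h_f = {(i, j): G[i, j] * mlp_f(np.concatenate([Z[i], Z[j]]))
           for i in range(D) for j in range(D) if i != j}
    h_b = {(i, j): G[i, j] * mlp_b(np.concatenate([Z[j], Z[i]]))
           for i in range(D) for j in range(D) if i != j}
    # e2n: node j aggregates incoming forward and outgoing backward messages
    Z_new = []
    for j in range(D):
        m = sum(h_f[(i, j)] for i in range(D) if i != j)
        m = m + sum(h_b[(j, k)] for k in range(D) if k != j)
        Z_new.append(np.tanh(m @ W_e2n))    # stand-in for MLP_e2n
    return np.stack(Z_new)

Z1 = message_passing_step(Z)                # Z(1); repeat T times for Z(T)
```

Because every message is multiplied by the corresponding entry of G, edges with near-zero probability carry almost no information, which is the coupling between the causal graph and the imputation task.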
An algorithm for training VICAUSE is shown in
Read-out layer in the generator. After T iterations of GNN message passing, we have Z(T). We then apply a final function that maps Z(T) to the reconstructed {circumflex over (x)}, i.e. {circumflex over (x)}=(g(z1(T)), . . . , g(zD(T))), with g given by an MLP. Notice that the decoder parameters θ include the parameters of four neural networks: MLPf, MLPb, MLPe2n and g.
Inference network. As in standard VAEs, the encoder maps a sample xn to its latent representation Zn. In VICAUSE, we additionally ensure that the encoder respects the structure of the latent space. As discussed before, Zn is partitioned in D parts, one for each variable. To obtain the mean and variance of Zn, we utilize a multi-head approach with shared parameters φ={φμ, φσ} for all the variables:
Here, μφ
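A sketch of such a shared multi-head encoder, with linear heads standing in for μφ and σφ and a softplus ensuring positive variances (the weights and shapes are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(1)
D, K = 4, 3                                   # variables, latent size per variable
x = rng.normal(size=D)
w_mu = rng.normal(size=K)                     # shared mean head, applied per variable
w_sigma = rng.normal(size=K)                  # shared variance head

mu = np.outer(x, w_mu)                        # (D, K): mean of each z_{n,d}
var = np.log1p(np.exp(np.outer(x, w_sigma)))  # softplus keeps variances positive
eps = rng.normal(size=(D, K))
Z = mu + np.sqrt(var) * eps                   # reparametrized sample of Zn
```

Because the same two heads are applied to every variable, the encoder respects the variable-wise partition of the latent space while keeping the parameter count independent of D.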
Given the model described above, we have the final objective to minimize w.r.t. θ, φ and G:
where ELBO is given by Equation 4 and the DAG regulariser R(G) is defined as above.
Evaluating the training loss of VICAUSE. VICAUSE can work with any type of data. The log-likelihood term (the first term in Equation 4) is defined according to the data type. We use a Gaussian likelihood for continuous variables and a Bernoulli likelihood for binary ones. The standard reparametrization trick is used to sample Zn from the Gaussian distribution qφ(Zn|xn). To backpropagate the gradients through the discrete variable G, we resort to the Gumbel-softmax trick to sample from q(G). The KL[qφ(Zn|xn)∥p(Zn)] term can be obtained in closed form, since both are Gaussian distributions. The KL[q(G)∥p(G)] term can also be obtained in closed form, since both are products of independent Bernoulli distributions over the edges. Notice that this term allows for specifying prior knowledge on the causal structure (e.g. sparsity). Finally, the DAG-loss regulariser in Eq. 8 can be computed by evaluating the function R on a Gumbel-softmax sample from q(G). To make the model adapt to different sparsity levels in the training data X, during training we drop a random percentage of the observed values. The full training procedure for VICAUSE is summarised in Algorithm 2 shown in
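The Gumbel-softmax (binary-concrete) relaxation used to backpropagate through the discrete edges can be sketched as follows; the function name is hypothetical, and for independent Bernoulli edges the Gumbel-max construction reduces to adding logistic noise to the edge logits:

```python
import numpy as np

def sample_edges(probs, tau=0.5, rng=None):
    # Binary-concrete relaxation of independent Bernoulli edges:
    # sigmoid((logits + logistic noise) / tau), differentiable w.r.t. probs.
    rng = rng or np.random.default_rng()
    logits = np.log(probs) - np.log1p(-probs)   # log-odds of each edge
    u = rng.uniform(1e-9, 1 - 1e-9, size=probs.shape)
    noise = np.log(u) - np.log1p(-u)            # Logistic(0, 1) sample
    return 1.0 / (1.0 + np.exp(-(logits + noise) / tau))

P = np.array([[0.05, 0.9], [0.1, 0.05]])        # current edge probabilities q(G)
G_sample = sample_edges(P, rng=np.random.default_rng(0))
```

As the temperature tau decreases, the relaxed samples concentrate near 0 or 1, approaching hard Bernoulli draws while keeping the sampling path differentiable.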
Two-step training. Although important for the imputation task, the use of both forward and backward MLPs introduces a symmetry that hampers the correct identification of the causal direction. Namely, if the forward and backward MLPs are similar models, then A→B and B→A produce exactly the same information flow when the two MLPs are swapped. To overcome this issue, we propose a two-step training scheme. In the first stage, the backward MLP is disabled so that the symmetry is broken and the algorithm can learn the causal structure. In the second stage, we fix the graph structure (i.e. the variational parameter G) and continue to train the model with the backward MLP. This two-stage training process allows VICAUSE to leverage the backward MLP for the imputation task without interfering with the causal discovery.
Revisiting the learning objectives. The optimal graph of relationships, which was denoted G* above, is given by the posterior graph of probabilities G (it gives the best score, as it maximizes the posterior). Similar to “Ma, C., Tschiatschek, S., Palla, K., Hernandez-Lobato, J. M., Nowozin, S., and Zhang, C. (2019). EDDI: Efficient dynamic discovery of high value information with partial VAE. In Chaudhuri, K. and Salakhutdinov, R., editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 4234-4243. PMLR” and/or “Nazabal, A., Olmos, P. M., Ghahramani, Z., and Valera, I. (2020). Handling incomplete heterogeneous data using VAEs. Pattern Recognition, 107:107501”, the trained model can impute missing values for a test instance {tilde over (x)} as
Therefore, the distribution over {tilde over (x)}U (the missing values) is obtained by applying the encoder and the decoder with {tilde over (x)} as input.
Thus far, we have assumed that the relationships between individual variables are of interest. As discussed above, finding the relationships between groups of variables is needed in many real-world applications. Here, we extend VICAUSE to discover relationships between (pre-defined) groups of variables.
Problem definition. We assume that the D variables in X are organized in M<<D groups. For each group m=1, . . . , M, we write Im for the variables associated with that group (i.e. Im={4, 5, 6} means that the m-th group contains the fourth, fifth and sixth variables). The goal is to learn to impute missing values for test samples {tilde over (x)}∈ℝD (as before), and to learn causal relationships between the M groups of variables. In particular, the shape of the learned parameter G is now M×M. Also, the structured latent representation Z is split in M parts, each one corresponding to a different group.
VICAUSE for groups. The formulation of Sec. 2.2 can be naturally generalised to this setting. The generative model is analogous, but each node must be thought now as a group of variables (instead of a single variable). The main difference lies in the mappings that connect the sample xn and its latent representation Zn. Specifically, there are two such mappings: the encoder and the read-out layer in the decoder. Unlike before (Eq. 7), the same neural network cannot be used now for all the latent subspaces, since different groups of variables may have different dimensionalities (namely, the m-th group has a dimensionality of |Im|, i.e. the number of variables in that group). To overcome this, we propose to use a group-specific neural network for each latent subspace. Specifically the encoder computes the mean of the latent variable as
where χm includes all the variables in the m-th group (i.e., χm=[xi]i∈Im), and μϕ
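The group-specific mapping can be sketched as follows: because group sizes |Im| differ, each group gets its own weight matrix, unlike the shared heads of the variable-wise encoder. The grouping, sizes and weights below are all hypothetical:

```python
import numpy as np

rng = np.random.default_rng(2)
groups = [[0, 1, 2], [3, 4], [5]]     # I_1, I_2, I_3 (hypothetical grouping)
K = 2                                 # latent size per group
x = rng.normal(size=6)                # sample with D = 6 variables, M = 3 groups
# one mean-head per group, sized |I_m| × K, since group dimensionalities differ
W_mu = [rng.normal(size=(len(Im), K)) for Im in groups]
Z_mu = np.stack([x[Im] @ W_mu[m] for m, Im in enumerate(groups)])  # (M, K)
```

The read-out layer of the decoder needs the mirror-image construction, mapping each group's K-dimensional latent back to |Im| reconstructed values.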
Since VICAUSE tackles missing value imputation and causal discovery simultaneously, we review the related work from both fields. Moreover, we review recent works that utilize causality to improve the performance of another deep learning task, similar to VICAUSE.
Causal discovery. Randomized controlled trials are often not possible in the real world. Causal discovery aims to find causal relationships between variables from historical data without additional experiments. There are mainly three types of methods: constraint-based, score-based and functional causal models. Constraint-based methods exploit (conditional) independence tests to find the underlying causal structure, such as PC and Fast Causal Inference (FCI). They have recently been extended to handle partially observed data through test-wise deletion and adjustments. Score-based methods find the causal structure by optimizing a scoring function, such as Greedy Equivalence Search (GES) and extensions. In functional causal models, the effect variable is represented as a function of the direct causes and some noise term, with different assumptions on the functional form and the noise. Traditional methods do not scale to large numbers of variables. Recently, continuous optimization of causal structures has become very popular within score-based methods. In particular, continuous optimization has been combined with GNNs to improve the performance of structural equation models (SEMs). VICAUSE also considers non-linear relationships through a GNN architecture. However, since it jointly learns to impute missing values, VICAUSE leverages a general GNN architecture based on message passing, which is not an extension of linear SEMs as in Eq. (3) of “Yu, Y., Chen, J., Gao, T., and Yu, M. (2019). DAG-GNN: DAG structure learning with graph neural networks. In Proceedings of the 36th International Conference on Machine Learning”. Moreover, VICAUSE treats the graph of relationships in a fully probabilistic manner, handles missing values in the training data, and can deal with groups of variables of different sizes.
Causal deep learning. Continuous optimization of causal structures has been used to boost performance in classification. In CASTLE [“Kyono, T., Zhang, Y., and van der Schaar, M. (2020). Castle: Regularization via auxiliary causal graph discovery. In Larochelle, H., Ranzato, M., Hadsell, R., Balcan, M. F., and Lin, H., editors, Advances in Neural Information Processing Systems, volume 33, pages 1501-1512. Curran Associates, Inc.], structure learning is introduced as a regulariser for a deep learning classification model. This regulariser reconstructs only the most relevant causal features, leading to improved out-of-sample predictions. In SLAPS [“Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., and Dahl, G. E. (2017). Neural message passing for quantum chemistry. In International Conference on Machine Learning, pages 1263-1272. PMLR”], the classification objective is supplemented with a self-supervised task that learns a graph of interactions between variables through a GNN. However, these works are focused on the supervised classification task and they did not advance the performance of causal discovery methods. Causal discovery has also been used within models that predict the dynamics of interacting systems with deep neural networks [“Kipf, T., Fetaya, E., Wang, K.-C., Welling, M., and Zemel, R. (2018). Neural relational inference for interacting systems. In International Conference on Machine Learning, pages 2688-2697. PMLR.”]. Unlike VICAUSE, these approaches are developed for time series with Granger causality.
Missing values imputation. The relevance of missing data in real-world problems has motivated a long history of research. A popular approach is to estimate the missing values based on the observed ones through different techniques. Here, we find popular methods such as missforest, which relies on Random Forest, and MICE, which is based on Bayesian Ridge Regression. Also, the efficiency of amortized inference in generative models has motivated its use for missing values imputation. VICAUSE also leverages amortized inference, although the imputation is informed by the discovered causal relationships through a GNN.
We evaluate the performance of VICAUSE in three different problems: a synthetic experiment where the data generation process is controlled, a semi-synthetic problem (simulated data from a real-world problem) with many more variables, and the real-world problem that motivated the development of the group-level extension.
Baselines. For the causal discovery task, we consider five baselines. PC [“Spirtes, P., Glymour, C. N., Scheines, R., and Heckerman, D. (2000). Causation, prediction, and search. MIT Press”] and GES [“Chickering, D. M. (2002). Optimal structure identification with greedy search. Journal of Machine Learning Research, 3 (Nov):507-554”] are the most popular methods in constraint-based and score-based causal discovery approaches, respectively. We also consider three recent algorithms based on continuous optimization and deep learning: NOTEARS [“Zheng, X., Aragam, B., Ravikumar, P., and Xing, E. P. (2018). DAGs with NO TEARS: Continuous Optimization for Structure Learning. In Advances in Neural Information Processing Systems”], the non-linear (NL) extension of NOTEARS [“Zheng, X., Dan, C., Aragam, B., Ravikumar, P., and Xing, E. P. (2020). Learning sparse nonparametric DAGs. In International Conference on Artificial Intelligence and Statistics”], and DAG-GNN [“Yu, Y., Chen, J., Gao, T., and Yu, M. (2019). DAG-GNN: DAG structure learning with graph neural networks. In Proceedings of the 36th International Conference on Machine Learning”]. Unlike VICAUSE, these causality baselines cannot deal with missing values in the training data. Therefore, in the first two of the next three sections we work with fully observed training data. In contrast, the real-world data in the third section comes with partially observed training data, and the goal is to discover group-wise relationships. Thus the causality baselines cannot be used there, as they deal with variable-wise relationships only. For the missing data imputation task, we also consider five baselines. Mean Imputing and Majority Vote are popular techniques used as reference, Missforest [“Stekhoven, D. J. and Bühlmann, P. (2012). Missforest-non-parametric missing value imputation for mixed-type data. Bioinformatics, 28(1):112-118”] and MICE [“Buuren, S. v. and Groothuis-Oudshoorn, K. (2010). mice: Multivariate imputation by chained equations in R. Journal of Statistical Software, pages 1-68”] are two of the most widely-used imputation algorithms, and PVAE [“Ma, C., Tschiatschek, S., Palla, K., Hernandez-Lobato, J. M., Nowozin, S., and Zhang, C. (2019). EDDI: Efficient dynamic discovery of high value information with partial VAE. In Chaudhuri, K. and Salakhutdinov, R., editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 4234-4243. PMLR”] is a recent algorithm based on amortized inference.
Metrics. Imputation performance is evaluated with standard metrics such as RMSE (for continuous variables) and accuracy (for categorical variables). For categorical variables, we also provide the area under the ROC and Precision-Recall curves (AUROC and AUPR, respectively), which are especially useful for imbalanced data (such as that in Sec. 4.2). Regarding causal discovery, we consider both adjacency and orientation metrics, as is common practice. Whereas the former do not take into account the direction of the edges, the latter do. For each metric (adjacency and orientation) we compute recall, precision and F1-score. We also provide causal accuracy, a popular metric introduced in “Claassen, T. and Heskes, T. (2012). A Bayesian approach to constraint based causal inference. In Proceedings of the Twenty-Eighth Conference on Uncertainty in Artificial Intelligence, pages 207-216” that takes into account edge orientation.
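The adjacency/orientation distinction can be made concrete with a short sketch: orientation metrics compare the directed adjacency matrices entry-wise, while adjacency metrics first symmetrise both graphs so direction is ignored (function names are illustrative):

```python
import numpy as np

def prf(tp, fp, fn):
    # precision, recall and F1 from raw counts
    p = tp / (tp + fp) if tp + fp else 0.0
    r = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f1

def causal_metrics(G_true, G_pred):
    def counts(A, B):
        tp = int(np.sum((A == 1) & (B == 1)))
        fp = int(np.sum((A == 0) & (B == 1)))
        fn = int(np.sum((A == 1) & (B == 0)))
        return tp, fp, fn
    adj_t = np.maximum(G_true, G_true.T)   # symmetrise: ignore edge direction
    adj_p = np.maximum(G_pred, G_pred.T)
    return {"orientation": prf(*counts(G_true, G_pred)),
            "adjacency": prf(*counts(adj_t, adj_p))}

G_true = np.array([[0, 1], [0, 0]])        # true edge A→B
G_pred = np.array([[0, 0], [1, 0]])        # predicted edge B→A (wrong direction)
metrics = causal_metrics(G_true, G_pred)
```

In this example the predicted graph has perfect adjacency scores but zero orientation scores, illustrating why the two families of metrics are reported separately.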
Synthetic experiment. We simulate fifteen synthetic datasets. To understand how the number of variables affects VICAUSE, we use D=5, 7, 9 variables (five datasets for each value of D). For each simulated dataset, we first sample the true causal structure G, see
Table 1 shows imputation results for the synthetic experiment. Mean and standard error over fifteen datasets.
Table 2 shows causal discovery results for the synthetic experiment (mean and std error over fifteen datasets).
Imputation performance. VICAUSE outperforms the baselines in terms of imputation, and this is consistent across all datasets with different numbers of variables, see Table 1. The results split by number of variables are shown in Table 8. Therefore, in addition to predicting the relationships between variables, VICAUSE exploits this information to obtain enhanced imputation.
Causal discovery performance. VICAUSE obtains better performance than the causality baselines, see Table 2. The results split by number of variables are shown in Table 10. Notice that NOTEARS (NL) is slightly better in terms of orientation-precision, i.e. the orientation of the edges that it predicts is slightly more reliable. However, this is at the expense of a significantly lower capacity to detect true edges, see the recall and the trade-off between both (F1-score). In this small synthetic experiment, it is possible to visually inspect the predicted graph.
Motivation and dataset description. This experiment extends the previous one in three directions. First, the relationships used are not synthetic, but instead come from a well-studied medical setting [“Tu, R., Zhang, K., Bertilson, B. C., Kjellström, H., and Zhang, C. (2019b). Neuropathic pain diagnosis simulator for causal discovery algorithm evaluation. In 33rd Conference on Neural Information Processing Systems (NeurIPS), Dec 8-14, 2019, Vancouver, Canada, volume 32. Neural Information Processing Systems (NIPS)”]. Second, the number of variables considered is 222, significantly larger than before. Third, the variables are binary, rather than continuous. The dataset contains records of different patients regarding the diagnosis of symptoms associated with neuropathic pain. The train and test sets have 1000 and 500 patients respectively, for which 222 binary variables have been measured (the value is 1 if the symptom is present for the patient and 0 otherwise). The data was generated with the Neuropathic Pain Diagnosis Simulator, whose properties have been evaluated from the medical and statistical perspectives.
Table 4 shows causal discovery results for neuropathic pain dataset (mean and std error over five runs).
Table 5 shows the average expert evaluation of the topic relationships found in Eedi. Cohen's κ inter-annotator agreement is 0.72 for adjacency and 0.76 for orientation (substantial agreement).
Table 6 shows imputation results for Eedi topics dataset (mean and standard error over five runs).
Imputation performance. VICAUSE shows competitive or superior performance when compared to the baselines, see Table 3. Notice that AUROC and AUPR allow for an appropriate threshold-free assessment in this imbalanced scenario. Indeed, as expected from medical data, the majority of values are 0 (no symptoms); here it is around 92% of them in the test set. Interestingly, it is precisely in AUPR where the differences between VICAUSE and the rest of the baselines are larger (except for MICE, whose performance is very similar to that of VICAUSE in this dataset).
Causality results. As in the synthetic experiment, VICAUSE outperforms the causal discovery baselines, see Table 4. Notice that NOTEARS (NL) is slightly better in terms of adjacency-precision, i.e. the edges that it predicts are slightly more reliable. However, this is at the expense of a significantly lower capacity to detect true edges, see the recall and the trade-off between both (F1-score).
Motivation and dataset description. This experiment extends the previous ones in three directions. First, we tackle an important real-world problem in the field of AI-powered educational systems [“Wang, Z., Lamb, A., Saveliev, E., Cameron, P., Zaykov, Y., Hernandez-Lobato, J. M., Turner, R. E., Baraniuk, R. G., Barton, C., Jones, S. P., et al. (2021). Results and insights from diagnostic questions: The NeurIPS 2020 education challenge. arXiv preprint arXiv:2104.04034”, “Wang, Z., Tschiatschek, S., Woodhead, S., Hernández-Lobato, J. M., Jones, S. P., Baraniuk, R. G., and Zhang, C. (2020). Educational question mining at scale: Prediction, analysis and personalization. arXiv preprint arXiv:2003.05980”]. Second, we are interested in relationships between groups of variables (instead of individual variables). Third, the training data is very sparse, with 25.9% observed values. The dataset contains the responses given by 6147 students to 948 mathematics questions. The 948 variables are binary (1 if the student provided the correct answer and 0 otherwise). These 948 questions target very specific mathematical concepts, and they are grouped within a more meaningful hierarchy of topics, see
Imputation results. VICAUSE achieves competitive or superior performance when compared to the baselines (Table 6). Although the dataset is relatively balanced (54% of the values are 1), we provide AUROC and AUPR for completeness. Notice that this setting is more challenging than the previous ones, since we learn relationships between groups of variables (topics). Indeed, whereas the group extension allows for more meaningful relationships, the information flow happens at a less granular level. Interestingly, even in this case, VICAUSE obtains similar or improved imputation results compared to the baselines.
Table 3 shows imputation results for neuropathic pain dataset (mean and standard error over five runs).
Table 7 shows a Distribution of the relationships across level 1 topics (Number, Algebra, and Geometry). The item (i, j) refers to edges in the direction i→j. The proportion of relationships inside level 1 topics is 82%, 42% and 34% for VICAUSE, DAG-GNN and Random, respectively.
Causal discovery results between groups. Most of the baselines used so far cannot be applied here because i) they cannot learn relationships between groups of variables and ii) they cannot deal with partially observed training data. DAG-GNN is the only one that can be adapted to satisfy both properties. To handle missing values, we adapt DAG-GNN following the same strategy as in VICAUSE, i.e. replacing missing values with a constant value. To handle groups, notice that DAG-GNN can be used for vector-valued variables according to the original formulation. However, all of them need to have the same dimensionality. To cope with arbitrary groups, we apply the group-specific mappings (Eq. 10). Finally, to have an additional reference, we also compare with randomly generated relationships, which we will refer to as Random.
Importantly, this is a real-world dataset with no ground truth on the true relationships. Therefore, we asked two experts (experienced high school teachers working with the Eedi dataset) to assess the validity of the relationships found by VICAUSE, DAG-GNN and Random. For each relationship, they evaluated the adjacency (whether it is sensible to connect the two topics) and the orientation (whether the first one is a prerequisite for the second one). They provided an integer value from 1 (strongly disagree) to 5 (strongly agree), i.e. the higher the better. The complete list of relationships and expert evaluations for VICAUSE, DAG-GNN and Random can be found in Table 11, Table 12 and Table 13, respectively. As a summary, Table 5 shows the average evaluations: the relationships discovered by VICAUSE score much more highly across both metrics than those of the baseline models.
Another interesting aspect is how the relationships found between level-3 topics are distributed across higher-level topics (recall
We introduced VICAUSE, a novel approach that simultaneously performs causal discovery and learns to impute missing values. Both tasks are performed jointly: imputation is informed by the discovered relationships and vice-versa. This is achieved through a structured latent space and a GNN-based decoder. Namely, each variable has its own latent subspace, and the interactions between the latent subspaces are governed by the GNN through a (global) graph of relationships.
Moreover, motivated by a real-world problem, VICAUSE is extended to learn the causal relationships among groups of variables (rather than the variables themselves). VICAUSE opens up several directions for further research. In terms of causality, it would be interesting to carry out a theoretical analysis on identifiability, sample complexity, etc. In terms of missing-value imputation, the quality of the predictions could be enhanced through a more advanced handling of missing data, beyond zero-imputation. For instance, techniques such as the set encoder used in “Ma, C., Tschiatschek, S., Palla, K., Hernandez-Lobato, J. M., Nowozin, S., and Zhang, C. (2019). EDDI: Efficient dynamic discovery of high value information with partial VAE. In Chaudhuri, K. and Salakhutdinov, R., editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 4234-4243. PMLR” or PointNet [“Qi, C. R., Su, H., Mo, K., and Guibas, L. J. (2017). PointNet: Deep learning on point sets for 3D classification and segmentation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 652-660.”, “Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., and Smola, A. J. (2017). Deep sets. In Guyon, I., Luxburg, U. V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., and Garnett, R., editors, Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc.”] could be adapted to the structured latent space defined by VICAUSE. In some examples, it is assumed that the data are missing at random.
Here we specify the complete experimental details for full reproducibility. We first provide all the details for the synthetic experiment (Sec. A.1). Then we explain the differences for the neuropathic pain and the Eedi topics experiments in Sec. A.2 and Sec. A.3, respectively.
Data generation process. We first sample the underlying true causal structure. An edge from variable i to variable j is sampled with probability 0.5 if i<j, and probability 0 if i≥j (this ensures that the true causal structure is a DAG, which is just a standard scenario, and not a requirement for any of the compared algorithms). Then, we generate the data points. Root nodes (i.e. nodes with no parents, like variables 1 and 2 in
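The edge-sampling scheme described above can be sketched as follows. Only the sampling of the causal structure is taken from the text; the structural equations for non-root nodes shown here are an illustrative assumption:

```python
import numpy as np

rng = np.random.default_rng(0)
D = 5  # number of variables

# Edge i -> j is sampled with probability 0.5 if i < j and 0 otherwise, so
# the adjacency matrix is strictly upper-triangular and hence a DAG.
A = np.triu(rng.random((D, D)) < 0.5, k=1).astype(float)

def sample(n):
    # Ancestral sampling in topological (index) order. Root nodes are drawn
    # from a standard Gaussian; non-root nodes are a noisy nonlinear function
    # of their parents (this functional form is only an illustration).
    X = np.zeros((n, D))
    for j in range(D):
        parents = np.flatnonzero(A[:, j])
        if parents.size == 0:
            X[:, j] = rng.normal(size=n)
        else:
            X[:, j] = np.tanh(X[:, parents].sum(axis=1)) + rng.normal(0.0, 0.1, size=n)
    return X

X = sample(1000)
```

Because edges only ever point from lower to higher indices, index order is a valid topological order and the loop above visits every parent before its children.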
Model parameters. We start by specifying the parameters associated with the generative process. We use a prior probability pij=0.05 in p(G) for all the edges. This favours sparse graphs, and can be adjusted depending on the problem at hand. The prior p(Z) is a standard Gaussian distribution, i.e. σz²=1. This provides a standard regularisation for the latent space. The output noise is set to σ²=0.02, which favours the accurate reconstruction of samples. As for the decoder, we perform T=3 iterations of GNN message passing. All the MLPs in the decoder (i.e. MLPf, MLPb, MLPe2n and g) have two linear layers with ReLU non-linearity. The dimensionality of the hidden layer, which is the dimensionality of each latent subspace, is 256. Regarding the encoder, it is given by a multi-head neural network that defines the mean and standard deviation of the latent representation. The neural network is an MLP with two standard linear layers with ReLU non-linearity. The dimension of the hidden layer is also 256. When using groups, there are as many such MLPs as groups. Finally, recall that the variational posterior q(G) is the product of independent Bernoulli distributions over the edges, with a probability Gij to be estimated for each edge. These values are all initialised to Gij=0.5.
Training hyperparameters. We use the Adam optimizer with learning rate 0.001. We train for 300 epochs with a batch size of 100 samples. Each of the two stages described in the two-step training takes half of the epochs. The percentage of data dropped during training for each instance is sampled from a uniform distribution. When doing the reparametrization trick (i.e. when sampling from Zn), we draw 1 sample during training (100 samples at test time). For the Gumbel-softmax sample, we use a temperature τ=0.5. The rest of the hyperparameters are the standard ones in torch.nn.functional.gumbel_softmax; in particular, we use soft samples. To compute the DAG regulariser R(G), we use the matrix exponential implementation in torch.matrix_exp. This is in contrast to previous approaches, which resort to approximations. When applying the encoder, missing values in the training data are replaced with the value 0 (continuous variables).
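The two PyTorch primitives named above can be sketched as follows. This is a minimal illustration under stated assumptions: the regulariser uses the standard tr(exp(G∘G))−D form for the acyclicity penalty, and the per-edge logits shape (D, D, 2), holding "absent"/"present" logits for each candidate edge, is an assumption rather than the exact implementation:

```python
import torch
import torch.nn.functional as F

def dag_regulariser(G):
    # R(G) = tr(exp(G ∘ G)) - D, computed with the exact matrix exponential
    # (torch.matrix_exp) rather than a truncated approximation.
    # R(G) = 0 exactly when the weighted graph G has no directed cycles.
    return torch.trace(torch.matrix_exp(G * G)) - G.shape[0]

def sample_edges(logits, tau=0.5):
    # One soft Gumbel-softmax sample per candidate edge, at temperature tau.
    return F.gumbel_softmax(logits, tau=tau, hard=False)[..., 1]

D = 4
acyclic = torch.triu(torch.ones(D, D), diagonal=1)  # strictly upper-triangular: a DAG
cyclic = acyclic.clone()
cyclic[D - 1, 0] = 1.0                              # a back edge closes a cycle
```

With the acyclic matrix the regulariser evaluates to (numerically) zero, while the added back edge makes it strictly positive, which is what drives the learned graph towards a DAG during training.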
Baselines details. Regarding the causality baselines, we ran both PC and GES with the Causal Command tool offered by the Center for Causal Discovery https://www.ccd.pitt.edu/tools/. We used the default parameters in each case (i.e. disc-bic-score for GES and cg-lr-test for PC). NOTEARS (L), NOTEARS (NL) and DAG-GNN were run with the code provided by the authors on GitHub: https://github.com/xunzheng/notears (NOTEARS (L) and NOTEARS (NL)) and https://github.com/fishmoon1234/DAG-GNN (DAG-GNN). In all cases, we used the default parameters proposed by the authors. Regarding the imputation baselines, MajorityVote and Mean Imputing were implemented in Python. MICE and MissForest were used from the scikit-learn library with default parameters https://scikit-learn.org/stable/modules/generated/sklearn.impute.IterativeImputer.html. For PVAE, we use the authors' implementation with their proposed parameters, see https://github.com/microsoft/EDDI.
Other experimental details. VICAUSE is implemented in PyTorch. The code is available in the Supplementary material. The experiments were run using a local Tesla K80 GPU and a compute cluster provided by the Azure Machine Learning platform with NVIDIA Tesla V100 GPUs.
Data generation process. We use the Neuropathic Pain Diagnosis Simulator in https://github.com/TURuibo/Neuropathic-Pain-Diagnosis-Simulator. We simulate five datasets with 1500 samples, and split each one randomly into 1000 training and 500 test samples. These five datasets are used for the five independent runs described above.
Model and training hyperparameters. Most of the hyperparameters are identical to the synthetic experiment. However, in this case we have to deal with 222 variables, many more than before. In particular, the number of possible edges is 49062. Therefore, we reduce the dimensionality of each latent subspace to 32, the batch size to 25, and the number of samples for Zn at test time to 10 (in training we still use 1 as before). Moreover, we reduce the initial posterior probability for each edge to 0.2. The reason is that, for a 0.5 initialization, the DAG regulariser R(G) evaluates to extremely high and unstable values for the 222×222 matrix. Since this is a more complex problem (not synthetically generated), we run the algorithm for 1000 epochs. When applying the encoder, missing values in the training data are replaced with the value 0.5 (binary variables).
Data generation process. The real-world Eedi topics dataset contains 6147 samples. We use a random 80%-10%-10% train-validation-test split. The validation set is used to perform Bayesian Optimization (BO) as described below. The five runs reported in the experimental section come from different initializations for the model parameters.
Model and training hyperparameters. Here, we follow the same specifications as in the neuropathic pain dataset. The only difference is that we perform BO for three hyperparameters: the dimensionality of the latent subspaces, the number of GNN message passing iterations, and the learning rate. The possible choices for each hyperparameter are {5, 10, 15, 20, 25, 30, 35, 40, 45, 50}, {3, 5, 8, 10, 12, 14, 16, 18, 20}, and {10−4, 10−3, 10−2}, respectively. We perform 39 runs of BO with the hyperdrive package in the Azure Machine Learning platform https://docs.microsoft.com/en-us/python/api/azureml-train-core/azureml.train.hyperdrive?view=azure-ml-py. We use validation accuracy as the target metric. The best configuration obtained through BO was 15, 8 and 10−4, respectively.
Baselines details. As explained above, in this experiment DAG-GNN is adapted to deal with missing values and groups of arbitrary size. For the former, we adapt the DAG-GNN code to replace missing values with the constant value 0.5, as in VICAUSE. For the latter, we also follow VICAUSE and use as many different neural networks as groups, all of them with the same architecture as the one used in the original code (https://github.com/fishmoon1234/DAG-GNN).
Other experimental details. The list of relationships found by VICAUSE (Table 11) and DAG-GNN (Table 12) aggregates the relationships obtained in the five independent runs. This is done by setting a threshold of 0.35 on the posterior probability of each edge (which is initialized to 0.2) and taking the union over the different runs. This resulted in 50 relationships for VICAUSE and 57 for DAG-GNN. For Random, we simulated 50 random relationships. Also, the probability reported in the first column of Table 11 is the average of the probabilities obtained for that relationship in the five different runs.
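The aggregation described above can be sketched as follows. This is a minimal illustration; `aggregate_relationships` is a hypothetical helper name, and averaging over all runs reflects our reading of how the reported probability is computed:

```python
import numpy as np

def aggregate_relationships(posteriors, threshold=0.35):
    # Keep an edge if its posterior probability exceeds the threshold in at
    # least one run (the union over runs); report, for each kept edge, the
    # average of its probabilities over all runs.
    kept = set()
    for P in posteriors:                 # one (D, D) posterior matrix per run
        kept.update(zip(*np.nonzero(P > threshold)))
    return {e: float(np.mean([P[e] for P in posteriors])) for e in kept}
```

For example, an edge with posteriors 0.6 and 0.2 in two runs is retained (0.6 exceeds the threshold) and reported with probability 0.4.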
Table 8 shows imputation results for the synthetic experiment in terms of RMSE (not aggregating by number of variables, D=5, 7, 9). The values are the mean and standard error over five different simulations.
Table 9 shows a mapping between indexes for row/column names in Table 14 and Table 16 and the actual level-2 topic names.
Table 10 shows causality results for the synthetic experiment (not aggregating by number of variables, D=5, 7, 9). The values are the mean and standard error over five different simulations.
[Table data: values missing or illegible when filed.]
Table 11 shows the full list of relationships found by VICAUSE in the Eedi topics dataset. Each row refers to one relationship (one edge). From left to right, the columns are the posterior probability of the edge, the sending node (topic), the receiving node (topic), and the adjacency and orientation evaluations from each expert. For each topic, the brackets contain its parent level 2 and level 1 topics.
Table 12 shows a full list of relationships found by DAG-GNN in the Eedi topics dataset. Each row refers to one relationship (one edge). From left to right, the columns are the sending node (topic), the receiving node (topic), and the adjacency and orientation evaluations from each expert. For each topic, the brackets contain its parent level 2 and level 1 topics.
Table 13 shows a full list of relationships found by Random in the Eedi topics dataset. Each row refers to one relationship (one edge). From left to right, the columns are the sending node (topic), the receiving node (topic), and the adjacency and orientation evaluations from each expert. For each topic, the brackets contain its parent level 2 and level 1 topics.
Table 14 shows how the 50 relationships found by VICAUSE are distributed across level 2 topics. The item (i, j) refers to edges in the direction i→j. There are 18 relationships inside level 2 topics (36%). See Table 9 for a mapping between the indexes shown here in row/column names and the actual level-2 topic names.
Table 15 shows how the 57 relationships found by DAG-GNN are distributed across level 2 topics. The item (i, j) refers to edges in the direction i→j. There are 8 relationships inside level 2 topics (14%). See Table 9 for a mapping between the indexes shown here in row/column names and the actual level-2 topic names.
Table 16 shows how the 50 relationships found by Random are distributed across level 2 topics. The item (i, j) refers to edges in the direction i→j. There are 3 relationships inside level 2 topics (6%). See Table 9 for a mapping between the indexes shown here in row/column names and the actual level-2 topic names.
More generally, according to one aspect disclosed herein, there is provided a computer-implemented method, the method comprising: receiving an input vector comprising values of variables; using a first neural network to encode the variables of the input vector into a plurality of latent vectors; inputting the plurality of latent vectors into a second neural network comprising a graph neural network, wherein the graph neural network is parametrized by a graph comprising edge probabilities indicating causal relationships between the variables, in order to determine a computed vector value; and tuning the edge probabilities of the graph, one or more parameters of the first neural network and one or more parameters of the second neural network to minimise a loss function, wherein the loss function comprises a measure of difference between the input vector and the computed vector value and a function of the graph.
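The claimed training step can be illustrated with a deliberately small sketch. All architectural choices here are assumptions for illustration only, not the actual embodiment: a single shared linear encoder, one round of message passing as a stand-in for the full GNN decoder, a sigmoid parametrisation of the edge probabilities, and a trace-of-matrix-exponential acyclicity penalty as the function of the graph:

```python
import torch
import torch.nn as nn

D, H = 5, 8                                    # number of variables, latent size
encoder = nn.Linear(1, H)                      # first network: per-variable encoder
decoder = nn.Linear(H, 1)                      # read-out of the second network
edge_logits = nn.Parameter(torch.zeros(D, D))  # parameters of the graph

def training_step(x, lam=1.0):
    G = torch.sigmoid(edge_logits)             # edge probabilities of the graph
    Z = encoder(x.unsqueeze(-1))               # (D, H): one latent vector per variable
    # One round of message passing: each latent receives a G-weighted sum of
    # the other latents (a stand-in for the full GNN decoder).
    Z = Z + G.t() @ Z
    x_hat = decoder(Z).squeeze(-1)             # computed vector value
    recon = ((x - x_hat) ** 2).sum()           # difference: input vs computed vector
    graph_term = torch.trace(torch.matrix_exp(G * G)) - D  # function of the graph
    return recon + lam * graph_term

loss = training_step(torch.randn(D))
loss.backward()  # gradients flow to the graph and to both networks jointly
```

A single backward pass produces gradients for the edge logits and for the parameters of both networks, so all three are tuned together by minimising the one loss, as the aspect above describes.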
In embodiments, the method of the above aspect is repeated for a plurality of further input vectors to provide further tuning of the edge probabilities of the graph and the one or more parameters of the first and second neural networks.
In embodiments, the method comprises: after tuning the edge probabilities of the graph, the one or more parameters of the first neural network and the one or more parameters of the second neural network, setting the edge probabilities of the graph and the one or more parameters of the first neural network and the one or more parameters of the second neural network; receiving a further input vector comprising the variables of the input vector, the further input vector having at least one missing value for at least one of the variables and at least one observed value for at least one of the variables; and applying the first neural network and the second neural network to the further input vector to obtain the at least one missing value.
In embodiments, the function of the graph penalizes cyclical relationships by increasing the loss function.
In embodiments, the function of the graph comprises a measure of a difference between two distributions, wherein the first distribution is an estimate of a posterior function of the graph and the second distribution is a predefined user function of the graph.
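For Bernoulli edge posteriors and priors, this measure of difference can be written per edge as the Bernoulli KL divergence. The following is an illustrative sketch of that standard formula, not the exact implementation:

```python
import math

def bernoulli_kl(q, p):
    # KL(Bernoulli(q) || Bernoulli(p)) for one edge; the graph term of the
    # loss sums this over all candidate edges, measuring how far the
    # posterior edge probabilities have moved from the predefined prior.
    return q * math.log(q / p) + (1 - q) * math.log((1 - q) / (1 - p))

# e.g. a posterior of 0.5 against a sparse prior of 0.05 for a single edge
kl = bernoulli_kl(0.5, 0.05)
```

The divergence is zero when posterior and prior agree and grows as the posterior places more mass on edges the sparse prior disfavours, so it acts as a sparsity regulariser on the graph.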
In embodiments, the loss function only operates on variables that are present in the input vector.
In embodiments, the using the first neural network to encode the variables of the input vector into the plurality of latent vectors comprises: using the first neural network to encode each variable of the input vector into a respective latent vector.
In embodiments, the method comprises: organizing the variables into a number of groups, wherein the number of groups is less than a number of variables in the input vector, and wherein the using the first neural network to encode the variables of the input vector into the plurality of latent vectors comprises: using the first neural network to encode each group into a respective latent vector.
Each group may comprise one or more related variables.
The variables may comprise one or more data values representing sensor values of one or more devices.
The one or more devices may comprise a health monitoring device for monitoring a patient, and the tuning of the edge probabilities of the graph may provide causal relationships between a plurality of health conditions, wherein the method comprises: using the causal relationships to diagnose the patient.
In embodiments, the one or more devices comprises a health monitoring device for monitoring a patient and the method comprises: after tuning the edge probabilities of the graph, the one or more parameters of the first neural network and the one or more parameters of the second neural network, setting the edge probabilities of the graph and the one or more parameters of the first neural network and the one or more parameters of the second neural network; receiving a further input vector comprising the variables of the input vector, the further input vector having at least one missing value for at least one of the variables and at least one observed value for at least one of the variables; and applying the first neural network and the second neural network to the further input vector to obtain the at least one missing value, the missing value representing a health condition.
In embodiments, tuning the edge probabilities of the graph provides causal relationships between a plurality of sensor measurements; and the method comprises: using the causal relationships to determine one or more faults in the one or more devices.
In embodiments, the method comprises: after tuning the edge probabilities of the graph, the one or more parameters of the first neural network and the one or more parameters of the second neural network, setting the edge probabilities of the graph and the one or more parameters of the first neural network and the one or more parameters of the second neural network; receiving a further input vector comprising the variables of the input vector, the further input vector having at least one missing value for at least one of the variables and at least one observed value for at least one of the variables; applying the first neural network and the second neural network to the further input vector to obtain the at least one missing value, the missing value representing a state of the device; and using the missing value to determine a fault in the one or more devices.
According to an aspect disclosed herein, there may be provided a computer-implemented method comprising: receiving an input vector comprising values of variables; using a first neural network to encode the values of the variables of the input vector into a plurality of latent vectors; determining an output vector by inputting the plurality of latent vectors into a second neural network comprising a graph neural network, wherein the graph neural network is parametrized by a graph comprising edge probabilities indicating causal relationships between the variables; and minimising a loss function by tuning the edge probabilities of the graph, at least one parameter of the first neural network and at least one parameter of the second neural network, wherein the loss function comprises a function of the graph and a measure of difference between the input vector and the output vector.
In embodiments, the method is repeated for a plurality of further input vectors to provide further tuning of the edge probabilities of the graph, the at least one parameter of the first neural network and the at least one parameter of the second neural network.
In embodiments, the method comprises: after minimising the loss function by tuning the edge probabilities of the graph, the at least one parameter of the first neural network and the at least one parameter of the second neural network: setting the edge probabilities of the graph, the at least one parameter of the first neural network and the at least one parameter of the second neural network; receiving a further input vector comprising the variables of the input vector, the further input vector having at least one missing value for at least one of the variables and at least one observed value for at least one of the variables; and applying the first neural network and the second neural network to the further input vector to obtain the at least one missing value.
In embodiments, the function of the graph increases the value of the loss function when a cyclic relationship is present in the graph.
In embodiments, the function of the graph comprises a measure of a difference between two distributions, wherein the first distribution is an estimate of a posterior function of the graph and the second distribution is a predefined user function of the graph.
In embodiments, the loss function only operates on variables that are present in the input vector.
In embodiments, the using the first neural network to encode the values of the variables of the input vector into the plurality of latent vectors comprises: using the first neural network to encode each variable of the input vector into a respective latent vector.
In embodiments, the method comprises: organizing the values of the variables into a number of groups, wherein the number of groups is less than a number of variables in the input vector, and wherein the using the first neural network to encode the values of the variables of the input vector into the plurality of latent vectors comprises: using the first neural network to encode each group into a respective latent vector.
In embodiments, each group comprises at least one related variable.
In embodiments, the variables comprise at least one data value representing at least one sensor value of at least one device.
In embodiments, the at least one device comprises a health monitoring device for monitoring a patient and wherein the tuning of the edge probabilities of the graph provides causal relationships between a plurality of health conditions, wherein the method comprises: using the causal relationships to diagnose a patient.
In embodiments, the at least one device comprises a health monitoring device for monitoring a patient and wherein the method comprises, after tuning the edge probabilities of the graph, the at least one parameter of the first neural network and the at least one parameter of the second neural network: setting the edge probabilities of the graph, the at least one parameter of the first neural network and the at least one parameter of the second neural network; receiving a further input vector comprising the variables of the input vector, the further input vector having at least one missing value for at least one of the variables and at least one observed value for at least one of the variables; and applying the first neural network and the second neural network to the further input vector to obtain the at least one missing value, the missing value representing a health condition.
In embodiments, the tuning the edge probabilities of the graph provides causal relationships between a plurality of sensor measurements; and wherein the method comprises: using the causal relationships to determine at least one fault in the at least one device.
In embodiments, the method comprises, after minimising the loss function by tuning the edge probabilities of the graph, the at least one parameter of the first neural network and the at least one parameter of the second neural network: setting the edge probabilities of the graph, the at least one parameter of the first neural network and the at least one parameter of the second neural network; receiving a further input vector comprising the variables of the input vector, the further input vector having at least one missing value for at least one of the variables and at least one observed value for at least one of the variables; applying the first neural network and the second neural network to the further input vector to obtain the at least one missing value, the missing value representing a state of the device; and using the missing value to determine a fault in the at least one device.
According to an aspect disclosed herein, there is provided storage comprising at least one memory unit and a processing apparatus comprising at least one processing unit; wherein the storage stores code arranged to run on the processing apparatus, the code being configured so as when thus run to perform the operations of: receiving an input vector comprising values of variables; using a first neural network to encode the values of the variables of the input vector into a plurality of latent vectors; determining an output vector by inputting the plurality of latent vectors into a second neural network comprising a graph neural network, wherein the graph neural network is parametrized by a graph comprising edge probabilities indicating causal relationships between the variables; and minimising a loss function by tuning the edge probabilities of the graph, at least one parameter of the first neural network and at least one parameter of the second neural network, wherein the loss function comprises a function of the graph and a measure of difference between the input vector and the output vector.
According to another aspect disclosed herein, there may be provided a computer program embodied on computer-readable storage, the program comprising code configured so as when run on one or more processors to perform the operations of any method disclosed herein.
According to another aspect disclosed herein, there is provided a computer system comprising: storage comprising one or more memory units, and processing apparatus comprising one or more processing units; wherein the storage stores code arranged to run on the processing apparatus, the code being configured so as when thus run to perform the operations of any method disclosed herein.
Other variants and applications of the disclosed techniques may become apparent to a skilled person once given the disclosure herein. The scope of the present disclosure is not limited by the described embodiments but only by the accompanying claims.
Number | Date | Country | Kind |
---|---|---|---|
21186786.6 | Jul 2021 | EP | regional |
Filing Document | Filing Date | Country | Kind |
---|---|---|---|
PCT/US2022/035404 | 6/29/2022 | WO |