The present application claims priority to Chinese Patent Application No. 202210649202.3, filed with the China National Intellectual Property Administration on Jun. 9, 2022 and entitled “Multi-turn dialogue system and method based on retrieval”, which is incorporated herein by reference in its entirety.
The present application relates to a multi-turn dialogue system and method based on retrieval, belonging to the technical field of natural language processing, and in particular to the technical field of dialogue robots.
Creating a robot that can communicate naturally with humans in the open domain is a challenging task in the field of artificial intelligence. At present, there are two methods to build such a dialogue robot, i.e., a generation-based method and a retrieval-based method. The generation-based method directly generates a reply based on a language model trained on a large-scale dialogue data set, while the retrieval-based method selects the best-matching reply from a candidate set.
The task of the retrieval-based dialogue robot is: given the first n turns of dialogue and a candidate set consisting of several candidate answers, the model is required to select from the candidate set the answer that is most suitable as the (n+1)-th turn of dialogue, which is an important and quite challenging task. A core step of this task is to calculate a matching score between the n turns of dialogue (i.e., the context) and each candidate answer. Early methods aggregated the n turns of dialogue into a dense vector through a recurrent neural network, then calculated the cosine similarity between each candidate answer representation and the dense vector, and selected the candidate answer with the highest score as the answer. In order to avoid the loss of context information, Wu et al. proposed the SMN model (see the literature: Wu Y, Wu W, Xing C, et al. Sequential Matching Network: A New Architecture for Multi-turn Response Selection in Retrieval-Based Chatbots[C]; Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics, ACL 2017, Vancouver, Canada, July 30-August 4, Volume 1: Long Papers, 2017. Association for Computational Linguistics.), which matches the candidate answers with each turn of dialogue to obtain a matching vector, aggregates the n matching vectors in time sequence through a recurrent neural network (RNN), and then calculates a score from the aggregated vector. Much subsequent work followed this approach. However, these works neglected that context information is very important in the matching phase.
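The early aggregation approach described above (context compressed into one dense vector, candidates ranked by cosine similarity) can be sketched as follows; this is a minimal illustration with invented names, not the SMN model or any specific prior system:

```python
import numpy as np

def cosine(a, b):
    # cosine similarity between two dense vectors
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def select_response(context_vec, candidate_vecs):
    # score every candidate answer against the aggregated context
    # vector and return the index of the best-matching candidate
    scores = [cosine(context_vec, c) for c in candidate_vecs]
    return int(np.argmax(scores))
```

In practice `context_vec` would be produced by a recurrent encoder over the n dialogue turns; here it is simply assumed to be given.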
The method of Gu et al. (see the literature: Gu J, Li T, Liu Q, et al. Speaker-Aware BERT for Multi-Turn Response Selection in Retrieval-Based Chatbots[C]; Proceedings of CIKM, 2020.) considers the problem of global information participating in the interaction, but their method simply compresses the entire dialogue context and matches it against the candidate answers to let global information participate in the interaction. The matching granularity of this method is too coarse, and compressing an overly long dialogue into a single dense vector loses a lot of information and may introduce noise.
In summary, how to effectively use dialogue context information to improve the matching accuracy of dialogue candidate answers has become an urgent technical problem in the technical field of dialogue robots.
In view of this, the present invention aims to provide a retrieval-based system and method which can dynamically absorb the context information in the matching process of dialogue and candidate answer, and improve the accuracy of dialogue-candidate answer matching.
In order to achieve the above purpose, the present invention proposes a multi-turn dialogue system based on retrieval, which includes the following modules:
Û_k = f_catt(U_{k-1}, U_{k-1}, C)

R̂_k = f_catt(R_{k-1}, R_{k-1}, C)

Ū_k = f_catt(U_{k-1}, R_{k-1}, C)

R̄_k = f_catt(R_{k-1}, U_{k-1}, C)

Ũ_k = [U_{k-1}, Û_k, Ū_k, U_{k-1} ⊙ Ū_k]

R̃_k = [R_{k-1}, R̂_k, R̄_k, R_{k-1} ⊙ R̄_k]
U_k = max(0, Ũ_k W_h + b_h)

R_k = max(0, R̃_k W_h + b_h) + R_{k-1}
In the formulas, U_{k-1} ∈ ℝ^{m×d} and R_{k-1} ∈ ℝ^{n×d} represent the inputs of the k-th global interaction layer, wherein m and n represent the number of words contained in the current turn of dialogue and the number of words contained in the candidate answer, respectively, and the inputs of the first global interaction layer are U_0 = E_u, R_0 = E_r; W_h ∈ ℝ^{4d×d} and b_h are training parameters; the operator ⊙ represents element-wise multiplication; d represents the dimension of a vector;
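The update performed by one global interaction layer can be sketched as follows, assuming the formulas above; the global attention mechanism f_catt is passed in as a function argument here because it is defined separately:

```python
import numpy as np

def global_interaction_layer(U_prev, R_prev, C, W_h, b_h, f_catt):
    # One global interaction layer: every attention call also sees the
    # full dialogue context C through the global attention f_catt.
    U_hat = f_catt(U_prev, U_prev, C)   # dialogue self-attention
    R_hat = f_catt(R_prev, R_prev, C)   # response self-attention
    U_bar = f_catt(U_prev, R_prev, C)   # dialogue attends to response
    R_bar = f_catt(R_prev, U_prev, C)   # response attends to dialogue
    # concatenate along the feature axis: shapes (m, 4d) and (n, 4d)
    U_tilde = np.concatenate([U_prev, U_hat, U_bar, U_prev * U_bar], axis=-1)
    R_tilde = np.concatenate([R_prev, R_hat, R_bar, R_prev * R_bar], axis=-1)
    # ReLU projection back to d dimensions; the response stream keeps
    # an extra residual connection, as in the last formula above
    U_k = np.maximum(0.0, U_tilde @ W_h + b_h)
    R_k = np.maximum(0.0, R_tilde @ W_h + b_h) + R_prev
    return U_k, R_k
```

This is a shape-level sketch only; stacking several such layers reuses the same signature with U_k, R_k fed back in as U_prev, R_prev.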
In the formulas, f_catt(·) represents the global attention mechanism, which is specifically defined as follows:
f_catt(Q, K, C) = Q̃ + FNN(Q̃)

Q̂ = S(Q, K, C) · K

C^q = softmax(Q W_a C^T) · C

C^k = softmax(K W_a C^T) · C
M_i = M_{i,self} ⊕ M_{i,interaction} ⊕ M_{i,enhanced}
The word-level vector in the representation module is obtained by the tool Word2vec; the character-level vector is obtained by encoding character information through a convolutional neural network.
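The character-level encoding can be illustrated with a minimal one-dimensional convolution over character embeddings followed by max-over-time pooling; the shapes and kernel here are illustrative assumptions, not the exact network of the cited work:

```python
import numpy as np

def char_cnn(char_embs, filters):
    # char_embs: (num_chars, char_dim) character embeddings of one word;
    # filters:   (window, char_dim, out_dim) convolution kernel.
    # Slide the window over the characters, apply ReLU, then take the
    # max over time to get one fixed-size character-level word vector.
    window, _, out_dim = filters.shape
    positions = char_embs.shape[0] - window + 1
    conv = np.stack([
        np.tensordot(char_embs[i:i + window], filters, axes=([0, 1], [0, 1]))
        for i in range(positions)
    ])                                          # (positions, out_dim)
    return np.maximum(0.0, conv).max(axis=0)    # (out_dim,)
```

The character-level vector is then concatenated with the Word2vec word-level vector to form the word representation.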
The specific calculation process of the short-term dependence information sequence (h1, . . . , hl) is:
h_i = GRU(v_i, h_{i-1})
The specific calculation process of the long-term dependence information sequence (g1, . . . , gl) is:
(g_1, . . . , g_l) = MultiHead(Q, K, V)
where,
Q = V_m W^Q, K = V_m W^K, V = V_m W^V,
The prediction module calculates the matching score of the context c and the candidate answer r involved in matching as follows:
to obtain (ĝ_1, . . . , ĝ_l), wherein ⊙ represents element-wise multiplication;
g̃_i = GRU(ĝ_i, g̃_{i-1})
g(c, r) = σ(g̃_l · w_o + b_o)
In the above formula, σ(·) represents a sigmoid function, and w_o and b_o are training parameters.
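The prediction step can be sketched as follows, assuming (as the text implies) that the short-term and long-term dependence sequences are combined element-wise before the GRU aggregation; the GRU cell is passed in as a function, and the toy cell in the test is a stand-in, not a real GRU:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def predict_score(h_seq, g_seq, gru_cell, w_o, b_o):
    # Combine the short-term (h_i) and long-term (g_i) sequences
    # element-wise to obtain (g^_1, ..., g^_l), aggregate them with a
    # GRU, and squash the final state through a sigmoid to get the
    # matching score g(c, r) in (0, 1).
    combined = [h * g for h, g in zip(h_seq, g_seq)]
    state = np.zeros_like(combined[0])
    for x in combined:
        state = gru_cell(x, state)      # g~_i = GRU(g^_i, g~_{i-1})
    return float(sigmoid(state @ w_o + b_o))
```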
The system is trained using the following loss function:
The present invention also proposes a multi-turn dialogue method based on retrieval, comprising:
Û_k = f_catt(U_{k-1}, U_{k-1}, C)

R̂_k = f_catt(R_{k-1}, R_{k-1}, C)

Ū_k = f_catt(U_{k-1}, R_{k-1}, C)

R̄_k = f_catt(R_{k-1}, U_{k-1}, C)

Ũ_k = [U_{k-1}, Û_k, Ū_k, U_{k-1} ⊙ Ū_k]

R̃_k = [R_{k-1}, R̂_k, R̄_k, R_{k-1} ⊙ R̄_k]
U_k = max(0, Ũ_k W_h + b_h)

R_k = max(0, R̃_k W_h + b_h) + R_{k-1}
f_catt(Q, K, C) = Q̃ + FNN(Q̃)

Q̂ = S(Q, K, C) · K

C^q = softmax(Q W_a C^T) · C

C^k = softmax(K W_a C^T) · C
M_i = M_{i,self} ⊕ M_{i,interaction} ⊕ M_{i,enhanced}
h_i = GRU(v_i, h_{i-1})
(g_1, . . . , g_l) = MultiHead(Q, K, V)
Q = V_m W^Q, K = V_m W^K, V = V_m W^V,
to obtain (ĝ_1, . . . , ĝ_l), wherein ⊙ represents element-wise multiplication;
g̃_i = GRU(ĝ_i, g̃_{i-1})
g(c, r) = σ(g̃_l · w_o + b_o)
The beneficial effect of the present invention is that the system and method of the present invention extend a general attention mechanism to a global attention mechanism, and dynamically absorb the context information in the dialogue-candidate answer matching process. The system of the present invention can simultaneously capture the short-term dependence and long-term dependence of the matching information sequence, and effectively improve the accuracy of dialogue-candidate answer matching.
In order to make the purpose, technical solutions and advantages of the present invention clearer, the present invention is further described in detail below in combination with the drawings.
Referring to
The attention mechanism is the basis of the global interaction layer. For details of attention mechanism, see Vaswani A, Shazeer N, Parmar N, et al. Attention is All you Need[C]; proceedings of the NIPS, F, 2017.
The specific calculation process is as follows:
Û_k = f_catt(U_{k-1}, U_{k-1}, C)

R̂_k = f_catt(R_{k-1}, R_{k-1}, C)

Ū_k = f_catt(U_{k-1}, R_{k-1}, C)

R̄_k = f_catt(R_{k-1}, U_{k-1}, C)

Ũ_k = [U_{k-1}, Û_k, Ū_k, U_{k-1} ⊙ Ū_k]

R̃_k = [R_{k-1}, R̂_k, R̄_k, R_{k-1} ⊙ R̄_k]
U_k = max(0, Ũ_k W_h + b_h)

R_k = max(0, R̃_k W_h + b_h) + R_{k-1}
In the above formulas, U_{k-1} ∈ ℝ^{m×d} and R_{k-1} ∈ ℝ^{n×d} represent the inputs of the k-th global interaction layer, wherein m and n represent the number of words contained in the current turn of dialogue and the number of words contained in the candidate answer, respectively, and the inputs of the first global interaction layer are U_0 = E_u, R_0 = E_r; W_h ∈ ℝ^{4d×d} and b_h are training parameters; the operator ⊙ represents element-wise multiplication; d represents the dimension of a vector.
In the above formulas, f_catt(·) represents the described global attention mechanism, which is specifically defined as follows:
f_catt(Q, K, C) = Q̃ + FNN(Q̃)
In the above formula, FNN(Q̃) = max(0, Q̃ W_f + b_f) W_g + b_g, wherein W_{f,g} ∈ ℝ^{d×d} and b_{f,g} are trainable parameters; Q and Q̂ are mixed using a residual connection to obtain Q̃.
In this embodiment, Q and Q̂ are mixed using the residual connection of He et al. (for details, see He K, Zhang X, Ren S, et al. Deep Residual Learning for Image Recognition[C]; Proceedings of CVPR 2016, 2016.) to obtain Q̃. In order to prevent gradient explosion or gradient vanishing, the present invention uses layer normalization (for details, see: Ba L J, Kiros J R, Hinton G E. Layer Normalization[J]. CoRR, 2016.).
Wherein Q̂ is calculated according to the following formula:

Q̂ = S(Q, K, C) · K
In the above formula, Q∈n
In the above formula, W_{b,c,d,e} are trainable parameters; C_i^q represents the i-th row of C^q, and its physical meaning is the fused context information related to the i-th word of the query sequence Q; C_j^k represents the j-th row of C^k, and its physical meaning is the fused context information related to the j-th word of the key sequence K;
C^q = softmax(Q W_a C^T) · C

C^k = softmax(K W_a C^T) · C
A convolutional neural network is used to extract a d-dimensional matching vector v_i from the matching image M_i of the i-th turn of dialogue; the matching vectors of the first to l-th turns of dialogue are represented by (v_1, . . . , v_l). The matching image M_i of the i-th turn of dialogue is calculated according to the following formula:
M_i = M_{i,self} ⊕ M_{i,interaction} ⊕ M_{i,enhanced}
In the above formula, Mi∈m
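The extraction of a matching vector from the three stacked matching images can be illustrated as follows; global max pooling and a linear projection stand in for the full convolutional network, so this is a shape-level sketch only, with invented parameter names:

```python
import numpy as np

def matching_vector(M_self, M_inter, M_enh, proj):
    # Stack the three m x n matching images as channels; global max
    # pooling replaces the full CNN of the embodiment, and a linear
    # projection maps the pooled features to d dimensions.
    M = np.stack([M_self, M_inter, M_enh], axis=-1)   # (m, n, 3)
    pooled = M.max(axis=(0, 1))                        # (3,)
    return pooled @ proj                               # (d,)
```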
The aggregation module: this module is composed of one RNN network and one Transformer network, and configured to: receive the matching vector (v1, . . . , vl) output by the matching module, process the matching vector by the RNN network to obtain a short-term dependence information sequence (h1, . . . , hl), and process the matching vector by the Transformer network to obtain a long-term dependence information sequence (g1, . . . , gl).
In this embodiment, an encoder in Transformer (For detail, see Vaswani A, Shazeer N, Parmar N, et al. Attention is All you Need[C]; proceedings of the NIPS, F, 2017.) captures the long-term dependence information in the matching vector (v1, . . . , vl).
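The aggregation module can be sketched as follows; a single scaled dot-product attention head stands in for the multi-head Transformer encoder, and the GRU cell is passed in as a function to keep the sketch self-contained:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def aggregate(V_m, W_q, W_k, W_v, gru_cell):
    # V_m: (l, d) sequence of matching vectors (v_1, ..., v_l).
    # Short-term dependence: step a GRU over the sequence.
    h, state = [], np.zeros(V_m.shape[1])
    for v in V_m:
        state = gru_cell(v, state)
        h.append(state)
    # Long-term dependence: single-head self-attention as a stand-in
    # for the multi-head Transformer encoder.
    Q, K, V = V_m @ W_q, V_m @ W_k, V_m @ W_v
    g = softmax(Q @ K.T / np.sqrt(K.shape[1])) @ V
    return np.array(h), g
```

The two returned sequences correspond to (h_1, . . . , h_l) and (g_1, . . . , g_l) in the text.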
The prediction module is configured to calculate a matching score of the context c and the candidate answer r involved in matching according to the short-term dependence information sequence (h1, . . . , hl) and the long-term dependence information sequence (g1, . . . , gl) output by the aggregation module.
The word level vector described in the representation module is obtained by the tool Word2vec (see Mikolov T, Sutskever I, Chen K, et al. Distributed Representations of Words and Phrases and their Compositionality[C]; proceedings of the NIPS 2013, F, 2013.); the character level vector is obtained by encoding character information through the convolutional neural network. For the convolutional neural network used in the embodiment, see Lee K, He L, Lewis M, et al. End-to-end Neural Coreference Resolution [Z]. EMNLP. 2017.
The specific calculation process of the short-term dependence information sequence (h1, . . . , hl) is:
h_i = GRU(v_i, h_{i-1})
The specific calculation process of the long-term dependence information sequence (g1, . . . , gl) is:
(g_1, . . . , g_l) = MultiHead(Q, K, V)
where,
Q = V_m W^Q, K = V_m W^K, V = V_m W^V,
The prediction module calculates the matching score of the context c and the candidate answer r involved in matching as follows:
to obtain (ĝ_1, . . . , ĝ_l), wherein ⊙ represents element-wise multiplication;
g̃_i = GRU(ĝ_i, g̃_{i-1})
g(c, r) = σ(g̃_l · w_o + b_o)
In the above formula, σ(·) represents a sigmoid function, and w_o and b_o are training parameters.
The system is trained using the following loss function:
Referring to
Û_k = f_catt(U_{k-1}, U_{k-1}, C)

R̂_k = f_catt(R_{k-1}, R_{k-1}, C)

Ū_k = f_catt(U_{k-1}, R_{k-1}, C)

R̄_k = f_catt(R_{k-1}, U_{k-1}, C)

Ũ_k = [U_{k-1}, Û_k, Ū_k, U_{k-1} ⊙ Ū_k]

R̃_k = [R_{k-1}, R̂_k, R̄_k, R_{k-1} ⊙ R̄_k]
U_k = max(0, Ũ_k W_h + b_h)

R_k = max(0, R̃_k W_h + b_h) + R_{k-1}
In the above formulas, U_{k-1} ∈ ℝ^{m×d} and R_{k-1} ∈ ℝ^{n×d} represent the inputs of the k-th global interaction layer, wherein m and n represent the number of words contained in the current turn of dialogue and the number of words contained in the candidate answer, respectively, and the inputs of the first global interaction layer are U_0 = E_u, R_0 = E_r; W_h ∈ ℝ^{4d×d} and b_h are training parameters; the operator ⊙ represents element-wise multiplication; d represents the dimension of a vector.
In the above formulas, f_catt(·) represents the described global attention mechanism, which is specifically defined as follows:
f_catt(Q, K, C) = Q̃ + FNN(Q̃)
In the above formula, FNN(Q̃) = max(0, Q̃ W_f + b_f) W_g + b_g, wherein W_{f,g} ∈ ℝ^{d×d} and b_{f,g} are trainable parameters; Q and Q̂ are mixed using a residual connection to obtain Q̃, wherein Q̂ is calculated according to the following formula:
Q̂ = S(Q, K, C) · K
In the above formula, Q∈n
In the above formula, W_{b,c,d,e} are trainable parameters; C_i^q represents the i-th row of C^q, and its physical meaning is the fused context information related to the i-th word of the query sequence Q; C_j^k represents the j-th row of C^k, and its physical meaning is the fused context information related to the j-th word of the key sequence K;
C^q = softmax(Q W_a C^T) · C

C^k = softmax(K W_a C^T) · C
M_i = M_{i,self} ⊕ M_{i,interaction} ⊕ M_{i,enhanced}
In the above formula, Mi∈m
The specific calculation process of the short-term dependence information sequence (h1, . . . , hl) is:
h_i = GRU(v_i, h_{i-1})
The specific calculation process of the long-term dependence information sequence (g1, . . . , gl) is:
(g_1, . . . , g_l) = MultiHead(Q, K, V)
In the above formula,
Q = V_m W^Q, K = V_m W^K, V = V_m W^V,
wherein W^Q, W^K and W^V are training parameters; MultiHead(·) represents a multi-head attention function; V_m = (v_1, . . . , v_l).
to obtain (ĝ_1, . . . , ĝ_l), wherein ⊙ represents element-wise multiplication;
g̃_i = GRU(ĝ_i, g̃_{i-1})
g(c, r) = σ(g̃_l · w_o + b_o)
In the above formula, σ(·) represents a sigmoid function, and w_o and b_o are training parameters.
The inventor has conducted extensive experiments on the proposed system and method on three widely used multi-turn dialogue retrieval data sets: the Ubuntu dialogue data set (see Lowe R, Pow N, Serban I, et al. The Ubuntu Dialogue Corpus: A Large Dataset for Research in Unstructured Multi-Turn Dialogue Systems[C]; Proceedings of SIGDIAL 2015, 2015.), the Douban dialogue data set (see Wu Y, Wu W, Xing C, et al. Sequential Matching Network: A New Architecture for Multi-turn Response Selection in Retrieval-Based Chatbots[C]; Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics, ACL 2017, Vancouver, Canada, July 30-August 4, Volume 1: Long Papers, 2017. Association for Computational Linguistics.), and the E-commerce dialogue data set (see Zhang Z, Li J, Zhu P, et al. Modeling Multi-turn Conversation with Deep Utterance Aggregation[C]; Proceedings of COLING, 2018.). The statistical information of these data sets is shown in Table 1.
In the experiments, the word-level vector of English text is a 200-dimensional word vector obtained by word2vec, and the character-level vector of English text is a 100-dimensional vector. The experimental results show that the method proposed by the present invention is effective and feasible, and its metrics are clearly superior to those of other methods.