看到 MPNN 之前的 GNN 算法,似乎有种技术考古的感觉。

MPNN 是 Google Brain 于 ICML 2017 提出的一种图神经网络框架,论文题目是 Neural Message Passing for Quantum Chemistry.

文章的目的是对现有的图神经网络进行归纳,同时为应用驱动的图神经网络开发提供了便利。

# 背景

# 分子特征预测问题

MPNN 的提出最早是为了解决分子特征预测问题。在 2017 年前后,基于高通量方法,分子特征的高效测定变得更加容易,一个包括 1.3×1051.3 \times 10^5 类分子和 1313 项特征的数据集 QM9 得以建立,而基于量子力学方法的分子性质模拟(实际上借助密度泛函理论 DFT)仍然困难。因此,需要一种面向分子性质预测的机器学习方法提高预测的效率。

# 统一的 GNN 框架

此外,Google Brain 还希望对现有的 GNN 方法进行一个解耦和归纳,从而实现一个全新的、高效的、包罗先各种 GNN 的框架。在当时 (2017 年), 提出的主要 GNN 包括:

  • 用于学习分子指纹的卷积网络 (Convolutional Networks for Learning Molecule Fingerprints, Duvenaud et al. 2015)

    随便看了看,真的一言难尽。抽空看看要不要补上简介。

  • 门控图神经网络 (Gated Graph Neural Networks (GG-NN), Li et al. 2016)
  • 交互网络 (Interaction Networks, Battaglia et al. 2016)
  • 分子图卷积 (Molecular Graph Convolution, Kearnes et al. 2017)
  • 深度张量神经网络 (Deep Tensor Neural Networks, Schütt et al. 2017)
  • 基于拉普拉斯矩阵的方法 (Laplacian Based Methods, Bruna et al. 2013; Defferrard et al. 2016; Kipf & Welling 2016)

# 门控图神经网络 (GG-NN)

门控图神经网络 (Gated Graph Neural Networks, GG-NN) 是需要重点关注的一种 GNN 结构。后续用来实验的 MPNN 便是在此基础上修改得到的。

# MPNN 框架

# 记号

MPNN 作为一个图神经网络框架,其记号仍然采用图神经网络的常用记号。图 G\mathcal{G} 顶点 vv 的特征为 xvx_v, 边 vwvw 的特征为 evwe_{vw}. 然后注意到这个图神经网络是有隐藏层的,顶点 vv 在第 tt 个隐藏层中的特征记为 hvth_{v}^t, 此外还有一个辅助的消息 mvtm_{v}^t, 表示生成第 tt 个隐藏层的消息传递阶段顶点 vv 收到的所有(来自第 t1t-1 层的)消息。

# 主要内容

MPNN 将 GNN 划分成为两个阶段:

  • 消息传递阶段 (message passing phase)
  • 读出阶段 (readout phase)

# 消息传递阶段

消息传递阶段的过程可以描述为以下两个式子:

\begin{align*} & m_{v}^{t+1} = \sum_{w \in N(v)} M_t(h_v^t, h_w^t, e_{vw}) \\ & h_{v}^{t+1} = U_t(h_v^t, m_{v}^{t+1}) \end{align*}

其中 MtM_t 是消息传递函数,接收第 tt 个隐藏层中对应节点和邻居节点、连边的信息。UtU_t 是顶点更新函数,实现 t+1t+1 隐藏层顶点特征的更新。

所以,在这个框架中,隐藏层不蕴含边的信息,使用的边权始终为初始值。换句话说,点与边的对称性被破坏了。

# 将 GCN 放入 MPNN 框架

如果将 GCN 放入 MPNN 框架,那么可以表示为

Mt(hvt,hwt)=cvwthwtUt(hvt,mvt+1)=ReLU(Wtmvt+1)\begin{aligned} & M_t(h_v^t, h_w^t) = c_{vw}^th_w^t \\ & U_t(h_v^t, m_v^{t+1}) = \text{ReLU}(W^tm_v^{t+1}) \end{aligned}

其中 cvwt=(deg(v)deg(w))1/2A~vwc_{vw}^t = \big(\deg(v)\deg(w)\big)^{-1/2}\tilde{A}_{vw}, C=(cij)N×N=D1/2A~D1/2C = (c_{ij})_{N \times N} = D^{-1/2}\tilde{A}D^{-1/2}. (这一看就很对...)

注意到,GCN 中每一个隐藏层 H(l)RN×CH^{(l)} \in \mathbb{R}^{N \times C} 实际上是各顶点 CC 维特征的组合。因此,按行抽取 H(l)H^{(l)}, 得到的就是消息传递阶段的消息 mvt+1m_{v}^{t+1}. 因此就得以证明了。

# 读出阶段

读出阶段就是基于末端隐藏层特征得到输出特征的过程。

y^=R({hvtvG})\hat{\boldsymbol{y}} = R(\{h_v^t|v \in \mathcal{G}\})

其中 RR 是读出函数,接收所有顶点末端隐藏层信息并实现特征输出。

# 实现

PyG 实现了 MPNN 框架的消息传递阶段,具体内容可以参考文档

基于文档的解释,消息传递阶段被形式化为

xi=γΘ(xi,jN(i)ϕΘ(xi,xj,ej,i))\boldsymbol{x}_i' = \gamma_{\mathbf{\Theta}}\left(\boldsymbol{x}_i, \bigoplus_{j \in \mathcal{N}(i)} \phi_{\mathbf{\Theta}}(\boldsymbol{x}_i, \boldsymbol{x}_j, \boldsymbol{e}_{j,i})\right)

其中 \oplus 是一个可微的、置换不变(即对称)的函数,例如求和 sum 、求平均 mean 、求最大 max 等。γΘ,ϕΘ\gamma_{\mathbf{\Theta}}, \phi_{\mathrm{\Theta}} 都是可学习的可微函数,例如多层感知机 (MLP) 等。

不要被这个式子的形式吓到,实际上,该式就是上述两式将 mvtm_v^t 消去并合并为一个式子的结果。

现在,我来简单介绍一下 PyG 中的 Message Passing 类。

# 代码

MPNN paper 对应的代码中显示了 MPNN 类的一些基本的实现细节。

# Ref

  • Neural Message Passing for Quantum Chemistry