问题描述
在tensorflow计算图中实现的Conjugate gradient算法,一开始迭代数值就变成nan。
这个问题是在实现natural gradient optimization的时候发现的,代码实现的思路严格参照了Wikipedia上有关conjugate gradient的讲解,由于代码真正运行是在静态的计算图中,运行时也很难debug出到底出问题的是哪一步。
最后问题解决的办法也非常暴力——在屏幕上打印出所有的变量一一检查,逐层定位问题。经检查,最有可能的两个问题是:
- Conjugate gradient算法实现中可能会出现分母为零的情况
- 计算KL divergence时出现浮点数下溢,导致log输出NaN
理论上来讲这种情况是不会发生的,因为information geometry optimization的核心度量Fisher information matrix一般来讲都是非奇异的。给定$N$个不同的样本,FIM的rank不小于$N\times{\dim(\mathcal{Z})}$,其中$\dim(\mathcal{Z})$为模型输出空间概率分布本身的维度。这也就是说我们只要把batch size给到足够大,FIM不会出现奇异的情况
但实际代码运行的时候,由于一些数值上的不稳定性,FIM奇异还是有可能会出现
Code implementation
Conjugate gradient in TensorFlow static computational graph
Conjugate gradient的实现确实是比较复杂,想要将其写在TensorFlow的静态计算图内部还是比较麻烦,遇到的坑会比用numpy实现多不少,以下先放代码
1 | def hessian_vector_product(x, grad, variable): |
Ideas worth noting:
- The conjugate gradient does not require direct access to the explicit form of the matrix, only matrix-vector-product is needed. Let $\hat{F}\in{\mathbb{R}^{N\times{N}}}$ be the FIM estimated with a batch of data, and let $v\in{\mathbb{R}^{N}}$ be an arbitrary vector, then $\hat{F}^{-1}v=\nabla_{\theta}(v^{T}\nabla_{\theta}D_{KL}(\pi_{\text{old}}||\pi))$.
- When calculating the matrix-vector-product $\hat{F}^{-1}v$, remember to block the gradient from vector $v$ using
tf.stop_gradient
in every iteration. Your gradient should only come from the KL divergence term. - There are devision operations for variables
alpha
andbeta
. Remember to add a small number (EPSILON
) for the denominator in case of NaN output.
github上也可以找到其他的基于TensorFlow静态计算图的conjugate gradient实现,其中一个印象比较深刻的实现方法是将while循环判断也写在了计算图的内部,用tf.while_loop
实现,代码如下
1 | def _conjugate_gradient_ops(self, pg_grads, kl_grads, max_iterations=20, residual_tol=1e-10): |
Model definition and environment interaction
调了一晚上的bug,根源就在这部分代码中。在jupyter notebook中打印出了所有变量的实际值后,发现一些奇怪的现象:
- 问题出在从普通的gradient转成natural gradient的过程中,普通gradient数值没有问题,natural gradient时常数值爆炸,或者直接出现NaN
- 搭了一个小网络来测试conjugate gradient实现的正确与否,发现偶尔会出现FIM奇异的情况,推测是numerical instability所致
- 进一步的测试中发现,每一步的迭代中并不总是会出现FIM奇异。相对而言较大、且靠近底层的参数不容易出现FIM奇异,参数量较小且靠近输出层的参数容易出现NaN,当时最频繁出现NaN的参数,就是网络最后一层的bias
猜测这种numerical instability可能是由于KL divergence不稳定导致的,所以干脆把旧的policy网络和新的policy网络写成同一个,设置reuse=True
的来共享网络参数,这样KL divergence就会始终为0,毕竟我们需要的只是KL divergence的gradient和Hessian而已。
后续实验中发现这样写会使得网络迭代更新速度缓慢,且容易收敛到sub-optimal点,有关这个问题之后如有新的发现再来更新这里的内容吧
1 | def collect_multi_batch(env, agent, maxlen, batch_size=64, qsize=5): |