tenforflow controlflow 2---the implementation of control flow

匿名 (未验证) 提交于 2019-12-03 00:43:02

如何判断计算哪个分支?如何判断循环的结束?如何把循环分割成计算子图?tf是如何实现控制流的呢?这不禁引起了我的好奇。

TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems,2TensorFlow: A system for large-scale machine learning,3Implementation of Control Flow in TensorFlow和4Dynamic Control Flow in Large-Scale Machine Learning,博客里的表述很多都是翻译的原文,也包含了一些我阅读源码后的理解,我姑且把这篇归为翻译。起初开始看tensorflow源代码,走了不少弯路,不知从何看起。看完这些论文再去研究源代码,才有些登堂入室的感觉。对tensorflow 计算图感兴趣的同学,可以研究一下。

Switch : 取决于输入p的值,Switch 算子把 d 的值传给两个输出中的某一个 。两个输入都可用,Switch节点才可执行

Merge : A Merge 算子把一个可用的输入传给输出。只要任意一个输入可用,Merge便可执行。

Enter

Exit

NextIteration: NextIteration 算子 把输入d 传入当前执行帧的下一个 iteration . TF 运行时维护迭代状态。 执行帧中执行的算子绑定一个iteration id, 用来标识op的一次执行(比如在whileloop里同一个op可能被多次执行)。输入可用时,Enter可被执行。


The cond Operator

# Build the graph for the true branch

#pred,fn1,fn2 are lists of tensors

context_t = CondContext(pred, branch=1)
res_t = context_t.Call(fn1)
# Build the graph for the false branch
context_f = CondContext(pred, branch=0)
res_f = context_f.Call(fn2)
# Add the Merge nodes for the outputs
merges = [Merge([f, t]) for (f, t) in zip(res_f, res_t)]

return merges

如上图生成的数据流图,x,y,z各自Switch 控制输入。由于Switch和Merge的存在,只有当 x<y 时,Add 运算符才被执行;反之,只有当x>y为false时,Square才被执行。取决于条件x<y,最终Merge输出Add的结果或 Square的结果。如上一段提到的那样,如果有多个输出,那么会有多个Merge,每个Merge对应一个输出。

The while_loop Operator

下面是用算子表示的构建数据流图 while_loop(pred, body,loop_vars)的伪代码:

while_context = WhileContext()
while_context.Enter()
# Add the Enter nodes for each loop variable.
enter_vars = [Enter(x, frame_name) for x in loop_vars]
# Add the Merge nodes. Note that input[1] will be updated later.
merge_vars = [Merge([x, x]) for x in enter_vars]
# Build the loop pred subgraph.
pred_result = pred(*merge_vars)
# Add the Switch nodes.
switch_vars = [Switch(x, pred_result) for x in merge_vars]
# Build the loop body subgraph.
body_result = body(*[x[1] for x in switch_vars])
# Add the NextIteration nodes.
next_vars = [NextIteration(x) for x in body_result]
# Form the cycles for the loop.
for m, v in zip(merge_vars, next_vars):
m.op._update_input(1, v)
# Add the Exit nodes.
exit_vars = [Exit(x[0]) for x in switch_vars]
while_context.Exit()

return exit_vars

以 循环变量为起点,对每一个循环变量,添加一个Enter op ,接着一个Merge op。然后使用Merge的输出构建出pred 子图用来计算循环终止条件。

以上是控制流大概的介绍,构建条件语句和循环语句的python api 实现,具体实现可以看tensorfow 源代码python部分,在后面的博客我会解读这一部分代码。

实现细节

excutor

计算图的构建是在client 端完成的(前向传递和反向传递),计算图构建完成后,通过session(direct session or distributed session)调用tensorflow runtime,tensorflow 运行时负责执行定义好的计算图。

在tensorflow 运行时看来,计算图是由一系列的whileloop的嵌套,一个whileloop大概是这样的结构。这里假设有两个循环变量,没有循环常量。

tf.whileloop(pred,body,[a,b])。pred是接受两个tensor,返回一个标量bool值tensor的函数;body是接受两个tensor,返回两个tensor的函数。

为了便于讲述,以下使用只有一个循环变量的图。

为了在多个设备运行计算图(本地模式是多cpu,分布式情况是多个服务器)tensorflow 自动地把这些ops Node 分配到个设备,并插入一些send/receive 节点对

以及一些协调节点,

这其实就是计算图的子图分割的过程,具体的实现在。为什么要加一些协调节点,下文会讲。

一个子图被一个绑定到一个device(cpu/gpu/tpu...etc)的excutor管理和执行。excutor从源节点开始,反复地执行可以执行的节点(上文有讲到,除Merge节点外一个节点可以被执行,如果它的输入全部被上游节点计算好了),直到可执行的队列为空。

如果没有control flow,计算图的执行是非常简单的:每个节点只被计算一次,这就是一个简单的DAG图,按照拓扑顺序(拓扑排序会吧?)一个一个执行节点就可以了。但是control flow引入了额外的复杂性,一个节点可以不被执行,也可以被执行任意次。excutor需要管理一个节点的多次运行,以及判断计算图的计算是否完成。

三元组(Node,frame,iteration Id)标识一个执行中的node,被称作taggedNode,tagedNode保存此次执行需要的输入值和计算输出值,这有点像程序和进程的关系。为了标识执行过程中的tensor,excutor 中的tensor由三元组(value,is_dead,tag)表示,其中value是真实的tensor,is_dead 标识当前是否是在一个不被执行的条件分支,tag 是一个string,唯一标识一个tensor(表示是某个tagNode的第几个输出)。tag 是send/recv pair 的传输key的一部分,以区分一对send/recv 的多个执行。

现在我们大概总结一下,excutor的执行过程。excutor 维持系列叫做frame的数据结构,每个frame维持一系列的iteration state 数据结果。frame 和 iteration status在执行的过程中动态地构建和销毁(new/del)。为每个tagednode维持一个peningcount, 表示还有几个输入值没有被计算出来。excutory 维持一个tagged node 工作队列,初始运行的时候,把一些source节点放到队列里,通过线程池,不断地从队列取出节点,计算输出,把输出传给需要这个输出作为输入的节点,改变该节点的pending count 值,如果为0,则可以把这个节点放到工作队列,直到工作队列为空。当然,遇到了control flow ops,需要特殊处理,比如说遇到Enter(name),如果是第一次遇到对应name的Enter节点,则需要初始化一个名字为name的子frame;如果是当前iteration i 中第一次遇到NextIteration,且传入的tensor is_dead 标识不为真,则初始化 iteration i+1,如果为真,不初始化下一个iteration,不往下传 tensor(nextiteration 节点截断whileloop循环)。

以下是计算节点的运算法则。


  • Exit(d) = r :



分布式执行

Distributed Conditional Execution

一个cond 有几个原子控制流算子和其他算子构成,所以分布式执行的时候,可能会被分到不同的devices,如下图所示。

因为recv是一种source节点,所以无条件被执行。即使send节点在一个untaken branch,Recv 也会被执行。为了让Recv 知道这是一个untaken branch,Send节点会把is_dead 标志传给Recv,Recv会把这个节点传下去,直到某个Merge或Nextiteration节点。

同理,同属于一个whileloop的算子 也可能被分配到不同的devices;

在上面的这个栗子中,循环体中的Op 节点被分配到deviceB。如上图简单的分割子图,无法让Op知道它是属于一个whileloop,只它计算一次就结束了(Recv 触发Op,Recv 只执行一次,Op也就只被执行一次)。解决方案是重写计算图,在每一个子图加入control-loop 状态机。

虚线是控制边

一个标量tenser 0 作为控制循环的Enter输入。这些控制流循环提供了分布式执行whileloop的必要的信息。

让我们来模拟一下whileloop 执行0次的情况:

  • 在device A,Enter、Merge、P、Switch依次被执行。因为P不为真,Merge 会把is_dead 标志传给 send,send 传给device B 的rec节点。Exit 节点同时也可以运行,使得外层依赖这个exit 的节点可以被同时运行。p 的send 节点把 p 的值传给device B。
  • 在device B,Enter 触发 controlloop开始循环,依次执行Enter 和Merge。因为两个Rec 节点依赖Merge,Merge会触发这两节点的执行。连接switch 的 Recv 节点收到 p 的值为false,Next 节点收到is_dead 标识,终止这个循环。连接Op的Rec会收到 一个dead tensor,所以send会传一个dead tensor 回device A。在这个时间点device B 的这个子图当前f没有需要计算的节点,执行完毕。
  • 回到device A,连接Next的Recv收到一个dead tensor,循环也将终止。device A 没有需要执行的节点,计算结束。

反向传递求导

tensorflow 支持根据链式法则自动求导。用户可以使用tensorflow的算子构建一个神经网络,定义一个损失函数,tensorflow可以自动求导,构建反向传递子图。

所谓反向传递,就是求导的链式法则。tensorflow 自动求导算法反向便利原始的计算图中的ops,依此调用ops注册的梯度函数,一步步地构建反向传播图。算子的梯度函数构造计算该算子符号梯度的计算子图。简单地说,tensorflow自动求导算法就是根据链式法则把这些个子图组装在一起。反向传播时计算梯度函数可能需要用到前向传播原始算子的输入和输出,所以这些数值需要被保持下来,以供反向传播时使用。

接下来讲有控制流的反向传播。

Backpropagation of Conditional

Backpropagation of While Loop

直观地说,对于循环变量,While Loop梯度是如下形式的。

def pred(i, _): return i < N

while_loop(pred, g_body, [0] + g_vars) 。

其中,N是前向传播whileloop 迭代的次数,g_body是前向whileloop循环体的梯度函数。下图是大致的前向传播和对应的反向传播图。

对于循环常量。循环常量是指在whileloop中使用到,又没有包含在tf.whilleloop的第三个参数中的tensor,tensorflow 假定在迭代的过程中这些值不会被改变。如下面的代码片段。

 import tensorflow as tf  i = tf.get_variable("ii", dtype=tf.int32, shape=[], initializer=tf.ones_initializer()) n = tf.Variable(10) b=tf.Variable(1) def cond(a):     return  a< n def body(a):     a = a + b     return a  a= tf.while_loop(cond, body, [i]) with tf.Session() as sess:     tf.global_variables_initializer().run()     print(a.eval())

其中 b和n是循环常量,i 是循环变量。

对于循环常量,梯度是如下形式的循环体,

 def pred(i, _): return i < N acc = 0.0; while (_pivot) {   acc += grad; }

由于和循环变量的梯度使用相同的pred,循环常量和循环变量的梯度整合在一个whileloop里。

为了计算出N,需要在前向传播的whileloop中加入计算N的逻辑。

为了保存反向传播中需要用到的前向传播中计算出的值。在构建op时,需要检测op的输入值。如果当前需要构建的op是在一个反向传播的一个whileloop中,而某个输入值是来自前向传播的whileloop,则需要为这个输入值引入一个stack,并加入相应的stackpush/stackpop 算子,在前向传播whileloop 的每一个iteration把这个符号在这一轮的值入栈,在反向传播中安装相反的顺序出 栈。这个栈存在于前向和反向whileloop之外的Frame(所以要有2个Enter,如下图)

符号:在tensorflow 运行时执行计算图之前,tensor只是一个符号。

此外,因为iteration 可以并行,为了保证入栈的次序和出栈的次序。还需要加入一些控制依赖边。保证一个whileloop中,对同一个符号对应的变量的入栈或出栈操作,iteration i中的操作在 iteraion i+1 中的操作之前被执行。以及在whileloop嵌套的情况下,

同一个符号

Frame:frame 是一个whileloop的执行时体现,因为whileloop可能嵌套,一个嵌套在whileloop 里的whileloop 可能被执行多次,每次执行对应一个frame,类似于程序和进程的概念。

以上是个对控制流实现大概的梳理。控制流相关的代码,client 端主要在tensorfow/python/ops/control_flow_ops.py。目前C/C++ client API 并不支持对whileloop 求梯度。运行时的对控制流的处理主要是在 1、子图分割时,加入控制流循环,这部分实现在graph 的graph_partition。2、对控制流原元语安装上文介绍的规则特殊处理,这部分代码实现是在excuter。感兴趣的同学可以看看源码,相信会有不小的收获。

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!