I'm confused about the distributed training process in TensorFlow.
I think TensorFlow feeds a batch_size of data to a worker and then the worker updates the ps server.
There aren't really official docs besides the HowTo, so a good way to figure out how things work is by studying examples.
The basic concept to understand is that there are 3 kinds of TensorFlow processes.
The client -- this is the Python process which builds the graph, connects to a local master (Session()) or remote master (Session("grpc://...")) and issues session.run calls.
There's the master, which is the process that the client connects to, and which figures out how to distribute the work among workers.
There's the worker, which does the actual work. If your graph has a with tf.device("/job:worker/task:0"): block, then computation in that block should be executed on task:0.
When you create a new server with server = tf.train.Server(...), the process that's started is both a worker and a master, but it's useful to understand the difference for debugging.
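For a concrete picture, here's a minimal sketch (my own illustration, not from the question) of a one-task cluster: tf.train.Server starts a process that acts as both master and worker, and the client connects to it through server.target. The localhost:2222 address is arbitrary.

import tensorflow as tf

# One-task cluster: this process is simultaneously client, master, and worker.
cluster = tf.train.ClusterSpec({"worker": ["localhost:2222"]})
server = tf.train.Server(cluster, job_name="worker", task_index=0)

c = tf.constant(42.0)                    # graph built by the client
with tf.Session(server.target) as sess:  # client connects to the in-process master
    print(sess.run(c))                   # work is executed by the worker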
The easiest example of distributed TF is when you have a single client, which starts an in-process master, and multiple workers. Here's one such example. In this usage, the main difference from the non-distributed version is that you do with tf.device("worker1") instead of tf.device("gpu1") to tell it to execute that part of the graph on worker1.
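As a rough sketch of that pattern (my own example, assuming a second worker server is already running at the address shown), a single client pins different parts of the graph to different worker tasks:

import tensorflow as tf

# Assumes another tf.train.Server is already listening at localhost:2223 as task 1.
cluster = tf.train.ClusterSpec({"worker": ["localhost:2222", "localhost:2223"]})
server = tf.train.Server(cluster, job_name="worker", task_index=0)  # in-process master

with tf.device("/job:worker/task:0"):
    a = tf.constant(1.0)
with tf.device("/job:worker/task:1"):
    b = a + 2.0                          # executed on task:1

with tf.Session(server.target) as sess:  # single client talking to its in-process master
    print(sess.run(b))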
It gets more complicated when you have multiple clients, as in the case of "between-graph replication". In the parameter server example, you have multiple parallel training loops, where each loop corresponds to a separate client, i.e. a Python process issuing run calls. To see on which worker the ops are actually located, you can look at the with tf.device annotations.
In your example you don't have explicit with tf.device("/job:worker/task:...") blocks in your snippet, but this part is done by tf.device(tf.train.replica_device_setter(...)). Essentially, instead of having a fixed device for all ops in the block, the code runs replica_device_setter for each op to generate the device to place it on. It places all variables onto the /job:ps tasks, and the rest of the ops on the current worker. The code for replica_device_setter got a bit complicated over time, but you could use a simpler implementation of it for the same effect, as below:
def simple_setter(ps_device="/job:ps/task:0"):
    # Returns a device function: variables are placed on ps_device,
    # everything else stays on this worker.
    def _assign(op):
        node_def = op if isinstance(op, tf.NodeDef) else op.node_def
        if node_def.op == "Variable":
            return ps_device
        else:
            # FLAGS.task is assumed to hold this process's worker index
            return "/job:worker/task:%d" % (FLAGS.task)
    return _assign
...
with tf.device(simple_setter()):
    ...
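For comparison, a sketch of how the built-in helper is typically used in place of the snippet above (the ps_tasks value is illustrative, and FLAGS.task is assumed to be defined as before):

with tf.device(tf.train.replica_device_setter(
        ps_tasks=2,
        worker_device="/job:worker/task:%d" % FLAGS.task)):
    global_step = tf.Variable(0, trainable=False, name="global_step")
    # ... variables land on the ps tasks, other ops stay on this worker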
When you run this, each Python process will create a slightly different version of the graph, except for the Variable nodes, which will look identical in each process (check with tf.get_default_graph().as_graph_def()).
When you have multiple clients running training loops, one issue is -- who executes tasks that need to be done once for all clients? For instance, someone needs to run the initializers for all variables. You could put sess.run(tf.initialize_all_variables()) in the client body, but with multiple clients running in parallel, this means the initialization ops get run more than once. So the solution is to designate one worker as the "chief" worker, and only have that worker run the operation.
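A common way to do this (my own sketch, not from the question) is to treat task 0 as the chief and let tf.train.Supervisor run the one-time initialization on it, while the other workers wait:

import tensorflow as tf

task_index = 0                  # hypothetical: this process's worker index
is_chief = (task_index == 0)    # designate worker 0 as the chief

v = tf.Variable(0, name="v")
init_op = tf.initialize_all_variables()  # tf.global_variables_initializer() in later 1.x

# The Supervisor runs init_op only on the chief; non-chief workers block
# until the variables have been initialized.
sv = tf.train.Supervisor(is_chief=is_chief, init_op=init_op)
with sv.managed_session(server.target) as sess:  # server created as shown earlier
    print(sess.run(v))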
Also, there's no built-in distinction between worker and ps devices -- it's just a convention that variables get assigned to ps devices, and ops are assigned to worker devices. You could alternatively have only worker devices, and have a version of replica_device_setter put variables on the 0th worker.
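For instance, a device function along these lines (my own variation on the snippet above, with task_index standing in for FLAGS.task) keeps everything on worker devices:

task_index = 1   # hypothetical: this process's worker index

def variables_on_worker0(op):
    # Variables go to worker 0; everything else stays on the current worker.
    node_def = op if isinstance(op, tf.NodeDef) else op.node_def
    if node_def.op == "Variable":
        return "/job:worker/task:0"
    return "/job:worker/task:%d" % task_index

with tf.device(variables_on_worker0):
    w = tf.Variable(tf.zeros([10]))   # placed on /job:worker/task:0
    loss = tf.reduce_sum(w * w)       # placed on /job:worker/task:1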
Here's a barebones example with m workers updating variables sharded over n PS tasks, which uses explicit device assignment instead of replica_device_setter.
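A stripped-down sketch of that idea (my own, assuming 2 ps shards; variable names and sizes are arbitrary):

import tensorflow as tf

# Each variable shard is pinned to a specific ps task by hand.
with tf.device("/job:ps/task:0"):
    w0 = tf.Variable(tf.zeros([10]), name="w0")
with tf.device("/job:ps/task:1"):
    w1 = tf.Variable(tf.zeros([10]), name="w1")

task_index = 0   # hypothetical: this worker's index
with tf.device("/job:worker/task:%d" % task_index):
    # The update ops run on the worker and write to the ps-hosted variables.
    update0 = w0.assign_add(tf.ones([10]))
    update1 = w1.assign_add(tf.ones([10]))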
To summarize, in your case replica_device_setter makes sure that your global_step is a variable stored on the ps task, and as such this variable is shared across all of your training loops. As to why you get the same value of global_step on both workers -- there's nothing in your graph forcing global_step to be read after it's incremented. So if you run sess.run([increment_global_step, fetch_global_step]) in parallel on two different workers, you could potentially see
worker 0: 0
worker 1: 0
worker 0: 2
worker 1: 2
etc
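If you did want the fetched value to reflect the increment, one option (a sketch of mine, not part of your setup) is to force the read to happen after the update with a control dependency:

import tensorflow as tf

global_step = tf.Variable(0, trainable=False, name="global_step")
increment_global_step = tf.assign_add(global_step, 1)

with tf.control_dependencies([increment_global_step]):
    # tf.identity forces a fresh read that is ordered after the increment.
    fetch_after_increment = tf.identity(global_step)

# sess.run(fetch_after_increment) now returns the post-increment value.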