Monday, 12 March 2018

Introduction to Distributed TensorFlow and sharing any TensorFlow object between sessions across different processes

I'm currently working on Reinforcement Learning, specifically an algorithm called A3C, where I need a global network and parallel workers running in different processes that update it.

Using just tf.Session(), we can't access the nodes (ops, tensors, variables, or anything else) of a TensorFlow session running in a different process.

Enter Distributed TensorFlow. Anything is possible using the power of Distributed TF.

Here's a very friendly article on Distributed TF:
Distributed TensorFlow: A Gentle Introduction

It's highly recommended that you read it before proceeding (at least skim through it), as my post builds on it to make things easier.

Everything is easy when explained with an example. I'll be explaining based on the following cluster configuration:

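For concreteness, here's a sketch of such a configuration. The hostnames and ports are just placeholders; replace them with whatever fits your setup:

import tensorflow as tf

jobs = {"ps":     ["localhost:2222"],
        "worker": ["localhost:2223",
                   "localhost:2224",
                   "localhost:2225",
                   "localhost:2226"]}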
Here, we have 2 jobs, namely 'worker' (the worker that'll update the global network) and 'ps' (let's call it the parameter server, which contains the global network). The workers can run tasks on four different servers (as can be seen from the cluster config), and there's a single server for the global network. (You can design it as you see fit for your application)

We create a cluster object by:
cluster = tf.train.ClusterSpec(jobs)

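Since the parameter server will live in the current program in this example, we also start a server for the 'ps' job right here (a sketch):

server = tf.train.Server(cluster, job_name="ps", task_index=0)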
So, we can create our network in the parameter server (i.e., the current program). But how do you reference it from the other workers, so as to access the shared tensors, variables, or ops?

For example, consider the following scenario:

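(The network itself doesn't matter here; the sketch below just creates one object of each kind on the parameter server's device, with made-up shapes and layer parameters, and initializes the variables from this process.)

with tf.device("/job:ps/task:0"):
    var = tf.Variable(0.0, name="var")                         # a variable
    state = tf.assign_add(var, 1.0, name="state")              # a tensor
    inputs = tf.random_normal([1, 84, 84, 4])                  # dummy input
    with tf.variable_scope("Conv"):
        conv1 = tf.layers.conv2d(inputs, filters=16, kernel_size=8,
                                 strides=4, activation=tf.nn.elu,
                                 name="conv")                   # a tensor
    loss = tf.reduce_mean(tf.square(conv1))
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)     # an operation

sess = tf.Session(target=server.target)
sess.run(tf.global_variables_initializer())    # initialize the shared variables once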

If you refer to the TensorFlow docs, you can see the return type of any call. For instance, var, state, conv1, and train_op above are respectively of types 'tf.Variable', 'tf.Tensor', 'tf.Tensor' and 'tf.Operation'.

Also, note that these are created on the server "/job:ps/task:0", and for now that's the only place you can access them from.

" Okay, how do refer/access it from other processes?"

Each object created above has a member named 'name', which can be accessed like var.name, state.name, conv1.name or train_op.name. It is a string giving the full reference name of the object, including any variable_scopes (if used). You may print it out and check the result. For instance, conv1.name will return something like 'Conv/conv/Elu:0'.

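Going the other way, a name string can be turned back into the corresponding object by asking the graph for it. A small sketch (run in a process whose graph actually contains these ops):

graph = tf.get_default_graph()
conv1_again = graph.get_tensor_by_name(conv1.name)            # tensor names end in ':0'
train_op_again = graph.get_operation_by_name(train_op.name)   # op names have no ':0' suffix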
You need to store those names so that you can refer to the objects from anywhere.

I have created a simple class to do that (You may extend it however you want):


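Since all it has to do is hold name strings, a minimal sketch could look like the following (the class and attribute names are just placeholders):

class TFObjectNames(object):
    """Holds the .name strings of shared TensorFlow objects,
    so they can be looked up by name from other processes."""

    def __init__(self):
        self.tensors = {}       # e.g. {'conv1': 'Conv/conv/Elu:0'}
        self.operations = {}    # e.g. {'train_op': 'Adam'}

    def add_tensor(self, key, tensor):
        self.tensors[key] = tensor.name

    def add_operation(self, key, op):
        self.operations[key] = op.name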
The code is self-explanatory I guess.

So you can now create an object to hold all such reference names, and add them to it like this:

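Along these lines (using the sketch class from above):

tf_names = TFObjectNames()
tf_names.add_tensor("var", var)
tf_names.add_tensor("state", state)
tf_names.add_tensor("conv1", conv1)
tf_names.add_operation("train_op", train_op)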



Now that you have the objects' references that can be shared, you can start your workers.

You can spawn the processes from a single program using the multiprocessing module (or manage your processes however you prefer).

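A sketch of what that could look like (worker_function is defined further down):

import multiprocessing

processes = []
for task_index in range(len(jobs["worker"])):
    p = multiprocessing.Process(target=worker_function,
                                args=(jobs, task_index, tf_names, finish_counter))
    p.start()
    processes.append(p)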




Wondering what finish_counter is? We'll get to that soon.


Let's see how workers can be implemented. Say you have a function like this that each worker process executes:

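Here's a sketch of one possibility. One caveat: a tf.Session can only fetch names that exist in its own process's graph, so the lookups below assume the worker's graph already contains those ops (either because the worker was forked from the main program after the graph was built, or because the same graph-construction code is run in the worker as well).

def worker_function(jobs, task_index, tf_names, finish_counter):
    # Each worker starts a server for its own task in the 'worker' job
    cluster = tf.train.ClusterSpec(jobs)
    server = tf.train.Server(cluster, job_name="worker", task_index=task_index)

    # Look up the shared objects by the names that were stored earlier.
    # The variables live on /job:ps/task:0, so their values are shared.
    graph = tf.get_default_graph()
    state = graph.get_tensor_by_name(tf_names.tensors["state"])
    train_op = graph.get_operation_by_name(tf_names.operations["train_op"])

    with tf.Session(target=server.target) as sess:
        for _ in range(100):            # whatever the worker's actual job is
            sess.run([state, train_op])

    # ... signal that we're done and keep the server alive (see below) ...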

If you've read the article linked above, you may have noticed that there's no way to decide when to terminate all these processes, since the server.join() call (after the processing is complete) blocks forever. (And a server can't simply be killed, since other tasks may depend on it.)

So I thought of creating a shared class (with shared memory) that acts as a counter, tracking the number of workers that have finished processing and are waiting to be killed.

To create a shared object:
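A minimal sketch of such a counter, built on multiprocessing's shared-memory Value (the exact class is up to you):

import multiprocessing

class FinishCounter(object):
    """A process-safe counter: each worker increments it when it's done."""

    def __init__(self):
        self._count = multiprocessing.Value('i', 0)   # shared integer, starts at 0

    def increment(self):
        with self._count.get_lock():
            self._count.value += 1

    @property
    def value(self):
        return self._count.value

finish_counter = FinishCounter()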

(This is the finish_counter that was passed when processes were created, remember? ;) )

So, the end of the worker_function() should look like:

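Something along these lines (still inside worker_function, using the counter sketched above):

    # ... the worker's processing is done ...
    finish_counter.increment()   # tell the main process this worker is finished
    server.join()                # keep the server alive; other tasks may still need it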


The main program's ending should look like:

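A sketch, using finish_counter and the processes list from the earlier snippets:

import time

num_workers = len(jobs["worker"])

# Wait until every worker has reported that it's finished
while finish_counter.value < num_workers:
    time.sleep(1)

# At this point the workers are only blocking in server.join(),
# so it's safe to kill them
for p in processes:
    p.terminate()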



That's it. You've now implemented Distributed TF with a parameter server and multiple workers.
You can extend this solution according to your application needs.
