ROS Tensorflow

4 minute read


In this blog post I go through some useful techniques to nicely integrate Tensorflow (TF) in ROS nodes. The related code can be found in the following ros_tensorflow github repo.

What this blog post covers:

Integrating Tensorflow 2 in ROS nodes

Tensorflow requires python 3, but ROS melodic and ubuntu 18.04 still rely on python 2 by default. You can get away by running the Tensorflow node with python 3, using the shebang #!/usr/bin/env python3. You will have to install some missing ROS python 3 libraries, which you can do with conda or pip. If you plan to work with images, you will encounter an additional issue with cv_bridge, which you can solve by following this tutorial.

Model creation

Since rospy callbacks are called in a separate thread than the main, it is important to handle the Tensorflow session accordingly. When you create your model at initialization phase in the main function, your code is running in the main thread. Here, you should store the Tensorflow session so that you can restore it later in your callbacks:

`#!/usr/bin/env python3`

import tensorflow as tf
import rospy

class ModelWrapper():
    def __init__(self):
        # store the session object from the main thread
        self.session = tf.compat.v1.keras.backend.get_session()

        self.model = tf.keras.Sequential()
        # ...

class RosInterface():
    def __init__(self):
        self.wrapped_model = ModelWrapper()

def main():
    ri = RosInterface()
    rate = rospy.Rate(1)
    while not rospy.is_shutdown():

if __name__ == "__main__":

Model training

If you are recording data at runtime, you may want to re-train your model at runtime too. The best way is to perform training in a ROS action server, so that a client can also cancel training if needed. Since the action callback is called in a separate thread, you have to restore the session from the main thread. Here is the training function for the tensorflow model:

import tensorflow as tf

class ModelWrapper():
    # ...
    def train(self, x_train, y_train, n_epochs=100, callbacks=[]):
        with self.session.graph.as_default():

  , y_train,

The line tf.compat.v1.keras.backend.set_session(self.session) restores the session object to the main thread. If you omit this line, you we’ll get a strange error that will look like the following:

Error processing request: Error while reading resource variable model/dense_1/kernel from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/model/dense_1/kernel)

We will the train function from an action server. Note that we declare a callbacks parameter, which will allow us to provide action feedback while the model is training and cancel training when the action is cancelled. Here are some nice wrappers to achieve this:

import tensorflow as tf

class StopTrainOnCancel(tf.keras.callbacks.Callback):
    def __init__(self, check_preempt):
        super(tf.keras.callbacks.Callback, self).__init__()
        self.check_preempt = check_preempt
    def on_batch_end(self, batch, logs={}):
        self.model.stop_training = self.check_preempt()

class EpochCallback(tf.keras.callbacks.Callback):
    def __init__(self, cb):
        super(tf.keras.callbacks.Callback, self).__init__()
        self.cb = cb
    def on_epoch_end(self, epoch, logs):
        self.cb(epoch, logs)

Now let’s implement the ROS action server in our RosInterface class:

import rospy
import actionlib
from ros_tensorflow_msgs.msg import TrainAction, TrainFeedback, TrainResult

class RosInterface():
    def __init__(self):
        self.wrapped_model = ModelWrapper()
        self.train_as = actionlib.SimpleActionServer('train', TrainAction,
                                                     self.train_cb, False)

    def train_cb(self, goal):
        stop_on_cancel = StopTrainOnCancel(
            check_preempt=lambda : self.train_as.is_preempt_requested())
        pub_feedback = EpochCallback(
            lambda epoch, logs: self.train_as.publish_feedback(
                TrainFeedback(i_epoch=epoch, loss=logs['loss'], acc=logs['accuracy'])))

        # ... load x_train and y_train
        self.wrapped_model.train(x_train, y_train, callbacks=[

        # Training finished either because it was done or because it was cancelled
        if self.train_as.is_preempt_requested():

Model predictions

There are two ways you can integrate a Tensorflow model in a ROS nodes:

  • Making predictions in a loop
  • Making predictions on callbacks

If you want to perform prediction on callbacks, you will have to restore the Tensorflow session of the main thread, just like training:

class ModelWrapper():
    # ...
    def predict(self, x):
        with self.session.graph.as_default():
            out = self.model.predict(x)
        return out

Restoring the session is not necessary if you perform predictions in the main loop, but it doesn’t hurt. Since predictions are fast compared to training, you can implement them with rostopic subscribers or rosservice servers, no need for actionlib in this case.

In the provided ros_tensorflow repo I demonstrate both making predictions in a service callback and making predictions in main loop. The repo also makes a clean separation between Tensorflow code and ROS code. If you plan to use Tensorflow in your ROS project you can simply fork the repo and start from there.