TensorFlow Integration

If you're already using TensorFlow, Comet works out of the box. Just add these lines of code to your training script:

```python
from comet_ml import Experiment
experiment = Experiment()

# Your code.
```
For more information on getting started, see the details on the Comet config file.
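The sketch below gives a rough idea of what such a file might contain. It parses an INI-style Comet config with Python's standard `configparser`. The `[comet]` section and the `api_key`, `project_name` and `workspace` keys are our assumption of a typical layout, so check the config file documentation for the authoritative key names and file locations. `read_comet_settings` is a hypothetical helper and is not part of `comet_ml`.

```python
import configparser

# Assumed layout of a Comet config file (for example ~/.comet.config).
# The key names are illustrative; verify them against the Comet docs.
EXAMPLE_CONFIG = """\
[comet]
api_key = YOUR-API-KEY
project_name = tf
workspace = YOUR-WORKSPACE
"""


def read_comet_settings(text: str) -> dict:
    """Parse INI text and return the [comet] section as a plain dict."""
    parser = configparser.ConfigParser()
    parser.read_string(text)
    if not parser.has_section("comet"):
        return {}
    return dict(parser["comet"])


settings = read_comet_settings(EXAMPLE_CONFIG)
print(sorted(settings))  # ['api_key', 'project_name', 'workspace']
```

Keeping the API key in a file like this, rather than in the script, is what lets the bare `Experiment()` call above work without arguments.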

TensorFlow Examples

Over at github.com/comet-ml/comet-examples, we have an extensive collection of examples that combine Comet with many ML frameworks, including TensorFlow.

  1. TensorFlow 1 examples
  2. TensorFlow 2 examples

If you have an example you'd like to add, please feel free to make a Pull Request. If there is an example you'd like to see, please feel free to open an issue requesting it.

TensorFlow Estimator Integration

The Comet auto-logging system has instrumented the following canned TensorFlow Estimators to log all hyperparameters and model graph definitions for TensorFlow versions 1 and 2:

  • BaselineClassifier
  • BaselineEstimator
  • BaselineRegressor
  • BoostedTreesClassifier
  • BoostedTreesRegressor
  • DNNClassifier
  • DNNEstimator
  • DNNLinearCombinedClassifier
  • DNNLinearCombinedEstimator
  • DNNLinearCombinedRegressor
  • DNNRegressor
  • KMeansClustering
  • LinearClassifier
  • LinearEstimator
  • LinearRegressor
  • RNNEstimator

TensorFlow metrics are auto-logged via the TensorBoard summary API. In addition, you can log further hyperparameters and metrics manually, as shown below.

If you have extended Estimator (or are using the base class directly), you will need to log your hyperparameters manually. Your model graph definition and metrics will still be auto-logged.
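One way to apply this rule in your own code is to check the estimator's class against the list of canned Estimators. If the class is not on the list, call `Experiment.log_parameters()` yourself. This is a minimal sketch with several assumptions:

- `log_params_if_needed` and `RecordingExperiment` are hypothetical names.
- The stand-in classes replace real TensorFlow types so the snippet runs without TensorFlow installed.
- Any class not named exactly in the list, including a subclass of a canned Estimator, is treated as "extended." This is our own conservative reading.

```python
# Canned Estimators from the list above whose hyperparameters Comet auto-logs.
AUTOLOGGED_ESTIMATORS = frozenset({
    "BaselineClassifier", "BaselineEstimator", "BaselineRegressor",
    "BoostedTreesClassifier", "BoostedTreesRegressor",
    "DNNClassifier", "DNNEstimator",
    "DNNLinearCombinedClassifier", "DNNLinearCombinedEstimator",
    "DNNLinearCombinedRegressor", "DNNRegressor",
    "KMeansClustering",
    "LinearClassifier", "LinearEstimator", "LinearRegressor",
    "RNNEstimator",
})


def log_params_if_needed(experiment, estimator, params):
    """Hypothetical helper: log params unless the exact class is auto-logged.

    Returns True if the params were logged manually, False otherwise.
    """
    if type(estimator).__name__ in AUTOLOGGED_ESTIMATORS:
        return False
    experiment.log_parameters(params)
    return True


class RecordingExperiment:
    """Stand-in for comet_ml.Experiment that just records logged params."""

    def __init__(self):
        self.logged = {}

    def log_parameters(self, params):
        self.logged.update(params)


class Estimator:  # hypothetical stand-in for tf.estimator.Estimator
    pass


class DNNClassifier:  # hypothetical stand-in for the canned Estimator
    pass


exp = RecordingExperiment()
print(log_params_if_needed(exp, Estimator(), {"learning_rate": 0.01}))  # True
print(log_params_if_needed(exp, DNNClassifier(), {"hidden_units": 64}))  # False
print(exp.logged)  # {'learning_rate': 0.01}
```

With a real run, you would pass your `comet_ml.Experiment` and the `params` dict you give to your custom Estimator. When in doubt, the check errs toward logging manually.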

End-to-end example

Here is an end-to-end TensorFlow example.

For more examples using TensorFlow, see our Comet Examples GitHub repository.

```python
"""A very simple MNIST classifier.
See extensive documentation at
https://www.tensorflow.org/get_started/mnist/beginners
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Import comet_ml before TensorFlow so auto-logging can hook in.
from comet_ml import Experiment

# TensorFlow 1.x only: tf.placeholder, tf.Session and this MNIST helper
# are not available in TensorFlow 2.
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf


def get_data():
    mnist = input_data.read_data_sets("/tmp/tensorflow/mnist/input_data/", one_hot=True)
    return mnist


def build_model_graph(hyper_params):
    # Create the model: a single linear layer (softmax regression)
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.matmul(x, W) + b

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 10])

    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y))
    train_step = tf.train.GradientDescentOptimizer(
        hyper_params['learning_rate']).minimize(cross_entropy)

    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    correct_prediction = tf.cast(correct_prediction, tf.float32)
    accuracy = tf.reduce_mean(correct_prediction)

    return train_step, cross_entropy, accuracy, x, y, y_


def train(hyper_params):
    mnist = get_data()

    # Get graph definition, tensors and ops
    train_step, cross_entropy, accuracy, x, y, y_ = build_model_graph(hyper_params)

    experiment = Experiment(project_name="tf")
    # This is not a canned Estimator, so log hyperparameters manually
    experiment.log_parameters(hyper_params)

    with tf.Session() as sess:
        # Initialize W and b before running any op that reads them
        sess.run(tf.global_variables_initializer())

        with experiment.train():
            for i in range(hyper_params["steps"]):
                batch = mnist.train.next_batch(hyper_params["batch_size"])
                # Compute train accuracy every 10 steps
                if i % 10 == 0:
                    train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1]})
                    print('step %d, training accuracy %g' % (i, train_accuracy))
                    experiment.log_metric("accuracy", train_accuracy, step=i)

                # Update weights (back propagation)
                _, loss_val = sess.run([train_step, cross_entropy],
                                       feed_dict={x: batch[0], y_: batch[1]})
                experiment.log_metric("loss", loss_val, step=i)

        ### Finished Training ###

        with experiment.test():
            # Compute test accuracy
            acc = accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
            print('test accuracy %g' % acc)
            experiment.log_metric("accuracy", acc)


if __name__ == '__main__':
    hyper_params = {"learning_rate": 0.5, "steps": 1000, "batch_size": 50}
    train(hyper_params)
```
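The `cross_entropy` and `accuracy` ops in the example can be hard to read in graph form. The plain-Python sketch below reproduces what they compute for a small batch of logits and one-hot labels. The first value is the mean softmax cross-entropy, computed with the log-sum-exp trick for numerical stability. The second is the fraction of rows where the predicted class matches the label. The sketch is for intuition only and is not how TensorFlow implements these ops. The function names are our own.

```python
import math


def softmax_cross_entropy(logits, one_hot):
    """Mean softmax cross-entropy over a batch, mirroring
    tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(...))."""
    total = 0.0
    for row, label in zip(logits, one_hot):
        m = max(row)  # subtract the max before exp() for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        # -sum_k label_k * log(softmax(row)_k)
        total += -sum(t * (v - log_sum_exp) for v, t in zip(row, label))
    return total / len(logits)


def argmax(values):
    """Index of the first maximum, like tf.argmax along one row."""
    return max(range(len(values)), key=values.__getitem__)


def accuracy(logits, one_hot):
    """Fraction of rows where argmax(logits) equals argmax(label)."""
    correct = [argmax(r) == argmax(t) for r, t in zip(logits, one_hot)]
    return sum(correct) / len(correct)


logits = [[2.0, 0.5, 0.1], [0.2, 0.3, 1.5]]
labels = [[1, 0, 0], [0, 1, 0]]
print(accuracy(logits, labels))  # 0.5
print(softmax_cross_entropy(logits, labels))
```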