(This post was featured in DataScienceWeekly and translated into Chinese (中文) by Xiatian.)

Have you ever wondered what magic lies behind the tutorials on Large-scale Linear Models and Wide & Deep Learning? I hope this post will at least point you in the right direction.

Please take a look at my previous blog posts to understand the basics of TensorFlow Learn and its integration with other high-level TensorFlow modules.

The purpose of this post is to help you better understand the underlying principles of estimators in TensorFlow Learn and to offer some tips and hints in case you want to build your own estimator suited to your particular application. It should be helpful whenever you wonder how everything works internally and get overwhelmed by the large codebase.

## Understanding `BaseEstimator` and `Estimator`

`BaseEstimator` is the abstract base class for training and evaluating TensorFlow models. It provides basic functionality such as `fit()`, `partial_fit()`, `evaluate()`, and `predict()`, relying on the detailed logic in `graph_actions.py` to handle model inference, evaluation, and training, and on `data_feeder.py` to fetch data batches for different types of input (note: in the future, `DataFeeder` will be replaced by `learn.DataFrame`). It also uses `estimators.tensor_signature` to check that inputs are compatible in terms of `dtypes` and whether they are sparse.

In the meantime, `BaseEstimator` initializes the settings for monitors, checkpointing, etc. While it provides most of the logic required for building and evaluating a customized model function, it leaves the implementations of `_get_train_ops()`, `_get_eval_ops()`, and `_get_predict_ops()` to its sub-classes, giving freedom to sub-classes that require custom handling. `BaseEstimator` also supports distributed training, which I discussed briefly in a previous blog post.
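To make this division of labour concrete, here is a minimal pure-Python sketch of the same pattern. It does not use TensorFlow at all: `ToyBaseEstimator` and `MeanEstimator` are hypothetical names I made up for illustration, and the "ops" are plain callables rather than graph operations.

```python
from abc import ABC, abstractmethod


class ToyBaseEstimator(ABC):
    """Hypothetical stand-in for BaseEstimator; NOT the TensorFlow Learn class."""

    def fit(self, features, targets, steps=1):
        # The shared training loop lives here; sub-classes only supply the "ops".
        train_op, loss = self._get_train_ops(features, targets)
        for _ in range(steps):
            train_op()
        return loss()

    def predict(self, features):
        return self._get_predict_ops(features)()

    @abstractmethod
    def _get_train_ops(self, features, targets):
        """Return a (train_op, loss) pair of callables."""

    @abstractmethod
    def _get_predict_ops(self, features):
        """Return a callable that produces predictions."""


class MeanEstimator(ToyBaseEstimator):
    """Predicts one constant, nudged toward the mean target on each train step."""

    def __init__(self, learning_rate=0.5):
        self.value = 0.0
        self.learning_rate = learning_rate

    def _get_train_ops(self, features, targets):
        mean_target = sum(targets) / len(targets)

        def train_op():
            self.value += self.learning_rate * (mean_target - self.value)

        def loss():
            return sum((t - self.value) ** 2 for t in targets) / len(targets)

        return train_op, loss

    def _get_predict_ops(self, features):
        return lambda: [self.value for _ in features]
```

The sub-class never touches the training loop. It only answers the question "what should one step do?", which is the freedom the real hooks are meant to provide.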

`Estimator`, implemented in the same module, is the perfect example of how to implement the functions that `BaseEstimator` leaves to be overridden by its sub-classes.

For example, `_get_train_ops()` in `Estimator` takes `features` and `targets` as inputs and returns a tuple of the train `Operation` and the loss `Tensor`, using the customized model function. If you implement your own estimator, this also gives you the freedom to ignore `targets` when the estimator can be trained in an unsupervised fashion.
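As a rough illustration of that contract, the sketch below wires a model function into a `_get_train_ops()` hook in plain Python. Everything here is an assumption for illustration: `ToyEstimator` and `centroid_model_fn` are made-up names, the model function is assumed to return a `(predictions, loss, train_op)` triple, and callables stand in for the `Operation` and `Tensor` objects.

```python
class ToyEstimator:
    """Hypothetical sketch; not the real Estimator class."""

    def __init__(self, model_fn):
        self._model_fn = model_fn

    def _get_train_ops(self, features, targets=None):
        # Assumed model_fn contract: (predictions, loss, train_op).
        _, loss, train_op = self._model_fn(features, targets)
        return train_op, loss


def centroid_model_fn(features, targets=None):
    """Unsupervised toy model: `targets` is accepted but never read."""
    state = {"center": 0.0}

    def train_op():
        # One "training step": move the center to the mean of the features.
        state["center"] = sum(features) / len(features)

    def loss():
        # Mean squared distance of the features from the current center.
        return sum((x - state["center"]) ** 2 for x in features) / len(features)

    def predictions():
        return [x - state["center"] for x in features]

    return predictions, loss, train_op
```

Calling `ToyEstimator(centroid_model_fn)._get_train_ops([1.0, 3.0])` without any targets works fine, because the model function simply never looks at them.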

Similarly, `_get_eval_ops()` lets a sub-class use customized metrics to evaluate each training step. A list of available metrics can be found in a couple of high-level modules in TensorFlow. It should return a dictionary of `Tensor` objects that represent the evaluation ops for the specified metrics.

`_get_predict_ops()` is implemented to customize predictions, e.g. probabilities vs. actual prediction outputs. It returns a `Tensor` or a dictionary of `Tensor` objects that represent the prediction ops. You can then easily use the super-class’s `predict()` to achieve functionality like `transform()`, similar to the one in Scikit-learn for unsupervised problems.
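Here is a small sketch of the "probability vs. actual prediction" idea, returning both flavours in one dictionary. It is plain Python rather than TensorFlow, and the function name `get_predict_ops` and the dictionary keys are my own choices for illustration, not names taken from the library.

```python
import math


def get_predict_ops(logits, threshold=0.5):
    """Return both prediction flavours, keyed by name (hypothetical keys)."""
    # Sigmoid turns each raw score into a probability of the positive class.
    probabilities = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    # The "actual prediction" is the thresholded class label.
    classes = [int(p >= threshold) for p in probabilities]
    return {"probabilities": probabilities, "classes": classes}
```

A caller can then pick whichever entry it needs, for example `get_predict_ops(scores)["classes"]` for hard labels.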

## Examples of Estimators

### `LogisticRegressor`

`Estimator` already provides most of the implementation you need. For example, `LogisticRegressor` only needs to provide its own metrics, such as AUC, accuracy, precision, and recall, which are specific to binary classification problems. A user can then sub-class `LogisticRegressor` to implement an estimator for binary classification without much further effort.
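To give a feel for what "providing its own metrics" amounts to, here is a minimal pure-Python sketch that computes accuracy, precision, and recall for a binary problem and returns them as a dictionary. The function name `binary_metrics` is hypothetical, AUC is left out to keep it short, and the real implementation builds these as graph ops rather than Python numbers.

```python
def binary_metrics(probabilities, targets, threshold=0.5):
    """Return {metric_name: value} for binary labels in {0, 1}."""
    predicted = [int(p >= threshold) for p in probabilities]
    pairs = list(zip(predicted, targets))

    # Confusion-matrix counts.
    tp = sum(1 for p, t in pairs if p == 1 and t == 1)
    fp = sum(1 for p, t in pairs if p == 1 and t == 0)
    fn = sum(1 for p, t in pairs if p == 0 and t == 1)
    tn = sum(1 for p, t in pairs if p == 0 and t == 0)

    return {
        "accuracy": (tp + tn) / len(pairs),
        # Guard against division by zero when a class is never predicted/present.
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }
```

Because the result is just a dictionary keyed by metric name, a sub-class could add or swap entries without touching any of the evaluation plumbing.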

### `TensorForestEstimator`

A `TensorForestEstimator` has also been added to TensorFlow Learn recently. It hides most of the detailed implementation of Random Forests in `contrib.tensor_forest` and builds the estimator from a few exposed high-level components, so users can use `contrib.tensor_forest` more easily.

For example, instead of passing every hyper-parameter to the constructor of `TensorForestEstimator` individually, they are passed in together as `params`. The constructor fills them in with `params.fill()`, and they are later used by Tensor Forest’s own `RandomForestGraphs` to construct the whole graph.

```
class TensorForestEstimator(estimator.BaseEstimator):
    """An estimator that can train and evaluate a random forest."""

    def __init__(self, params, device_assigner=None, model_dir=None,
                 graph_builder_class=tensor_forest.RandomForestGraphs,
                 master='', accuracy_metric=None,
                 tf_random_seed=None, verbose=1,
                 config=None):
        # All hyper-parameters arrive bundled in `params` and are filled once here.
        self.params = params.fill()
```
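If the `params.fill()` idiom is unfamiliar, the following toy version shows the general shape. This is only a sketch under my own assumptions: `ToyForestParams`, `ToyForestEstimator`, and the derived `max_nodes` field are hypothetical and are not taken from `contrib.tensor_forest`.

```python
class ToyForestParams:
    """Hypothetical hyper-parameter bundle; not tensor_forest's real class."""

    def __init__(self, num_trees=10, max_depth=4):
        self.num_trees = num_trees
        self.max_depth = max_depth

    def fill(self):
        """Derive dependent settings, then return self so the call can be chained."""
        # Assumed derived value: node count of a full binary tree of this depth.
        self.max_nodes = 2 ** (self.max_depth + 1) - 1
        return self


class ToyForestEstimator:
    """Takes one params object instead of a long list of keyword arguments."""

    def __init__(self, params):
        self.params = params.fill()
```

The appeal of this design is that the estimator's constructor stays short and stable even as the set of hyper-parameters grows.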

Because the implementation of Random Forest inference involves many more details (many of which have been written as separate kernels to speed things up), its `_get_predict_ops()` uses `tensor_forest.RandomForestGraphs` as its graph builder and calls `graph_builder.inference_graph` to get the prediction ops.

```
def _get_predict_ops(self, features):
    # Build the forest graph in inference mode (training=False).
    graph_builder = self.graph_builder_class(
        self.params, device_assigner=self.device_assigner, training=False,
        **self.construction_args)
    # Parse the input features, which may be a tensor or a dict.
    features, spec = data_ops.ParseDataTensorOrDict(features)
    return graph_builder.inference_graph(features, data_spec=spec)
```

Similarly, it uses `graph_builder.training_loss` in its implementation of `_get_train_ops()`. Note that `TensorForestEstimator` uses functions from the `tensor_forest.data.data_ops` module, such as `ParseDataTensorOrDict` and `ParseLabelTensorOrDict`, to parse the input features and labels.
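The "tensor or dict" idea behind those parsing helpers can be sketched in plain Python as follows. This is a loose analogy and not the library's code: `parse_data_list_or_dict` is a hypothetical function, lists stand in for tensors, and the "spec" it returns is just a list of column names that I invented for the example.

```python
def parse_data_list_or_dict(data):
    """Return (rows, spec) from either a list of rows or a {name: column} dict."""
    if isinstance(data, dict):
        # Sort the names so the column order is deterministic.
        names = sorted(data)
        rows = [list(values) for values in zip(*(data[name] for name in names))]
        return rows, names

    # Already row-oriented: copy the rows and synthesize placeholder names.
    rows = [list(row) for row in data]
    width = len(rows[0]) if rows else 0
    return rows, ["feature_%d" % i for i in range(width)]
```

Either input form ends up in the same row-oriented layout, so the code that builds the model only has to handle one case.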

## Other Examples

A new estimator for K-means clustering has been added today, located in `contrib.factorization.python.ops.kmeans`. As with `TensorForestEstimator`, dedicated kernels are written to optimize for speed, and only some of the surfaced high-level components are used to implement the estimator. More examples of implementing different estimators can be found in `learn.estimators`.

I highly recommend taking a look to better understand the underlying code structure, and then start implementing your own estimators!

Please do not hesitate to leave a message if you have any questions.

## More Resources:

- Key Features of Scikit Flow Illustrated
- High-level Learn Module in TensorFlow
- Introduction to Scikit Flow and why you want to start learning TensorFlow
- DNNs, custom model and Digit recognition examples
- Categorical variables: One hot vs Distributed representation
- Scikit Flow: Easy Deep Learning with TensorFlow and Scikit-learn

Copyright © Yuan Tang 2021
