Please take a look at my previous blog posts to understand some basics of TensorFlow Learn and its integration with other high-level TensorFlow modules.
The purpose of this post is to help you better understand the underlying principles of estimators in TensorFlow Learn and to point out some tips and hints in case you want to build your own estimator that is suitable for your particular application. This post should be helpful if you ever wonder how everything works internally and get overwhelmed by the large codebase.
BaseEstimator is the abstract base class for training and evaluating TensorFlow models. It provides basic functionality such as predict() by relying on the detailed logic hidden in graph_actions.py, which handles model inference, evaluation, and training, and in data_feeder.py, which handles fetching data batches for different types of input (note: in the future, DataFeeder will be replaced by learn.DataFrame). It also checks that inputs are compatible in terms of dtypes and whether they are sparse.
Meanwhile, BaseEstimator initializes the settings for monitors, checkpointing, etc. While it provides most of the logic required for building and evaluating a customized model function, it leaves the implementation of methods such as _get_predict_ops() to its sub-classes, which gives freedom to sub-classes that require custom handling.
BaseEstimator can also run in a distributed setting, which I discussed briefly in a previous blog post.
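The division of labor described above, where a base class owns the shared workflow and sub-classes fill in a few hooks, can be illustrated with a small pure-Python sketch. This is not the TensorFlow Learn code: ToyBaseEstimator, MeanEstimator, and the hook names _train_step and _predict_one are hypothetical names chosen only for illustration.

```python
from abc import ABC, abstractmethod


class ToyBaseEstimator(ABC):
    """Toy stand-in for a base estimator (hypothetical, not TensorFlow Learn)."""

    def __init__(self, model_dir=None):
        # Settings shared by every sub-class live in the base class.
        self.model_dir = model_dir
        self.state = None

    def fit(self, features, targets=None, steps=1):
        # The training loop is shared; the update rule comes from the sub-class.
        for _ in range(steps):
            self.state = self._train_step(features, targets)
        return self

    def predict(self, features):
        # Shared inference entry point that delegates to the sub-class hook.
        if self.state is None:
            raise RuntimeError("call fit() before predict()")
        return [self._predict_one(x) for x in features]

    @abstractmethod
    def _train_step(self, features, targets):
        """Hook: return the new model state after one training step."""

    @abstractmethod
    def _predict_one(self, x):
        """Hook: return the prediction for a single example."""


class MeanEstimator(ToyBaseEstimator):
    """Predicts the mean of the training targets for every input."""

    def _train_step(self, features, targets):
        return sum(targets) / len(targets)

    def _predict_one(self, x):
        return self.state


model = MeanEstimator().fit([[0.0], [1.0], [2.0]], [1.0, 2.0, 3.0])
predictions = model.predict([[5.0], [6.0]])
```

The base class cannot be instantiated on its own, which mirrors the idea that the hooks must be supplied by a sub-class before the shared fit() and predict() become usable.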
Estimator, implemented in the module, is a good example of how to implement the functions that are left to be overridden by sub-classes of BaseEstimator. _get_train_ops() takes features and targets as inputs and returns a tuple of a train Operation and a loss Tensor, using the customized model function. If you want to implement your own estimator, this also gives you the freedom to decide whether targets can be ignored, for example when the estimator can be trained in an unsupervised fashion.
_get_eval_ops() lets a sub-class use customized metrics to evaluate each training step. A list of available metrics can be found in a couple of high-level modules in TensorFlow. It should return a dictionary of Tensor objects that represent the evaluation ops for the specified metrics.
_get_predict_ops() is implemented to customize predictions, e.g. probabilities vs. actual prediction output. It returns a Tensor or a dictionary of Tensor objects that represent the prediction ops. You can then use the super-class's predict() to achieve functionality like transform(), similar to the one in Scikit-learn, for unsupervised problems.
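To illustrate the last two points (targets that can be ignored, and a transform() built on top of predict()), here is a minimal pure-Python sketch. ToyCenterer is a hypothetical class written for this post's explanation only; it does not use TensorFlow and does not reflect how TensorFlow Learn implements these methods.

```python
class ToyCenterer:
    """Hypothetical unsupervised estimator that learns per-column means."""

    def __init__(self):
        self.means = None

    def fit(self, features, targets=None):
        # targets is accepted for interface compatibility but never used.
        n_rows = len(features)
        n_cols = len(features[0])
        self.means = [
            sum(row[j] for row in features) / n_rows for j in range(n_cols)
        ]
        return self

    def predict(self, features):
        # The "prediction" of this model is the mean-centered input.
        if self.means is None:
            raise RuntimeError("call fit() before predict()")
        return [[v - m for v, m in zip(row, self.means)] for row in features]

    def transform(self, features):
        # Scikit-learn style method that simply reuses predict().
        return self.predict(features)


data = [[1.0, 10.0], [3.0, 30.0]]
centered = ToyCenterer().fit(data).transform(data)
```

Because transform() only delegates to predict(), any customization of the prediction output is picked up by both methods.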
Examples of Estimators
Estimator already provides most of the implementations you need. For example, LogisticRegressor only needs to provide its own metrics, such as AUC, accuracy, precision, and recall, which are dedicated to binary classification problems. A user can then sub-class LogisticRegressor to implement an estimator for binary classification without much further effort.
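As a rough illustration of what "providing its own metrics" means, the sketch below computes accuracy, precision, and recall for binary labels and exposes them as a dictionary keyed by metric name. It is plain Python, not the LogisticRegressor implementation; the function names, the BINARY_METRICS dictionary, and the 0.5 threshold are assumptions made for the example, and AUC is left out to keep it short.

```python
def confusion_counts(labels, predictions):
    """Count true/false positives and negatives for 0/1 labels."""
    tp = fp = tn = fn = 0
    for y, p in zip(labels, predictions):
        if p == 1 and y == 1:
            tp += 1
        elif p == 1 and y == 0:
            fp += 1
        elif p == 0 and y == 0:
            tn += 1
        else:
            fn += 1
    return tp, fp, tn, fn


def accuracy(labels, predictions):
    tp, fp, tn, fn = confusion_counts(labels, predictions)
    total = tp + fp + tn + fn
    return (tp + tn) / total if total else 0.0


def precision(labels, predictions):
    tp, fp, _, _ = confusion_counts(labels, predictions)
    return tp / (tp + fp) if (tp + fp) else 0.0


def recall(labels, predictions):
    tp, _, _, fn = confusion_counts(labels, predictions)
    return tp / (tp + fn) if (tp + fn) else 0.0


# Metrics are registered by name, similar in spirit to a dict of eval ops.
BINARY_METRICS = {"accuracy": accuracy, "precision": precision, "recall": recall}


def evaluate(labels, probabilities, threshold=0.5):
    # Turn probabilities into hard 0/1 predictions before scoring.
    predictions = [1 if p >= threshold else 0 for p in probabilities]
    return {name: fn(labels, predictions) for name, fn in BINARY_METRICS.items()}


labels = [1, 0, 1, 1, 0]
probabilities = [0.9, 0.6, 0.4, 0.8, 0.1]
results = evaluate(labels, probabilities)
```

A sub-class for a specific binary classification task would only need to add to or replace entries in such a dictionary, which is the kind of small extension the post describes.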
TensorForestEstimator has also been added to TensorFlow Learn recently. It hides most of the detailed implementation of Random Forests in contrib.tensor_forest while using some exposed high-level components to build the estimator, so users can use contrib.tensor_forest more easily.
For example, instead of passing all hyper-parameters to the constructor of TensorForestEstimator, they are passed in via params in the constructor. The params are filled by params.fill() and are later used in Tensor Forest's own RandomForestGraphs for constructing the whole graph.
    class TensorForestEstimator(estimator.BaseEstimator):
      """An estimator that can train and evaluate a random forest."""

      def __init__(self, params, device_assigner=None, model_dir=None,
                   graph_builder_class=tensor_forest.RandomForestGraphs,
                   master='', accuracy_metric=None,
                   tf_random_seed=None, verbose=1, config=None):
        self.params = params.fill()
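The idea of handing the estimator a single params object and letting fill() complete it can be sketched in plain Python as follows. Everything here is hypothetical: ToyForestParams, ToyForestEstimator, the default values, and the derived total_nodes field are invented for illustration and are not taken from contrib.tensor_forest.

```python
class ToyForestParams:
    """Hypothetical hyper-parameter container (not the tensor_forest class)."""

    # Illustrative defaults only; these values are assumptions.
    DEFAULTS = {"num_trees": 100, "max_nodes": 10000}

    def __init__(self, **kwargs):
        # Keep whatever the user passed in as attributes.
        for name, value in kwargs.items():
            setattr(self, name, value)

    def fill(self):
        # Fill in defaults for anything the user did not set.
        for name, value in self.DEFAULTS.items():
            if not hasattr(self, name):
                setattr(self, name, value)
        # Compute a derived value from the now-complete settings.
        self.total_nodes = self.num_trees * self.max_nodes
        # Return self so the call can be used in an assignment.
        return self


class ToyForestEstimator:
    """Hypothetical estimator that only stores the filled params."""

    def __init__(self, params):
        self.params = params.fill()


params = ToyForestParams(num_trees=10)
toy_estimator = ToyForestEstimator(params)
```

Keeping the hyper-parameters in one object means the constructor signature stays short, and the graph-building code can read every setting from a single place.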
Since the implementation of Random Forest inference involves many more details (many of them have been written as separate kernels to speed things up), the estimator relies on tensor_forest.RandomForestGraphs as its graph builder. It calls graph_builder.inference_graph to get the prediction ops.
    def _get_predict_ops(self, features):
      graph_builder = self.graph_builder_class(
          self.params, device_assigner=self.device_assigner, training=False,
          **self.construction_args)
      features, spec = data_ops.ParseDataTensorOrDict(features)
      return graph_builder.inference_graph(features, data_spec=spec)
Similarly, it uses graph_builder.training_loss for the implementation of _get_train_ops(). Note that TensorForestEstimator uses functions in the tensor_forest.data.data_ops module, such as ParseDataTensorOrDict and ParseLabelTensorOrDict, to parse the input features and labels.
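The general "accept either a single block of data or a dict of named columns" pattern behind such parsing helpers can be sketched in plain Python. The function below is a hypothetical analogue that works on lists rather than tensors; it is not the data_ops implementation, and the shape of the returned spec (a sorted list of column names, or None) is an assumption made for this example.

```python
def parse_rows_or_dict(data):
    """Return (rows, spec) from a list of rows or a dict of named columns.

    Hypothetical helper: spec is the sorted list of column names for dict
    input and None for plain row input.
    """
    if isinstance(data, dict):
        # Sort names so the column order is deterministic.
        names = sorted(data)
        columns = [data[name] for name in names]
        # Zip the columns together to get one list per example.
        rows = [list(values) for values in zip(*columns)]
        return rows, names
    # Plain input is passed through as a list of row lists.
    return [list(row) for row in data], None


rows_from_dict, spec_from_dict = parse_rows_or_dict({"b": [3, 4], "a": [1, 2]})
rows_from_list, spec_from_list = parse_rows_or_dict([[1, 3], [2, 4]])
```

Normalizing the input at the boundary like this lets the rest of the estimator work with one representation regardless of how the caller supplied the data.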
A new estimator for K-means clustering has been added today, located in contrib.factorization.python.ops.kmeans. Similar to TensorForestEstimator, dedicated kernels are written to optimize speed, and only some of the surfaced high-level components are used to implement the estimator. More examples of implementing different estimators can be found in the TensorFlow Learn codebase.
I highly recommend taking a look to understand the underlying code structure better, and then start implementing your own estimators!
Please do not hesitate to leave a message if you have any questions.
- Key Features of Scikit Flow Illustrated
- High-level Learn Module in TensorFlow
- Introduction to Scikit Flow and why you want to start learning TensorFlow
- DNNs, custom model and Digit recognition examples
- Categorical variables: One hot vs Distributed representation
- Scikit Flow: Easy Deep Learning with TensorFlow and Scikit-learn
Copyright © Yuan Tang 2023