# blob: ed1ed0ce3efe783c24e9e1d7cdf6288822b7e659 [file] [log] [blame]
#!/usr/bin/python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train a model from the ICML-2015 paper.
"""
# pylint: disable=line-too-long
# pylint: enable=line-too-long
import os
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from biology import model_config
from biology.icml import icml_models
# Command-line flags for the training job. 'config', 'logdir' and 'fold'
# default to None and are marked as required in the __main__ block at the
# bottom of this file.

# Model configuration source; Run() loads it with config.ReadFromFile().
flags.DEFINE_string('config', None, 'Serialized ModelConfig proto.')
# Passed to the IcmlModel constructor as `master` in Run().
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master.')
# Output directory; replica 0 creates it and writes config.pbtxt there.
flags.DEFINE_string('logdir', None, 'Directory for output files.')
# replica_id and ps_tasks are forwarded unchanged to model.Train() in Run().
flags.DEFINE_integer('replica_id', 0, 'Task ID of this replica.')
flags.DEFINE_integer('ps_tasks', 0, 'Number of parameter server tasks.')
# Cross-validation selection: when both are set, Run() trains on every fold
# except `fold` (see kfold_pattern).
flags.DEFINE_integer('num_folds', 5, 'Number of cross-validation folds.')
flags.DEFINE_integer('fold', None, 'Fold index for this model.')
# Module-level handle used to read parsed flag values.
FLAGS = flags.FLAGS
def kfold_pattern(input_pattern, num_folds, fold=None):
  """Generator for train/test filename splits.

  The pattern is not expanded except for the %d being replaced by the fold
  index.

  Args:
    input_pattern: Input filename pattern. Should contain %d for fold index.
    num_folds: Number of folds.
    fold: If not None, the generator only yields the train/test split for the
      given fold. A value outside range(num_folds) yields nothing.

  Yields:
    train_filenames: A list of file patterns in training set.
    test_filenames: A list of file patterns in test set.

  Raises:
    ValueError: If a test pattern also appears in the training patterns (for
      example when input_pattern does not vary with the fold index), or if the
      split does not account for every per-fold pattern. Because this is a
      generator, the error surfaces when the offending split is requested.
  """
  # One file pattern per fold, obtained by substituting the fold index.
  fold_filepatterns = [input_pattern % i for i in range(num_folds)]
  # Loop-invariant: the full set of patterns every split must cover.
  all_patterns = set(fold_filepatterns)

  for i in range(num_folds):
    if fold is not None and i != fold:
      continue
    # Hold out fold i; everything else is training data.
    train = fold_filepatterns[:i] + fold_filepatterns[i + 1:]
    test = [fold_filepatterns[i]]
    # Sanity checks. These previously went through `logging.fatal`, but no
    # `logging` module is imported in this file, so a failed check died with a
    # NameError. Raise explicitly with the same messages instead.
    if any(f in test for f in train):
      raise ValueError('Train/test split is not complete.')
    if set(train + test) != all_patterns:
      raise ValueError('Not all input files are accounted for.')
    yield train, test
def Run(input_data_types=None):
  """Trains the model with specified parameters.

  Builds a default ModelConfig, overlays the file named by --config, and
  trains an IcmlModel. When both --num_folds and --fold are set, training uses
  the comma-joined patterns of every fold except --fold; otherwise the
  configured input_pattern is used as-is.

  Args:
    input_data_types: List of legacy_types_pb2 constants or None.

  Raises:
    ValueError: If --fold does not select any of the --num_folds folds.
  """
  config = model_config.ModelConfig({
      'input_pattern': '',  # Should have %d for fold index substitution.
      'num_classification_tasks': 259,
      'tasks_in_input': 259,  # Dimensionality of sstables
      'max_steps': 50000000,
      'summaries': False,
      'batch_size': 128,
      'learning_rate': 0.0003,
      'num_classes': 2,
      'optimizer': 'sgd',
      'penalty': 0.0,
      'num_features': 1024,
      'layer_sizes': [1200],
      'weight_init_stddevs': [0.01],
      'bias_init_consts': [0.5],
      'dropouts': [0.0],
  })
  # NOTE(review): presumably overwrite='required' restricts the file to keys
  # already present in the defaults above -- confirm against ModelConfig.
  config.ReadFromFile(FLAGS.config,
                      overwrite='required')

  # Only replica 0 creates the log directory and records the merged config.
  if FLAGS.replica_id == 0:
    gfile.MakeDirs(FLAGS.logdir)
    config.WriteToFile(os.path.join(FLAGS.logdir, 'config.pbtxt'))

  model = icml_models.IcmlModel(config,
                                train=True,
                                logdir=FLAGS.logdir,
                                master=FLAGS.master)

  if FLAGS.num_folds is not None and FLAGS.fold is not None:
    folds = kfold_pattern(config.input_pattern, FLAGS.num_folds,
                          FLAGS.fold)
    # Use the next() builtin rather than the Python-2-only .next() method so
    # this works on both Python 2.6+ and Python 3.
    try:
      train_pattern, _ = next(folds)  # The held-out test split is not used.
    except StopIteration:
      # kfold_pattern yields nothing when the fold index matches no fold.
      raise ValueError('Fold %d is out of range for %d folds.'
                       % (FLAGS.fold, FLAGS.num_folds))
    train_pattern = ','.join(train_pattern)
  else:
    train_pattern = config.input_pattern

  with model.graph.as_default():
    model.Train(model.ReadInput(train_pattern,
                                input_data_types=input_data_types),
                max_steps=config.max_steps,
                summaries=config.summaries,
                replica_id=FLAGS.replica_id,
                ps_tasks=FLAGS.ps_tasks)
def main(unused_argv=None):
  """Script entry point: runs training with the default input data types.

  Args:
    unused_argv: Ignored.
  """
  del unused_argv  # Intentionally unused.
  Run()
if __name__ == '__main__':
  # These three flags default to None above, so require them explicitly
  # before handing control to app.run().
  for required_flag in ('config', 'logdir', 'fold'):
    flags.MarkFlagAsRequired(required_flag)
  app.run()