#!/usr/bin/python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluate a model from the ICML-2015 paper.
This script requires a trained model with its associated config and checkpoint.
If you don't have a trained model, run icml_train.py first.
"""
# pylint: disable=line-too-long
# pylint: enable=line-too-long
from nowhere.research.biology.collaborations.pande.py import utils
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from biology import model_config
from biology.icml import icml_models
# Trained model to evaluate: its configuration and its weights.
flags.DEFINE_string('config', None, 'Serialized ModelConfig proto.')
flags.DEFINE_string(
    'checkpoint',
    None,
    'Model checkpoint file. File can contain either an '
    'absolute checkpoint (e.g. model.ckpt-{step}) or a '
    'serialized CheckpointState proto.')

# Evaluation data, TensorFlow master, and output directory.
flags.DEFINE_string(
    'input_pattern',
    None,
    'Input file pattern; '
    'It should include %d for fold index substitution.')
flags.DEFINE_string('master', 'local', 'BNS name of the TensorFlow master.')
flags.DEFINE_string('logdir', None, 'Directory for output files.')

# Cross-validation settings; main() only selects a fold when both are set.
flags.DEFINE_integer('num_folds', 5, 'Number of cross-validation folds.')
flags.DEFINE_integer('fold', None, 'Fold index for this model.')

# Key used by main() to look up the model constructor.
flags.DEFINE_enum(
    'model_type',
    'single',
    ['single', 'deep', 'deepaux', 'py',
     'pydrop1', 'pydrop2'],
    'Which model from the ICML paper should be trained/evaluated')

FLAGS = flags.FLAGS
def main(unused_argv=None):
    """Evaluate a trained ICML model on the input selected by the flags.

    Reads the model configuration from FLAGS.config, creates FLAGS.logdir,
    builds the model registered under FLAGS.model_type with train=False, and
    calls its Eval method on the test input against FLAGS.checkpoint.

    When both FLAGS.num_folds and FLAGS.fold are set, the test input is taken
    from the first item yielded by utils.kfold_pattern (its second element,
    joined with commas). Otherwise FLAGS.input_pattern is used unchanged.

    Args:
      unused_argv: Ignored. Presumably supplied by app.run() when it invokes
        this function -- confirm against the app module.
    """
    config = model_config.ModelConfig()
    config.ReadFromFile(FLAGS.config, overwrite='allowed')
    gfile.MakeDirs(FLAGS.logdir)
    model = icml_models.CONSTRUCTORS[FLAGS.model_type](
        config, train=False, logdir=FLAGS.logdir, master=FLAGS.master)

    if FLAGS.num_folds is not None and FLAGS.fold is not None:
        folds = utils.kfold_pattern(FLAGS.input_pattern, FLAGS.num_folds,
                                    FLAGS.fold)
        # Use the builtin next() rather than the Python-2-only .next()
        # method so this also works with Python 3 iterators/generators.
        _, test_pattern = next(folds)
        test_pattern = ','.join(test_pattern)
    else:
        test_pattern = FLAGS.input_pattern

    # Run evaluation with the model's own graph installed as the default.
    with model.graph.as_default():
        model.Eval(model.ReadInput(test_pattern), FLAGS.checkpoint)
if __name__ == '__main__':
    # These flags default to None, so they must be given on the command line.
    for required_flag in ('config', 'checkpoint', 'input_pattern', 'logdir'):
        flags.MarkFlagAsRequired(required_flag)
    app.run()