blob: 25f4dadb00950bd4c1d96cbf4fb55d4b50a52a8f [file] [log] [blame]
#!/usr/bin/python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
from biology import model_config
# Reference parameter set used throughout the tests.
EXAMPLE_DICT = {
    # One scalar parameter of each value type: str, float, int, bool.
    'hello': 'world',
    'pi': 3.14159,
    'Forty_Two': 42,
    'great': True,
    # One list parameter of each element type: str, float, int, bool.
    'spells': ['alohamora', 'expelliarmus'],
    'scores': [9.8, 10.0],
    'sizes': [2000, 100],
    'waver': [True, False, True],
}
# Defaults with exactly the keys of EXAMPLE_DICT but a different value for
# every key, so a file read that fails to overwrite any one default is
# detected when the result is compared against the example values.
EXAMPLE_DEFAULTS = {
    # Scalars (str, float, int, bool).
    'hello': 'there',
    'pi': 3.14,
    'Forty_Two': 24,
    'great': False,
    # Lists (str, float, int, bool elements).
    'spells': ['abracadabra', 'cruciatus'],
    'scores': [1.8, 1.0],
    'sizes': [1200, 10],
    'waver': [False, True, False],
}
# Expected text serialization of EXAMPLE_DICT. There is one `parameter` entry
# per key. Scalars use a typed field (int_value, bool_value, string_value,
# float_value); lists repeat a typed *_list field once per element.
# Entries appear in ASCII order of parameter name, so 'Forty_Two' sorts first.
# testWritesFile compares WriteToFile's output against this string exactly,
# so every character here, including whitespace, is significant.
EXAMPLE_FILE_CONTENTS = """parameter {
name: "Forty_Two"
int_value: 42
}
parameter {
name: "great"
bool_value: true
}
parameter {
name: "hello"
string_value: "world"
}
parameter {
name: "pi"
float_value: 3.14159
}
parameter {
name: "scores"
float_list: 9.8
float_list: 10.0
}
parameter {
name: "sizes"
int_list: 2000
int_list: 100
}
parameter {
name: "spells"
string_list: "alohamora"
string_list: "expelliarmus"
}
parameter {
name: "waver"
bool_list: true
bool_list: false
bool_list: true
}
"""
class ModelConfigTest(googletest.TestCase):
  """Unit tests for model_config.ModelConfig."""

  def setUp(self):
    super(ModelConfigTest, self).setUp()
    # Give each test its own directory under the framework's tmpdir so the
    # config files written by different tests cannot collide.
    self.root = tempfile.mkdtemp(dir=flags.FLAGS.test_tmpdir)

  def _writeExampleFile(self):
    """Writes EXAMPLE_FILE_CONTENTS to a file in self.root.

    Returns:
      The path of the written file.
    """
    filename = os.path.join(self.root, 'config.pbtxt')
    with open(filename, 'w') as pbtxt_file:
      pbtxt_file.write(EXAMPLE_FILE_CONTENTS)
    return filename

  def _assertMatchesExample(self, config):
    """Asserts that `config` exposes every EXAMPLE_DICT entry as an attribute."""
    self.assertEqual(config.hello, 'world')
    self.assertEqual(config.pi, 3.14159)
    self.assertEqual(config.Forty_Two, 42)
    self.assertTrue(config.great)
    self.assertEqual(config.scores, [9.8, 10.0])
    self.assertEqual(config.sizes, [2000, 100])
    self.assertEqual(config.spells, ['alohamora', 'expelliarmus'])
    self.assertEqual(config.waver, [True, False, True])

  def testCreatesAttributes(self):
    """Constructing from a dict exposes each key as an attribute."""
    config = model_config.ModelConfig(EXAMPLE_DICT)
    self._assertMatchesExample(config)

  def testGetOptionalParam(self):
    """GetOptionalParam returns the stored value, else the given default."""
    config = model_config.ModelConfig(EXAMPLE_DICT)
    # Present key: the stored value wins over the supplied default.
    self.assertEqual('world', config.GetOptionalParam('hello', 'everybody'))
    # Missing key: the supplied default is returned.
    self.assertEqual('default', config.GetOptionalParam('otherkey', 'default'))

  def testOnlyValidAttributeNamesAllowed(self):
    """AddParam rejects names that cannot be used as attribute names."""
    config = model_config.ModelConfig()
    invalid_params = [
        ('spaces not allowed', 'blah'),
        ('42_must_start_with_letter', 'blah'),
        ('hyphens-not-allowed', 'blah'),
        ('', 'empty string no good'),
    ]
    for name, value in invalid_params:
      # try/except/else (rather than a bare assertRaises) so a failure
      # reports which name was wrongly accepted.
      try:
        config.AddParam(name, value, overwrite='forbidden')
      except ValueError:
        pass
      else:
        self.fail('AddParam accepted invalid name %r' % name)

  def testDuplicateKeysNotAllowed(self):
    """overwrite='forbidden' rejects an existing key and keeps its value."""
    config = model_config.ModelConfig(EXAMPLE_DICT)
    with self.assertRaises(ValueError):
      config.AddParam('hello',
                      'everybody',
                      overwrite='forbidden')
    # The rejected call must not have clobbered the original value.
    self.assertEqual('world', config.hello)

  def testRequireDefault(self):
    """overwrite='required' replaces existing keys and rejects new ones."""
    config = model_config.ModelConfig(EXAMPLE_DICT)
    config.AddParam('hello',
                    'everybody',
                    overwrite='required')
    self.assertEqual('everybody', config.hello)
    with self.assertRaises(ValueError):
      config.AddParam('not',
                      'present',
                      overwrite='required')
    # The rejected key must not have been added.
    self.assertEqual('absent', config.GetOptionalParam('not', 'absent'))

  def testSilentOverwrite(self):
    """overwrite='allowed' both adds new keys and replaces existing ones."""
    config = model_config.ModelConfig(EXAMPLE_DICT)
    # 'not' is a Python keyword, so read it back via GetOptionalParam rather
    # than attribute syntax.
    config.AddParam('not', 'present', overwrite='allowed')
    self.assertEqual('present', config.GetOptionalParam('not', 'missing'))
    config.AddParam('not', 'anymore', overwrite='allowed')
    self.assertEqual('anymore', config.GetOptionalParam('not', 'missing'))

  def testHeterogeneousList(self):
    """A list mixing element types (here str and int) is rejected."""
    config = model_config.ModelConfig()
    with self.assertRaises(ValueError):
      config.AddParam('different',
                      ['types for', 'different', 0xF, 0x0, 'lks'],
                      overwrite='forbidden')

  def testWritesFile(self):
    """WriteToFile output matches EXAMPLE_FILE_CONTENTS exactly."""
    config = model_config.ModelConfig(EXAMPLE_DICT)
    filename = os.path.join(self.root, 'config.pbtxt')
    config.WriteToFile(filename)
    with open(filename) as pbtxt_file:
      self.assertEqual(EXAMPLE_FILE_CONTENTS, pbtxt_file.read())

  def testReadsFile_NoDuplicates(self):
    """ReadFromFile into an empty config reproduces the example values."""
    filename = self._writeExampleFile()
    config = model_config.ModelConfig()
    config.ReadFromFile(filename, overwrite='forbidden')
    self._assertMatchesExample(config)

  def testReadsFile_RequireDefaults(self):
    """ReadFromFile with overwrite='required' replaces every default."""
    filename = self._writeExampleFile()
    # Precondition: the file must cover exactly the defaulted keys, or
    # 'required' mode would reject it.
    self.assertEqual(set(EXAMPLE_DEFAULTS.keys()), set(EXAMPLE_DICT.keys()))
    config = model_config.ModelConfig(EXAMPLE_DEFAULTS)
    config.ReadFromFile(filename, overwrite='required')
    self._assertMatchesExample(config)
# Script entry point: when this module is executed directly, delegate to
# googletest.main().
if __name__ == '__main__':
  googletest.main()