diff options
-rw-r--r-- | lesana/collection.py | 77 | ||||
-rw-r--r-- | lesana/types.py | 10 | ||||
-rw-r--r-- | tests/test_types.py | 15 |
3 files changed, 57 insertions, 45 deletions
diff --git a/lesana/collection.py b/lesana/collection.py index b265b4d..4b3e964 100644 --- a/lesana/collection.py +++ b/lesana/collection.py @@ -1,4 +1,3 @@ -import decimal import logging import os import uuid @@ -8,6 +7,7 @@ import xapian import jinja2 from pkg_resources import resource_string +from . import types try: import git @@ -72,7 +72,14 @@ class Entry(object): @property def yaml_data(self): - return ruamel.yaml.dump(self.data, Dumper=ruamel.yaml.RoundTripDumper) + to_dump = self.data.copy() + # Decimal fields can't be represented by + # ruamel.yaml.RoundTripDumper, but transforming them to strings + # should be enough for all cases that we need. + for field in self.collection.settings['fields']: + if field['type'] == 'decimal': + to_dump[field['name']] = str(to_dump.get(field['name'], '')) + return ruamel.yaml.dump(to_dump, Dumper=ruamel.yaml.RoundTripDumper) @property def idterm(self): @@ -88,51 +95,26 @@ class Entry(object): for field in self.collection.settings['fields']: value = self.data.get(field['name'], None) t = field['type'] + try: + self.data[field['name']] = self.collection.types[t].load(value) + except KeyError: + errors.append( + { + 'field': field['name'], + 'error': "No such type {}".format(t), + } + ) + except types.LesanaValueError as e: + errors.append( + { + 'field': field['name'], + 'error': e, + } + ) + if t != 'list' and not value: # empty fields are always fine except for lists continue - if t == 'integer': - try: - int(value) - except ValueError: - valid = False - errors.append( - { - 'field': field['name'], - 'error': - 'Invalid value for integer field: {}'.format( - value - ), - } - ) - elif t == 'float': - try: - float(value) - except ValueError: - valid = False - errors.append( - { - 'field': field['name'], - 'error': - 'Invalid value for float field: {}'.format( - value - ), - } - ) - elif t == 'decimal': - try: - decimal.Decimal(value) - except decimal.InvalidOperation: - valid = False - errors.append( - { - 'field': field['name'], - 'error': - 'Invalid value for decimal field: {}'.format( - value - ), - } - ) elif t == 'list': if not hasattr(value, '__iter__'): valid = False @@ -168,6 +150,7 @@ class Collection(object): def __init__(self, directory=None, itemdir='items'): self.basedir = directory or os.getcwd() self.itemdir = os.path.join(self.basedir, itemdir) + self.types = self._load_types() try: with open(os.path.join(self.basedir, 'settings.yaml')) as fp: self.settings = ruamel.yaml.load( @@ -193,6 +176,12 @@ class Collection(object): self.safe = False self.entry_class = Entry + def _load_types(self): + type_loaders = {} + for t in types.LesanaType.__subclasses__(): + type_loaders[t.name] = t() + return type_loaders + def _index_file(self, fname, cache): with open(os.path.join(self.itemdir, fname)) as fp: if self.safe: diff --git a/lesana/types.py b/lesana/types.py index e093c9f..db6b661 100644 --- a/lesana/types.py +++ b/lesana/types.py @@ -25,6 +25,8 @@ class LesanaString(LesanaType): name = 'string' def load(self, data): + if not data: + return data return str(data) def empty(self): @@ -45,6 +47,8 @@ class LesanaInt(LesanaType): name = "integer" def load(self, data): + if not data: + return data try: return int(data) except ValueError: @@ -63,6 +67,8 @@ class LesanaFloat(LesanaType): name = "float" def load(self, data): + if not data: + return data try: return float(data) except ValueError: @@ -78,9 +84,11 @@ class LesanaDecimal(LesanaType): """ A floating point number """ - name = "float" + name = "decimal" def load(self, data): + if not data: + return data try: return decimal.Decimal(data) except decimal.InvalidOperation: diff --git a/tests/test_types.py b/tests/test_types.py index 907be20..7aff73f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -20,6 +20,9 @@ class testTypes(unittest.TestCase): s = checker.load("Hello World!") self.assertEqual(s, "Hello World!") + s = checker.load(None) + self.assertEqual(s, None) + def test_text(self): checker = types.LesanaText() @@ -29,6 +32,9 @@ class testTypes(unittest.TestCase): s = checker.load("Hello World!") self.assertEqual(s, "Hello World!") + s = checker.load(None) + self.assertEqual(s, None) + def test_int(self): checker = types.LesanaInt() @@ -45,6 +51,9 @@ class testTypes(unittest.TestCase): with self.assertRaises(types.LesanaValueError): checker.load(d) + v = checker.load(None) + self.assertEqual(v, None) + def test_float(self): checker = types.LesanaFloat() @@ -64,6 +73,9 @@ class testTypes(unittest.TestCase): with self.assertRaises(types.LesanaValueError): checker.load(d) + v = checker.load(None) + self.assertEqual(v, None) + def test_decimal(self): checker = types.LesanaDecimal() @@ -83,6 +95,9 @@ class testTypes(unittest.TestCase): with self.assertRaises(types.LesanaValueError): checker.load(d) + v = checker.load(None) + self.assertEqual(v, None) + if __name__ == '__main__': unittest.main() |