diff options
| -rw-r--r-- | lesana/collection.py | 8 | ||||
| -rw-r--r-- | lesana/types.py | 3 | ||||
| -rw-r--r-- | tests/data/derivative/items/48d73d796c0b47af964722e154fe879c.yaml | 2 | ||||
| -rw-r--r-- | tests/data/derivative/settings.yaml | 9 | ||||
| -rw-r--r-- | tests/test_derivatives.py | 40 | ||||
| -rw-r--r-- | tests/test_types.py | 32 | 
6 files changed, 77 insertions, 17 deletions
| diff --git a/lesana/collection.py b/lesana/collection.py index 5b7b809..a449de1 100644 --- a/lesana/collection.py +++ b/lesana/collection.py @@ -150,7 +150,7 @@ class Collection(object):                  )          except FileNotFoundError:              self.settings = ruamel.yaml.safe_load("{}") -        self.fields = self.load_field_types() +        self.fields = self._load_field_types()          os.makedirs(os.path.join(self.basedir, '.lesana'), exist_ok=True)          if 'lang' in self.settings:              try: @@ -174,14 +174,14 @@ class Collection(object):              yield c              yield from self._get_subsubclasses(c) -    def load_field_types(self): +    def _load_field_types(self):          type_loaders = {}          for t in self._get_subsubclasses(types.LesanaType):              type_loaders[t.name] = t          fields = {}          for field in self.settings.get('fields', []):              try: -                fields[field['name']] = type_loaders[field['type']]() +                fields[field['name']] = type_loaders[field['type']](field)              except KeyError:                  # unknown fields are treated as if they were                  # (unvalidated) generic YAML to support working with @@ -191,7 +191,7 @@ class Collection(object):                      field['type'],                      field['name'],                  ) -                fields[field['name']] = types.LesanaYAML() +                fields[field['name']] = types.LesanaYAML(field)          return fields      def _index_file(self, fname, cache): diff --git a/lesana/types.py b/lesana/types.py index a252830..c4061f2 100644 --- a/lesana/types.py +++ b/lesana/types.py @@ -14,6 +14,9 @@ class LesanaType:      """      Base class for lesana field types.      """ +    def __init__(self, field): +        self.field = field +      def load(self, data):          raise NotImplementedError diff --git a/tests/data/derivative/items/48d73d796c0b47af964722e154fe879c.yaml b/tests/data/derivative/items/48d73d796c0b47af964722e154fe879c.yaml new file mode 100644 index 0000000..9da6b21 --- /dev/null +++ b/tests/data/derivative/items/48d73d796c0b47af964722e154fe879c.yaml @@ -0,0 +1,2 @@ +name: 'An item' +unknown: 'future' diff --git a/tests/data/derivative/settings.yaml b/tests/data/derivative/settings.yaml new file mode 100644 index 0000000..5f3f826 --- /dev/null +++ b/tests/data/derivative/settings.yaml @@ -0,0 +1,9 @@ +name: "Derivative lesana collection" +lang: 'english' +fields: +    - name: name +      type: string +      index: free +    - name: unknown +      type: derived +      index: free diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py new file mode 100644 index 0000000..f79123c --- /dev/null +++ b/tests/test_derivatives.py @@ -0,0 +1,40 @@ +import shutil +import tempfile +import unittest + +import lesana +from lesana import types + + +class DerivedType(types.LesanaString): +    """ +    A custom type +    """ +    name = 'derived' + + +class Derivative(lesana.Collection): +    """ +    A class serived from lesana.Collection +    """ + + +class testDerivatives(unittest.TestCase): +    def setUp(self): +        self.tmpdir = tempfile.mkdtemp() +        shutil.copytree( +            'tests/data/derivative', +            self.tmpdir, +            dirs_exist_ok=True +        ) +        self.collection = Derivative(self.tmpdir) + +    def tearDown(self): +        shutil.rmtree(self.tmpdir) + +    def test_load_subclasses(self): +        self.assertIsInstance(self.collection.fields['unknown'], DerivedType) + + +if __name__ == '__main__': +    unittest.main() diff --git a/tests/test_types.py b/tests/test_types.py index f27089d..363d9a4 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -12,8 +12,14 @@ class testTypes(unittest.TestCase):      def tearDown(self):          pass +    def _get_field_def(self, type_name): +        return { +            'type': type_name, +            'name': 'test_field', +        } +      def test_base(self): -        checker = types.LesanaType() +        checker = types.LesanaType(self._get_field_def('base'))          # The base class does not implement empty nor load          with self.assertRaises(NotImplementedError): @@ -23,7 +29,7 @@ class testTypes(unittest.TestCase):              checker.load("")      def test_string(self): -        checker = types.LesanaString() +        checker = types.LesanaString(self._get_field_def('string'))          s = checker.empty()          self.assertEqual(s, "") @@ -35,7 +41,7 @@ class testTypes(unittest.TestCase):          self.assertEqual(s, None)      def test_text(self): -        checker = types.LesanaText() +        checker = types.LesanaText(self._get_field_def('text'))          s = checker.empty()          self.assertEqual(s, "") @@ -47,7 +53,7 @@ class testTypes(unittest.TestCase):          self.assertEqual(s, None)      def test_int(self): -        checker = types.LesanaInt() +        checker = types.LesanaInt(self._get_field_def('integer'))          v = checker.empty()          self.assertEqual(v, 0) @@ -66,7 +72,7 @@ class testTypes(unittest.TestCase):          self.assertEqual(v, None)      def test_float(self): -        checker = types.LesanaFloat() +        checker = types.LesanaFloat(self._get_field_def('float'))          v = checker.empty()          self.assertEqual(v, 0.0) @@ -88,7 +94,7 @@ class testTypes(unittest.TestCase):          self.assertEqual(v, None)      def test_decimal(self): -        checker = types.LesanaDecimal() +        checker = types.LesanaDecimal(self._get_field_def('decimal'))          v = checker.empty()          self.assertEqual(v, decimal.Decimal(0)) @@ -110,7 +116,7 @@ class testTypes(unittest.TestCase):          self.assertEqual(v, None)      def test_timestamp(self): -        checker = types.LesanaTimestamp() +        checker = types.LesanaTimestamp(self._get_field_def('timestamp'))          v = checker.empty()          self.assertEqual(v, None) @@ -136,7 +142,7 @@ class testTypes(unittest.TestCase):          self.assertEqual(v, None)      def test_datetime(self): -        checker = types.LesanaDatetime() +        checker = types.LesanaDatetime(self._get_field_def('datetime'))          v = checker.empty()          self.assertEqual(v, None) @@ -165,7 +171,7 @@ class testTypes(unittest.TestCase):          self.assertEqual(v, None)      def test_date(self): -        checker = types.LesanaDate() +        checker = types.LesanaDate(self._get_field_def('date'))          v = checker.empty()          self.assertEqual(v, None) @@ -194,7 +200,7 @@ class testTypes(unittest.TestCase):          self.assertEqual(v, None)      def test_boolean(self): -        checker = types.LesanaBoolean() +        checker = types.LesanaBoolean(self._get_field_def('boolean'))          v = checker.empty()          self.assertEqual(v, None) @@ -210,7 +216,7 @@ class testTypes(unittest.TestCase):          self.assertEqual(v, None)      def test_file(self): -        checker = types.LesanaFile() +        checker = types.LesanaFile(self._get_field_def('file'))          v = checker.empty()          self.assertEqual(v, "") @@ -224,7 +230,7 @@ class testTypes(unittest.TestCase):          # TODO: check for invalid file paths      def test_url(self): -        checker = types.LesanaURL() +        checker = types.LesanaURL(self._get_field_def('url'))          v = checker.empty()          self.assertEqual(v, "") @@ -238,7 +244,7 @@ class testTypes(unittest.TestCase):          # TODO: check for invalid URLs      def test_yaml(self): -        checker = types.LesanaYAML() +        checker = types.LesanaYAML(self._get_field_def('yaml'))          v = checker.empty()          self.assertEqual(v, None) | 
