summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lesana/collection.py8
-rw-r--r--lesana/types.py3
-rw-r--r--tests/data/derivative/items/48d73d796c0b47af964722e154fe879c.yaml2
-rw-r--r--tests/data/derivative/settings.yaml9
-rw-r--r--tests/test_derivatives.py40
-rw-r--r--tests/test_types.py32
6 files changed, 77 insertions, 17 deletions
diff --git a/lesana/collection.py b/lesana/collection.py
index 5b7b809..a449de1 100644
--- a/lesana/collection.py
+++ b/lesana/collection.py
@@ -150,7 +150,7 @@ class Collection(object):
)
except FileNotFoundError:
self.settings = ruamel.yaml.safe_load("{}")
- self.fields = self.load_field_types()
+ self.fields = self._load_field_types()
os.makedirs(os.path.join(self.basedir, '.lesana'), exist_ok=True)
if 'lang' in self.settings:
try:
@@ -174,14 +174,14 @@ class Collection(object):
yield c
yield from self._get_subsubclasses(c)
- def load_field_types(self):
+ def _load_field_types(self):
type_loaders = {}
for t in self._get_subsubclasses(types.LesanaType):
type_loaders[t.name] = t
fields = {}
for field in self.settings.get('fields', []):
try:
- fields[field['name']] = type_loaders[field['type']]()
+ fields[field['name']] = type_loaders[field['type']](field)
except KeyError:
# unknown fields are treated as if they were
# (unvalidated) generic YAML to support working with
@@ -191,7 +191,7 @@ class Collection(object):
field['type'],
field['name'],
)
- fields[field['name']] = types.LesanaYAML()
+ fields[field['name']] = types.LesanaYAML(field)
return fields
def _index_file(self, fname, cache):
diff --git a/lesana/types.py b/lesana/types.py
index a252830..c4061f2 100644
--- a/lesana/types.py
+++ b/lesana/types.py
@@ -14,6 +14,9 @@ class LesanaType:
"""
Base class for lesana field types.
"""
+ def __init__(self, field):
+ self.field = field
+
def load(self, data):
raise NotImplementedError
diff --git a/tests/data/derivative/items/48d73d796c0b47af964722e154fe879c.yaml b/tests/data/derivative/items/48d73d796c0b47af964722e154fe879c.yaml
new file mode 100644
index 0000000..9da6b21
--- /dev/null
+++ b/tests/data/derivative/items/48d73d796c0b47af964722e154fe879c.yaml
@@ -0,0 +1,2 @@
+name: 'An item'
+unknown: 'future'
diff --git a/tests/data/derivative/settings.yaml b/tests/data/derivative/settings.yaml
new file mode 100644
index 0000000..5f3f826
--- /dev/null
+++ b/tests/data/derivative/settings.yaml
@@ -0,0 +1,9 @@
+name: "Derivative lesana collection"
+lang: 'english'
+fields:
+ - name: name
+ type: string
+ index: free
+ - name: unknown
+ type: derived
+ index: free
diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py
new file mode 100644
index 0000000..f79123c
--- /dev/null
+++ b/tests/test_derivatives.py
@@ -0,0 +1,40 @@
+import shutil
+import tempfile
+import unittest
+
+import lesana
+from lesana import types
+
+
+class DerivedType(types.LesanaString):
+ """
+ A custom type
+ """
+ name = 'derived'
+
+
+class Derivative(lesana.Collection):
+ """
+ A class serived from lesana.Collection
+ """
+
+
+class testDerivatives(unittest.TestCase):
+ def setUp(self):
+ self.tmpdir = tempfile.mkdtemp()
+ shutil.copytree(
+ 'tests/data/derivative',
+ self.tmpdir,
+ dirs_exist_ok=True
+ )
+ self.collection = Derivative(self.tmpdir)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdir)
+
+ def test_load_subclasses(self):
+ self.assertIsInstance(self.collection.fields['unknown'], DerivedType)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/test_types.py b/tests/test_types.py
index f27089d..363d9a4 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -12,8 +12,14 @@ class testTypes(unittest.TestCase):
def tearDown(self):
pass
+ def _get_field_def(self, type_name):
+ return {
+ 'type': type_name,
+ 'name': 'test_field',
+ }
+
def test_base(self):
- checker = types.LesanaType()
+ checker = types.LesanaType(self._get_field_def('base'))
# The base class does not implement empty nor load
with self.assertRaises(NotImplementedError):
@@ -23,7 +29,7 @@ class testTypes(unittest.TestCase):
checker.load("")
def test_string(self):
- checker = types.LesanaString()
+ checker = types.LesanaString(self._get_field_def('string'))
s = checker.empty()
self.assertEqual(s, "")
@@ -35,7 +41,7 @@ class testTypes(unittest.TestCase):
self.assertEqual(s, None)
def test_text(self):
- checker = types.LesanaText()
+ checker = types.LesanaText(self._get_field_def('text'))
s = checker.empty()
self.assertEqual(s, "")
@@ -47,7 +53,7 @@ class testTypes(unittest.TestCase):
self.assertEqual(s, None)
def test_int(self):
- checker = types.LesanaInt()
+ checker = types.LesanaInt(self._get_field_def('integer'))
v = checker.empty()
self.assertEqual(v, 0)
@@ -66,7 +72,7 @@ class testTypes(unittest.TestCase):
self.assertEqual(v, None)
def test_float(self):
- checker = types.LesanaFloat()
+ checker = types.LesanaFloat(self._get_field_def('float'))
v = checker.empty()
self.assertEqual(v, 0.0)
@@ -88,7 +94,7 @@ class testTypes(unittest.TestCase):
self.assertEqual(v, None)
def test_decimal(self):
- checker = types.LesanaDecimal()
+ checker = types.LesanaDecimal(self._get_field_def('decimal'))
v = checker.empty()
self.assertEqual(v, decimal.Decimal(0))
@@ -110,7 +116,7 @@ class testTypes(unittest.TestCase):
self.assertEqual(v, None)
def test_timestamp(self):
- checker = types.LesanaTimestamp()
+ checker = types.LesanaTimestamp(self._get_field_def('timestamp'))
v = checker.empty()
self.assertEqual(v, None)
@@ -136,7 +142,7 @@ class testTypes(unittest.TestCase):
self.assertEqual(v, None)
def test_datetime(self):
- checker = types.LesanaDatetime()
+ checker = types.LesanaDatetime(self._get_field_def('datetime'))
v = checker.empty()
self.assertEqual(v, None)
@@ -165,7 +171,7 @@ class testTypes(unittest.TestCase):
self.assertEqual(v, None)
def test_date(self):
- checker = types.LesanaDate()
+ checker = types.LesanaDate(self._get_field_def('date'))
v = checker.empty()
self.assertEqual(v, None)
@@ -194,7 +200,7 @@ class testTypes(unittest.TestCase):
self.assertEqual(v, None)
def test_boolean(self):
- checker = types.LesanaBoolean()
+ checker = types.LesanaBoolean(self._get_field_def('boolean'))
v = checker.empty()
self.assertEqual(v, None)
@@ -210,7 +216,7 @@ class testTypes(unittest.TestCase):
self.assertEqual(v, None)
def test_file(self):
- checker = types.LesanaFile()
+ checker = types.LesanaFile(self._get_field_def('file'))
v = checker.empty()
self.assertEqual(v, "")
@@ -224,7 +230,7 @@ class testTypes(unittest.TestCase):
# TODO: check for invalid file paths
def test_url(self):
- checker = types.LesanaURL()
+ checker = types.LesanaURL(self._get_field_def('url'))
v = checker.empty()
self.assertEqual(v, "")
@@ -238,7 +244,7 @@ class testTypes(unittest.TestCase):
# TODO: check for invalid URLs
def test_yaml(self):
- checker = types.LesanaYAML()
+ checker = types.LesanaYAML(self._get_field_def('yaml'))
v = checker.empty()
self.assertEqual(v, None)