aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGELOG.rst3
-rw-r--r--lesana/collection.py4
-rw-r--r--lesana/templating.py27
-rw-r--r--tests/test_templating.py42
4 files changed, 73 insertions, 3 deletions
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 04a8463..c8c2786 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -5,7 +5,8 @@
Unreleased
==========
-* New data type: geo (for Geo URIs)
+* New data type: geo (for Geo URIs).
+* New custom filter for templates: to_yaml.
0.8.1
=====
diff --git a/lesana/collection.py b/lesana/collection.py
index a78de2b..a4546df 100644
--- a/lesana/collection.py
+++ b/lesana/collection.py
@@ -9,7 +9,7 @@ import xapian
import jinja2
from pkg_resources import resource_string
-from . import types
+from . import types, templating
try:
import git
@@ -473,7 +473,7 @@ class Collection(object):
self.update_cache([e.fname for e in changed])
def get_template(self, template_fname, searchpath='.'):
- env = jinja2.Environment(
+ env = templating.Environment(
loader=jinja2.FileSystemLoader(
searchpath=searchpath, followlinks=True,
),
diff --git a/lesana/templating.py b/lesana/templating.py
new file mode 100644
index 0000000..375a6f0
--- /dev/null
+++ b/lesana/templating.py
@@ -0,0 +1,27 @@
+"""
+Custom jinja2 filters and other templating helpers
+"""
+import decimal
+
+import jinja2
+import ruamel.yaml
+
+
+class Environment(jinja2.Environment):
+ """
+ A customized jinja2 environment that includes our filters.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.filters['to_yaml'] = to_yaml
+
+
+def to_yaml(data):
+ """
+ Return the yaml representation of data.
+ """
+ if isinstance(data, decimal.Decimal):
+ data = str(data)
+ return ruamel.yaml.dump(data).strip('...\n').strip()
diff --git a/tests/test_templating.py b/tests/test_templating.py
new file mode 100644
index 0000000..33889ae
--- /dev/null
+++ b/tests/test_templating.py
@@ -0,0 +1,42 @@
+import decimal
+import unittest
+
+from lesana import templating
+
+
+class testFilters(unittest.TestCase):
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def test_to_yaml(self):
+ res = templating.to_yaml(None)
+ self.assertIsInstance(res, str)
+ self.assertEqual(res, 'null')
+
+ s = "A short string"
+ res = templating.to_yaml(s)
+ self.assertEqual(res, s)
+
+ s = """
+ A long, multiline
+ string
+ with multiple
+ lines
+ """
+ res = templating.to_yaml(s)
+ self.assertIsInstance(res, str)
+ self.assertIn('"', res)
+ self.assertIn('\n', res)
+
+ res = templating.to_yaml(10)
+ self.assertEqual(res, '10')
+
+ res = templating.to_yaml(decimal.Decimal('10.1'))
+ self.assertEqual(res, "'10.1'")
+
+
+if __name__ == '__main__':
+ unittest.main()