From 1dee2ede32b1fcf47640cfc24e26e04e533146c0 Mon Sep 17 00:00:00 2001
From: Elena ``of Valhalla'' Grandi <valhalla@trueelena.org>
Date: Wed, 28 Oct 2020 11:52:26 +0100
Subject: Improve round trip loading of yaml files

---
 CHANGELOG.rst                                      |  3 ++
 lesana/collection.py                               | 37 ++++++++--------------
 .../items/5084bc6e94f24dc6976629282ef30419.yaml    | 15 +++++++++
 tests/test_collection.py                           | 23 ++++++++++----
 4 files changed, 49 insertions(+), 29 deletions(-)
 create mode 100644 tests/data/complex/items/5084bc6e94f24dc6976629282ef30419.yaml

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 3f7fdc7..a16ac4a 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -5,6 +5,9 @@
 Unreleased
 ==========
 
+* Improved round trip loading of data results in less spurious changes
+  when editing entries.
+
 0.6.2
 =====
 
diff --git a/lesana/collection.py b/lesana/collection.py
index 21eb364..f399f62 100644
--- a/lesana/collection.py
+++ b/lesana/collection.py
@@ -1,3 +1,4 @@
+import io
 import logging
 import os
 import uuid
@@ -68,7 +69,7 @@ class Entry(object):
                 data += "{name}: []\n".format(**field)
             else:
                 data += "{name}: \n".format(**field)
-        return ruamel.yaml.load(data, Loader=ruamel.yaml.RoundTripLoader)
+        return self.collection.yaml.load(data)
 
     @property
     def yaml_data(self):
@@ -81,7 +82,9 @@ class Entry(object):
                 v = to_dump.get(field['name'], '')
                 if v:
                     to_dump[field['name']] = str(v)
-        return ruamel.yaml.dump(to_dump, Dumper=ruamel.yaml.RoundTripDumper)
+        s_io = io.StringIO()
+        self.collection.yaml.dump(to_dump, s_io)
+        return s_io.getvalue()
 
     @property
     def idterm(self):
@@ -130,13 +133,13 @@ class Collection(object):
     def __init__(self, directory=None, itemdir='items'):
         self.basedir = directory or os.getcwd()
         self.itemdir = os.path.join(self.basedir, itemdir)
+        self.yaml = ruamel.yaml.YAML()
+        self.yaml.preserve_quotes = True
         try:
             with open(os.path.join(self.basedir, 'settings.yaml')) as fp:
-                self.settings = ruamel.yaml.load(
-                    fp, ruamel.yaml.RoundTripLoader
-                )
+                self.settings = self.yaml.load(fp)
         except FileNotFoundError:
-            self.settings = ruamel.yaml.safe_load("{}")
+            self.settings = self.yaml.load("{}")
         self.fields = self._load_field_types()
         os.makedirs(os.path.join(self.basedir, '.lesana'), exist_ok=True)
         if 'lang' in self.settings:
@@ -151,9 +154,6 @@ class Collection(object):
         else:
             self.stemmer = xapian.Stem('english')
         self._enquire = None
-        # This selects whether to load all other yaml files with
-        # safe_load or load + RoundTripLoader
-        self.safe = False
         self.entry_class = Entry
 
     def _get_subsubclasses(self, cls):
@@ -186,10 +186,7 @@ class Collection(object):
 
     def _index_file(self, fname, cache):
         with open(os.path.join(self.itemdir, fname)) as fp:
-            if self.safe:
-                data = ruamel.yaml.safe_load(fp)
-            else:
-                data = ruamel.yaml.load(fp, ruamel.yaml.RoundTripLoader)
+            data = self.yaml.load(fp)
         entry = self.entry_class(self, data, fname)
         valid, errors = entry.validate()
         if not valid:
@@ -389,12 +386,7 @@ class Collection(object):
 
     def _doc_to_entry(self, doc):
         fname = doc.get_value(0).decode('utf-8')
-        if self.safe:
-            data = ruamel.yaml.safe_load(doc.get_data())
-        else:
-            data = ruamel.yaml.load(
-                doc.get_data(), ruamel.yaml.RoundTripLoader
-            )
+        data = self.yaml.load(doc.get_data())
         entry = self.entry_class(self, data=data, fname=fname,)
         return entry
 
@@ -522,13 +514,12 @@ class Collection(object):
             skel = resource_string('lesana', 'data/settings.yaml').decode(
                 'utf-8'
             )
-            skel_dict = ruamel.yaml.load(skel, ruamel.yaml.RoundTripLoader)
+            yaml = ruamel.yaml.YAML()
+            skel_dict = yaml.load(skel)
             skel_dict['git'] = git_enabled
             skel_dict.update(settings)
             with open(filepath, 'w') as fp:
-                ruamel.yaml.dump(
-                    skel_dict, stream=fp, Dumper=ruamel.yaml.RoundTripDumper
-                )
+                yaml.dump(skel_dict, stream=fp)
         if edit_file:
             edit_file(filepath)
         if git_enabled and repo:
diff --git a/tests/data/complex/items/5084bc6e94f24dc6976629282ef30419.yaml b/tests/data/complex/items/5084bc6e94f24dc6976629282ef30419.yaml
new file mode 100644
index 0000000..874833e
--- /dev/null
+++ b/tests/data/complex/items/5084bc6e94f24dc6976629282ef30419.yaml
@@ -0,0 +1,15 @@
+# This entry has a comment at the beginning
+name: 'A commented entry'
+# ruamel.yaml does not support preserving indent levels, so please leave the
+# description indented by two spaces.
+description: |
+  An entry with comments in the yaml data
+position: 'there'
+# There is a comment above something
+something:
+tags: []
+keywords: []
+exists: true
+with_default: default value
+amount: 1
+# and a comment at the end
diff --git a/tests/test_collection.py b/tests/test_collection.py
index f7ddf6d..bbc35ba 100644
--- a/tests/test_collection.py
+++ b/tests/test_collection.py
@@ -153,12 +153,6 @@ class testSimpleCollection(unittest.TestCase):
         self.collection.update_cache()
         self.assertIsNotNone(self.collection.stemmer)
 
-    def test_load_safe(self):
-        # Simply run the code with self.collection.safe = True to check
-        # that it doesn't break.
-        self.collection.safe = True
-        self.collection.update_cache()
-
     def test_full_search(self):
         self.collection.start_search('Item')
         res = self.collection.get_all_search_results()
@@ -321,6 +315,23 @@ class testComplexCollection(unittest.TestCase):
         for f in to_test:
             self.assertIsInstance(self.collection.fields[f[0]], f[1])
 
+    def test_comments_are_preserved(self):
+        e = self.collection.entry_from_eid('5084bc6e94f24dc6976629282ef30419')
+        yaml_data = e.yaml_data
+        self.assertTrue(
+            yaml_data.startswith("# This entry has a comment at the beginning")
+        )
+        self.assertTrue(
+            yaml_data.endswith("# and a comment at the end\n")
+        )
+
+    def test_data_is_stored_as_written_on_file(self):
+        e = self.collection.entry_from_eid('5084bc6e94f24dc6976629282ef30419')
+        fname = 'tests/data/complex/items/' + \
+            '5084bc6e94f24dc6976629282ef30419.yaml'
+        with open(fname, 'r') as fp:
+            self.assertEqual(e.yaml_data, fp.read())
+
 
 class testCollectionWithErrors(unittest.TestCase):
     def setUp(self):
-- 
cgit v1.2.3