summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lesana/command.py30
1 files changed, 17 insertions, 13 deletions
diff --git a/lesana/command.py b/lesana/command.py
index 7f4144b..f2f3e56 100644
--- a/lesana/command.py
+++ b/lesana/command.py
@@ -62,6 +62,10 @@ def edit_file_in_external_editor(filepath):
class Command():
help = ''
+ def __init__(self, collection_class=Collection, entry_class=Entry):
+ self.collection_class = collection_class
+ self.entry_class = entry_class
+
def _main(self, args):
self.args = args
self.main()
@@ -80,8 +84,8 @@ class New(Command):
]
def main(self):
- collection = Collection(self.args.collection)
- new_entry = Entry(collection)
+ collection = self.collection_class(self.args.collection)
+ new_entry = self.entry_class(collection)
collection.save_entries([new_entry])
filepath = os.path.join(
collection.itemdir,
@@ -111,8 +115,8 @@ class Edit(Command):
]
def main(self):
- collection = Collection(self.args.collection)
- entries = collection.entries_from_short_eid(self.args.eid)
+ collection = self.collection_class(self.args.collection)
+ entries = collection.entries_from_short_uid(self.args.eid)
if len(entries) > 1:
return "{} is not an unique eid".format(self.args.eid)
if not entries:
@@ -146,8 +150,8 @@ class Show(Command):
]
def main(self):
- collection = Collection(self.args.collection)
- entries = collection.entries_from_short_eid(self.args.eid)
+ collection = self.collection_class(self.args.collection)
+ entries = collection.entries_from_short_uid(self.args.eid)
if len(entries) > 1:
return "{} is not an unique eid".format(self.args.eid)
if not entries:
@@ -178,7 +182,7 @@ class Index(Command):
]
def main(self):
- collection = Collection(self.args.collection)
+ collection = self.collection_class(self.args.collection)
if self.args.files:
files = (os.path.basename(f) for f in self.args.files)
else:
@@ -225,7 +229,7 @@ class Search(Command):
)
offset = self.args.offset or 0
pagesize = self.args.pagesize or 12
- collection = Collection(self.args.collection)
+ collection = self.collection_class(self.args.collection)
if self.args.query == ['*']:
results = collection.get_all_documents()
else:
@@ -267,8 +271,8 @@ class Export(Command):
]
def main(self):
- collection = Collection(self.args.collection)
- destination = Collection(self.args.destination)
+ collection = self.collection_class(self.args.collection)
+ destination = self.collection_class(self.args.destination)
if not self.args.query:
results = collection.get_all_documents()
else:
@@ -288,7 +292,7 @@ class Export(Command):
logging.error("Error converting entry: {}".format(entry))
logging.error("{}".format(e))
sys.exit(1)
- e = Entry(destination, data=data)
+ e = self.entry_class(destination, data=data)
destination.save_entries([e])
@@ -306,7 +310,7 @@ class Init(Command):
]
def main(self):
- Collection.init(
+ self.collection_class.init(
self.args.collection,
git_enabled=self.args.git,
edit_file=edit_file_in_external_editor
@@ -325,5 +329,5 @@ class Remove(Command):
]
def main(self):
- collection = Collection(self.args.collection)
+ collection = self.collection_class(self.args.collection)
collection.remove_entries(eids=self.args.entries)