From 88280380c7aec293a8211499d439a3712dd0df81 Mon Sep 17 00:00:00 2001 From: Patrick Ruckstuhl Date: Fri, 27 Dec 2024 16:02:53 +0000 Subject: [PATCH] Upgarde for beancount3/beangulp fixes #135 --- CHANGES | 5 +++++ setup.cfg | 3 ++- smart_importer/detector.py | 8 ++++---- smart_importer/hooks.py | 13 ++++++------- smart_importer/predictor.py | 2 +- tests/data_test.py | 14 ++++++++++---- tests/predictors_test.py | 29 +++++++++++++++-------------- 7 files changed, 43 insertions(+), 31 deletions(-) diff --git a/CHANGES b/CHANGES index 77c78ac..8532c9c 100644 --- a/CHANGES +++ b/CHANGES @@ -1,5 +1,10 @@ Changelog ========= +v0.6 (2024-12-27) +----------------- + +Upgrade to beancount 3 and beangulp. + v0.5 (2024-01-21) ----------------- diff --git a/setup.cfg b/setup.cfg index 0af1f60..495d397 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,8 @@ packages = find: setup_requires = setuptools_scm install_requires = - beancount>=2.3.5,<3.0.0 + beancount>=3 + beangulp scikit-learn>=1.0 numpy>=1.18.0 diff --git a/smart_importer/detector.py b/smart_importer/detector.py index 42b578e..b5932f3 100644 --- a/smart_importer/detector.py +++ b/smart_importer/detector.py @@ -2,7 +2,7 @@ import logging -from beancount.ingest import similar +from beangulp import similar from smart_importer.hooks import ImporterHook @@ -23,12 +23,12 @@ def __init__(self, comparator=None, window_days=2): self.comparator = comparator self.window_days = window_days - def __call__(self, importer, file, imported_entries, existing_entries): + def __call__(self, importer, file, imported_entries, existing): """Add duplicate metadata for imported transactions. Args: imported_entries: The list of imported entries. - existing_entries: The list of existing entries as passed to the + existing: The list of existing entries as passed to the importer. Returns: @@ -37,7 +37,7 @@ def __call__(self, importer, file, imported_entries, existing_entries): duplicate_pairs = similar.find_similar_entries( imported_entries, - existing_entries, + existing, self.comparator, self.window_days, ) diff --git a/smart_importer/hooks.py b/smart_importer/hooks.py index b767824..f2ca2e8 100644 --- a/smart_importer/hooks.py +++ b/smart_importer/hooks.py @@ -9,14 +9,14 @@ class ImporterHook: """Interface for an importer hook.""" - def __call__(self, importer, file, imported_entries, existing_entries): + def __call__(self, importer, file, imported_entries, existing): """Apply the hook and modify the imported entries. Args: importer: The importer that this hooks is being applied to. file: The file that is being imported. imported_entries: The current list of imported entries. - existing_entries: The existing entries, as passed to the extract + existing: The existing entries, as passed to the extract function. Returns: @@ -36,18 +36,17 @@ def apply_hooks(importer, hooks): unpatched_extract = importer.extract @wraps(unpatched_extract) - def patched_extract_method(file, existing_entries=None): + def patched_extract_method(filepath, existing=None): logger.debug("Calling the importer's extract method.") - imported_entries = unpatched_extract( - file, existing_entries=existing_entries - ) + imported_entries = unpatched_extract(filepath, existing=existing) for hook in hooks: imported_entries = hook( - importer, file, imported_entries, existing_entries + importer, filepath, imported_entries, existing ) return imported_entries importer.extract = patched_extract_method + importer.deduplicate = lambda entries, existing: None return importer diff --git a/smart_importer/predictor.py b/smart_importer/predictor.py index 71593a5..eeec34b 100644 --- a/smart_importer/predictor.py +++ b/smart_importer/predictor.py @@ -83,7 +83,7 @@ def __call__(self, importer, file, imported_entries, existing_entries): A list of entries, modified by this predictor. """ logging.debug("Running %s for file %s", self.__class__.__name__, file) - self.account = importer.file_account(file) + self.account = importer.account(file) self.load_training_data(existing_entries) with self.lock: self.define_pipeline() diff --git a/tests/data_test.py b/tests/data_test.py index 577ea71..f219f04 100644 --- a/tests/data_test.py +++ b/tests/data_test.py @@ -7,8 +7,8 @@ import pytest from beancount.core.compare import stable_hash_namedtuple -from beancount.ingest.importer import ImporterProtocol from beancount.parser import parser +from beangulp import Importer from smart_importer import PredictPostings, apply_hooks @@ -51,14 +51,20 @@ def test_testset(testset, string_tokenizer): # pylint: disable=unbalanced-tuple-unpacking imported, training_data, expected = _load_testset(testset) - class DummyImporter(ImporterProtocol): - def extract(self, file, existing_entries=None): + class DummyImporter(Importer): + def extract(self, filepath, existing=None): return imported + def account(self, filepath): + return "" + + def identify(self, filepath): + return True + importer = DummyImporter() apply_hooks(importer, [PredictPostings(string_tokenizer=string_tokenizer)]) imported_transactions = importer.extract( - "dummy-data", existing_entries=training_data + "dummy-data", existing=training_data ) for txn1, txn2 in zip(imported_transactions, expected): diff --git a/tests/predictors_test.py b/tests/predictors_test.py index 3460f6f..dd3de51 100644 --- a/tests/predictors_test.py +++ b/tests/predictors_test.py @@ -1,8 +1,8 @@ """Tests for the `PredictPayees` and the `PredictPostings` decorator""" # pylint: disable=missing-docstring -from beancount.ingest.importer import ImporterProtocol from beancount.parser import parser +from beangulp import Importer from smart_importer import PredictPayees, PredictPostings from smart_importer.hooks import apply_hooks @@ -132,18 +132,21 @@ DENYLISTED_ACCOUNTS = ["Expenses:Denylisted"] -class BasicTestImporter(ImporterProtocol): - def extract(self, file, existing_entries=None): - if file == "dummy-data": +class BasicTestImporter(Importer): + def extract(self, filepath, existing=None): + if filepath == "dummy-data": return TEST_DATA - if file == "empty": + if filepath == "empty": return [] assert False return [] - def file_account(self, file): + def account(self, filepath): return "Assets:US:BofA:Checking" + def identify(self, filepath): + return True + PAYEE_IMPORTER = apply_hooks(BasicTestImporter(), [PredictPayees()]) POSTING_IMPORTER = apply_hooks( @@ -166,8 +169,8 @@ def test_no_transactions(): """ POSTING_IMPORTER.extract("empty") PAYEE_IMPORTER.extract("empty") - POSTING_IMPORTER.extract("empty", existing_entries=TRAINING_DATA) - PAYEE_IMPORTER.extract("empty", existing_entries=TRAINING_DATA) + POSTING_IMPORTER.extract("empty", existing=TRAINING_DATA) + PAYEE_IMPORTER.extract("empty", existing=TRAINING_DATA) def test_unchanged_narrations(): @@ -178,7 +181,7 @@ def test_unchanged_narrations(): extracted_narrations = [ transaction.narration for transaction in PAYEE_IMPORTER.extract( - "dummy-data", existing_entries=TRAINING_DATA + "dummy-data", existing=TRAINING_DATA ) ] assert extracted_narrations == correct_narrations @@ -194,7 +197,7 @@ def test_unchanged_first_posting(): extracted_first_postings = [ transaction.postings[0] for transaction in PAYEE_IMPORTER.extract( - "dummy-data", existing_entries=TRAINING_DATA + "dummy-data", existing=TRAINING_DATA ) ] assert extracted_first_postings == correct_first_postings @@ -204,9 +207,7 @@ def test_payee_predictions(): """ Verifies that the decorator adds predicted postings. """ - transactions = PAYEE_IMPORTER.extract( - "dummy-data", existing_entries=TRAINING_DATA - ) + transactions = PAYEE_IMPORTER.extract("dummy-data", existing=TRAINING_DATA) predicted_payees = [transaction.payee for transaction in transactions] assert predicted_payees == PAYEE_PREDICTIONS @@ -218,7 +219,7 @@ def test_account_predictions(): predicted_accounts = [ entry.postings[-1].account for entry in POSTING_IMPORTER.extract( - "dummy-data", existing_entries=TRAINING_DATA + "dummy-data", existing=TRAINING_DATA ) ] assert predicted_accounts == ACCOUNT_PREDICTIONS