diff --git a/smart_importer/predictor.py b/smart_importer/predictor.py index 4ec3b74..71593a5 100644 --- a/smart_importer/predictor.py +++ b/smart_importer/predictor.py @@ -42,6 +42,8 @@ class EntryPredictor(ImporterHook): string_tokenizer: Tokenizer can let smart_importer support more languages. This parameter should be an callable function with string parameter and the returning should be a list. + denylist_accounts: Transations with any of these accounts will be + removed from the training data. """ # pylint: disable=too-many-instance-attributes @@ -54,10 +56,12 @@ def __init__( predict=True, overwrite=False, string_tokenizer: Callable[[str], list] | None = None, + denylist_accounts: list[str] | None = None, ): super().__init__() self.training_data = None self.open_accounts: dict[str, str] = {} + self.denylist_accounts = set(denylist_accounts or []) self.pipeline: Pipeline | None = None self.is_fitted = False self.lock = threading.Lock() @@ -133,6 +137,8 @@ def training_data_filter(self, txn): for pos in txn.postings: if pos.account not in self.open_accounts: return False + if pos.account in self.denylist_accounts: + return False if self.account == pos.account: found_import_account = True return found_import_account or not self.account diff --git a/tests/predictors_test.py b/tests/predictors_test.py index 14c82e3..3460f6f 100644 --- a/tests/predictors_test.py +++ b/tests/predictors_test.py @@ -32,6 +32,9 @@ 2017-01-13 * "Gas Quick" Assets:US:BofA:Checking -17.45 USD + +2017-01-14 * "Axe Throwing with Joe" + Assets:US:BofA:Checking -13.37 USD """ ) @@ -43,6 +46,7 @@ 2016-01-01 open Expenses:Auto:Gas USD 2016-01-01 open Expenses:Food:Groceries USD 2016-01-01 open Expenses:Food:Restaurant USD +2016-01-01 open Expenses:Denylisted USD 2016-01-06 * "Farmer Fresh" "Buying groceries" Assets:US:BofA:Checking -2.50 USD @@ -93,6 +97,11 @@ 2016-01-12 * "Gas Quick" Assets:US:BofA:Checking -24.09 USD Expenses:Auto:Gas + +2016-01-08 * "Axe Throwing with Joe" + Assets:US:BofA:Checking -38.36 USD + Expenses:Denylisted + """ ) @@ -105,6 +114,7 @@ "Gimme Coffee", "Uncle Boons", None, + None, ] ACCOUNT_PREDICTIONS = [ @@ -116,8 +126,11 @@ "Expenses:Food:Coffee", "Expenses:Food:Groceries", "Expenses:Auto:Gas", + "Expenses:Food:Groceries", ] +DENYLISTED_ACCOUNTS = ["Expenses:Denylisted"] + class BasicTestImporter(ImporterProtocol): def extract(self, file, existing_entries=None): @@ -133,7 +146,10 @@ def file_account(self, file): PAYEE_IMPORTER = apply_hooks(BasicTestImporter(), [PredictPayees()]) -POSTING_IMPORTER = apply_hooks(BasicTestImporter(), [PredictPostings()]) +POSTING_IMPORTER = apply_hooks( + BasicTestImporter(), + [PredictPostings(denylist_accounts=DENYLISTED_ACCOUNTS)], +) def test_empty_training_data():