Skip to content

Commit

Permalink
ingest: add typed protocols for importers
Browse files Browse the repository at this point in the history
  • Loading branch information
yagebu committed Jan 21, 2024
1 parent 0f1d3bf commit f351018
Show file tree
Hide file tree
Showing 27 changed files with 349 additions and 132 deletions.
7 changes: 1 addition & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,12 @@ extend-select = [
"ALL",
]
extend-ignore = [
"A003", # allow class attributes to shadow builtins
"ANN101", # allow `self` to be untyped
"ANN401", # allow `typing.Any`
"C901", # ignore mccabe complecity for now
"CPY", # no copyright notices in all files
"D102", # allow undocumented methods
"D105", # allow magic methods to be undocumented
"D107", # allow __init__ to be undocumented - the class should be.
"DTZ", # allow naive dates (for now)
"E203", # black
"EM101", # allow long strings as error messages
"EM102", # allow f-strings as error messages
"PD", # pandas-related, has false-positives
Expand Down Expand Up @@ -171,7 +167,6 @@ max-args = 9
"docs/**" = ["D", "ANN", "INP"]
"tests/conftest.py" = ["S101"]
"tests/test_*.py" = ["D", "PLC2701", "S101", "SLF001"]
"tests/data/import_config.py" = ["ANN", "INP"]
"src/fava/ext/portfolio_list/__init__.py" = ["ANN"]
"tests/data/import_config.py" = ["D", "INP"]
"stubs/**" = ["D"]
"src/fava/core/filters.py" = ["D"]
3 changes: 2 additions & 1 deletion src/fava/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from dataclasses import fields
from datetime import date
from datetime import datetime
from datetime import timezone
from functools import lru_cache
from io import BytesIO
from pathlib import Path
Expand Down Expand Up @@ -372,7 +373,7 @@ def download_query(result_format: str) -> Response:
@fava_app.route("/<bfile>/download-journal/")
def download_journal() -> Response:
"""Download a Journal file."""
now = datetime.now().replace(microsecond=0)
now = datetime.now(tz=timezone.utc).replace(microsecond=0)
filename = f"journal_{now.isoformat()}.beancount"
data = BytesIO(bytes(render_template("beancount_file"), "utf8"))
return send_file(data, as_attachment=True, download_name=filename)
Expand Down
26 changes: 17 additions & 9 deletions src/fava/beans/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from decimal import Decimal
from typing import overload
from typing import TYPE_CHECKING

from beancount.core import data
Expand All @@ -16,6 +16,7 @@

if TYPE_CHECKING: # pragma: no cover
import datetime
from decimal import Decimal

from fava.beans.abc import Balance
from fava.beans.abc import Cost
Expand All @@ -27,20 +28,27 @@
from fava.beans.flags import Flag


def decimal(num: Decimal | str) -> Decimal:
"""Decimal from a string."""
if isinstance(num, str):
return Decimal(num)
return num
@overload
def amount(amt: Amount) -> Amount: ...


def amount(amt: Amount | tuple[Decimal, str] | str) -> Amount:
"""Amount from a string."""
@overload
def amount(amt: str) -> Amount: ...


@overload
def amount(amt: Decimal, currency: str) -> Amount: ...


def amount(amt: Amount | Decimal | str, currency: str | None = None) -> Amount:
"""Amount from a string or tuple."""
if isinstance(amt, Amount):
return amt
if isinstance(amt, str):
return BEANCOUNT_A(amt) # type: ignore[no-any-return]
return BeancountAmount(*amt) # type: ignore[return-value]
if not isinstance(currency, str):
raise TypeError
return BeancountAmount(amt, currency) # type: ignore[return-value]


def position(units: Amount, cost: Cost | None) -> Position:
Expand Down
74 changes: 74 additions & 0 deletions src/fava/beans/ingest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Types for Beancount importers."""

from __future__ import annotations

from typing import Protocol
from typing import runtime_checkable
from typing import TYPE_CHECKING
from typing import TypeVar

if TYPE_CHECKING:
import datetime
from collections.abc import Callable

from fava.beans.abc import Directive


T = TypeVar("T")


class FileMemo(Protocol):
"""The file with caching support that is passed to importers."""

name: str

def convert(self, converter_func: Callable[[str], T]) -> T:
"""Run a conversion function for the file."""

def mimetype(self) -> str:
"""Get the mimetype of the file."""

def contents(self) -> str:
"""Get the file contents."""


@runtime_checkable
class BeanImporterProtocol(Protocol):
"""Interface for Beancount importers.
typing.Protocol version of beancount.ingest.importer.ImporterProtocol
Importers can subclass from this one instead of the Beancount one to
get type checking for the methods.
"""

def name(self) -> str:
"""Return a unique id/name for this importer."""
cls = self.__class__
return f"{cls.__module__}.{cls.__name__}"

def identify(self, file: FileMemo) -> bool:
"""Return true if this importer matches the given file."""

def extract(
self,
file: FileMemo, # noqa: ARG002
*,
existing_entries: list[Directive] | None = None, # noqa: ARG002
) -> list[Directive] | None:
"""Extract transactions from a file."""
return None

def file_account(self, file: FileMemo) -> str:
"""Return an account associated with the given file."""

def file_name(self, file: FileMemo) -> str | None: # noqa: ARG002
"""A filter that optionally renames a file before filing."""
return None

def file_date(
self,
file: FileMemo, # noqa: ARG002
) -> datetime.date | None:
"""Attempt to obtain a date that corresponds to the given file."""
return None
8 changes: 5 additions & 3 deletions src/fava/core/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import datetime
from dataclasses import dataclass
from dataclasses import field
from typing import Dict
Expand All @@ -17,8 +16,11 @@
from fava.core.group_entries import TransactionPosting
from fava.core.module_base import FavaModule
from fava.core.tree import Tree
from fava.util.date import local_today

if TYPE_CHECKING: # pragma: no cover
import datetime

from fava.beans.abc import Directive
from fava.beans.abc import Meta
from fava.core.tree import TreeNode
Expand Down Expand Up @@ -72,7 +74,7 @@ def uptodate_status(
def balance_string(tree_node: TreeNode) -> str:
"""Balance directive for the given account for today."""
account = tree_node.name
today = str(datetime.date.today())
today = str(local_today())
res = ""
for currency, number in units(tree_node.balance).items():
res += f"{today} balance {account:<28} {number:>15} {currency}\n"
Expand Down Expand Up @@ -129,7 +131,7 @@ def setdefault(
self[key] = AccountData()
return self[key]

def load_file(self) -> None:
def load_file(self) -> None: # noqa: D102
self.clear()
entries_by_account = group_entries_by_account(self.ledger.all_entries)
tree = Tree(self.ledger.all_entries)
Expand Down
2 changes: 1 addition & 1 deletion src/fava/core/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, ledger: FavaLedger) -> None:
self.tags: list[str] = []
self.years: list[str] = []

def load_file(self) -> None:
def load_file(self) -> None: # noqa: D102
all_entries = self.ledger.all_entries

all_links = set()
Expand Down
2 changes: 1 addition & 1 deletion src/fava/core/budgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, ledger: FavaLedger) -> None:
super().__init__(ledger)
self.budget_entries: BudgetDict = {}

def load_file(self) -> None:
def load_file(self) -> None: # noqa: D102
self.budget_entries, errors = parse_budgets(
self.ledger.all_entries_by_type.Custom,
)
Expand Down
4 changes: 2 additions & 2 deletions src/fava/core/charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ def loads(s: str | bytes) -> Any:
class FavaJSONProvider(JSONProvider):
"""Use custom JSON encoder and decoder."""

def dumps(self, obj: Any, **_kwargs: Any) -> str:
def dumps(self, obj: Any, **_kwargs: Any) -> str: # noqa: D102
return dumps(obj)

def loads(self, s: str | bytes, **_kwargs: Any) -> Any:
def loads(self, s: str | bytes, **_kwargs: Any) -> Any: # noqa: D102
return simplejson_loads(s)


Expand Down
2 changes: 1 addition & 1 deletion src/fava/core/commodities.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, ledger: FavaLedger) -> None:
self.names: dict[str, str] = {}
self.precisions: dict[str, int] = {}

def load_file(self) -> None:
def load_file(self) -> None: # noqa: D102
self.names = {}
self.precisions = {}
for commodity in self.ledger.all_entries_by_type.Commodity:
Expand Down
9 changes: 8 additions & 1 deletion src/fava/core/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, ledger: FavaLedger) -> None:
self._instances: dict[str, FavaExtensionBase] = {}
self._loaded_extensions: set[type[FavaExtensionBase]] = set()

def load_file(self) -> None:
def load_file(self) -> None: # noqa: D102
all_extensions = []
custom_entries = self.ledger.all_entries_by_type.Custom
_extension_entries = extension_entries(custom_entries)
Expand Down Expand Up @@ -71,22 +71,27 @@ def get_extension(self, name: str) -> FavaExtensionBase | None:
return self._instances.get(name, None)

def after_load_file(self) -> None:
"""Run all `after_load_file` hooks."""
for ext in self._exts:
ext.after_load_file()

def before_request(self) -> None:
"""Run all `before_request` hooks."""
for ext in self._exts:
ext.before_request()

def after_entry_modified(self, entry: Directive, new_lines: str) -> None:
"""Run all `after_entry_modified` hooks."""
for ext in self._exts:
ext.after_entry_modified(entry, new_lines)

def after_insert_entry(self, entry: Directive) -> None:
"""Run all `after_insert_entry` hooks."""
for ext in self._exts:
ext.after_insert_entry(entry)

def after_delete_entry(self, entry: Directive) -> None:
"""Run all `after_delete_entry` hooks."""
for ext in self._exts:
ext.after_delete_entry(entry)

Expand All @@ -96,10 +101,12 @@ def after_insert_metadata(
key: str,
value: str,
) -> None:
"""Run all `after_insert_metadata` hooks."""
for ext in self._exts:
ext.after_insert_metadata(entry, key, value)

def after_write_source(self, path: str, source: str) -> None:
"""Run all `after_write_source` hooks."""
for ext in self._exts:
ext.after_write_source(path, source)

Expand Down
39 changes: 29 additions & 10 deletions src/fava/core/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import datetime
import sys
import traceback
from dataclasses import dataclass
Expand All @@ -17,11 +16,15 @@
from beancount.ingest import extract
from beancount.ingest import identify

from fava.beans.ingest import BeanImporterProtocol
from fava.core.module_base import FavaModule
from fava.helpers import BeancountError
from fava.helpers import FavaAPIError
from fava.util.date import local_today

if TYPE_CHECKING: # pragma: no cover
import datetime

from fava.beans.abc import Directive
from fava.core import FavaLedger

Expand Down Expand Up @@ -49,7 +52,10 @@ class FileImporters:
importers: list[FileImportInfo]


def file_import_info(filename: str, importer: Any) -> FileImportInfo:
def file_import_info(
filename: str,
importer: BeanImporterProtocol,
) -> FileImportInfo:
"""Generate info about a file with an importer."""
file = cache.get_file(filename)
try:
Expand All @@ -62,7 +68,7 @@ def file_import_info(filename: str, importer: Any) -> FileImportInfo:
return FileImportInfo(
importer.name(),
account or "",
date or datetime.date.today(),
date or local_today(),
name or Path(filename).name,
)

Expand All @@ -73,7 +79,7 @@ class IngestModule(FavaModule):
def __init__(self, ledger: FavaLedger) -> None:
super().__init__(ledger)
self.config: list[Any] = []
self.importers: dict[str, Any] = {}
self.importers: dict[str, BeanImporterProtocol] = {}
self.hooks: list[Any] = []
self.mtime: int | None = None

Expand All @@ -86,9 +92,15 @@ def module_path(self) -> Path | None:
return self.ledger.join_path(config_path)

def _error(self, msg: str) -> None:
self.ledger.errors.append(IngestError(None, msg, None))
self.ledger.errors.append(
IngestError(
{"filename": str(self.module_path), "lineno": 0},
msg,
None,
),
)

def load_file(self) -> None:
def load_file(self) -> None: # noqa: D102
if self.module_path is None:
return
module_path = self.module_path
Expand Down Expand Up @@ -119,9 +131,16 @@ def load_file(self) -> None:
self._error(f"Error in importer '{module_path}': {message}")
else:
self.hooks = hooks
self.importers = {
importer.name(): importer for importer in self.config
}
self.importers = {}
for importer in self.config:
if not isinstance(importer, BeanImporterProtocol):
name = importer.__class__.__name__
self._error(
f"Importer class '{name}' in '{module_path}' does "
"not satisfy importer protocol",
)
else:
self.importers[importer.name()] = importer

def import_data(self) -> list[FileImporters]:
"""Identify files and importers that can be imported.
Expand All @@ -132,7 +151,7 @@ def import_data(self) -> list[FileImporters]:
if not self.config:
return []

ret: list[FileImporters] = []
ret = []

for directory in self.ledger.fava_options.import_dirs:
full_path = self.ledger.join_path(directory)
Expand Down
Loading

0 comments on commit f351018

Please sign in to comment.