ingest: initial beangulp Importer compat
yagebu committed Dec 15, 2024
1 parent f638350 commit 862d1e0
Showing 11 changed files with 301 additions and 120 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -129,6 +129,7 @@ disable = [
"duplicate-code",
"unused-argument", # is checked by ruff (ARG001)
"stop-iteration-return",
"ungrouped-imports",
"invalid-unary-operand-type", # type-checking like, had false-positives
"not-an-iterable", # type-checking like, had false-positives
"unsubscriptable-object", # type-checking like, had false-positives
1 change: 1 addition & 0 deletions src/fava/beans/abc.py
@@ -25,6 +25,7 @@
)
Meta: TypeAlias = Mapping[str, MetaValue]
TagsOrLinks: TypeAlias = set[str] | frozenset[str]
Account = str


class Amount(ABC):
4 changes: 2 additions & 2 deletions src/fava/beans/create.py
@@ -108,8 +108,8 @@ def transaction(
flag: Flag,
payee: str | None,
narration: str,
tags: TagsOrLinks | None = None,
links: TagsOrLinks | None = None,
tags: frozenset[str] | None = None,
links: frozenset[str] | None = None,
postings: list[Posting] | None = None,
) -> Transaction:
"""Create a Beancount Transaction."""
152 changes: 93 additions & 59 deletions src/fava/core/ingest.py
@@ -13,17 +13,17 @@
from runpy import run_path
from typing import TYPE_CHECKING

from beangulp import Importer

try: # pragma: no cover
from beancount.ingest import cache # type: ignore[import-untyped]
from beancount.ingest import extract

DEFAULT_HOOKS = [extract.find_duplicate_entries]
BEANGULP = False
except ImportError:
from beangulp import cache # type: ignore[import-untyped]
from beangulp import cache

DEFAULT_HOOKS = []
BEANGULP = True

from fava.beans.ingest import BeanImporterProtocol
from fava.core.file import _incomplete_sortkey
@@ -35,6 +35,7 @@
if TYPE_CHECKING: # pragma: no cover
import datetime
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Callable

@@ -43,7 +44,7 @@
from fava.core import FavaLedger

HookOutput = list[tuple[str, list[Directive]]]
Hooks = list[Callable[[HookOutput, Sequence[Directive]], HookOutput]]
Hooks = Sequence[Callable[[HookOutput, Sequence[Directive]], HookOutput]]


class IngestError(BeancountError):
@@ -53,16 +54,15 @@ class IngestError(BeancountError):
class ImporterMethodCallError(FavaAPIError):
"""Error calling one of the importer methods."""

def __init__(self, err: Exception) -> None:
super().__init__(f"Error calling importer method: {err}")
def __init__(self) -> None:
super().__init__(
f"Error calling method on importer:\n\n{traceback.format_exc()}"
)


class ImporterExtractError(FavaAPIError):
class ImporterExtractError(ImporterMethodCallError):
"""Error calling extract for importer."""

def __init__(self, name: str, err: Exception) -> None:
super().__init__(f"Error calling extract for importer {name}: {err}")


class MissingImporterConfigError(FavaAPIError):
"""Missing import-config option."""
@@ -98,6 +98,8 @@ class ImportConfigLoadError(FavaAPIError):
def walk_dir(directory: Path) -> Iterable[Path]:
"""Walk through all files in dir.
Ignores common dot-directories like .git, .cache, .venv, see IGNORE_DIRS.
Args:
directory: The directory to start in.
@@ -112,24 +114,25 @@ def walk_dir(directory: Path) -> Iterable[Path]:


# Keep our own cache to also keep track of file mtimes
_CACHE: dict[str, tuple[int, FileMemo]] = {}
_CACHE: dict[Path, tuple[int, FileMemo]] = {}


def get_cached_file(filename: str) -> FileMemo:
def get_cached_file(path: Path) -> FileMemo:
"""Get a cached FileMemo.
This checks the file's mtime before getting it from the cache, in addition to using the beangulp cache.
"""
mtime = Path(filename).stat().st_mtime_ns
cached = _CACHE.get(filename)
mtime = path.stat().st_mtime_ns
filename = str(path)
cached = _CACHE.get(path)
if cached:
mtime_cached, memo_cached = cached
if mtime <= mtime_cached:
if mtime <= mtime_cached: # pragma: no cover
return memo_cached
memo: FileMemo = cache._FileMemo(filename) # noqa: SLF001
cache._CACHE[filename] = memo # noqa: SLF001
_CACHE[filename] = (mtime, memo)
_CACHE[path] = (mtime, memo)
return memo


@@ -152,21 +155,49 @@ class FileImporters:
importers: list[FileImportInfo]


def get_name(importer: BeanImporterProtocol | Importer) -> str:
"""Get the name of an importer."""
try:
if isinstance(importer, Importer):
return importer.name
return importer.name()
except Exception as err:
raise ImporterMethodCallError from err


def importer_identify(
importer: BeanImporterProtocol | Importer, path: Path
) -> bool:
"""Get the name of an importer."""
try:
if isinstance(importer, Importer):
return importer.identify(str(path))
return importer.identify(get_cached_file(path))
except Exception as err:
raise ImporterMethodCallError from err


def file_import_info(
filename: str,
importer: BeanImporterProtocol,
path: Path,
importer: BeanImporterProtocol | Importer,
) -> FileImportInfo:
"""Generate info about a file with an importer."""
file = get_cached_file(filename)
filename = str(path)
try:
account = importer.file_account(file)
date = importer.file_date(file)
name = importer.file_name(file)
if isinstance(importer, Importer):
account = importer.account(filename)
date = importer.date(filename)
name = importer.filename(filename)
else:
file = get_cached_file(path)
account = importer.file_account(file)
date = importer.file_date(file)
name = importer.file_name(file)
except Exception as err:
raise ImporterMethodCallError(err) from err
raise ImporterMethodCallError from err

return FileImportInfo(
importer.name(),
get_name(importer),
account or "",
date or local_today(),
name or Path(filename).name,
@@ -178,64 +209,67 @@ def file_import_info(


def find_imports(
config: Sequence[BeanImporterProtocol], directory: Path
config: Sequence[BeanImporterProtocol | Importer], directory: Path
) -> Iterable[FileImporters]:
"""Pair files and matching importers.
Yields:
For each file in directory, a pair of its filename and the matching
importers.
"""
for filename in walk_dir(directory):
stat = filename.stat()
if stat.st_size > _FILE_TOO_LARGE_THRESHOLD:
for path in walk_dir(directory):
stat = path.stat()
if stat.st_size > _FILE_TOO_LARGE_THRESHOLD: # pragma: no cover
continue

file = get_cached_file(str(filename))
importers = [
file_import_info(str(filename), importer)
file_import_info(path, importer)
for importer in config
if importer.identify(file)
if importer_identify(importer, path)
]
yield FileImporters(
name=str(filename), basename=filename.name, importers=importers
name=str(path), basename=path.name, importers=importers
)


def extract_from_file(
importer: BeanImporterProtocol,
filename: str,
importer: BeanImporterProtocol | Importer,
path: Path,
existing_entries: Sequence[Directive],
) -> list[Directive]:
"""Import entries from a document.
Args:
importer: The importer instance to handle the document.
filename: Filesystem path to the document.
path: Filesystem path to the document.
existing_entries: Existing entries.
Returns:
The list of imported entries.
"""
file = get_cached_file(filename)
entries = (
importer.extract(file, existing_entries=existing_entries)
if "existing_entries" in signature(importer.extract).parameters
else importer.extract(file)
)
if not entries:
return []
filename = str(path)
if isinstance(importer, Importer):
entries = importer.extract(filename, existing=existing_entries)
else:
file = get_cached_file(path)
entries = (
importer.extract(file, existing_entries=existing_entries)
if "existing_entries" in signature(importer.extract).parameters
else importer.extract(file)
) or []

if hasattr(importer, "sort"):
importer.sort(entries)
else:
entries.sort(key=_incomplete_sortkey)
if isinstance(importer, Importer):
importer.deduplicate(entries, existing=existing_entries)
return entries


def load_import_config(
module_path: Path,
) -> tuple[dict[str, BeanImporterProtocol], Hooks]:
) -> tuple[Mapping[str, BeanImporterProtocol | Importer], Hooks]:
"""Load the given import config and extract importers and hooks.
Args:
@@ -253,13 +287,13 @@ def load_import_config(
if "CONFIG" not in mod:
msg = "CONFIG is missing"
raise ImportConfigLoadError(msg)
if not isinstance(mod["CONFIG"], list):
if not isinstance(mod["CONFIG"], list): # pragma: no cover
msg = "CONFIG is not a list"
raise ImportConfigLoadError(msg)

config = mod["CONFIG"]
hooks = list(DEFAULT_HOOKS)
if "HOOKS" in mod:
hooks = DEFAULT_HOOKS
if "HOOKS" in mod: # pragma: no cover
hooks = mod["HOOKS"]
if not isinstance(hooks, list) or not all(
callable(fn) for fn in hooks
@@ -268,14 +302,16 @@
raise ImportConfigLoadError(msg)
importers = {}
for importer in config:
if not isinstance(importer, BeanImporterProtocol):
if not isinstance(
importer, (BeanImporterProtocol, Importer)
): # pragma: no cover
name = importer.__class__.__name__
msg = (
f"Importer class '{name}' in '{module_path}' does "
"not satisfy importer protocol"
)
raise ImportConfigLoadError(msg)
importers[importer.name()] = importer
importers[get_name(importer)] = importer
return importers, hooks


@@ -284,7 +320,7 @@ class IngestModule(FavaModule):

def __init__(self, ledger: FavaLedger) -> None:
super().__init__(ledger)
self.importers: dict[str, BeanImporterProtocol] = {}
self.importers: Mapping[str, BeanImporterProtocol | Importer] = {}
self.hooks: Hooks = []
self.mtime: int | None = None

@@ -306,9 +342,9 @@ def _error(self, msg: str) -> None:
)

def load_file(self) -> None: # noqa: D102
if self.module_path is None:
return
module_path = self.module_path
if module_path is None:
return

if not module_path.exists():
self._error("Import config does not exist")
@@ -356,20 +392,18 @@ def extract(self, filename: str, importer_name: str) -> list[Directive]:
if not self.module_path:
raise MissingImporterConfigError

if (
self.mtime is None
or self.module_path.stat().st_mtime_ns > self.mtime
):
self.load_file()
# reload (if changed)
self.load_file()

try:
path = Path(filename)
new_entries = extract_from_file(
self.importers[importer_name],
filename,
path,
existing_entries=self.ledger.all_entries,
)
except Exception as exc:
raise ImporterExtractError(importer_name, exc) from exc
raise ImporterExtractError from exc

new_entries_list = [(filename, new_entries)]
for hook_fn in self.hooks:
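
For reference, a minimal sketch of an import config module that the updated load_import_config should accept, mixing a legacy-protocol importer with a beangulp-style one. The module and class names (example_importers, ExampleBankImporter, legacy_importers, LegacyCSVImporter) are hypothetical placeholders, not part of this commit; ExampleBankImporter is sketched after the beangulp stub below.

# import_config.py -- the file pointed to by Fava's import-config option
from example_importers import ExampleBankImporter  # hypothetical beangulp.Importer subclass, see the sketch below
from legacy_importers import LegacyCSVImporter  # hypothetical importer satisfying BeanImporterProtocol

CONFIG = [
    ExampleBankImporter(),
    LegacyCSVImporter(),
]

# HOOKS is optional: when it is absent, DEFAULT_HOOKS is used, i.e. the
# duplicate finder from beancount.ingest if that package is importable,
# otherwise an empty list under beangulp.
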
26 changes: 26 additions & 0 deletions stubs/beangulp/__init__.pyi
@@ -0,0 +1,26 @@
import datetime
from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence

from fava.beans.abc import Account
from fava.beans.abc import Directive

class Importer(ABC):
@property
def name(self) -> str: ...
@abstractmethod
def identify(self, filepath: str) -> bool: ...
@abstractmethod
def account(self, filepath: str) -> Account: ...
def date(self, filepath: str) -> datetime.date | None: ...
def filename(self, filepath: str) -> str | None: ...
def extract(
self, filepath: str, existing: Sequence[Directive]
) -> list[Directive]: ...
def deduplicate(
self, entries: list[Directive], existing: Sequence[Directive]
) -> None: ...
def sort(
self, entries: list[Directive], reverse: bool = False
) -> None: ...
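
A minimal sketch of an importer implementing the interface stubbed above, assuming beangulp is installed; the importer name, filename suffix, and account are made-up illustration values.

import datetime
from collections.abc import Sequence

from beangulp import Importer


class ExampleBankImporter(Importer):
    """Illustrative beangulp-style importer for a hypothetical bank CSV export."""

    @property
    def name(self) -> str:
        return "example.bank"

    def identify(self, filepath: str) -> bool:
        # Claim files whose name ends with the (made-up) export suffix.
        return filepath.endswith("example-bank.csv")

    def account(self, filepath: str) -> str:
        return "Assets:Bank:Checking"

    def date(self, filepath: str) -> datetime.date | None:
        return None  # file_import_info then falls back to today's date

    def extract(self, filepath: str, existing: Sequence) -> list:
        return []  # a real importer would parse the document into entries here
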
3 changes: 3 additions & 0 deletions stubs/beangulp/cache.pyi
@@ -0,0 +1,3 @@
from typing import Any

_CACHE: dict[str, Any]
@@ -66,7 +66,7 @@
"currency": "USD",
"number": "10"
},
"date": "TODAY",
"date": "2022-12-12",
"diff_amount": null,
"meta": {
"__source__": "Balance",