Skip to content

Commit

Permalink
try out more complete Beancount types
Browse files Browse the repository at this point in the history
  • Loading branch information
yagebu committed Jan 25, 2025
1 parent 6dd9dac commit ad5412a
Show file tree
Hide file tree
Showing 29 changed files with 102 additions and 342 deletions.
59 changes: 7 additions & 52 deletions src/fava/beans/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Any
from typing import TYPE_CHECKING

from beancount.core import amount
from beancount.core import data
from beancount.core import position

Expand All @@ -28,61 +27,17 @@
Account = str


class Amount(ABC):
"""An amount in some currency."""

@property
@abstractmethod
def number(self) -> Decimal:
"""Number of units in the amount."""

@property
@abstractmethod
def currency(self) -> str:
"""Currency of the amount."""


Amount.register(amount.Amount)


class Cost(ABC):
"""A cost (basically an amount with date and label)."""

@property
@abstractmethod
def number(self) -> Decimal:
"""Number of units in the cost."""

@property
@abstractmethod
def currency(self) -> str:
"""Currency of the cost."""

@property
@abstractmethod
def date(self) -> datetime.date:
"""Date of the cost."""

@property
@abstractmethod
def label(self) -> str | None:
"""Label of the cost."""


Cost.register(position.Cost)


class Position(ABC):
"""A Beancount position - just cost and units."""

@property
@abstractmethod
def units(self) -> Amount:
def units(self) -> protocols.Amount:
"""Units of the posting."""

@property
@abstractmethod
def cost(self) -> Cost | None:
def cost(self) -> protocols.Cost | None:
"""Units of the position."""


Expand All @@ -99,17 +54,17 @@ def account(self) -> str:

@property
@abstractmethod
def units(self) -> Amount:
def units(self) -> protocols.Amount:
"""Units of the posting."""

@property
@abstractmethod
def cost(self) -> Cost | None:
def cost(self) -> protocols.Cost | None:
"""Units of the posting."""

@property
@abstractmethod
def price(self) -> Amount | None:
def price(self) -> protocols.Amount | None:
"""Price of the posting."""

@property
Expand Down Expand Up @@ -184,7 +139,7 @@ def account(self) -> str:

@property
@abstractmethod
def diff_amount(self) -> Amount | None:
def diff_amount(self) -> protocols.Amount | None:
"""Account of the directive."""


Expand Down Expand Up @@ -310,7 +265,7 @@ def currency(self) -> str:

@property
@abstractmethod
def amount(self) -> Amount:
def amount(self) -> protocols.Amount:
"""Price amount."""


Expand Down
51 changes: 34 additions & 17 deletions src/fava/beans/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,19 @@
from typing import TYPE_CHECKING

from beancount.core import data
from beancount.core.amount import ( # type: ignore[attr-defined]
A as BEANCOUNT_A,
)
from beancount.core.amount import A as BEANCOUNT_A
from beancount.core.amount import Amount as BeancountAmount
from beancount.core.position import Cost as BeancountCost
from beancount.core.position import Position as BeancountPosition

from fava.beans import BEANCOUNT_V3
from fava.beans.abc import Amount

if TYPE_CHECKING: # pragma: no cover
import datetime
from decimal import Decimal

from fava.beans.abc import Balance
from fava.beans.abc import Close
from fava.beans.abc import Cost
from fava.beans.abc import Document
from fava.beans.abc import Meta
from fava.beans.abc import Note
Expand All @@ -31,6 +27,8 @@
from fava.beans.abc import Posting
from fava.beans.abc import Transaction
from fava.beans.flags import Flag
from fava.beans.protocols import Amount
from fava.beans.protocols import Cost


@overload
Expand All @@ -47,10 +45,10 @@ def amount(amt: Decimal, currency: str) -> Amount: ... # pragma: no cover

def amount(amt: Amount | Decimal | str, currency: str | None = None) -> Amount:
"""Amount from a string or tuple."""
if isinstance(amt, Amount):
return amt
if isinstance(amt, str):
return BEANCOUNT_A(amt) # type: ignore[no-any-return]
return BEANCOUNT_A(amt) # type: ignore[return-value]
if hasattr(amt, "number") and hasattr(amt, "currency"):
return amt
if not isinstance(currency, str): # pragma: no cover
raise TypeError
return BeancountAmount(amt, currency) # type: ignore[return-value]
Expand All @@ -66,7 +64,7 @@ def cost(
label: str | None = None,
) -> Cost:
"""Create a Cost."""
return BeancountCost(number, currency, date, label) # type: ignore[return-value]
return BeancountCost(number, currency, date, label)


def position(units: Amount, cost: Cost | None) -> Position:
Expand Down Expand Up @@ -94,7 +92,7 @@ def posting(
cost, # type: ignore[arg-type]
price, # type: ignore[arg-type]
flag,
meta,
dict(meta) if meta is not None else None,
)


Expand All @@ -113,7 +111,7 @@ def transaction(
) -> Transaction:
"""Create a Beancount Transaction."""
return data.Transaction( # type: ignore[return-value]
meta,
dict(meta),
date,
flag,
payee,
Expand All @@ -134,7 +132,7 @@ def balance(
) -> Balance:
"""Create a Beancount Balance."""
return data.Balance( # type: ignore[return-value]
meta,
dict(meta),
date,
account,
_amount(amount), # type: ignore[arg-type]
Expand All @@ -150,7 +148,9 @@ def close(
) -> Close:
"""Create a Beancount Open."""
return data.Close( # type: ignore[return-value]
meta, date, account
dict(meta),
date,
account,
)


Expand All @@ -164,7 +164,12 @@ def document(
) -> Document:
"""Create a Beancount Document."""
return data.Document( # type: ignore[return-value]
meta, date, account, filename, tags, links
dict(meta),
date,
account,
filename,
tags,
links,
)


Expand All @@ -179,10 +184,18 @@ def note(
"""Create a Beancount Note."""
if not BEANCOUNT_V3: # pragma: no cover
return data.Note( # type: ignore[call-arg,return-value]
meta, date, account, comment
dict(meta),
date,
account,
comment,
)
return data.Note( # type: ignore[return-value]
meta, date, account, comment, tags, links
dict(meta),
date,
account,
comment,
tags,
links,
)


Expand All @@ -195,5 +208,9 @@ def open( # noqa: A001
) -> Open:
"""Create a Beancount Open."""
return data.Open( # type: ignore[return-value]
meta, date, account, currencies, booking
dict(meta),
date,
account,
currencies,
booking,
)
6 changes: 3 additions & 3 deletions src/fava/beans/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

def load_string(value: str) -> LoaderResult:
"""Load a Beancoun string."""
return loader.load_string(value)
return loader.load_string(value) # type: ignore[return-value]


def load_uncached(
Expand All @@ -22,9 +22,9 @@ def load_uncached(
) -> LoaderResult:
"""Load a Beancount file."""
if is_encrypted: # pragma: no cover
return loader.load_file(beancount_file_path)
return loader.load_file(beancount_file_path) # type: ignore[return-value]

return loader._load( # type: ignore[attr-defined,no-any-return] # noqa: SLF001
return loader._load( # type: ignore[return-value] # noqa: SLF001
[(beancount_file_path, True)],
None,
None,
Expand Down
34 changes: 25 additions & 9 deletions src/fava/beans/str.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from functools import singledispatch
from typing import TYPE_CHECKING

from beancount.core import amount
from beancount.core import data
from beancount.core import position
from beancount.core.position import CostSpec
from beancount.parser.printer import format_entry

from fava.beans.abc import Amount
from fava.beans.abc import Cost
from fava.beans.abc import Directive
from fava.beans.abc import Position
from fava.beans.helpers import replace
Expand All @@ -22,22 +23,36 @@

@singledispatch
def to_string(
obj: Amount | Cost | CostSpec | Directive | Position,
obj: amount.Amount
| protocols.Amount
| protocols.Cost
| CostSpec
| Directive
| Position,
_currency_column: int | None = None,
_indent: int | None = None,
) -> str:
"""Convert to a string."""
number = getattr(obj, "number", None)
currency = getattr(obj, "currency", None)
if isinstance(number, Decimal) and isinstance(currency, str):
# The Amount and Cost protocols are ambigous, so handle this here
# instead of having this be dispatched
if hasattr(obj, "date"): # pragma: no cover
cost_to_string(obj) # type: ignore[arg-type]
return f"{number} {currency}"
msg = f"Unsupported object of type {type(obj)}"
raise TypeError(msg)


@to_string.register(Amount)
def _(obj: Amount) -> str:
@to_string.register(amount.Amount)
def amount_to_string(obj: amount.Amount | protocols.Amount) -> str:
"""Convert an amount to a string."""
return f"{obj.number} {obj.currency}"


@to_string.register(Cost)
def cost_to_string(cost: Cost | protocols.Cost) -> str:
@to_string.register(position.Cost)
def cost_to_string(cost: protocols.Cost | position.Cost) -> str:
"""Convert a cost to a string."""
res = f"{cost.number} {cost.currency}, {cost.date.isoformat()}"
return f'{res}, "{cost.label}"' if cost.label else res
Expand Down Expand Up @@ -69,7 +84,7 @@ def _(cost: CostSpec) -> str:

@to_string.register(Position)
def _(obj: Position) -> str:
units_str = to_string(obj.units)
units_str = amount_to_string(obj.units)
if obj.cost is None:
return units_str
cost_str = to_string(obj.cost)
Expand All @@ -86,7 +101,8 @@ def _format_entry(
key: entry.meta[key] for key in entry.meta if not key.startswith("_")
}
entry = replace(entry, meta=meta)
printed_entry = format_entry(entry, prefix=" " * indent) # type: ignore[arg-type]
assert isinstance(entry, data.ALL_DIRECTIVES) # noqa: S101
printed_entry = format_entry(entry, prefix=" " * indent)
string = align(printed_entry, currency_column)
string = string.replace("<class 'beancount.core.number.MISSING'>", "")
return "\n".join(line.rstrip() for line in string.split("\n"))
8 changes: 4 additions & 4 deletions src/fava/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def interval_balances(
interval_balances = [
Tree(
iter_entry_dates(
filtered.entries,
filtered.entries, # type: ignore[arg-type]
date.min if accumulate else date_range.begin,
date_range.end,
),
Expand Down Expand Up @@ -558,18 +558,18 @@ def context(
for posting in entry_.postings:
balance = balances.get(posting.account, None)
if balance is not None:
balance.add_position(posting)
balance.add_position(posting) # type: ignore[arg-type]

def visualise(inv: Inventory) -> Sequence[str]:
return [to_string(pos) for pos in sorted(inv)]
return [to_string(pos) for pos in sorted(iter(inv))]

before = {acc: visualise(inv) for acc, inv in balances.items()}

if isinstance(entry, Balance):
return entry, before, None, source_slice, sha256sum

for posting in entry.postings:
balances[posting.account].add_position(posting)
balances[posting.account].add_position(posting) # type: ignore[arg-type]
after = {acc: visualise(inv) for acc, inv in balances.items()}
return entry, before, after, source_slice, sha256sum

Expand Down
Loading

0 comments on commit ad5412a

Please sign in to comment.