Skip to content

Commit

Permalink
internal_api: add separate dataclasses for chart types
Browse files Browse the repository at this point in the history
  • Loading branch information
yagebu committed Feb 4, 2024
1 parent 71deea5 commit e4a869e
Showing 1 changed file with 85 additions and 64 deletions.
149 changes: 85 additions & 64 deletions src/fava/internal_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any
from typing import TYPE_CHECKING

from flask import current_app
Expand All @@ -20,11 +19,16 @@

if TYPE_CHECKING: # pragma: no cover
from datetime import date
from typing import Literal

from fava.beans.abc import Meta
from fava.beans.abc import Query
from fava.core.accounts import AccountDict
from fava.core.charts import DateAndBalance
from fava.core.charts import DateAndBalanceWithBudget
from fava.core.extensions import ExtensionDetails
from fava.core.fava_options import FavaOptions
from fava.core.tree import SerialisedTreeNode
from fava.helpers import BeancountError
from fava.util.date import Interval

Expand Down Expand Up @@ -65,7 +69,7 @@ class LedgerData:
precisions: dict[str, int]
tags: list[str]
years: list[str]
user_queries: list[Any]
user_queries: list[Query]
upcoming_events_count: int
extensions: list[ExtensionDetails]
sidebar_links: list[tuple[str, str]]
Expand All @@ -77,7 +81,7 @@ def get_errors() -> list[SerialisedError]:
return [SerialisedError.from_beancount_error(e) for e in g.ledger.errors]


def _get_options() -> dict[str, Any]:
def _get_options() -> dict[str, str | list[str]]:
options = g.ledger.options
return {
"documents": options["documents"],
Expand Down Expand Up @@ -126,77 +130,94 @@ def get_ledger_data() -> LedgerData:


@dataclass(frozen=True)
class ChartData:
"""The common data format to pass charts to the frontend."""
class BalancesChart:
"""Data for a balances chart."""

type: str
label: str
data: Any


def _chart_interval_totals(
interval: Interval,
account_name: str | tuple[str, ...],
label: str | None = None,
*,
invert: bool = False,
) -> ChartData:
return ChartData(
"bar",
label or str(account_name),
g.ledger.charts.interval_totals(
g.filtered,
interval,
account_name,
g.conv,
invert=invert,
),
)
data: list[DateAndBalance]
type: Literal["balances"] = "balances"


def _chart_hierarchy(
account_name: str,
begin_date: date | None = None,
end_date: date | None = None,
label: str | None = None,
) -> ChartData:
return ChartData(
"hierarchy",
label or account_name,
g.ledger.charts.hierarchy(
g.filtered,
account_name,
g.conv,
begin_date,
end_date or g.filtered.end_date,
),
)
@dataclass(frozen=True)
class BarChart:
"""Data for a bar chart."""

label: str
data: list[DateAndBalanceWithBudget]
type: Literal["bar"] = "bar"

def _chart_net_worth() -> ChartData:
return ChartData(
"balances",
gettext("Net Worth"),
g.ledger.charts.net_worth(g.filtered, g.interval, g.conv),
)

@dataclass(frozen=True)
class HierarchyChart:
"""Data for a hierarchy chart."""

def _chart_account_balance(account_name: str) -> ChartData:
return ChartData(
"balances",
gettext("Account Balance"),
g.ledger.charts.linechart(
g.filtered,
account_name,
g.conv,
),
)
label: str
data: SerialisedTreeNode
type: Literal["hierarchy"] = "hierarchy"


if TYPE_CHECKING:
ChartData = BalancesChart | BarChart | HierarchyChart


class ChartApi:
"""Functions to generate chart data."""

account_balance = _chart_account_balance
hierarchy = _chart_hierarchy
interval_totals = _chart_interval_totals
net_worth = _chart_net_worth
@staticmethod
def account_balance(account_name: str) -> ChartData:
"""Generate data for an account balances chart."""
return BalancesChart(
gettext("Account Balance"),
g.ledger.charts.linechart(
g.filtered,
account_name,
g.conv,
),
)

@staticmethod
def hierarchy(
account_name: str,
begin_date: date | None = None,
end_date: date | None = None,
label: str | None = None,
) -> ChartData:
"""Generate data for an account hierarchy chart."""
return HierarchyChart(
label or account_name,
g.ledger.charts.hierarchy(
g.filtered,
account_name,
g.conv,
begin_date,
end_date or g.filtered.end_date,
),
)

@staticmethod
def interval_totals(
interval: Interval,
account_name: str | tuple[str, ...],
label: str | None = None,
*,
invert: bool = False,
) -> ChartData:
"""Generate data for an account per interval chart."""
return BarChart(
label or str(account_name),
g.ledger.charts.interval_totals(
g.filtered,
interval,
account_name,
g.conv,
invert=invert,
),
)

@staticmethod
def net_worth() -> ChartData:
"""Generate data for net worth chart."""
return BalancesChart(
gettext("Net Worth"),
g.ledger.charts.net_worth(g.filtered, g.interval, g.conv),
)

0 comments on commit e4a869e

Please sign in to comment.