From e894c9182f4d16f9a46ccb87bdaeca1a7dede040 Mon Sep 17 00:00:00 2001 From: blais Date: Wed, 19 Jun 2024 11:27:21 -0400 Subject: [PATCH] (Lint everything through `ruff format`.) --- beanprice/__init__.py | 1 + beanprice/date_utils.py | 12 +- beanprice/date_utils_test.py | 17 +- beanprice/net_utils.py | 4 +- beanprice/net_utils_test.py | 23 +- beanprice/price.py | 469 ++++++++++++-------- beanprice/price_test.py | 371 +++++++++------- beanprice/source.py | 24 +- beanprice/sources/__init__.py | 1 + beanprice/sources/alphavantage.py | 52 +-- beanprice/sources/alphavantage_test.py | 43 +- beanprice/sources/coinbase.py | 11 +- beanprice/sources/coinbase_test.py | 34 +- beanprice/sources/coinmarketcap.py | 31 +- beanprice/sources/coinmarketcap_test.py | 27 +- beanprice/sources/eastmoneyfund.py | 5 +- beanprice/sources/eastmoneyfund_test.py | 1 - beanprice/sources/iex.py | 18 +- beanprice/sources/iex_test.py | 23 +- beanprice/sources/oanda.py | 39 +- beanprice/sources/oanda_test.py | 125 +++--- beanprice/sources/quandl.py | 40 +- beanprice/sources/quandl_test.py | 241 +++++----- beanprice/sources/ratesapi.py | 25 +- beanprice/sources/ratesapi_test.py | 28 +- beanprice/sources/tsp.py | 112 ++--- beanprice/sources/tsp_test.py | 76 ++-- beanprice/sources/yahoo.py | 134 +++--- experiments/dividends/download_dividends.py | 53 ++- 29 files changed, 1134 insertions(+), 906 deletions(-) diff --git a/beanprice/__init__.py b/beanprice/__init__.py index 0fd8ef8..2913cb8 100644 --- a/beanprice/__init__.py +++ b/beanprice/__init__.py @@ -137,5 +137,6 @@ historical data for its CURRENCY:* instruments. """ + __copyright__ = "Copyright (C) 2015-2020 Martin Blais" __license__ = "GNU GPLv2" diff --git a/beanprice/date_utils.py b/beanprice/date_utils.py index 3bd7285..fc1bbd4 100644 --- a/beanprice/date_utils.py +++ b/beanprice/date_utils.py @@ -1,5 +1,5 @@ -"""Date utilities. -""" +"""Date utilities.""" + __copyright__ = "Copyright (C) 2020 Martin Blais" __license__ = "GNU GPLv2" @@ -39,14 +39,14 @@ def intimezone(tz_value: str): Returns: A contextmanager in the given timezone locale. """ - tz_old = os.environ.get('TZ', None) - os.environ['TZ'] = tz_value + tz_old = os.environ.get("TZ", None) + os.environ["TZ"] = tz_value time.tzset() try: yield finally: if tz_old is None: - del os.environ['TZ'] + del os.environ["TZ"] else: - os.environ['TZ'] = tz_old + os.environ["TZ"] = tz_old time.tzset() diff --git a/beanprice/date_utils_test.py b/beanprice/date_utils_test.py index 054c16d..2271cb9 100644 --- a/beanprice/date_utils_test.py +++ b/beanprice/date_utils_test.py @@ -9,17 +9,16 @@ class TestDateUtils(unittest.TestCase): - def test_parse_date_liberally(self): const_date = datetime.date(2014, 12, 7) test_cases = ( - ('12/7/2014',), - ('7-Dec-2014',), - ('7/12/2014', {'parserinfo': dateutil.parser.parserinfo(dayfirst=True)}), - ('12/7', {'default': datetime.datetime(2014, 1, 1)}), - ('7.12.2014', {'dayfirst': True}), - ('14 12 7', {'yearfirst': True}), - ('Transaction of 7th December 2014', {'fuzzy': True}), + ("12/7/2014",), + ("7-Dec-2014",), + ("7/12/2014", {"parserinfo": dateutil.parser.parserinfo(dayfirst=True)}), + ("12/7", {"default": datetime.datetime(2014, 1, 1)}), + ("7.12.2014", {"dayfirst": True}), + ("14 12 7", {"yearfirst": True}), + ("Transaction of 7th December 2014", {"fuzzy": True}), ) for case in test_cases: if len(case) == 2: @@ -40,5 +39,5 @@ def test_intimezone(self): self.assertNotEqual(now_tokyo, now_nyc) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/beanprice/net_utils.py b/beanprice/net_utils.py index d11f8d9..782db00 100644 --- a/beanprice/net_utils.py +++ b/beanprice/net_utils.py @@ -1,5 +1,5 @@ -"""Network utilities. -""" +"""Network utilities.""" + __copyright__ = "Copyright (C) 2015-2016 Martin Blais" __license__ = "GNU GPLv2" diff --git a/beanprice/net_utils_test.py b/beanprice/net_utils_test.py index 2cea9b4..de8d3d5 100644 --- a/beanprice/net_utils_test.py +++ b/beanprice/net_utils_test.py @@ -9,30 +9,29 @@ class TestRetryingUrlopen(unittest.TestCase): - def test_success_200(self): response = http.client.HTTPResponse(mock.MagicMock()) response.status = 200 - with mock.patch('urllib.request.urlopen', return_value=response): - self.assertIs(net_utils.retrying_urlopen('http://nowhere.com'), response) + with mock.patch("urllib.request.urlopen", return_value=response): + self.assertIs(net_utils.retrying_urlopen("http://nowhere.com"), response) def test_success_other(self): response = http.client.HTTPResponse(mock.MagicMock()) - with mock.patch('urllib.request.urlopen', return_value=response): - self.assertIsNone(net_utils.retrying_urlopen('http://nowhere.com')) + with mock.patch("urllib.request.urlopen", return_value=response): + self.assertIsNone(net_utils.retrying_urlopen("http://nowhere.com")) def test_timeout_once(self): response = http.client.HTTPResponse(mock.MagicMock()) response.status = 200 - with mock.patch('urllib.request.urlopen', - side_effect=[None, response]): - self.assertIs(net_utils.retrying_urlopen('http://nowhere.com'), response) + with mock.patch("urllib.request.urlopen", side_effect=[None, response]): + self.assertIs(net_utils.retrying_urlopen("http://nowhere.com"), response) def test_max_retry(self): - with mock.patch('urllib.request.urlopen', - side_effect=[None, None, None, None, None, None]): - self.assertIsNone(net_utils.retrying_urlopen('http://nowhere.com')) + with mock.patch( + "urllib.request.urlopen", side_effect=[None, None, None, None, None, None] + ): + self.assertIsNone(net_utils.retrying_urlopen("http://nowhere.com")) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/beanprice/price.py b/beanprice/price.py index 8ac3871..a467bdc 100644 --- a/beanprice/price.py +++ b/beanprice/price.py @@ -1,5 +1,5 @@ -"""Driver code for the price script. -""" +"""Driver code for the price script.""" + __copyright__ = "Copyright (C) 2015-2020 Martin Blais" __license__ = "GNU GPLv2" @@ -63,22 +63,22 @@ class DatedPrice(NamedTuple): # The Python package where the default sources are found. -DEFAULT_PACKAGE = 'beanprice.sources' +DEFAULT_PACKAGE = "beanprice.sources" # Stand-in currency name for unknown currencies. -UNKNOWN_CURRENCY = '?' +UNKNOWN_CURRENCY = "?" # A cache for the prices. _CACHE = None # Expiration for latest prices in the cache. -DEFAULT_EXPIRATION = datetime.timedelta(seconds=30*60) # 30 mins. +DEFAULT_EXPIRATION = datetime.timedelta(seconds=30 * 60) # 30 mins. # The default source parser is back. -DEFAULT_SOURCE = 'beanprice.sources.yahoo' +DEFAULT_SOURCE = "beanprice.sources.yahoo" def format_dated_price_str(dprice: DatedPrice) -> str: @@ -89,15 +89,16 @@ def format_dated_price_str(dprice: DatedPrice) -> str: Returns: The string for a DatedPrice instance. """ - psstrs = ['{}({}{})'.format(psource.module.__name__, - '1/' if psource.invert else '', - psource.symbol) - for psource in dprice.sources] - base_quote = '{} /{}'.format(dprice.base, dprice.quote) - return '{:<32} @ {:10} [ {} ]'.format( - base_quote, - dprice.date.isoformat() if dprice.date else 'latest', - ','.join(psstrs)) + psstrs = [ + "{}({}{})".format( + psource.module.__name__, "1/" if psource.invert else "", psource.symbol + ) + for psource in dprice.sources + ] + base_quote = "{} /{}".format(dprice.base, dprice.quote) + return "{:<32} @ {:10} [ {} ]".format( + base_quote, dprice.date.isoformat() if dprice.date else "latest", ",".join(psstrs) + ) def parse_source_map(source_map_spec: str) -> Dict[str, List[PriceSource]]: @@ -133,16 +134,15 @@ def parse_source_map(source_map_spec: str) -> Dict[str, List[PriceSource]]: ValueError: If an invalid pattern has been specified. """ source_map: Dict[str, List[PriceSource]] = collections.defaultdict(list) - for source_list_spec in re.split('[ ;]', source_map_spec): - match = re.match('({}):(.*)$'.format(amount.CURRENCY_RE), - source_list_spec) + for source_list_spec in re.split("[ ;]", source_map_spec): + match = re.match("({}):(.*)$".format(amount.CURRENCY_RE), source_list_spec) if not match: raise ValueError('Invalid source map pattern: "{}"'.format(source_list_spec)) currency, source_strs = match.groups() source_map[currency].extend( - parse_single_source(source_str) - for source_str in source_strs.split(',')) + parse_single_source(source_str) for source_str in source_strs.split(",") + ) return source_map @@ -163,7 +163,7 @@ def parse_single_source(source: str) -> PriceSource: Raises: ValueError: If invalid. """ - match = re.match(r'([a-zA-Z]+[a-zA-Z0-9\._]+)/(\^?)([a-zA-Z0-9:=_\-\.\(\)]+)$', source) + match = re.match(r"([a-zA-Z]+[a-zA-Z0-9\._]+)/(\^?)([a-zA-Z0-9:=_\-\.\(\)]+)$", source) if not match: raise ValueError('Invalid source name: "{}"'.format(source)) short_module_name, invert, symbol = match.groups() @@ -184,7 +184,7 @@ def import_source(module_name: str): Raises: ImportError: If the module cannot be imported. """ - default_name = '{}.{}'.format(DEFAULT_PACKAGE, module_name) + default_name = "{}.{}".format(DEFAULT_PACKAGE, module_name) try: __import__(default_name) return sys.modules[default_name] @@ -193,13 +193,14 @@ def import_source(module_name: str): __import__(module_name) return sys.modules[module_name] except ImportError as exc: - raise ImportError('Could not find price source module "{}"'.format( - module_name)) from exc + raise ImportError( + 'Could not find price source module "{}"'.format(module_name) + ) from exc def find_currencies_declared( - entries: data.Entries, - date: datetime.date = None) -> List[Tuple[str, str, List[PriceSource]]]: + entries: data.Entries, date: datetime.date = None +) -> List[Tuple[str, str, List[PriceSource]]]: """Return currencies declared in Commodity directives. If a 'price' metadata field is provided, include all the quote currencies @@ -226,17 +227,21 @@ def find_currencies_declared( # First, we look for a "price" metadata field, which defines conversions # for various currencies. Each of these quote currencies generates a # pair in the output. - source_str = entry.meta.get('price', None) + source_str = entry.meta.get("price", None) if source_str is not None: if source_str == "": - logging.debug("Skipping ignored currency (with empty price): %s", - entry.currency) + logging.debug( + "Skipping ignored currency (with empty price): %s", entry.currency + ) continue try: source_map = parse_source_map(source_str) except ValueError as exc: - logging.warning("Ignoring currency with invalid 'price' source: %s (%s)", - entry.currency, exc) + logging.warning( + "Ignoring currency with invalid 'price' source: %s (%s)", + entry.currency, + exc, + ) else: for quote, psources in source_map.items(): currencies.append((entry.currency, quote, psources)) @@ -258,13 +263,15 @@ def log_currency_list(message, currencies): """ logging.debug("-------- {}:".format(message)) for base, quote in currencies: - logging.debug(" {:>32}".format('{} /{}'.format(base, quote))) + logging.debug(" {:>32}".format("{} /{}".format(base, quote))) -def get_price_jobs_at_date(entries: data.Entries, - date: Optional[datetime.date] = None, - inactive: bool = False, - undeclared_source: Optional[str] = None): +def get_price_jobs_at_date( + entries: data.Entries, + date: Optional[datetime.date] = None, + inactive: bool = False, + undeclared_source: Optional[str] = None, +): """Get a list of prices to fetch from a stream of entries. The active holdings held on the given date are included. @@ -285,8 +292,7 @@ def get_price_jobs_at_date(entries: data.Entries, # tickers for each (base, quote) pair. This is the only place tickers # appear. declared_triples = find_currencies_declared(entries, date) - currency_map = {(base, quote): psources - for base, quote, psources in declared_triples} + currency_map = {(base, quote): psources for base, quote, psources in declared_triples} # Compute the initial list of currencies to consider. if undeclared_source: @@ -333,12 +339,14 @@ def get_price_jobs_at_date(entries: data.Entries, # or perhaps to extend it to intervals, and let the price source decide for # itself how to implement fetching (e.g., use a single call + filter, or use # multiple calls). Querying independently for each day is not the best strategy. -def get_price_jobs_up_to_date(entries, - date_last=None, - inactive=False, - undeclared_source=None, - update_rate='weekday', - compress_days=1): +def get_price_jobs_up_to_date( + entries, + date_last=None, + inactive=False, + undeclared_source=None, + update_rate="weekday", + compress_days=1, +): """Get a list of trailing prices to fetch from a stream of entries. The list of dates runs from the latest available price up to the latest date. @@ -360,8 +368,7 @@ def get_price_jobs_up_to_date(entries, # tickers for each (base, quote) pair. This is the only place tickers # appear. declared_triples = find_currencies_declared(entries, date_last) - currency_map = {(base, quote): psources - for base, quote, psources in declared_triples} + currency_map = {(base, quote): psources for base, quote, psources in declared_triples} # Compute the initial list of currencies to consider. if undeclared_source: @@ -381,7 +388,6 @@ def get_price_jobs_up_to_date(entries, log_currency_list("Currencies in primary list", currencies) - # By default, restrict to only the currencies with non-zero balances # up to the given date. # Also, find the earliest start date to fetch prices from. @@ -408,16 +414,15 @@ def get_price_jobs_up_to_date(entries, for base_quote in lifetimes_map: intervals = lifetimes_map[base_quote] result = prices.get_latest_price(price_map, base_quote) - if (result is None or result[0] is None): + if result is None or result[0] is None: lifetimes_map[base_quote] = lifetimes.trim_intervals(intervals, None, date_last) else: latest_price_date = result[0] date_first = latest_price_date + datetime.timedelta(days=1) if date_first < date_last: - lifetimes_map[base_quote] = \ - lifetimes.trim_intervals(intervals, - date_first, - date_last) + lifetimes_map[base_quote] = lifetimes.trim_intervals( + intervals, date_first, date_last + ) else: # We don't need to update if we're already up to date. lifetimes_map[base_quote] = [] @@ -430,22 +435,18 @@ def get_price_jobs_up_to_date(entries, del lifetimes_map[key] # Create price jobs based on fetch rate - if update_rate == 'daily': + if update_rate == "daily": required_prices = lifetimes.required_daily_prices( - lifetimes_map, - date_last, - weekdays_only=False) - elif update_rate == 'weekday': + lifetimes_map, date_last, weekdays_only=False + ) + elif update_rate == "weekday": required_prices = lifetimes.required_daily_prices( - lifetimes_map, - date_last, - weekdays_only=True) - elif update_rate == 'weekly': - required_prices = lifetimes.required_weekly_prices( - lifetimes_map, - date_last) + lifetimes_map, date_last, weekdays_only=True + ) + elif update_rate == "weekly": + required_prices = lifetimes.required_weekly_prices(lifetimes_map, date_last) else: - raise ValueError('Invalid Update Rate') + raise ValueError("Invalid Update Rate") jobs = [] # Build up the list of jobs to fetch prices for. @@ -491,15 +492,17 @@ def fetch_cached_price(source, symbol, date): if _CACHE is None: # The cache is disabled; just call and return. - result = (source.get_latest_price(symbol) - if time is None else - source.get_historical_price(symbol, time)) + result = ( + source.get_latest_price(symbol) + if time is None + else source.get_historical_price(symbol, time) + ) else: # The cache is enabled and we have to compute the current/latest price. # Try to fetch from the cache but miss if the price is too old. md5 = hashlib.md5() - md5.update(str((type(source).__module__, symbol, date)).encode('utf-8')) + md5.update(str((type(source).__module__, symbol, date)).encode("utf-8")) key = md5.hexdigest() timestamp_now = int(now().timestamp()) try: @@ -510,7 +513,8 @@ def fetch_cached_price(source, symbol, date): # aware datetime objects cannot be serialized properly due to bug.) if result_naive.time is not None: result = result_naive._replace( - time=result_naive.time.replace(tzinfo=tz.tzutc())) + time=result_naive.time.replace(tzinfo=tz.tzutc()) + ) else: result = result_naive @@ -519,9 +523,11 @@ def fetch_cached_price(source, symbol, date): except KeyError: logging.info("Fetching: %s (time: %s)", symbol, time) try: - result = (source.get_latest_price(symbol) - if time is None else - source.get_historical_price(symbol, time)) + result = ( + source.get_latest_price(symbol) + if time is None + else source.get_historical_price(symbol, time) + ) except ValueError as exc: logging.error("Error fetching %s: %s", symbol, exc) result = None @@ -550,17 +556,16 @@ def setup_cache(cache_filename: Optional[str], clear_cache: bool): if not cache_filename: return - logging.info('Using price cache at "%s" (with indefinite expiration)', - cache_filename) + logging.info('Using price cache at "%s" (with indefinite expiration)', cache_filename) - flag = 'c' + flag = "c" if clear_cache and cache_filename: logging.info("Clearing cache %s*", cache_filename) - flag = 'n' + flag = "n" global _CACHE _CACHE = shelve.open(cache_filename, flag=flag) - _CACHE.expiration = DEFAULT_EXPIRATION # type: ignore + _CACHE.expiration = DEFAULT_EXPIRATION # type: ignore def reset_cache(): @@ -604,10 +609,10 @@ def fetch_price(dprice: DatedPrice, swap_inverted: bool = False) -> Optional[dat if swap_inverted: base, quote = quote, base else: - price = ONE/price + price = ONE / price assert base is not None - fileloc = data.new_metadata('<{}>'.format(type(psource.module).__name__), 0) + fileloc = data.new_metadata("<{}>".format(type(psource.module).__name__), 0) # The datetime instance is required to be aware. We always convert to the # user's timezone before extracting the date. This means that if the market @@ -621,14 +626,12 @@ def fetch_price(dprice: DatedPrice, swap_inverted: bool = False) -> Optional[dat raise ValueError("Time returned by the price source is not timezone aware.") date = srctime.astimezone(tz.tzlocal()).date() - return data.Price(fileloc, date, base, - amount.Amount(price, quote or UNKNOWN_CURRENCY)) + return data.Price(fileloc, date, base, amount.Amount(price, quote or UNKNOWN_CURRENCY)) def filter_redundant_prices( - price_entries: List[data.Price], - existing_entries: List[data.Price], - diffs: bool = False) -> Tuple[List[data.Price], List[data.Price]]: + price_entries: List[data.Price], existing_entries: List[data.Price], diffs: bool = False +) -> Tuple[List[data.Price], List[data.Price]]: """Filter out new entries that are redundant from an existing set. If the price differs, we override it with the new entry only on demand. This @@ -646,9 +649,11 @@ def filter_redundant_prices( # Note: We have to be careful with the dates, because requesting the latest # price for a date may yield the price at a previous date. Clobber needs to # take this into account. See {1cfa25e37fc1}. - existing_prices = {(entry.date, entry.currency): entry - for entry in existing_entries - if isinstance(entry, data.Price)} + existing_prices = { + (entry.date, entry.currency): entry + for entry in existing_entries + if isinstance(entry, data.Price) + } filtered_prices: List[data.Price] = [] ignored_prices: List[data.Price] = [] for entry in price_entries: @@ -679,84 +684,174 @@ def process_args() -> Tuple[Any, List[DatedPrice], List[data.Price], Any]: parser = argparse.ArgumentParser(description=beanprice.__doc__.splitlines()[0]) # Input sources or filenames. - parser.add_argument('sources', nargs='+', help=( - 'A list of filenames (or source "module/symbol", if -e is ' - 'specified) from which to create a list of jobs.')) - - parser.add_argument('-e', '--expressions', '--expression', action='store_true', help=( - 'Interpret the arguments as "module/symbol" source strings.')) + parser.add_argument( + "sources", + nargs="+", + help=( + 'A list of filenames (or source "module/symbol", if -e is ' + "specified) from which to create a list of jobs." + ), + ) + + parser.add_argument( + "-e", + "--expressions", + "--expression", + action="store_true", + help=('Interpret the arguments as "module/symbol" source strings.'), + ) # Regular options. - parser.add_argument('-v', '--verbose', action='count', help=( - "Print out progress log. Specify twice for debugging info.")) - - parser.add_argument('-d', '--date', action='store', - type=date_utils.parse_date_liberally, help=( - "Specify the date for which to fetch the prices.")) - - parser.add_argument('--update', action='store_true', help=( - "Fetch prices from most recent price for each source " - "up to present day or specified --date. See also " - "--update-rate, --update-compress options.")) - - parser.add_argument('--update-rate', choices=['daily', 'weekday', 'weekly'], - default='weekday', help=( - "Specify how often dates are fetched. Options are daily, weekday, or weekly " - "(fridays)")) - - parser.add_argument('--update-compress', action='store', type=int, default=0, help=( - "Specify the number of inactive days to ignore. This option ignored if --inactive " - "used.")) - - parser.add_argument('-i', '--inactive', action='store_true', help=( - "Select all commodities from input files, not just the ones active on the date")) - - parser.add_argument('-u', '--undeclared', action='store_true', help=( - "Include commodities viewed in the file even without a " - "corresponding Commodity directive, from this default source. " - "The currency name itself is used as the lookup symbol in this default source.")) - - parser.add_argument('-c', '--clobber', action='store_true', help=( - "Do not skip prices which are already present in input files; fetch them anyway.")) - - parser.add_argument('-a', '--all', action='store_true', help=( - "A shorthand for --inactive, --undeclared, --clobber.")) - - parser.add_argument('-s', '--swap-inverted', action='store_true', help=( - "For inverted sources, swap currencies instead of inverting the rate. " - "For example, if fetching the rate for CAD from 'USD:google/^CURRENCY:USDCAD' " - "results in 1.25, by default we would output \"price CAD 0.8000 USD\". " - "Using this option we would instead output \" price USD 1.2500 CAD\".")) - - parser.add_argument('-w', '--workers', action='store', type=int, default=1, help=( - "Specify the number of concurrent fetchers.")) - - parser.add_argument('-n', '--dry-run', action='store_true', help=( - "Don't actually fetch the prices, just print the list of the ones to be fetched.")) + parser.add_argument( + "-v", + "--verbose", + action="count", + help=("Print out progress log. Specify twice for debugging info."), + ) + + parser.add_argument( + "-d", + "--date", + action="store", + type=date_utils.parse_date_liberally, + help=("Specify the date for which to fetch the prices."), + ) + + parser.add_argument( + "--update", + action="store_true", + help=( + "Fetch prices from most recent price for each source " + "up to present day or specified --date. See also " + "--update-rate, --update-compress options." + ), + ) + + parser.add_argument( + "--update-rate", + choices=["daily", "weekday", "weekly"], + default="weekday", + help=( + "Specify how often dates are fetched. Options are daily, weekday, or weekly " + "(fridays)" + ), + ) + + parser.add_argument( + "--update-compress", + action="store", + type=int, + default=0, + help=( + "Specify the number of inactive days to ignore. This option ignored if " + "--inactive used." + ), + ) + + parser.add_argument( + "-i", + "--inactive", + action="store_true", + help=( + "Select all commodities from input files, not just the ones active on the date" + ), + ) + + parser.add_argument( + "-u", + "--undeclared", + action="store_true", + help=( + "Include commodities viewed in the file even without a " + "corresponding Commodity directive, from this default source. " + "The currency name itself is used as the lookup symbol in this default source." + ), + ) + + parser.add_argument( + "-c", + "--clobber", + action="store_true", + help=( + "Do not skip prices which are already present in input files; " + "fetch them anyway." + ), + ) + + parser.add_argument( + "-a", + "--all", + action="store_true", + help=("A shorthand for --inactive, --undeclared, --clobber."), + ) + + parser.add_argument( + "-s", + "--swap-inverted", + action="store_true", + help=( + "For inverted sources, swap currencies instead of inverting the rate. " + "For example, if fetching the rate for CAD from 'USD:google/^CURRENCY:USDCAD' " + 'results in 1.25, by default we would output "price CAD 0.8000 USD". ' + 'Using this option we would instead output " price USD 1.2500 CAD".' + ), + ) + + parser.add_argument( + "-w", + "--workers", + action="store", + type=int, + default=1, + help=("Specify the number of concurrent fetchers."), + ) + + parser.add_argument( + "-n", + "--dry-run", + action="store_true", + help=( + "Don't actually fetch the prices, just print the list of the ones " + "to be fetched." + ), + ) # Caching options. - cache_group = parser.add_argument_group('cache') - cache_filename = path.join(tempfile.gettempdir(), - "{}.cache".format(path.basename(sys.argv[0]))) - cache_group.add_argument('--cache', dest='cache_filename', - action='store', default=cache_filename, - help="The base filename for the underlying price cache " - "database. An extension may be added to the filename and " - "more than one file may be created.") - cache_group.add_argument('--no-cache', dest='cache_filename', - action='store_const', const=None, - help="Disable the price cache.") - cache_group.add_argument('--clear-cache', action='store_true', - help="Clear the cache prior to startup.") + cache_group = parser.add_argument_group("cache") + cache_filename = path.join( + tempfile.gettempdir(), "{}.cache".format(path.basename(sys.argv[0])) + ) + cache_group.add_argument( + "--cache", + dest="cache_filename", + action="store", + default=cache_filename, + help="The base filename for the underlying price cache " + "database. An extension may be added to the filename and " + "more than one file may be created.", + ) + cache_group.add_argument( + "--no-cache", + dest="cache_filename", + action="store_const", + const=None, + help="Disable the price cache.", + ) + cache_group.add_argument( + "--clear-cache", action="store_true", help="Clear the cache prior to startup." + ) args = parser.parse_args() - verbose_levels = {None: logging.WARN, - 0: logging.WARN, - 1: logging.INFO, - 2: logging.DEBUG} - logging.basicConfig(level=verbose_levels[args.verbose], - format='%(levelname)-8s: %(message)s') + verbose_levels = { + None: logging.WARN, + 0: logging.WARN, + 1: logging.INFO, + 2: logging.DEBUG, + } + logging.basicConfig( + level=verbose_levels[args.verbose], format="%(levelname)-8s: %(message)s" + ) if args.undeclared: args.undeclared = DEFAULT_SOURCE @@ -782,52 +877,62 @@ def process_args() -> Tuple[Any, List[DatedPrice], List[data.Price], Any]: try: psource_map = parse_source_map(source_str) except ValueError: - extra = "; did you provide a filename?" if path.exists(source_str) else '' - msg = ('Invalid source "{{}}"{}. '.format(extra) + - 'Supported format is "CCY:module/SYMBOL"') + extra = "; did you provide a filename?" if path.exists(source_str) else "" + msg = ( + 'Invalid source "{{}}"{}. '.format(extra) + + 'Supported format is "CCY:module/SYMBOL"' + ) parser.error(msg.format(source_str)) else: for currency, psources in psource_map.items(): for date in dates: - jobs.append(DatedPrice( - psources[0].symbol, currency, date, psources)) + jobs.append( + DatedPrice(psources[0].symbol, currency, date, psources) + ) elif args.update: # Use Beancount input filename sources to create # prices jobs up to present time. for filename in args.sources: if not path.exists(filename) or not path.isfile(filename): - parser.error('File does not exist: "{}"; ' - 'did you mean to use -e?'.format(filename)) + parser.error( + 'File does not exist: "{}"; ' "did you mean to use -e?".format(filename) + ) continue logging.info('Loading "%s"', filename) entries, errors, options_map = loader.load_file(filename, log_errors=sys.stderr) if dcontext is None: - dcontext = options_map['dcontext'] + dcontext = options_map["dcontext"] if args.date is None: latest_date = datetime.date.today() else: latest_date = args.date - jobs.extend(get_price_jobs_up_to_date(entries, - latest_date, - args.inactive, - args.undeclared, - args.update_rate, - args.update_compress)) + jobs.extend( + get_price_jobs_up_to_date( + entries, + latest_date, + args.inactive, + args.undeclared, + args.update_rate, + args.update_compress, + ) + ) all_entries.extend(entries) else: # Interpret the arguments as Beancount input filenames. for filename in args.sources: if not path.exists(filename) or not path.isfile(filename): - parser.error('File does not exist: "{}"; ' - 'did you mean to use -e?'.format(filename)) + parser.error( + 'File does not exist: "{}"; ' "did you mean to use -e?".format(filename) + ) continue logging.info('Loading "%s"', filename) entries, errors, options_map = loader.load_file(filename, log_errors=sys.stderr) if dcontext is None: - dcontext = options_map['dcontext'] + dcontext = options_map["dcontext"] for date in dates: - jobs.extend(get_price_jobs_at_date( - entries, date, args.inactive, args.undeclared)) + jobs.extend( + get_price_jobs_at_date(entries, date, args.inactive, args.undeclared) + ) all_entries.extend(entries) return args, jobs, data.sorted(all_entries), dcontext @@ -844,8 +949,12 @@ def main(): # Fetch all the required prices, processing all the jobs. executor = futures.ThreadPoolExecutor(max_workers=args.workers) - price_entries = filter(None, executor.map( - functools.partial(fetch_price, swap_inverted=args.swap_inverted), jobs)) + price_entries = filter( + None, + executor.map( + functools.partial(fetch_price, swap_inverted=args.swap_inverted), jobs + ), + ) # Sort them by currency, regardless of date (the dates should be close # anyhow, and we tend to put them in chunks in the input files anyhow). diff --git a/beanprice/price_test.py b/beanprice/price_test.py index 944d3cf..0668358 100644 --- a/beanprice/price_test.py +++ b/beanprice/price_test.py @@ -1,5 +1,5 @@ -"""Tests for main driver for price fetching. -""" +"""Tests for main driver for price fetching.""" + __copyright__ = "Copyright (C) 2015-2020 Martin Blais" __license__ = "GNU GPLv2" @@ -56,36 +56,36 @@ def run_with_args(function, args, runner_file=None): class TestCache(unittest.TestCase): - def test_fetch_cached_price__disabled(self): # Latest. - with mock.patch('beanprice.price._CACHE', None): + with mock.patch("beanprice.price._CACHE", None): self.assertIsNone(price._CACHE) source = mock.MagicMock() - price.fetch_cached_price(source, 'HOOL', None) + price.fetch_cached_price(source, "HOOL", None) self.assertTrue(source.get_latest_price.called) # Historical. - with mock.patch('beanprice.price._CACHE', None): + with mock.patch("beanprice.price._CACHE", None): self.assertIsNone(price._CACHE) source = mock.MagicMock() - price.fetch_cached_price(source, 'HOOL', datetime.date.today()) + price.fetch_cached_price(source, "HOOL", datetime.date.today()) self.assertTrue(source.get_historical_price.called) def test_fetch_cached_price__latest(self): tmpdir = tempfile.mkdtemp() - tmpfile = path.join(tmpdir, 'prices.cache') + tmpfile = path.join(tmpdir, "prices.cache") try: price.setup_cache(tmpfile, False) - srcprice = SourcePrice(Decimal('1.723'), datetime.datetime.now(tz.tzutc()), - 'USD') + srcprice = SourcePrice( + Decimal("1.723"), datetime.datetime.now(tz.tzutc()), "USD" + ) source = mock.MagicMock() source.get_latest_price.return_value = srcprice - source.__file__ = '' + source.__file__ = "" # Cache miss. - result = price.fetch_cached_price(source, 'HOOL', None) + result = price.fetch_cached_price(source, "HOOL", None) self.assertTrue(source.get_latest_price.called) self.assertEqual(1, len(price._CACHE)) self.assertEqual(srcprice, result) @@ -93,20 +93,21 @@ def test_fetch_cached_price__latest(self): source.get_latest_price.reset_mock() # Cache hit. - result = price.fetch_cached_price(source, 'HOOL', None) + result = price.fetch_cached_price(source, "HOOL", None) self.assertFalse(source.get_latest_price.called) self.assertEqual(1, len(price._CACHE)) self.assertEqual(srcprice, result) srcprice2 = SourcePrice( - Decimal('1.894'), datetime.datetime.now(tz.tzutc()), 'USD') + Decimal("1.894"), datetime.datetime.now(tz.tzutc()), "USD" + ) source.get_latest_price.reset_mock() source.get_latest_price.return_value = srcprice2 # Cache expired. time_beyond = datetime.datetime.now() + price._CACHE.expiration * 2 - with mock.patch('beanprice.price.now', return_value=time_beyond): - result = price.fetch_cached_price(source, 'HOOL', None) + with mock.patch("beanprice.price.now", return_value=time_beyond): + result = price.fetch_cached_price(source, "HOOL", None) self.assertTrue(source.get_latest_price.called) self.assertEqual(1, len(price._CACHE)) self.assertEqual(srcprice2, result) @@ -117,18 +118,19 @@ def test_fetch_cached_price__latest(self): def test_fetch_cached_price__clear_cache(self): tmpdir = tempfile.mkdtemp() - tmpfile = path.join(tmpdir, 'prices.cache') + tmpfile = path.join(tmpdir, "prices.cache") try: price.setup_cache(tmpfile, False) - srcprice = SourcePrice(Decimal('1.723'), datetime.datetime.now(tz.tzutc()), - 'USD') + srcprice = SourcePrice( + Decimal("1.723"), datetime.datetime.now(tz.tzutc()), "USD" + ) source = mock.MagicMock() source.get_latest_price.return_value = srcprice - source.__file__ = '' + source.__file__ = "" # Cache miss. - result = price.fetch_cached_price(source, 'HOOL', None) + result = price.fetch_cached_price(source, "HOOL", None) self.assertTrue(source.get_latest_price.called) self.assertEqual(1, len(price._CACHE)) self.assertEqual(srcprice, result) @@ -136,13 +138,14 @@ def test_fetch_cached_price__clear_cache(self): source.get_latest_price.reset_mock() # Cache hit. - result = price.fetch_cached_price(source, 'HOOL', None) + result = price.fetch_cached_price(source, "HOOL", None) self.assertFalse(source.get_latest_price.called) self.assertEqual(1, len(price._CACHE)) self.assertEqual(srcprice, result) srcprice2 = SourcePrice( - Decimal('1.894'), datetime.datetime.now(tz.tzutc()), 'USD') + Decimal("1.894"), datetime.datetime.now(tz.tzutc()), "USD" + ) source.get_latest_price.reset_mock() source.get_latest_price.return_value = srcprice2 @@ -151,7 +154,7 @@ def test_fetch_cached_price__clear_cache(self): price.setup_cache(tmpfile, True) # Cache cleared. - result = price.fetch_cached_price(source, 'HOOL', None) + result = price.fetch_cached_price(source, "HOOL", None) self.assertTrue(source.get_latest_price.called) self.assertEqual(1, len(price._CACHE)) self.assertEqual(srcprice2, result) @@ -162,19 +165,20 @@ def test_fetch_cached_price__clear_cache(self): def test_fetch_cached_price__historical(self): tmpdir = tempfile.mkdtemp() - tmpfile = path.join(tmpdir, 'prices.cache') + tmpfile = path.join(tmpdir, "prices.cache") try: price.setup_cache(tmpfile, False) srcprice = SourcePrice( - Decimal('1.723'), datetime.datetime.now(tz.tzutc()), 'USD') + Decimal("1.723"), datetime.datetime.now(tz.tzutc()), "USD" + ) source = mock.MagicMock() source.get_historical_price.return_value = srcprice - source.__file__ = '' + source.__file__ = "" # Cache miss. day = datetime.date(2006, 1, 2) - result = price.fetch_cached_price(source, 'HOOL', day) + result = price.fetch_cached_price(source, "HOOL", day) self.assertTrue(source.get_historical_price.called) self.assertEqual(1, len(price._CACHE)) self.assertEqual(srcprice, result) @@ -182,7 +186,7 @@ def test_fetch_cached_price__historical(self): source.get_historical_price.reset_mock() # Cache hit. - result = price.fetch_cached_price(source, 'HOOL', day) + result = price.fetch_cached_price(source, "HOOL", day) self.assertFalse(source.get_historical_price.called) self.assertEqual(1, len(price._CACHE)) self.assertEqual(srcprice, result) @@ -193,12 +197,10 @@ def test_fetch_cached_price__historical(self): class TestProcessArguments(unittest.TestCase): - def test_filename_not_exists(self): - with test_utils.capture('stderr'): + with test_utils.capture("stderr"): with self.assertRaises(SystemExit): - run_with_args( - price.process_args, ['--no-cache', '/some/file.beancount']) + run_with_args(price.process_args, ["--no-cache", "/some/file.beancount"]) @test_utils.docfile def test_explicit_file__badcontents(self, filename): @@ -206,49 +208,54 @@ def test_explicit_file__badcontents(self, filename): 2015-01-01 open Assets:Invest 2015-01-01 open USD ;; Error """ - with test_utils.capture('stderr'): - args, jobs, _, __ = run_with_args( - price.process_args, ['--no-cache', filename]) + with test_utils.capture("stderr"): + args, jobs, _, __ = run_with_args(price.process_args, ["--no-cache", filename]) self.assertEqual([], jobs) def test_filename_exists(self): - with tempfile.NamedTemporaryFile('w') as tmpfile: - with test_utils.capture('stderr'): + with tempfile.NamedTemporaryFile("w") as tmpfile: + with test_utils.capture("stderr"): args, jobs, _, __ = run_with_args( - price.process_args, ['--no-cache', tmpfile.name]) + price.process_args, ["--no-cache", tmpfile.name] + ) self.assertEqual([], jobs) # Empty file. def test_expressions(self): - with test_utils.capture('stderr'): + with test_utils.capture("stderr"): args, jobs, _, __ = run_with_args( - price.process_args, ['--no-cache', '-e', 'USD:yahoo/AAPL']) + price.process_args, ["--no-cache", "-e", "USD:yahoo/AAPL"] + ) self.assertEqual( - [price.DatedPrice( - 'AAPL', 'USD', None, - [price.PriceSource(yahoo, 'AAPL', False)])], jobs) + [ + price.DatedPrice( + "AAPL", "USD", None, [price.PriceSource(yahoo, "AAPL", False)] + ) + ], + jobs, + ) class TestClobber(cmptest.TestCase): - @loader.load_doc() def setUp(self, entries, _, __): """ - ;; Existing file. - 2015-01-05 price HDV 75.56 USD - 2015-01-23 price HDV 77.34 USD - 2015-02-06 price HDV 77.16 USD - 2015-02-12 price HDV 78.17 USD - 2015-05-01 price HDV 77.48 USD - 2015-06-02 price HDV 76.33 USD - 2015-06-29 price HDV 73.74 USD - 2015-07-06 price HDV 73.79 USD - 2015-08-11 price HDV 74.19 USD - 2015-09-04 price HDV 68.98 USD + ;; Existing file. + 2015-01-05 price HDV 75.56 USD + 2015-01-23 price HDV 77.34 USD + 2015-02-06 price HDV 77.16 USD + 2015-02-12 price HDV 78.17 USD + 2015-05-01 price HDV 77.48 USD + 2015-06-02 price HDV 76.33 USD + 2015-06-29 price HDV 73.74 USD + 2015-07-06 price HDV 73.79 USD + 2015-08-11 price HDV 74.19 USD + 2015-09-04 price HDV 68.98 USD """ self.entries = entries # New entries. - self.price_entries, _, __ = loader.load_string(""" + self.price_entries, _, __ = loader.load_string( + """ 2015-01-27 price HDV 76.83 USD 2015-02-06 price HDV 77.16 USD 2015-02-19 price HDV 77.5 USD @@ -257,163 +264,185 @@ def setUp(self, entries, _, __): 2015-07-06 price HDV 73.79 USD 2015-07-31 price HDV 74.64 USD 2015-08-11 price HDV 74.20 USD ;; Different - """, dedent=True) + """, + dedent=True, + ) def test_clobber_nodiffs(self): - new_price_entries, _ = price.filter_redundant_prices(self.price_entries, - self.entries, - diffs=False) - self.assertEqualEntries(""" + new_price_entries, _ = price.filter_redundant_prices( + self.price_entries, self.entries, diffs=False + ) + self.assertEqualEntries( + """ 2015-01-27 price HDV 76.83 USD 2015-02-19 price HDV 77.5 USD 2015-06-19 price HDV 76 USD 2015-07-31 price HDV 74.64 USD - """, new_price_entries) + """, + new_price_entries, + ) def test_clobber_diffs(self): - new_price_entries, _ = price.filter_redundant_prices(self.price_entries, - self.entries, - diffs=True) - self.assertEqualEntries(""" + new_price_entries, _ = price.filter_redundant_prices( + self.price_entries, self.entries, diffs=True + ) + self.assertEqualEntries( + """ 2015-01-27 price HDV 76.83 USD 2015-02-19 price HDV 77.5 USD 2015-06-19 price HDV 76 USD 2015-07-31 price HDV 74.64 USD 2015-08-11 price HDV 74.20 USD ;; Different - """, new_price_entries) + """, + new_price_entries, + ) class TestTimezone(unittest.TestCase): - - @mock.patch.object(price, 'fetch_cached_price') + @mock.patch.object(price, "fetch_cached_price") def test_fetch_price__naive_time_no_timeozne(self, fetch_cached): fetch_cached.return_value = SourcePrice( - Decimal('125.00'), datetime.datetime(2015, 11, 22, 16, 0, 0), 'JPY') - dprice = price.DatedPrice('JPY', 'USD', datetime.date(2015, 11, 22), None) + Decimal("125.00"), datetime.datetime(2015, 11, 22, 16, 0, 0), "JPY" + ) + dprice = price.DatedPrice("JPY", "USD", datetime.date(2015, 11, 22), None) with self.assertRaises(ValueError): - price.fetch_price(dprice._replace(sources=[ - price.PriceSource(yahoo, 'USDJPY', False)]), False) + price.fetch_price( + dprice._replace(sources=[price.PriceSource(yahoo, "USDJPY", False)]), False + ) class TestInverted(unittest.TestCase): - def setUp(self): - fetch_cached = mock.patch('beanprice.price.fetch_cached_price').start() + fetch_cached = mock.patch("beanprice.price.fetch_cached_price").start() fetch_cached.return_value = SourcePrice( - Decimal('125.00'), datetime.datetime(2015, 11, 22, 16, 0, 0, - tzinfo=tz.tzlocal()), - 'JPY') - self.dprice = price.DatedPrice('JPY', 'USD', datetime.date(2015, 11, 22), - None) + Decimal("125.00"), + datetime.datetime(2015, 11, 22, 16, 0, 0, tzinfo=tz.tzlocal()), + "JPY", + ) + self.dprice = price.DatedPrice("JPY", "USD", datetime.date(2015, 11, 22), None) self.addCleanup(mock.patch.stopall) def test_fetch_price__normal(self): - entry = price.fetch_price(self.dprice._replace(sources=[ - price.PriceSource(yahoo, 'USDJPY', False)]), False) - self.assertEqual(('JPY', 'USD'), (entry.currency, entry.amount.currency)) - self.assertEqual(Decimal('125.00'), entry.amount.number) + entry = price.fetch_price( + self.dprice._replace(sources=[price.PriceSource(yahoo, "USDJPY", False)]), False + ) + self.assertEqual(("JPY", "USD"), (entry.currency, entry.amount.currency)) + self.assertEqual(Decimal("125.00"), entry.amount.number) def test_fetch_price__inverted(self): - entry = price.fetch_price(self.dprice._replace(sources=[ - price.PriceSource(yahoo, 'USDJPY', True)]), False) - self.assertEqual(('JPY', 'USD'), (entry.currency, entry.amount.currency)) - self.assertEqual(Decimal('0.008'), entry.amount.number) + entry = price.fetch_price( + self.dprice._replace(sources=[price.PriceSource(yahoo, "USDJPY", True)]), False + ) + self.assertEqual(("JPY", "USD"), (entry.currency, entry.amount.currency)) + self.assertEqual(Decimal("0.008"), entry.amount.number) def test_fetch_price__swapped(self): - entry = price.fetch_price(self.dprice._replace(sources=[ - price.PriceSource(yahoo, 'USDJPY', True)]), True) - self.assertEqual(('USD', 'JPY'), (entry.currency, entry.amount.currency)) - self.assertEqual(Decimal('125.00'), entry.amount.number) + entry = price.fetch_price( + self.dprice._replace(sources=[price.PriceSource(yahoo, "USDJPY", True)]), True + ) + self.assertEqual(("USD", "JPY"), (entry.currency, entry.amount.currency)) + self.assertEqual(Decimal("125.00"), entry.amount.number) class TestImportSource(unittest.TestCase): - def test_import_source_valid(self): - for name in 'oanda', 'yahoo': + for name in "oanda", "yahoo": module = price.import_source(name) self.assertIsInstance(module, types.ModuleType) - module = price.import_source('beanprice.sources.yahoo') + module = price.import_source("beanprice.sources.yahoo") self.assertIsInstance(module, types.ModuleType) def test_import_source_invalid(self): with self.assertRaises(ImportError): - price.import_source('non.existing.module') + price.import_source("non.existing.module") class TestParseSource(unittest.TestCase): - def test_source_invalid(self): with self.assertRaises(ValueError): - price.parse_single_source('AAPL') + price.parse_single_source("AAPL") with self.assertRaises(ValueError): - price.parse_single_source('***//--') + price.parse_single_source("***//--") # The module gets imported at this stage. with self.assertRaises(ImportError): - price.parse_single_source('invalid.module.name/NASDAQ:AAPL') + price.parse_single_source("invalid.module.name/NASDAQ:AAPL") def test_source_valid(self): - psource = price.parse_single_source('yahoo/CNYUSD=X') - self.assertEqual(PS(yahoo, 'CNYUSD=X', False), psource) + psource = price.parse_single_source("yahoo/CNYUSD=X") + self.assertEqual(PS(yahoo, "CNYUSD=X", False), psource) # Make sure that an invalid name at the tail doesn't succeed. with self.assertRaises(ValueError): - psource = price.parse_single_source('yahoo/CNYUSD&X') + psource = price.parse_single_source("yahoo/CNYUSD&X") - psource = price.parse_single_source('beanprice.sources.yahoo/AAPL') - self.assertEqual(PS(yahoo, 'AAPL', False), psource) + psource = price.parse_single_source("beanprice.sources.yahoo/AAPL") + self.assertEqual(PS(yahoo, "AAPL", False), psource) class TestParseSourceMap(unittest.TestCase): - def _clean_source_map(self, smap): - return {currency: [PS(s[0].__name__, s[1], s[2]) for s in sources] - for currency, sources in smap.items()} + return { + currency: [PS(s[0].__name__, s[1], s[2]) for s in sources] + for currency, sources in smap.items() + } def test_source_map_invalid(self): - for expr in 'USD', 'something else', 'USD:NASDAQ:AAPL': + for expr in "USD", "something else", "USD:NASDAQ:AAPL": with self.assertRaises(ValueError): price.parse_source_map(expr) def test_source_map_onecur_single(self): - smap = price.parse_source_map('USD:yahoo/AAPL') + smap = price.parse_source_map("USD:yahoo/AAPL") self.assertEqual( - {'USD': [PS('beanprice.sources.yahoo', 'AAPL', False)]}, - self._clean_source_map(smap)) + {"USD": [PS("beanprice.sources.yahoo", "AAPL", False)]}, + self._clean_source_map(smap), + ) def test_source_map_onecur_multiple(self): - smap = price.parse_source_map('USD:oanda/USDCAD,yahoo/CAD=X') + smap = price.parse_source_map("USD:oanda/USDCAD,yahoo/CAD=X") self.assertEqual( - {'USD': [PS('beanprice.sources.oanda', 'USDCAD', False), - PS('beanprice.sources.yahoo', 'CAD=X', False)]}, - self._clean_source_map(smap)) + { + "USD": [ + PS("beanprice.sources.oanda", "USDCAD", False), + PS("beanprice.sources.yahoo", "CAD=X", False), + ] + }, + self._clean_source_map(smap), + ) def test_source_map_manycur_single(self): - smap = price.parse_source_map('USD:yahoo/USDCAD ' - 'CAD:yahoo/CAD=X') + smap = price.parse_source_map("USD:yahoo/USDCAD " "CAD:yahoo/CAD=X") self.assertEqual( - {'USD': [PS('beanprice.sources.yahoo', 'USDCAD', False)], - 'CAD': [PS('beanprice.sources.yahoo', 'CAD=X', False)]}, - self._clean_source_map(smap)) + { + "USD": [PS("beanprice.sources.yahoo", "USDCAD", False)], + "CAD": [PS("beanprice.sources.yahoo", "CAD=X", False)], + }, + self._clean_source_map(smap), + ) def test_source_map_manycur_multiple(self): - smap = price.parse_source_map('USD:yahoo/GBPUSD,oanda/GBPUSD ' - 'CAD:yahoo/GBPCAD') + smap = price.parse_source_map("USD:yahoo/GBPUSD,oanda/GBPUSD " "CAD:yahoo/GBPCAD") self.assertEqual( - {'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', False), - PS('beanprice.sources.oanda', 'GBPUSD', False)], - 'CAD': [PS('beanprice.sources.yahoo', 'GBPCAD', False)]}, - self._clean_source_map(smap)) + { + "USD": [ + PS("beanprice.sources.yahoo", "GBPUSD", False), + PS("beanprice.sources.oanda", "GBPUSD", False), + ], + "CAD": [PS("beanprice.sources.yahoo", "GBPCAD", False)], + }, + self._clean_source_map(smap), + ) def test_source_map_inverse(self): - smap = price.parse_source_map('USD:yahoo/^GBPUSD') + smap = price.parse_source_map("USD:yahoo/^GBPUSD") self.assertEqual( - {'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', True)]}, - self._clean_source_map(smap)) + {"USD": [PS("beanprice.sources.yahoo", "GBPUSD", True)]}, + self._clean_source_map(smap), + ) class TestFilters(unittest.TestCase): - @loader.load_doc() def test_get_price_jobs__date(self, entries, _, __): """ @@ -441,22 +470,21 @@ def test_get_price_jobs__date(self, entries, _, __): Assets:US:Invest:VEA -200 VEA {43.22 USD} @ 41.01 USD Assets:US:Invest:Margin """ - jobs = price.get_price_jobs_at_date(entries, datetime.date(2014, 1, 1), - False, None) + jobs = price.get_price_jobs_at_date(entries, datetime.date(2014, 1, 1), False, None) self.assertEqual(set(), {(job.base, job.quote) for job in jobs}) - jobs = price.get_price_jobs_at_date(entries, datetime.date(2014, 6, 1), - False, None) - self.assertEqual({('QQQ', 'USD'), ('VEA', 'USD')}, - {(job.base, job.quote) for job in jobs}) + jobs = price.get_price_jobs_at_date(entries, datetime.date(2014, 6, 1), False, None) + self.assertEqual( + {("QQQ", "USD"), ("VEA", "USD")}, {(job.base, job.quote) for job in jobs} + ) - jobs = price.get_price_jobs_at_date(entries, datetime.date(2014, 10, 1), - False, None) - self.assertEqual({('VEA', 'USD')}, - {(job.base, job.quote) for job in jobs}) + jobs = price.get_price_jobs_at_date( + entries, datetime.date(2014, 10, 1), False, None + ) + self.assertEqual({("VEA", "USD")}, {(job.base, job.quote) for job in jobs}) jobs = price.get_price_jobs_at_date(entries, None, False, None) - self.assertEqual({('QQQ', 'USD')}, {(job.base, job.quote) for job in jobs}) + self.assertEqual({("QQQ", "USD")}, {(job.base, job.quote) for job in jobs}) @loader.load_doc() def test_get_price_jobs__inactive(self, entries, _, __): @@ -481,11 +509,12 @@ def test_get_price_jobs__inactive(self, entries, _, __): Assets:US:Invest:Margin """ jobs = price.get_price_jobs_at_date(entries, None, False, None) - self.assertEqual({('VEA', 'USD')}, {(job.base, job.quote) for job in jobs}) + self.assertEqual({("VEA", "USD")}, {(job.base, job.quote) for job in jobs}) jobs = price.get_price_jobs_at_date(entries, None, True, None) - self.assertEqual({('VEA', 'USD'), ('QQQ', 'USD')}, - {(job.base, job.quote) for job in jobs}) + self.assertEqual( + {("VEA", "USD"), ("QQQ", "USD")}, {(job.base, job.quote) for job in jobs} + ) @loader.load_doc() def test_get_price_jobs__undeclared(self, entries, _, __): @@ -503,11 +532,12 @@ def test_get_price_jobs__undeclared(self, entries, _, __): Assets:US:Invest:Margin """ jobs = price.get_price_jobs_at_date(entries, None, False, None) - self.assertEqual({('QQQ', 'USD')}, {(job.base, job.quote) for job in jobs}) + self.assertEqual({("QQQ", "USD")}, {(job.base, job.quote) for job in jobs}) - jobs = price.get_price_jobs_at_date(entries, None, False, 'yahoo') - self.assertEqual({('QQQ', 'USD'), ('VEA', 'USD')}, - {(job.base, job.quote) for job in jobs}) + jobs = price.get_price_jobs_at_date(entries, None, False, "yahoo") + self.assertEqual( + {("QQQ", "USD"), ("VEA", "USD")}, {(job.base, job.quote) for job in jobs} + ) @loader.load_doc() def test_get_price_jobs__default_source(self, entries, _, __): @@ -522,7 +552,7 @@ def test_get_price_jobs__default_source(self, entries, _, __): Assets:US:Invest:QQQ 100 QQQ {86.23 USD} Assets:US:Invest:Margin """ - jobs = price.get_price_jobs_at_date(entries, None, False, 'yahoo') + jobs = price.get_price_jobs_at_date(entries, None, False, "yahoo") self.assertEqual(1, len(jobs[0].sources)) self.assertIsInstance(jobs[0].sources[0], price.PriceSource) @@ -541,17 +571,16 @@ def test_get_price_jobs__currencies_not_at_cost(self, entries, _, __): Assets:US:BofA:CHF -110 CHF @@ 100 USD """ # TODO: Shouldn't we actually return (CHF, USD) here? - jobs = price.get_price_jobs_at_date(entries, datetime.date(2021, 1, 4), - False, None) + jobs = price.get_price_jobs_at_date(entries, datetime.date(2021, 1, 4), False, None) self.assertEqual(set(), {(job.base, job.quote) for job in jobs}) - jobs = price.get_price_jobs_at_date(entries, datetime.date(2021, 1, 6), - False, None) - self.assertEqual({('CHF', 'USD')}, {(job.base, job.quote) for job in jobs}) + jobs = price.get_price_jobs_at_date(entries, datetime.date(2021, 1, 6), False, None) + self.assertEqual({("CHF", "USD")}, {(job.base, job.quote) for job in jobs}) # TODO: Shouldn't we return (CHF, USD) here, as above? - jobs = price.get_price_jobs_up_to_date(entries, datetime.date(2021, 1, 6), - False, None) + jobs = price.get_price_jobs_up_to_date( + entries, datetime.date(2021, 1, 6), False, None + ) self.assertEqual(set(), {(job.base, job.quote) for job in jobs}) @loader.load_doc() @@ -582,19 +611,21 @@ def test_get_price_jobs_up_to_date(self, entries, _, __): Assets:US:Invest:Margin """ jobs = price.get_price_jobs_up_to_date(entries, datetime.date(2021, 1, 8)) - self.assertEqual({ - ('QQQ', 'USD', datetime.date(2021, 1, 4)), - ('QQQ', 'USD', datetime.date(2021, 1, 5)), - ('QQQ', 'USD', datetime.date(2021, 1, 7)), - ('VEA', 'USD', datetime.date(2021, 1, 4)), - ('VEA', 'USD', datetime.date(2021, 1, 5)), - ('VEA', 'USD', datetime.date(2021, 1, 6)), - ('VEA', 'USD', datetime.date(2021, 1, 7)), - }, {(job.base, job.quote, job.date) for job in jobs}) + self.assertEqual( + { + ("QQQ", "USD", datetime.date(2021, 1, 4)), + ("QQQ", "USD", datetime.date(2021, 1, 5)), + ("QQQ", "USD", datetime.date(2021, 1, 7)), + ("VEA", "USD", datetime.date(2021, 1, 4)), + ("VEA", "USD", datetime.date(2021, 1, 5)), + ("VEA", "USD", datetime.date(2021, 1, 6)), + ("VEA", "USD", datetime.date(2021, 1, 7)), + }, + {(job.base, job.quote, job.date) for job in jobs}, + ) class TestFromFile(unittest.TestCase): - @loader.load_doc() def setUp(self, entries, _, __): """ @@ -624,8 +655,8 @@ def setUp(self, entries, _, __): def test_find_currencies_declared(self): currencies = price.find_currencies_declared(self.entries, None) currencies2 = [(base, quote) for base, quote, _ in currencies] - self.assertEqual([('QQQ', 'USD')], currencies2) + self.assertEqual([("QQQ", "USD")], currencies2) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/beanprice/source.py b/beanprice/source.py index b17ffb8..42d9a12 100644 --- a/beanprice/source.py +++ b/beanprice/source.py @@ -8,6 +8,7 @@ decide whether to share this with the user or to ignore and continue with other sources. """ + __copyright__ = "Copyright (C) 2015-2020 Martin Blais" __license__ = "GNU GPLv2" @@ -25,10 +26,14 @@ # used to compute a corresponding date in the user's timezone. # quote-currency: A string, the quote currency of the given price, if # available. -SourcePrice = NamedTuple('SourcePrice', - [('price', Decimal), - ('time', Optional[datetime.datetime]), - ('quote_currency', Optional[str])]) +SourcePrice = NamedTuple( + "SourcePrice", + [ + ("price", Decimal), + ("time", Optional[datetime.datetime]), + ("quote_currency", Optional[str]), + ], +) class Source: @@ -67,8 +72,9 @@ def get_latest_price(self, ticker: str) -> Optional[SourcePrice]: A SourcePrice instance, or None if we failed to fetch. """ - def get_historical_price(self, ticker: str, - time: datetime.datetime) -> Optional[SourcePrice]: + def get_historical_price( + self, ticker: str, time: datetime.datetime + ) -> Optional[SourcePrice]: """Return the lastest historical price found for the symbol at the given date. This could be the price of the close of the day, for instance. We assume @@ -85,9 +91,9 @@ def get_historical_price(self, ticker: str, A SourcePrice instance, or None if we failed to fetch. """ - def get_prices_series(self, ticker: str, - time_begin: datetime.datetime, - time_end: datetime.datetime) -> Optional[List[SourcePrice]]: + def get_prices_series( + self, ticker: str, time_begin: datetime.datetime, time_end: datetime.datetime + ) -> Optional[List[SourcePrice]]: """Return the historical daily price series between two dates. Note that weekends don't have any prices, so there's no guarantee that diff --git a/beanprice/sources/__init__.py b/beanprice/sources/__init__.py index 6381a5f..5eee444 100644 --- a/beanprice/sources/__init__.py +++ b/beanprice/sources/__init__.py @@ -3,5 +3,6 @@ This package is looked up by the driver script to figure out which extractor to use. """ + __copyright__ = "Copyright (C) 2015-2020 Martin Blais" __license__ = "GNU GPLv2" diff --git a/beanprice/sources/alphavantage.py b/beanprice/sources/alphavantage.py index dbea8fe..409be17 100644 --- a/beanprice/sources/alphavantage.py +++ b/beanprice/sources/alphavantage.py @@ -37,6 +37,7 @@ class AlphavantageApiError(ValueError): "An error from the Alphavantage API." + def _parse_ticker(ticker): """Parse the base and quote currencies from the ticker. @@ -45,60 +46,61 @@ def _parse_ticker(ticker): Returns: A (kind, symbol, base) tuple. """ - match = re.match(r'^(?Pprice|fx):(?P[^:]+):(?P\w+)$', ticker) + match = re.match(r"^(?Pprice|fx):(?P[^:]+):(?P\w+)$", ticker) if not match: - raise ValueError( - 'Invalid ticker. Use "price:SYMBOL:BASE" or "fx:CCY:BASE" format.') + raise ValueError('Invalid ticker. Use "price:SYMBOL:BASE" or "fx:CCY:BASE" format.') return match.groups() + def _do_fetch(params): - params['apikey'] = environ['ALPHAVANTAGE_API_KEY'] + params["apikey"] = environ["ALPHAVANTAGE_API_KEY"] - resp = requests.get(url='https://www.alphavantage.co/query', params=params) + resp = requests.get(url="https://www.alphavantage.co/query", params=params) data = resp.json() # This is for dealing with the rate limit, sleep for 60 seconds and then retry - if 'Note' in data: + if "Note" in data: sleep(60) - resp = requests.get(url='https://www.alphavantage.co/query', params=params) + resp = requests.get(url="https://www.alphavantage.co/query", params=params) data = resp.json() if resp.status_code != requests.codes.ok: - raise AlphavantageApiError("Invalid response ({}): {}".format(resp.status_code, - resp.text)) + raise AlphavantageApiError( + "Invalid response ({}): {}".format(resp.status_code, resp.text) + ) - if 'Error Message' in data: - raise AlphavantageApiError("Invalid response: {}".format(data['Error Message'])) + if "Error Message" in data: + raise AlphavantageApiError("Invalid response: {}".format(data["Error Message"])) return data class Source(source.Source): - def get_latest_price(self, ticker): kind, symbol, base = _parse_ticker(ticker) - if kind == 'price': + if kind == "price": params = { - 'function': 'GLOBAL_QUOTE', - 'symbol': symbol, + "function": "GLOBAL_QUOTE", + "symbol": symbol, } data = _do_fetch(params) - price_data = data['Global Quote'] - price = Decimal(price_data['05. price']) - date = parse(price_data['07. latest trading day']).replace(tzinfo=tz.tzutc()) + price_data = data["Global Quote"] + price = Decimal(price_data["05. price"]) + date = parse(price_data["07. latest trading day"]).replace(tzinfo=tz.tzutc()) else: params = { - 'function': 'CURRENCY_EXCHANGE_RATE', - 'from_currency': symbol, - 'to_currency': base, + "function": "CURRENCY_EXCHANGE_RATE", + "from_currency": symbol, + "to_currency": base, } data = _do_fetch(params) - price_data = data['Realtime Currency Exchange Rate'] - price = Decimal(price_data['5. Exchange Rate']) - date = parse(price_data['6. Last Refreshed']).replace( - tzinfo=tz.gettz(price_data['7. Time Zone'])) + price_data = data["Realtime Currency Exchange Rate"] + price = Decimal(price_data["5. Exchange Rate"]) + date = parse(price_data["6. Last Refreshed"]).replace( + tzinfo=tz.gettz(price_data["7. Time Zone"]) + ) return source.SourcePrice(price, date, base) diff --git a/beanprice/sources/alphavantage_test.py b/beanprice/sources/alphavantage_test.py index 8a851a7..a29242b 100644 --- a/beanprice/sources/alphavantage_test.py +++ b/beanprice/sources/alphavantage_test.py @@ -18,33 +18,30 @@ def response(contents, status_code=requests.codes.ok): response.status_code = status_code response.text = "" response.json.return_value = contents - return mock.patch('requests.get', return_value=response) + return mock.patch("requests.get", return_value=response) class AlphavantagePriceFetcher(unittest.TestCase): - def setUp(self): - environ['ALPHAVANTAGE_API_KEY'] = 'foo' + environ["ALPHAVANTAGE_API_KEY"] = "foo" def tearDown(self): - del environ['ALPHAVANTAGE_API_KEY'] + del environ["ALPHAVANTAGE_API_KEY"] def test_error_invalid_ticker(self): with self.assertRaises(ValueError): - alphavantage.Source().get_latest_price('INVALID') + alphavantage.Source().get_latest_price("INVALID") def test_error_network(self): - with response('Foobar', 404): + with response("Foobar", 404): with self.assertRaises(alphavantage.AlphavantageApiError): - alphavantage.Source().get_latest_price('price:IBM:USD') + alphavantage.Source().get_latest_price("price:IBM:USD") def test_error_response(self): - contents = { - "Error Message": "Something wrong" - } + contents = {"Error Message": "Something wrong"} with response(contents): with self.assertRaises(alphavantage.AlphavantageApiError): - alphavantage.Source().get_latest_price('price:IBM:USD') + alphavantage.Source().get_latest_price("price:IBM:USD") def test_valid_response_price(self): contents = { @@ -54,12 +51,13 @@ def test_valid_response_price(self): } } with response(contents): - srcprice = alphavantage.Source().get_latest_price('price:FOO:USD') + srcprice = alphavantage.Source().get_latest_price("price:FOO:USD") self.assertIsInstance(srcprice, source.SourcePrice) - self.assertEqual(Decimal('144.7400'), srcprice.price) - self.assertEqual('USD', srcprice.quote_currency) - self.assertEqual(datetime.datetime(2021, 1, 21, 0, 0, 0, tzinfo=tz.tzutc()), - srcprice.time) + self.assertEqual(Decimal("144.7400"), srcprice.price) + self.assertEqual("USD", srcprice.quote_currency) + self.assertEqual( + datetime.datetime(2021, 1, 21, 0, 0, 0, tzinfo=tz.tzutc()), srcprice.time + ) def test_valid_response_fx(self): contents = { @@ -70,13 +68,14 @@ def test_valid_response_fx(self): } } with response(contents): - srcprice = alphavantage.Source().get_latest_price('fx:USD:CHF') + srcprice = alphavantage.Source().get_latest_price("fx:USD:CHF") self.assertIsInstance(srcprice, source.SourcePrice) - self.assertEqual(Decimal('108.94000000'), srcprice.price) - self.assertEqual('CHF', srcprice.quote_currency) - self.assertEqual(datetime.datetime(2021, 2, 21, 20, 32, 25, tzinfo=tz.tzutc()), - srcprice.time) + self.assertEqual(Decimal("108.94000000"), srcprice.price) + self.assertEqual("CHF", srcprice.quote_currency) + self.assertEqual( + datetime.datetime(2021, 2, 21, 20, 32, 25, tzinfo=tz.tzutc()), srcprice.time + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/beanprice/sources/coinbase.py b/beanprice/sources/coinbase.py index 2353134..e5bb6a3 100644 --- a/beanprice/sources/coinbase.py +++ b/beanprice/sources/coinbase.py @@ -30,18 +30,19 @@ def fetch_quote(ticker, time=None): url = "https://api.coinbase.com/v2/prices/{}/spot".format(ticker.lower()) options = {} if time is not None: - options['date'] = time.astimezone(tz.tzutc()).date().isoformat() + options["date"] = time.astimezone(tz.tzutc()).date().isoformat() response = requests.get(url, options) if response.status_code != requests.codes.ok: - raise CoinbaseError("Invalid response ({}): {}".format(response.status_code, - response.text)) + raise CoinbaseError( + "Invalid response ({}): {}".format(response.status_code, response.text) + ) result = response.json() - price = Decimal(result['data']['amount']) + price = Decimal(result["data"]["amount"]) if time is None: time = datetime.datetime.now(tz.tzutc()) - currency = result['data']['currency'] + currency = result["data"]["currency"] return source.SourcePrice(price, time, currency) diff --git a/beanprice/sources/coinbase_test.py b/beanprice/sources/coinbase_test.py index 0b89448..23311cd 100644 --- a/beanprice/sources/coinbase_test.py +++ b/beanprice/sources/coinbase_test.py @@ -17,40 +17,36 @@ def response(contents, status_code=requests.codes.ok): response.status_code = status_code response.text = "" response.json.return_value = contents - return mock.patch('requests.get', return_value=response) + return mock.patch("requests.get", return_value=response) class CoinbasePriceFetcher(unittest.TestCase): - def test_error_network(self): with response(None, 404): with self.assertRaises(ValueError) as exc: - coinbase.fetch_quote('AAPL') - self.assertRegex(exc.message, 'premium') + coinbase.fetch_quote("AAPL") + self.assertRegex(exc.message, "premium") def test_valid_response(self): - contents = {"data": {"base": "BTC", - "currency": "USD", - "amount": "101.23456"}} + contents = {"data": {"base": "BTC", "currency": "USD", "amount": "101.23456"}} with response(contents): - srcprice = coinbase.Source().get_latest_price('BTC-GBP') + srcprice = coinbase.Source().get_latest_price("BTC-GBP") self.assertIsInstance(srcprice, source.SourcePrice) - self.assertEqual(Decimal('101.23456'), srcprice.price) - self.assertEqual('USD', srcprice.quote_currency) + self.assertEqual(Decimal("101.23456"), srcprice.price) + self.assertEqual("USD", srcprice.quote_currency) def test_historical_price(self): - contents = {"data": {"base": "BTC", - "currency": "USD", - "amount": "101.23456"}} + contents = {"data": {"base": "BTC", "currency": "USD", "amount": "101.23456"}} with response(contents): time = datetime.datetime(2018, 3, 27, 0, 0, 0, tzinfo=tz.tzutc()) - srcprice = coinbase.Source().get_historical_price('BTC-GBP', time) + srcprice = coinbase.Source().get_historical_price("BTC-GBP", time) self.assertIsInstance(srcprice, source.SourcePrice) - self.assertEqual(Decimal('101.23456'), srcprice.price) - self.assertEqual('USD', srcprice.quote_currency) - self.assertEqual(datetime.datetime(2018, 3, 27, 0, 0, 0, tzinfo=tz.tzutc()), - srcprice.time) + self.assertEqual(Decimal("101.23456"), srcprice.price) + self.assertEqual("USD", srcprice.quote_currency) + self.assertEqual( + datetime.datetime(2018, 3, 27, 0, 0, 0, tzinfo=tz.tzutc()), srcprice.time + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/beanprice/sources/coinmarketcap.py b/beanprice/sources/coinmarketcap.py index 7ca0c55..70a8e4e 100644 --- a/beanprice/sources/coinmarketcap.py +++ b/beanprice/sources/coinmarketcap.py @@ -29,43 +29,44 @@ def _parse_ticker(ticker): Returns: A pair of (base, quote) currencies. """ - match = re.match(r'^(?P\w+)-(?P\w+)$', ticker) + match = re.match(r"^(?P\w+)-(?P\w+)$", ticker) if not match: - raise ValueError( - 'Invalid ticker. Use "BASE-SYMBOL" format.') + raise ValueError('Invalid ticker. Use "BASE-SYMBOL" format.') return match.groups() class Source(source.Source): - def get_latest_price(self, ticker): symbol, base = _parse_ticker(ticker) headers = { - 'X-CMC_PRO_API_KEY': environ['COINMARKETCAP_API_KEY'], + "X-CMC_PRO_API_KEY": environ["COINMARKETCAP_API_KEY"], } params = { - 'symbol': symbol, - 'convert': base, + "symbol": symbol, + "convert": base, } resp = requests.get( - url='https://pro-api.coinmarketcap.com/v1/cryptocurrency/quotes/latest', - params=params, headers=headers) + url="https://pro-api.coinmarketcap.com/v1/cryptocurrency/quotes/latest", + params=params, + headers=headers, + ) if resp.status_code != requests.codes.ok: raise CoinmarketcapApiError( "Invalid response ({}): {}".format(resp.status_code, resp.text) ) data = resp.json() - if data['status']['error_code'] != 0: - status = data['status'] + if data["status"]["error_code"] != 0: + status = data["status"] raise CoinmarketcapApiError( "Invalid response ({}): {}".format( - status['error_code'], status['error_message']) + status["error_code"], status["error_message"] + ) ) - quote = data['data'][symbol]['quote'][base] - price = Decimal(str(quote['price'])) - date = parse(quote['last_updated']) + quote = data["data"][symbol]["quote"][base] + price = Decimal(str(quote["price"])) + date = parse(quote["last_updated"]) return source.SourcePrice(price, date, base) diff --git a/beanprice/sources/coinmarketcap_test.py b/beanprice/sources/coinmarketcap_test.py index 6c6befe..4a0defa 100644 --- a/beanprice/sources/coinmarketcap_test.py +++ b/beanprice/sources/coinmarketcap_test.py @@ -16,25 +16,24 @@ def response(contents, status_code=requests.codes.ok): response.status_code = status_code response.text = "" response.json.return_value = contents - return mock.patch('requests.get', return_value=response) + return mock.patch("requests.get", return_value=response) class CoinmarketcapPriceFetcher(unittest.TestCase): - def setUp(self): - environ['COINMARKETCAP_API_KEY'] = 'foo' + environ["COINMARKETCAP_API_KEY"] = "foo" def tearDown(self): - del environ['COINMARKETCAP_API_KEY'] + del environ["COINMARKETCAP_API_KEY"] def test_error_invalid_ticker(self): with self.assertRaises(ValueError): - coinmarketcap.Source().get_latest_price('INVALID') + coinmarketcap.Source().get_latest_price("INVALID") def test_error_network(self): - with response('Foobar', 404): + with response("Foobar", 404): with self.assertRaises(ValueError): - coinmarketcap.Source().get_latest_price('BTC-CHF') + coinmarketcap.Source().get_latest_price("BTC-CHF") def test_error_request(self): contents = { @@ -45,7 +44,7 @@ def test_error_request(self): } with response(contents): with self.assertRaises(ValueError): - coinmarketcap.Source().get_latest_price('BTC-CHF') + coinmarketcap.Source().get_latest_price("BTC-CHF") def test_valid_response(self): contents = { @@ -54,7 +53,7 @@ def test_valid_response(self): "quote": { "CHF": { "price": 1234.56, - "last_updated": "2018-08-09T21:56:28.000Z" + "last_updated": "2018-08-09T21:56:28.000Z", } } } @@ -62,14 +61,14 @@ def test_valid_response(self): "status": { "error_code": 0, "error_message": "", - } + }, } with response(contents): - srcprice = coinmarketcap.Source().get_latest_price('BTC-CHF') + srcprice = coinmarketcap.Source().get_latest_price("BTC-CHF") self.assertIsInstance(srcprice, source.SourcePrice) - self.assertEqual(Decimal('1234.56'), srcprice.price) - self.assertEqual('CHF', srcprice.quote_currency) + self.assertEqual(Decimal("1234.56"), srcprice.price) + self.assertEqual("CHF", srcprice.quote_currency) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/beanprice/sources/eastmoneyfund.py b/beanprice/sources/eastmoneyfund.py index e6373c9..42a14a8 100644 --- a/beanprice/sources/eastmoneyfund.py +++ b/beanprice/sources/eastmoneyfund.py @@ -40,9 +40,7 @@ class EastMoneyFundError(ValueError): "An error from the EastMoneyFund API." -UnsupportTickerError = EastMoneyFundError( - "header not match, dont support this ticker type" -) +UnsupportTickerError = EastMoneyFundError("header not match, dont support this ticker type") def parse_page(page): @@ -107,7 +105,6 @@ def get_price_series( class Source(source.Source): - def get_latest_price(self, ticker): end_time = datetime.datetime.now(TIMEZONE) begin_time = end_time - datetime.timedelta(days=10) diff --git a/beanprice/sources/eastmoneyfund_test.py b/beanprice/sources/eastmoneyfund_test.py index 5988f0f..defd9e2 100644 --- a/beanprice/sources/eastmoneyfund_test.py +++ b/beanprice/sources/eastmoneyfund_test.py @@ -30,7 +30,6 @@ def response(contents, status_code=requests.codes.ok): class EastMoneyFundFetcher(unittest.TestCase): - def test_error_network(self): with response(None, 404): with self.assertRaises(ValueError): diff --git a/beanprice/sources/iex.py b/beanprice/sources/iex.py index 08f0fe2..cbac92d 100644 --- a/beanprice/sources/iex.py +++ b/beanprice/sources/iex.py @@ -5,6 +5,7 @@ Timezone information: There is currency no support for historical prices. The output datetime is provided as a UNIX timestamp. """ + __copyright__ = "Copyright (C) 2018-2020 Martin Blais" __license__ = "GNU GPLv2" @@ -27,24 +28,24 @@ def fetch_quote(ticker): url = "https://api.iextrading.com/1.0/tops/last?symbols={}".format(ticker.upper()) response = requests.get(url) if response.status_code != requests.codes.ok: - raise IEXError("Invalid response ({}): {}".format( - response.status_code, response.text)) + raise IEXError( + "Invalid response ({}): {}".format(response.status_code, response.text) + ) results = response.json() if len(results) != 1: - raise IEXError("Invalid number of responses from IEX: {}".format( - response.text)) + raise IEXError("Invalid number of responses from IEX: {}".format(response.text)) result = results[0] - price = Decimal(result['price']).quantize(Decimal('0.01')) + price = Decimal(result["price"]).quantize(Decimal("0.01")) # IEX is American markets. us_timezone = tz.gettz("America/New_York") - time = datetime.datetime.fromtimestamp(result['time'] / 1000) + time = datetime.datetime.fromtimestamp(result["time"] / 1000) time = time.astimezone(us_timezone) # As far as can tell, all the instruments on IEX are priced in USD. - return source.SourcePrice(price, time, 'USD') + return source.SourcePrice(price, time, "USD") class Source(source.Source): @@ -58,4 +59,5 @@ def get_historical_price(self, ticker, time): """See contract in beanprice.source.Source.""" raise NotImplementedError( "This is now implemented at https://iextrading.com/developers/docs/#hist and " - "needs to be added here.") + "needs to be added here." + ) diff --git a/beanprice/sources/iex_test.py b/beanprice/sources/iex_test.py index 76a3b69..73b5e1c 100644 --- a/beanprice/sources/iex_test.py +++ b/beanprice/sources/iex_test.py @@ -20,30 +20,27 @@ def response(contents, status_code=requests.codes.ok): response.status_code = status_code response.text = "" response.json.return_value = contents - return mock.patch('requests.get', return_value=response) + return mock.patch("requests.get", return_value=response) class IEXPriceFetcher(unittest.TestCase): - def test_error_network(self): with response(None, 404): with self.assertRaises(ValueError) as exc: - iex.fetch_quote('AAPL') - self.assertRegex(exc.message, 'premium') + iex.fetch_quote("AAPL") + self.assertRegex(exc.message, "premium") def _test_valid_response(self): - contents = [{"symbol": "HOOL", - "price": 183.61, - "size": 100, - "time": 1590177596030}] + contents = [{"symbol": "HOOL", "price": 183.61, "size": 100, "time": 1590177596030}] with response(contents): - srcprice = iex.fetch_quote('HOOL') + srcprice = iex.fetch_quote("HOOL") self.assertIsInstance(srcprice, source.SourcePrice) - self.assertEqual(Decimal('183.61'), srcprice.price) + self.assertEqual(Decimal("183.61"), srcprice.price) self.assertEqual( datetime.datetime(2020, 5, 22, 19, 59, 56, 30000, tzinfo=tz.tzutc()), - srcprice.time.astimezone(tz.tzutc())) - self.assertEqual('USD', srcprice.quote_currency) + srcprice.time.astimezone(tz.tzutc()), + ) + self.assertEqual("USD", srcprice.quote_currency) def test_valid_response(self): for tzname in "America/New_York", "Europe/Berlin", "Asia/Tokyo": @@ -51,5 +48,5 @@ def test_valid_response(self): self._test_valid_response() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/beanprice/sources/oanda.py b/beanprice/sources/oanda.py index 7245b52..848a6b0 100644 --- a/beanprice/sources/oanda.py +++ b/beanprice/sources/oanda.py @@ -11,6 +11,7 @@ Timezone information: Input and output datetimes are specified via UTC timestamps. """ + __copyright__ = "Copyright (C) 2018-2020 Martin Blais" __license__ = "GNU GPLv2" @@ -53,14 +54,14 @@ def _fetch_candles(params): A sorted list of (time, price) points. """ - url = '?'.join((URL, parse.urlencode(sorted(params.items())))) + url = "?".join((URL, parse.urlencode(sorted(params.items())))) logging.info("Fetching '%s'", url) # Fetch the data. response = net_utils.retrying_urlopen(url) if response is None: return None - data_string = response.read().decode('utf-8') + data_string = response.read().decode("utf-8") # Parse it. data = json.loads(data_string, parse_float=Decimal) @@ -68,11 +69,12 @@ def _fetch_candles(params): # Find the candle with the latest time before the given time we're searching # for. time_prices = [] - candles = sorted(data['candles'], key=lambda candle: candle['time']) + candles = sorted(data["candles"], key=lambda candle: candle["time"]) for candle in candles: candle_dt_utc = datetime.datetime.strptime( - candle['time'], r"%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=tz.tzutc()) - candle_price = Decimal(candle['openMid']) + candle["time"], r"%Y-%m-%dT%H:%M:%S.%fZ" + ).replace(tzinfo=tz.tzutc()) + candle_price = Decimal(candle["openMid"]) time_prices.append((candle_dt_utc, candle_price)) except KeyError: logging.error("Unexpected response data: %s", data) @@ -82,11 +84,10 @@ def _fetch_candles(params): def _fetch_price(params_dict, time): """Fetch a price from OANDA using the given parameters.""" - ticker = params_dict['instrument'] + ticker = params_dict["instrument"] _, quote_currency = _get_currencies(ticker) if quote_currency is None: - logging.error("Invalid price source ticker '%s'; must be like 'EUR_USD'", - ticker) + logging.error("Invalid price source ticker '%s'; must be like 'EUR_USD'", ticker) return time_prices = _fetch_candles(params_dict) @@ -111,23 +112,23 @@ def get_latest_price(self, ticker): """See contract in beanprice.source.Source.""" time = datetime.datetime.now(tz.tzutc()) params_dict = { - 'instrument': ticker, - 'granularity': 'S5', # Every two hours. - 'count': '10', - 'candleFormat': 'midpoint', + "instrument": ticker, + "granularity": "S5", # Every two hours. + "count": "10", + "candleFormat": "midpoint", } return _fetch_price(params_dict, time) def get_historical_price(self, ticker, time): """See contract in beanprice.source.Source.""" time = time.astimezone(tz.tzutc()) - query_interval_begin = (time - datetime.timedelta(days=5)) - query_interval_end = (time + datetime.timedelta(days=1)) + query_interval_begin = time - datetime.timedelta(days=5) + query_interval_end = time + datetime.timedelta(days=1) params_dict = { - 'instrument': ticker, - 'granularity': 'H2', # Every two hours. - 'candleFormat': 'midpoint', - 'start': query_interval_begin.isoformat('T'), - 'end': query_interval_end.isoformat('T'), + "instrument": ticker, + "granularity": "H2", # Every two hours. + "candleFormat": "midpoint", + "start": query_interval_begin.isoformat("T"), + "end": query_interval_end.isoformat("T"), } return _fetch_price(params_dict, time) diff --git a/beanprice/sources/oanda_test.py b/beanprice/sources/oanda_test.py index 86a755f..4512c0b 100644 --- a/beanprice/sources/oanda_test.py +++ b/beanprice/sources/oanda_test.py @@ -1,5 +1,5 @@ -"""Test for price extractor of OANDA. -""" +"""Test for price extractor of OANDA.""" + __copyright__ = "Copyright (C) 2018-2020 Martin Blais" __license__ = "GNU GPLv2" @@ -25,53 +25,55 @@ def response(code, contents=None): urlopen = mock.MagicMock(return_value=None) if isinstance(contents, str): response = mock.MagicMock() - response.read = mock.MagicMock(return_value=contents.encode('utf-8')) + response.read = mock.MagicMock(return_value=contents.encode("utf-8")) response.getcode = mock.MagicMock(return_value=200) urlopen.return_value = response - return mock.patch.object(net_utils, 'retrying_urlopen', urlopen) + return mock.patch.object(net_utils, "retrying_urlopen", urlopen) class TestOandaMisc(unittest.TestCase): - def test_get_currencies(self): - self.assertEqual(('USD', 'CAD'), oanda._get_currencies('USD_CAD')) + self.assertEqual(("USD", "CAD"), oanda._get_currencies("USD_CAD")) def test_get_currencies_invalid(self): - self.assertEqual((None, None), oanda._get_currencies('USDCAD')) + self.assertEqual((None, None), oanda._get_currencies("USDCAD")) class TimezoneTestBase: - def setUp(self): - tz_value = 'Europe/Berlin' - self.tz_old = os.environ.get('TZ', None) - os.environ['TZ'] = tz_value + tz_value = "Europe/Berlin" + self.tz_old = os.environ.get("TZ", None) + os.environ["TZ"] = tz_value time.tzset() def tearDown(self): if self.tz_old is None: - del os.environ['TZ'] + del os.environ["TZ"] else: - os.environ['TZ'] = self.tz_old + os.environ["TZ"] = self.tz_old time.tzset() class TestOandaFetchCandles(TimezoneTestBase, unittest.TestCase): - @response(404) def test_null_response(self): self.assertIs(None, oanda._fetch_candles({})) - @response(200, ''' + @response( + 200, + """ { "instrument" : "USD_CAD", "granularity" : "S5" } - ''') + """, + ) def test_key_error(self): self.assertIs(None, oanda._fetch_candles({})) - @response(200, ''' + @response( + 200, + """ { "instrument" : "USD_CAD", "granularity" : "S5", @@ -96,39 +98,52 @@ def test_key_error(self): } ] } - ''') + """, + ) def test_valid(self): - self.assertEqual([ - (datetime.datetime(2017, 1, 23, 0, 45, 15, tzinfo=UTC), Decimal('1.330115')), - (datetime.datetime(2017, 1, 23, 0, 45, 20, tzinfo=UTC), Decimal('1.330065')) - ], oanda._fetch_candles({})) + self.assertEqual( + [ + ( + datetime.datetime(2017, 1, 23, 0, 45, 15, tzinfo=UTC), + Decimal("1.330115"), + ), + ( + datetime.datetime(2017, 1, 23, 0, 45, 20, tzinfo=UTC), + Decimal("1.330065"), + ), + ], + oanda._fetch_candles({}), + ) class TestOandaGetLatest(unittest.TestCase): - def setUp(self): self.fetcher = oanda.Source() def test_invalid_ticker(self): - srcprice = self.fetcher.get_latest_price('NOTATICKER') + srcprice = self.fetcher.get_latest_price("NOTATICKER") self.assertIsNone(srcprice) def test_no_candles(self): - with mock.patch.object(oanda, '_fetch_candles', return_value=None): - self.assertEqual(None, self.fetcher.get_latest_price('USD_CAD')) + with mock.patch.object(oanda, "_fetch_candles", return_value=None): + self.assertEqual(None, self.fetcher.get_latest_price("USD_CAD")) def _test_valid(self): candles = [ - (datetime.datetime(2017, 1, 21, 0, 45, 15, tzinfo=UTC), Decimal('1.330115')), - (datetime.datetime(2017, 1, 21, 0, 45, 20, tzinfo=UTC), Decimal('1.330065')), + (datetime.datetime(2017, 1, 21, 0, 45, 15, tzinfo=UTC), Decimal("1.330115")), + (datetime.datetime(2017, 1, 21, 0, 45, 20, tzinfo=UTC), Decimal("1.330065")), ] - with mock.patch.object(oanda, '_fetch_candles', return_value=candles): - srcprice = self.fetcher.get_latest_price('USD_CAD') + with mock.patch.object(oanda, "_fetch_candles", return_value=candles): + srcprice = self.fetcher.get_latest_price("USD_CAD") # Latest price, with current time as time. - self.assertEqual(source.SourcePrice( - Decimal('1.330065'), - datetime.datetime(2017, 1, 21, 0, 45, 20, tzinfo=UTC), - 'CAD'), srcprice) + self.assertEqual( + source.SourcePrice( + Decimal("1.330065"), + datetime.datetime(2017, 1, 21, 0, 45, 20, tzinfo=UTC), + "CAD", + ), + srcprice, + ) def test_valid(self): for tzname in "America/New_York", "Europe/Berlin", "Asia/Tokyo": @@ -137,37 +152,37 @@ def test_valid(self): class TestOandaGetHistorical(TimezoneTestBase, unittest.TestCase): - def setUp(self): self.fetcher = oanda.Source() super().setUp() def test_invalid_ticker(self): - srcprice = self.fetcher.get_latest_price('NOTATICKER') + srcprice = self.fetcher.get_latest_price("NOTATICKER") self.assertIsNone(srcprice) def test_no_candles(self): - with mock.patch.object(oanda, '_fetch_candles', return_value=None): - self.assertEqual(None, self.fetcher.get_latest_price('USD_CAD')) + with mock.patch.object(oanda, "_fetch_candles", return_value=None): + self.assertEqual(None, self.fetcher.get_latest_price("USD_CAD")) def _check_valid(self, query_date, out_time, out_price): candles = [ - (datetime.datetime(2017, 1, 21, 0, 0, 0, tzinfo=UTC), Decimal('1.3100')), - (datetime.datetime(2017, 1, 21, 8, 0, 0, tzinfo=UTC), Decimal('1.3300')), - (datetime.datetime(2017, 1, 21, 16, 0, 0, tzinfo=UTC), Decimal('1.3500')), - (datetime.datetime(2017, 1, 22, 0, 0, 0, tzinfo=UTC), Decimal('1.3700')), - (datetime.datetime(2017, 1, 22, 8, 0, 0, tzinfo=UTC), Decimal('1.3900')), - (datetime.datetime(2017, 1, 22, 16, 0, 0, tzinfo=UTC), Decimal('1.4100')), - (datetime.datetime(2017, 1, 23, 0, 0, 0, tzinfo=UTC), Decimal('1.4300')), - (datetime.datetime(2017, 1, 23, 8, 0, 0, tzinfo=UTC), Decimal('1.4500')), - (datetime.datetime(2017, 1, 23, 16, 0, 0, tzinfo=UTC), Decimal('1.4700')), + (datetime.datetime(2017, 1, 21, 0, 0, 0, tzinfo=UTC), Decimal("1.3100")), + (datetime.datetime(2017, 1, 21, 8, 0, 0, tzinfo=UTC), Decimal("1.3300")), + (datetime.datetime(2017, 1, 21, 16, 0, 0, tzinfo=UTC), Decimal("1.3500")), + (datetime.datetime(2017, 1, 22, 0, 0, 0, tzinfo=UTC), Decimal("1.3700")), + (datetime.datetime(2017, 1, 22, 8, 0, 0, tzinfo=UTC), Decimal("1.3900")), + (datetime.datetime(2017, 1, 22, 16, 0, 0, tzinfo=UTC), Decimal("1.4100")), + (datetime.datetime(2017, 1, 23, 0, 0, 0, tzinfo=UTC), Decimal("1.4300")), + (datetime.datetime(2017, 1, 23, 8, 0, 0, tzinfo=UTC), Decimal("1.4500")), + (datetime.datetime(2017, 1, 23, 16, 0, 0, tzinfo=UTC), Decimal("1.4700")), ] - with mock.patch.object(oanda, '_fetch_candles', return_value=candles): + with mock.patch.object(oanda, "_fetch_candles", return_value=candles): query_time = datetime.datetime.combine( - query_date, time=datetime.time(16, 0, 0), tzinfo=UTC) - srcprice = self.fetcher.get_historical_price('USD_CAD', query_time) + query_date, time=datetime.time(16, 0, 0), tzinfo=UTC + ) + srcprice = self.fetcher.get_historical_price("USD_CAD", query_time) if out_time is not None: - self.assertEqual(source.SourcePrice(out_price, out_time, 'CAD'), srcprice) + self.assertEqual(source.SourcePrice(out_price, out_time, "CAD"), srcprice) else: self.assertEqual(None, srcprice) @@ -177,7 +192,8 @@ def test_valid_same_date(self): self._check_valid( datetime.date(2017, 1, 22), datetime.datetime(2017, 1, 22, 16, 0, tzinfo=UTC), - Decimal('1.4100')) + Decimal("1.4100"), + ) def test_valid_before(self): for tzname in "America/New_York", "Europe/Berlin", "Asia/Tokyo": @@ -185,7 +201,8 @@ def test_valid_before(self): self._check_valid( datetime.date(2017, 1, 23), datetime.datetime(2017, 1, 23, 16, 0, tzinfo=UTC), - Decimal('1.4700')) + Decimal("1.4700"), + ) def test_valid_after(self): for tzname in "America/New_York", "Europe/Berlin", "Asia/Tokyo": @@ -193,5 +210,5 @@ def test_valid_after(self): self._check_valid(datetime.date(2017, 1, 20), None, None) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/beanprice/sources/quandl.py b/beanprice/sources/quandl.py index 33a8414..967e7d8 100644 --- a/beanprice/sources/quandl.py +++ b/beanprice/sources/quandl.py @@ -28,6 +28,7 @@ believe the dates are presumed to live in the timezone of each particular data source. (It's unclear, not documented.) """ + __copyright__ = "Copyright (C) 2018-2020 Martin Blais" __license__ = "GNU GPLv2" @@ -48,18 +49,17 @@ class QuandlError(ValueError): "An error from the Quandl API." -TickerSpec = collections.namedtuple('TickerSpec', 'database dataset column') +TickerSpec = collections.namedtuple("TickerSpec", "database dataset column") def parse_ticker(ticker): """Convert ticker to Quandl codes.""" if not re.match(r"[A-Z0-9]+:[A-Z0-9]+(:[^:; ]+)?$", ticker): - raise ValueError( - 'Invalid code. Use ":[:]" format.') + raise ValueError('Invalid code. Use ":[:]" format.') split = ticker.split(":") if len(split) == 2: return TickerSpec(split[0], split[1], None) - return TickerSpec(split[0], split[1], split[2].replace('_', ' ')) + return TickerSpec(split[0], split[1], split[2].replace("_", " ")) def fetch_time_series(ticker, time=None): @@ -67,7 +67,8 @@ def fetch_time_series(ticker, time=None): # Create request payload. ticker_spec = parse_ticker(ticker) url = "https://www.quandl.com/api/v3/datasets/{}/{}.json".format( - ticker_spec.database, ticker_spec.dataset) + ticker_spec.database, ticker_spec.dataset + ) payload = {"limit": 1} if time is not None: date = time.date() @@ -75,41 +76,42 @@ def fetch_time_series(ticker, time=None): payload["end_date"] = date.isoformat() # Add API key, if it is set in the environment. - if 'QUANDL_API_KEY' in os.environ: - payload['api_key'] = os.environ['QUANDL_API_KEY'] + if "QUANDL_API_KEY" in os.environ: + payload["api_key"] = os.environ["QUANDL_API_KEY"] # Fetch and process errors. response = requests.get(url, params=payload) if response.status_code != requests.codes.ok: - raise QuandlError("Invalid response ({}): {}".format(response.status_code, - response.text)) + raise QuandlError( + "Invalid response ({}): {}".format(response.status_code, response.text) + ) result = response.json() - if 'quandl_error' in result: - raise QuandlError(result['quandl_error']['message']) + if "quandl_error" in result: + raise QuandlError(result["quandl_error"]["message"]) # Parse result container. - dataset = result['dataset'] - column_names = dataset['column_names'] - date_index = column_names.index('Date') + dataset = result["dataset"] + column_names = dataset["column_names"] + date_index = column_names.index("Date") if ticker_spec.column is not None: data_index = column_names.index(ticker_spec.column) else: try: - data_index = column_names.index('Adj. Close') + data_index = column_names.index("Adj. Close") except ValueError: - data_index = column_names.index('Close') - data = dataset['data'][0] + data_index = column_names.index("Close") + data = dataset["data"][0] # Gather time and assume it's in UTC timezone (Quandl does not provide the # market's timezone). - time = datetime.datetime.strptime(data[date_index], '%Y-%m-%d') + time = datetime.datetime.strptime(data[date_index], "%Y-%m-%d") time = time.replace(tzinfo=tz.tzutc()) # Gather price. # Quantize with the same precision default rendering of floats occur. price_float = data[data_index] price = Decimal(price_float) - match = re.search(r'(\..*)', str(price_float)) + match = re.search(r"(\..*)", str(price_float)) if match: price = price.quantize(Decimal(match.group(1))) diff --git a/beanprice/sources/quandl_test.py b/beanprice/sources/quandl_test.py index f4ab153..06e8bf0 100644 --- a/beanprice/sources/quandl_test.py +++ b/beanprice/sources/quandl_test.py @@ -20,110 +20,139 @@ def response(contents, status_code=requests.codes.ok): response.status_code = status_code response.text = "" response.json.return_value = contents - return mock.patch('requests.get', return_value=response) + return mock.patch("requests.get", return_value=response) class QuandlPriceFetcher(unittest.TestCase): - def test_parse_ticker(self): # NOTE(pmarciniak): "LBMA:GOLD:USD (PM)" is a valid ticker in Quandl # requests, but since space is not allowed in price source syntax, we're # representing space with an underscore. - self.assertEqual(quandl.TickerSpec('WIKI', 'FB', None), - quandl.parse_ticker('WIKI:FB')) - self.assertEqual(quandl.TickerSpec('LBMA', 'GOLD', 'USD (PM)'), - quandl.parse_ticker('LBMA:GOLD:USD_(PM)')) - for test in ['WIKI/FB', 'FB', 'WIKI.FB', 'WIKI,FB', - 'LBMA:GOLD:USD (PM)', 'LBMA:GOLD:col:umn']: + self.assertEqual( + quandl.TickerSpec("WIKI", "FB", None), quandl.parse_ticker("WIKI:FB") + ) + self.assertEqual( + quandl.TickerSpec("LBMA", "GOLD", "USD (PM)"), + quandl.parse_ticker("LBMA:GOLD:USD_(PM)"), + ) + for test in [ + "WIKI/FB", + "FB", + "WIKI.FB", + "WIKI,FB", + "LBMA:GOLD:USD (PM)", + "LBMA:GOLD:col:umn", + ]: with self.assertRaises(ValueError): quandl.parse_ticker(test) def test_error_premium(self): - contents = {'quandl_error': { - 'code': 'QEPx05', - 'message': ('You have attempted to view a premium database in ' - 'anonymous mode, i.e., without providing a Quandl ' - 'key. Please register for a free Quandl account, ' - 'and then include your API key with your ' - 'requests.')}} + contents = { + "quandl_error": { + "code": "QEPx05", + "message": ( + "You have attempted to view a premium database in " + "anonymous mode, i.e., without providing a Quandl " + "key. Please register for a free Quandl account, " + "and then include your API key with your " + "requests." + ), + } + } with response(contents): with self.assertRaises(ValueError) as exc: - quandl.fetch_time_series('WIKI:FB', None) - self.assertRegex(exc.message, 'premium') + quandl.fetch_time_series("WIKI:FB", None) + self.assertRegex(exc.message, "premium") def test_error_subscription(self): - contents = {'quandl_error': { - 'code': 'QEPx04', - 'message': ('You do not have permission to view this dataset. ' - 'Please subscribe to this database to get ' - 'access.')}} + contents = { + "quandl_error": { + "code": "QEPx04", + "message": ( + "You do not have permission to view this dataset. " + "Please subscribe to this database to get " + "access." + ), + } + } with response(contents): with self.assertRaises(ValueError) as exc: - quandl.fetch_time_series('WIKI:FB', None) - self.assertRegex(exc.message, 'premium') + quandl.fetch_time_series("WIKI:FB", None) + self.assertRegex(exc.message, "premium") def test_error_network(self): with response(None, 404): with self.assertRaises(ValueError) as exc: - quandl.fetch_time_series('WIKI:FB', None) - self.assertRegex(exc.message, 'premium') + quandl.fetch_time_series("WIKI:FB", None) + self.assertRegex(exc.message, "premium") def _test_valid_response(self): contents = { - 'dataset': {'collapse': None, - 'column_index': None, - 'column_names': ['Date', - 'Open', - 'High', - 'Low', - 'Close', - 'Volume', - 'Ex-Dividend', - 'Split Ratio', - 'Adj. Open', - 'Adj. High', - 'Adj. Low', - 'Adj. Close', - 'Adj. Volume'], - 'data': [['2018-03-27', - 1063.9, - 1064.54, - 997.62, - 1006.94, - 2940957.0, - 0.0, - 1.0, - 1063.9, - 1064.54, - 997.62, - 1006.94, - 2940957.0]], - 'database_code': 'WIKI', - 'database_id': 4922, - 'dataset_code': 'GOOGL', - 'description': 'This dataset has no description.', - 'end_date': '2018-03-27', - 'frequency': 'daily', - 'id': 11304017, - 'limit': 1, - 'name': ('Alphabet Inc (GOOGL) Prices, Dividends, Splits and ' - 'Trading Volume'), - 'newest_available_date': '2018-03-27', - 'oldest_available_date': '2004-08-19', - 'order': None, - 'premium': False, - 'refreshed_at': '2018-03-27T21:46:11.201Z', - 'start_date': '2004-08-19', - 'transform': None, - 'type': 'Time Series'}} + "dataset": { + "collapse": None, + "column_index": None, + "column_names": [ + "Date", + "Open", + "High", + "Low", + "Close", + "Volume", + "Ex-Dividend", + "Split Ratio", + "Adj. Open", + "Adj. High", + "Adj. Low", + "Adj. Close", + "Adj. Volume", + ], + "data": [ + [ + "2018-03-27", + 1063.9, + 1064.54, + 997.62, + 1006.94, + 2940957.0, + 0.0, + 1.0, + 1063.9, + 1064.54, + 997.62, + 1006.94, + 2940957.0, + ] + ], + "database_code": "WIKI", + "database_id": 4922, + "dataset_code": "GOOGL", + "description": "This dataset has no description.", + "end_date": "2018-03-27", + "frequency": "daily", + "id": 11304017, + "limit": 1, + "name": ( + "Alphabet Inc (GOOGL) Prices, Dividends, Splits and " "Trading Volume" + ), + "newest_available_date": "2018-03-27", + "oldest_available_date": "2004-08-19", + "order": None, + "premium": False, + "refreshed_at": "2018-03-27T21:46:11.201Z", + "start_date": "2004-08-19", + "transform": None, + "type": "Time Series", + } + } with response(contents): - srcprice = quandl.fetch_time_series('WIKI:FB', None) + srcprice = quandl.fetch_time_series("WIKI:FB", None) self.assertIsInstance(srcprice, source.SourcePrice) - self.assertEqual(Decimal('1006.94'), srcprice.price) - self.assertEqual(datetime.datetime(2018, 3, 27, 0, 0, 0, - tzinfo=tz.tzutc()), - srcprice.time.astimezone(tz.tzutc())) + self.assertEqual(Decimal("1006.94"), srcprice.price) + self.assertEqual( + datetime.datetime(2018, 3, 27, 0, 0, 0, tzinfo=tz.tzutc()), + srcprice.time.astimezone(tz.tzutc()), + ) self.assertEqual(None, srcprice.quote_currency) def test_valid_response(self): @@ -133,38 +162,40 @@ def test_valid_response(self): def test_non_standard_columns(self): contents = { - 'dataset': {'collapse': None, - 'column_index': None, - 'column_names': ['Date', - 'USD (AM)', - 'USD (PM)', - 'GBP (AM)', - 'GBP (PM)', - 'EURO (AM)', - 'EURO (PM)'], - 'data': [['2019-06-18', - 1344.55, - 1341.35, - 1073.22, - 1070.67, - 1201.89, - 1198.09]], - 'end_date': '2019-06-18', - 'frequency': 'daily', - 'order': None, - 'limit': 1, - 'start_date': '2019-06-08', - 'transform': None}} + "dataset": { + "collapse": None, + "column_index": None, + "column_names": [ + "Date", + "USD (AM)", + "USD (PM)", + "GBP (AM)", + "GBP (PM)", + "EURO (AM)", + "EURO (PM)", + ], + "data": [ + ["2019-06-18", 1344.55, 1341.35, 1073.22, 1070.67, 1201.89, 1198.09] + ], + "end_date": "2019-06-18", + "frequency": "daily", + "order": None, + "limit": 1, + "start_date": "2019-06-08", + "transform": None, + } + } with response(contents): - srcprice = quandl.fetch_time_series('LBMA:GOLD:USD_(PM)', None) + srcprice = quandl.fetch_time_series("LBMA:GOLD:USD_(PM)", None) self.assertIsInstance(srcprice, source.SourcePrice) - self.assertEqual(Decimal('1341.35'), srcprice.price) - self.assertEqual(datetime.datetime(2019, 6, 18, 0, 0, 0, - tzinfo=tz.tzutc()), - srcprice.time.astimezone(tz.tzutc())) + self.assertEqual(Decimal("1341.35"), srcprice.price) + self.assertEqual( + datetime.datetime(2019, 6, 18, 0, 0, 0, tzinfo=tz.tzutc()), + srcprice.time.astimezone(tz.tzutc()), + ) self.assertEqual(None, srcprice.quote_currency) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/beanprice/sources/ratesapi.py b/beanprice/sources/ratesapi.py index 3ee652d..e7a107e 100644 --- a/beanprice/sources/ratesapi.py +++ b/beanprice/sources/ratesapi.py @@ -27,6 +27,7 @@ class RatesApiError(ValueError): "An error from the Rates API." + def _parse_ticker(ticker): """Parse the base and quote currencies from the ticker. @@ -35,37 +36,37 @@ def _parse_ticker(ticker): Returns: A pair of (base, quote) currencies. """ - match = re.match(r'^(?P\w+)-(?P\w+)$', ticker) + match = re.match(r"^(?P\w+)-(?P\w+)$", ticker) if not match: - raise ValueError( - 'Invalid ticker. Use "BASE-SYMBOL" format.') + raise ValueError('Invalid ticker. Use "BASE-SYMBOL" format.') return match.groups() + def _get_quote(ticker, date): """Fetch a exchangerate from ratesapi.""" base, symbol = _parse_ticker(ticker) params = { - 'base': base, - 'symbol': symbol, + "base": base, + "symbol": symbol, } - response = requests.get(url='https://api.frankfurter.app/' + date, params=params) + response = requests.get(url="https://api.frankfurter.app/" + date, params=params) if response.status_code != requests.codes.ok: - raise RatesApiError("Invalid response ({}): {}".format(response.status_code, - response.text)) + raise RatesApiError( + "Invalid response ({}): {}".format(response.status_code, response.text) + ) result = response.json() - price = Decimal(str(result['rates'][symbol])) - time = parse(result['date']).replace(tzinfo=tz.tzutc()) + price = Decimal(str(result["rates"][symbol])) + time = parse(result["date"]).replace(tzinfo=tz.tzutc()) return source.SourcePrice(price, time, symbol) class Source(source.Source): - def get_latest_price(self, ticker): - return _get_quote(ticker, 'latest') + return _get_quote(ticker, "latest") def get_historical_price(self, ticker, time): return _get_quote(ticker, time.date().isoformat()) diff --git a/beanprice/sources/ratesapi_test.py b/beanprice/sources/ratesapi_test.py index 96718f4..26410aa 100644 --- a/beanprice/sources/ratesapi_test.py +++ b/beanprice/sources/ratesapi_test.py @@ -17,19 +17,18 @@ def response(contents, status_code=requests.codes.ok): response.status_code = status_code response.text = "" response.json.return_value = contents - return mock.patch('requests.get', return_value=response) + return mock.patch("requests.get", return_value=response) class RatesapiPriceFetcher(unittest.TestCase): - def test_error_invalid_ticker(self): with self.assertRaises(ValueError): - ratesapi.Source().get_latest_price('INVALID') + ratesapi.Source().get_latest_price("INVALID") def test_error_network(self): - with response('Foobar', 404): + with response("Foobar", 404): with self.assertRaises(ValueError): - ratesapi.Source().get_latest_price('EUR-CHF') + ratesapi.Source().get_latest_price("EUR-CHF") def test_valid_response(self): contents = { @@ -38,10 +37,10 @@ def test_valid_response(self): "date": "2019-04-20", } with response(contents): - srcprice = ratesapi.Source().get_latest_price('EUR-CHF') + srcprice = ratesapi.Source().get_latest_price("EUR-CHF") self.assertIsInstance(srcprice, source.SourcePrice) - self.assertEqual(Decimal('1.2001'), srcprice.price) - self.assertEqual('CHF', srcprice.quote_currency) + self.assertEqual(Decimal("1.2001"), srcprice.price) + self.assertEqual("CHF", srcprice.quote_currency) def test_historical_price(self): time = datetime.datetime(2018, 3, 27, 0, 0, 0, tzinfo=tz.tzutc()) @@ -51,13 +50,14 @@ def test_historical_price(self): "date": "2018-03-27", } with response(contents): - srcprice = ratesapi.Source().get_historical_price('EUR-CHF', time) + srcprice = ratesapi.Source().get_historical_price("EUR-CHF", time) self.assertIsInstance(srcprice, source.SourcePrice) - self.assertEqual(Decimal('1.2001'), srcprice.price) - self.assertEqual('CHF', srcprice.quote_currency) - self.assertEqual(datetime.datetime(2018, 3, 27, 0, 0, 0, tzinfo=tz.tzutc()), - srcprice.time) + self.assertEqual(Decimal("1.2001"), srcprice.price) + self.assertEqual("CHF", srcprice.quote_currency) + self.assertEqual( + datetime.datetime(2018, 3, 27, 0, 0, 0, tzinfo=tz.tzutc()), srcprice.time + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/beanprice/sources/tsp.py b/beanprice/sources/tsp.py index 9430529..29755c5 100644 --- a/beanprice/sources/tsp.py +++ b/beanprice/sources/tsp.py @@ -1,4 +1,4 @@ -""" Fetch prices from US Government Thrift Savings Plan +"""Fetch prices from US Government Thrift Savings Plan As of 7 July 2020, the Thrift Savings Plan (TSP) rolled out a new web site that has an API (instead of scraping a CSV). Unable to @@ -7,6 +7,7 @@ https://secure.tsp.gov/components/CORS/ """ + __copyright__ = "Copyright (C) 2020 Martin Blais" __license__ = "GNU GPLv2" @@ -20,40 +21,44 @@ from beanprice import source # All of the TSP funds are in USD. -CURRENCY = 'USD' +CURRENCY = "USD" -TIMEZONE = datetime.timezone(datetime.timedelta(hours=-4), 'America/New_York') +TIMEZONE = datetime.timezone(datetime.timedelta(hours=-4), "America/New_York") TSP_FUND_NAMES = [ - "LInco", #0 - "L2025", #1 - "L2030", #2 - "L2035", #3 - "L2040", #4 - "L2045", #5 - "L2050", #6 - "L2055", #7 - "L2060", #8 - "L2065", #9 - "GFund", #10 - "FFund", #11 - "CFund", #12 - "SFund", #13 - "IFund", #14 + "LInco", # 0 + "L2025", # 1 + "L2030", # 2 + "L2035", # 3 + "L2040", # 4 + "L2045", # 5 + "L2050", # 6 + "L2055", # 7 + "L2060", # 8 + "L2065", # 9 + "GFund", # 10 + "FFund", # 11 + "CFund", # 12 + "SFund", # 13 + "IFund", # 14 ] -csv.register_dialect('tsp', - delimiter=',', - quoting=csv.QUOTE_NONE, - # NOTE(blais): This fails to import in 3.12 (and perhaps before). - # quotechar='', - lineterminator='\n') +csv.register_dialect( + "tsp", + delimiter=",", + quoting=csv.QUOTE_NONE, + # NOTE(blais): This fails to import in 3.12 (and perhaps before). + # quotechar='', + lineterminator="\n", +) + class TSPError(ValueError): "An error from the Thrift Savings Plan (TSP) API." + def parse_tsp_csv(response: requests.models.Response) -> OrderedDict: - """ Parses a Thrift Savings Plan output CSV file. + """Parses a Thrift Savings Plan output CSV file. Function takes in a requests response and returns an OrderedDict with newest closing cost at front of OrderedDict. @@ -63,33 +68,37 @@ def parse_tsp_csv(response: requests.models.Response) -> OrderedDict: text = response.iter_lines(decode_unicode=True) - reader = csv.DictReader(text, dialect='tsp') + reader = csv.DictReader(text, dialect="tsp") for row in reader: # Date from TSP looks like "July 30. 2020" # There is indeed a period after the day of month. - date = datetime.datetime.strptime(row['Date'], "%b %d. %Y") + date = datetime.datetime.strptime(row["Date"], "%b %d. %Y") date = date.replace(hour=16, tzinfo=TIMEZONE) - names = ['L Income', - 'L 2025', - 'L 2030', - 'L 2035', - 'L 2040', - 'L 2045', - 'L 2050', - 'L 2055', - 'L 2060', - 'L 2065', - 'G Fund', - 'F Fund', - 'C Fund', - 'S Fund', - 'I Fund'] - data[date] = [Decimal(row[name]) if row[name] else Decimal() - for name in map(str.strip, names)] + names = [ + "L Income", + "L 2025", + "L 2030", + "L 2035", + "L 2040", + "L 2045", + "L 2050", + "L 2055", + "L 2060", + "L 2065", + "G Fund", + "F Fund", + "C Fund", + "S Fund", + "I Fund", + ] + data[date] = [ + Decimal(row[name]) if row[name] else Decimal() for name in map(str.strip, names) + ] return OrderedDict(sorted(data.items(), key=lambda t: t[0], reverse=True)) + def parse_response(response: requests.models.Response) -> OrderedDict: """Process as response from TSP. @@ -117,17 +126,18 @@ def get_historical_price(self, fund, time): if fund not in TSP_FUND_NAMES: raise TSPError( "Invalid TSP Fund Name '{}'. Valid Funds are:\n\t{}".format( - fund, - "\n\t".join(TSP_FUND_NAMES))) + fund, "\n\t".join(TSP_FUND_NAMES) + ) + ) url = "https://secure.tsp.gov/components/CORS/getSharePricesRaw.html" payload = { # Grabbing the last fourteen days of data in event the markets were closed. - 'startdate' : (time - datetime.timedelta(days=14)).strftime("%Y%m%d"), - 'enddate': time.strftime("%Y%m%d"), - 'download': '0', - 'Lfunds': '1', - 'InvFunds' : '1' + "startdate": (time - datetime.timedelta(days=14)).strftime("%Y%m%d"), + "enddate": time.strftime("%Y%m%d"), + "download": "0", + "Lfunds": "1", + "InvFunds": "1", } response = requests.get(url, params=payload) diff --git a/beanprice/sources/tsp_test.py b/beanprice/sources/tsp_test.py index 9e1323d..00fbff9 100644 --- a/beanprice/sources/tsp_test.py +++ b/beanprice/sources/tsp_test.py @@ -32,7 +32,8 @@ "Jul 14. 2020, 21.2898, 10.1398, 34.5110, 10.1829, 37.8542, 10.2115," " 22.0301, 10.2651, 10.2651, 10.2652, 16.4515, 21.0608, 47.2391, 53.8560, 30.0643\n" "Jul 15. 2020, 21.3513, 10.2067, 34.7862, 10.2723, 38.2174, 10.3170," - " 22.2736, 10.4025, 10.4026, 10.4027, 16.4519, 21.0574, 47.6702, 55.2910, 30.4751") + " 22.2736, 10.4025, 10.4026, 10.4027, 16.4519, 21.0574, 47.6702, 55.2910, 30.4751" +) HISTORIC_DATA = ( @@ -59,7 +60,8 @@ "Jun 18. 2020, 21.1562,, 33.9827,, 37.1713,, 21.5824,,,," "16.4432, 20.8718, 45.9718, 52.9328, 29.3908\n" "Jun 19. 2020, 21.1354,, 33.8890,, 37.0491,, 21.5018,,,," - "16.4435, 20.8742, 45.7171, 52.7196, 29.2879") + "16.4435, 20.8742, 45.7171, 52.7196, 29.2879" +) class MockResponse: @@ -72,63 +74,69 @@ def __init__(self, contents, status_code=requests.codes.ok): def iter_lines(self, decode_unicode=False): return iter(self._content.splitlines()) -class TSPFinancePriceFetcher(unittest.TestCase): +class TSPFinancePriceFetcher(unittest.TestCase): def test_get_latest_price_L2050(self): response = MockResponse(textwrap.dedent(CURRENT_DATA)) - with mock.patch('requests.get', return_value=response): - srcprice = tsp.Source().get_latest_price('L2050') + with mock.patch("requests.get", return_value=response): + srcprice = tsp.Source().get_latest_price("L2050") self.assertTrue(isinstance(srcprice.price, Decimal)) - self.assertEqual(Decimal('22.2736'), srcprice.price) - timezone = datetime.timezone(datetime.timedelta(hours=-4), 'America/New_York') - self.assertEqual(datetime.datetime(2020, 7, 15, 16, 0, 0, tzinfo=timezone), - srcprice.time) - self.assertEqual('USD', srcprice.quote_currency) + self.assertEqual(Decimal("22.2736"), srcprice.price) + timezone = datetime.timezone(datetime.timedelta(hours=-4), "America/New_York") + self.assertEqual( + datetime.datetime(2020, 7, 15, 16, 0, 0, tzinfo=timezone), srcprice.time + ) + self.assertEqual("USD", srcprice.quote_currency) def test_get_latest_price_SFund(self): response = MockResponse(textwrap.dedent(CURRENT_DATA)) - with mock.patch('requests.get', return_value=response): - srcprice = tsp.Source().get_latest_price('SFund') + with mock.patch("requests.get", return_value=response): + srcprice = tsp.Source().get_latest_price("SFund") self.assertTrue(isinstance(srcprice.price, Decimal)) - self.assertEqual(Decimal('55.2910'), srcprice.price) - timezone = datetime.timezone(datetime.timedelta(hours=-4), 'America/New_York') - self.assertEqual(datetime.datetime(2020, 7, 15, 16, 0, 0, tzinfo=timezone), - srcprice.time) - self.assertEqual('USD', srcprice.quote_currency) + self.assertEqual(Decimal("55.2910"), srcprice.price) + timezone = datetime.timezone(datetime.timedelta(hours=-4), "America/New_York") + self.assertEqual( + datetime.datetime(2020, 7, 15, 16, 0, 0, tzinfo=timezone), srcprice.time + ) + self.assertEqual("USD", srcprice.quote_currency) def test_get_historical_price(self): response = MockResponse(textwrap.dedent(HISTORIC_DATA)) - with mock.patch('requests.get', return_value=response): + with mock.patch("requests.get", return_value=response): srcprice = tsp.Source().get_historical_price( - 'CFund', time=datetime.datetime(2020, 6, 19)) + "CFund", time=datetime.datetime(2020, 6, 19) + ) self.assertTrue(isinstance(srcprice.price, Decimal)) - self.assertEqual(Decimal('45.7171'), srcprice.price) - timezone = datetime.timezone(datetime.timedelta(hours=-4), 'America/New_York') - self.assertEqual(datetime.datetime(2020, 6, 19, 16, 0, 0, tzinfo=timezone), - srcprice.time) - self.assertEqual('USD', srcprice.quote_currency) + self.assertEqual(Decimal("45.7171"), srcprice.price) + timezone = datetime.timezone(datetime.timedelta(hours=-4), "America/New_York") + self.assertEqual( + datetime.datetime(2020, 6, 19, 16, 0, 0, tzinfo=timezone), srcprice.time + ) + self.assertEqual("USD", srcprice.quote_currency) def test_get_historical_price_L2060(self): # This fund did not exist until 01 Jul 2020. Ensuring we get a Decimal(0.0) back. response = MockResponse(textwrap.dedent(HISTORIC_DATA)) - with mock.patch('requests.get', return_value=response): + with mock.patch("requests.get", return_value=response): srcprice = tsp.Source().get_historical_price( - 'L2060', time=datetime.datetime(2020, 6, 19)) + "L2060", time=datetime.datetime(2020, 6, 19) + ) self.assertTrue(isinstance(srcprice.price, Decimal)) - self.assertEqual(Decimal('0.0'), srcprice.price) - timezone = datetime.timezone(datetime.timedelta(hours=-4), 'America/New_York') - self.assertEqual(datetime.datetime(2020, 6, 19, 16, 0, 0, tzinfo=timezone), - srcprice.time) - self.assertEqual('USD', srcprice.quote_currency) + self.assertEqual(Decimal("0.0"), srcprice.price) + timezone = datetime.timezone(datetime.timedelta(hours=-4), "America/New_York") + self.assertEqual( + datetime.datetime(2020, 6, 19, 16, 0, 0, tzinfo=timezone), srcprice.time + ) + self.assertEqual("USD", srcprice.quote_currency) def test_invalid_fund_latest(self): with self.assertRaises(tsp.TSPError): - tsp.Source().get_latest_price('InvalidFund') + tsp.Source().get_latest_price("InvalidFund") def test_invalid_fund_historical(self): with self.assertRaises(tsp.TSPError): - tsp.Source().get_historical_price('InvalidFund', time=datetime.datetime.now()) + tsp.Source().get_historical_price("InvalidFund", time=datetime.datetime.now()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/beanprice/sources/yahoo.py b/beanprice/sources/yahoo.py index 9e327f2..6c4b28b 100644 --- a/beanprice/sources/yahoo.py +++ b/beanprice/sources/yahoo.py @@ -14,6 +14,7 @@ Timezone information: Input and output datetimes are specified via UNIX timestamps, but the timezone of the particular market is included in the output. """ + __copyright__ = "Copyright (C) 2015-2020 Martin Blais" __license__ = "GNU GPLv2" @@ -25,6 +26,7 @@ from beanprice import source + class YahooError(ValueError): "An error from the Yahoo API." @@ -38,72 +40,80 @@ def parse_response(response: requests.models.Response) -> Dict: json = response.json(parse_float=Decimal) content = next(iter(json.values())) if response.status_code != requests.codes.ok: - raise YahooError("Status {}: {}".format(response.status_code, content['error'])) + raise YahooError("Status {}: {}".format(response.status_code, content["error"])) if len(json) != 1: - raise YahooError("Invalid format in response from Yahoo; many keys: {}".format( - ','.join(json.keys()))) - if content['error'] is not None: - raise YahooError("Error fetching Yahoo data: {}".format(content['error'])) - if not content['result']: + raise YahooError( + "Invalid format in response from Yahoo; many keys: {}".format( + ",".join(json.keys()) + ) + ) + if content["error"] is not None: + raise YahooError("Error fetching Yahoo data: {}".format(content["error"])) + if not content["result"]: raise YahooError("No data returned from Yahoo, ensure that the symbol is correct") - return content['result'][0] + return content["result"][0] # Note: Feel free to suggest more here via a PR. _MARKETS = { - 'us_market': 'USD', - 'ca_market': 'CAD', - 'ch_market': 'CHF', + "us_market": "USD", + "ca_market": "CAD", + "ch_market": "CHF", } def parse_currency(result: Dict[str, Any]) -> Optional[str]: """Infer the currency from the result.""" - if 'market' not in result: + if "market" not in result: return None - return _MARKETS.get(result['market'], None) + return _MARKETS.get(result["market"], None) _DEFAULT_PARAMS = { - 'lang': 'en-US', - 'corsDomain': 'finance.yahoo.com', - '.tsrc': 'finance', + "lang": "en-US", + "corsDomain": "finance.yahoo.com", + ".tsrc": "finance", } -def get_price_series(ticker: str, - time_begin: datetime, - time_end: datetime) -> Tuple[List[Tuple[datetime, Decimal]], str]: +def get_price_series( + ticker: str, time_begin: datetime, time_end: datetime +) -> Tuple[List[Tuple[datetime, Decimal]], str]: """Return a series of timestamped prices.""" if requests is None: raise YahooError("You must install the 'requests' library.") url = "https://query1.finance.yahoo.com/v8/finance/chart/{}".format(ticker) payload: Dict[str, Union[int, str]] = { - 'period1': int(time_begin.timestamp()), - 'period2': int(time_end.timestamp()), - 'interval': '1d', + "period1": int(time_begin.timestamp()), + "period2": int(time_end.timestamp()), + "interval": "1d", } payload.update(_DEFAULT_PARAMS) - response = requests.get(url, params=payload, headers={'User-Agent': None}) + response = requests.get(url, params=payload, headers={"User-Agent": None}) result = parse_response(response) - meta = result['meta'] - tzone = timezone(timedelta(hours=meta['gmtoffset'] / 3600), - meta['exchangeTimezoneName']) + meta = result["meta"] + tzone = timezone( + timedelta(hours=meta["gmtoffset"] / 3600), meta["exchangeTimezoneName"] + ) - if 'timestamp' not in result: + if "timestamp" not in result: raise YahooError( "Yahoo returned no data for ticker {} for time range {} - {}".format( - ticker, time_begin, time_end)) - - timestamp_array = result['timestamp'] - close_array = result['indicators']['quote'][0]['close'] - series = [(datetime.fromtimestamp(timestamp, tz=tzone), Decimal(price)) - for timestamp, price in zip(timestamp_array, close_array) - if price is not None] - - currency = result['meta']['currency'] + ticker, time_begin, time_end + ) + ) + + timestamp_array = result["timestamp"] + close_array = result["indicators"]["quote"][0]["close"] + series = [ + (datetime.fromtimestamp(timestamp, tz=tzone), Decimal(price)) + for timestamp, price in zip(timestamp_array, close_array) + if price is not None + ] + + currency = result["meta"]["currency"] return series, currency @@ -114,21 +124,23 @@ def get_latest_price(self, ticker: str) -> Optional[source.SourcePrice]: """See contract in beanprice.source.Source.""" session = requests.Session() - session.headers.update({ - 'User-Agent': - 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/110.0' - }) + session.headers.update( + { + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:109.0) " + "Gecko/20100101 Firefox/110.0" + } + ) # This populates the correct cookies in the session - session.get('https://fc.yahoo.com') - crumb = session.get('https://query1.finance.yahoo.com/v1/test/getcrumb').text + session.get("https://fc.yahoo.com") + crumb = session.get("https://query1.finance.yahoo.com/v1/test/getcrumb").text url = "https://query1.finance.yahoo.com/v7/finance/quote" - fields = ['symbol', 'regularMarketPrice', 'regularMarketTime'] + fields = ["symbol", "regularMarketPrice", "regularMarketTime"] payload = { - 'symbols': ticker, - 'fields': ','.join(fields), - 'exchange': 'NYSE', - 'crumb': crumb, + "symbols": ticker, + "fields": ",".join(fields), + "exchange": "NYSE", + "crumb": crumb, } payload.update(_DEFAULT_PARAMS) response = session.get(url, params=payload) @@ -139,23 +151,25 @@ def get_latest_price(self, ticker: str) -> Optional[source.SourcePrice]: # but the user definitely needs to know which ticker failed! raise YahooError("%s (ticker: %s)" % (error, ticker)) from error try: - price = Decimal(result['regularMarketPrice']) + price = Decimal(result["regularMarketPrice"]) tzone = timezone( - timedelta(hours=result['gmtOffSetMilliseconds'] / 3600000), - result['exchangeTimezoneName']) - trade_time = datetime.fromtimestamp(result['regularMarketTime'], - tz=tzone) + timedelta(hours=result["gmtOffSetMilliseconds"] / 3600000), + result["exchangeTimezoneName"], + ) + trade_time = datetime.fromtimestamp(result["regularMarketTime"], tz=tzone) except KeyError as exc: - raise YahooError("Invalid response from Yahoo: {}".format( - repr(result))) from exc + raise YahooError( + "Invalid response from Yahoo: {}".format(repr(result)) + ) from exc currency = parse_currency(result) return source.SourcePrice(price, trade_time, currency) - def get_historical_price(self, ticker: str, - time: datetime) -> Optional[source.SourcePrice]: + def get_historical_price( + self, ticker: str, time: datetime + ) -> Optional[source.SourcePrice]: """See contract in beanprice.source.Source.""" # Get the latest data returned over the last 5 days. @@ -170,11 +184,9 @@ def get_historical_price(self, ticker: str, return source.SourcePrice(price, data_dt, currency) - def get_daily_prices(self, - ticker: str, - time_begin: datetime, - time_end: datetime) -> Optional[List[source.SourcePrice]]: + def get_daily_prices( + self, ticker: str, time_begin: datetime, time_end: datetime + ) -> Optional[List[source.SourcePrice]]: """See contract in beanprice.source.Source.""" series, currency = get_price_series(ticker, time_begin, time_end) - return [source.SourcePrice(price, time, currency) - for time, price in series] + return [source.SourcePrice(price, time, currency) for time, price in series] diff --git a/experiments/dividends/download_dividends.py b/experiments/dividends/download_dividends.py index b195673..6c2b493 100755 --- a/experiments/dividends/download_dividends.py +++ b/experiments/dividends/download_dividends.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -"""Download all dividends in a particular date interval. -""" +"""Download all dividends in a particular date interval.""" __copyright__ = "Copyright (C) 2020 Martin Blais" __license__ = "GNU GPLv2" @@ -18,16 +17,18 @@ import requests -def download_dividends(instrument: str, - start_date: Date, - end_date: Date) -> List[Tuple[Date, Decimal]]: +def download_dividends( + instrument: str, start_date: Date, end_date: Date +) -> List[Tuple[Date, Decimal]]: """Download a list of dividends issued over a time interval.""" tim = datetime.time() - payload = {'period1': str(int(datetime.datetime.combine(start_date, tim).timestamp())), - 'period2': str(int(datetime.datetime.combine(end_date, tim).timestamp())), - 'interval': '1d', - 'events': 'div', - 'includeAdjustedClose': 'true'} + payload = { + "period1": str(int(datetime.datetime.combine(start_date, tim).timestamp())), + "period2": str(int(datetime.datetime.combine(end_date, tim).timestamp())), + "interval": "1d", + "events": "div", + "includeAdjustedClose": "true", + } template = " https://query1.finance.yahoo.com/v7/finance/download/{ticker}" url = template.format(ticker=instrument) resp = requests.get(url, params=payload) @@ -37,8 +38,9 @@ def download_dividends(instrument: str, rows = iter(csv.reader(io.StringIO(resp.text))) header = next(rows) if header != ["Date", "Dividends"]: - raise ValueError("Error fetching dividends: " - "invalid response format: {}".format(header)) + raise ValueError( + "Error fetching dividends: " "invalid response format: {}".format(header) + ) dividends = [] for row in rows: @@ -52,16 +54,21 @@ def main(): """Top-level function.""" today = datetime.date.today() parser = argparse.ArgumentParser(description=__doc__.strip()) - parser.add_argument('instrument', - help="Yahoo!Finance code for financial instrument.") - parser.add_argument('start', action='store', - type=lambda x: dateutil.parser.parse(x).date(), - default=today.replace(year=today.year-1), - help="Start date of interval. Default is one year ago.") - parser.add_argument('end', action='store', - type=lambda x: dateutil.parser.parse(x).date(), - default=today, - help="End date of interval. Default is today ago.") + parser.add_argument("instrument", help="Yahoo!Finance code for financial instrument.") + parser.add_argument( + "start", + action="store", + type=lambda x: dateutil.parser.parse(x).date(), + default=today.replace(year=today.year - 1), + help="Start date of interval. Default is one year ago.", + ) + parser.add_argument( + "end", + action="store", + type=lambda x: dateutil.parser.parse(x).date(), + default=today, + help="End date of interval. Default is today ago.", + ) args = parser.parse_args() @@ -69,5 +76,5 @@ def main(): pprint.pprint(dividends) -if __name__ == '__main__': +if __name__ == "__main__": main()