diff --git a/CHANGELOG.md b/CHANGELOG.md index ab22f90..5071532 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## Unreleased +- Improve the performance of `FiniteDatetimeRange.intersection`, + `FiniteDatetimeRange.union` and `FiniteDatetimeRange.__lt__` [#187](https://github.com/octoenergy/xocto/pull/187). + ## V7.1.0 - 2025-01-13 - Add `ranges.any_gaps` function [#185](https://github.com/octoenergy/xocto/pull/185). diff --git a/makefile b/makefile index 4e5a4d7..34d8fd0 100644 --- a/makefile +++ b/makefile @@ -17,7 +17,7 @@ test: py.test --benchmark-skip benchmark: - py.test --benchmark-only --benchmark-autosave --benchmark-compare + py.test --benchmark-only --benchmark-autosave --benchmark-compare --benchmark-group-by=func --benchmark-columns mean,rounds,iterations mypy: mypy diff --git a/tests/benchmarks/test_ranges.py b/tests/benchmarks/test_ranges.py index f572e8f..350bddd 100644 --- a/tests/benchmarks/test_ranges.py +++ b/tests/benchmarks/test_ranges.py @@ -1,8 +1,7 @@ +import datetime import random from decimal import Decimal as D -import pytest - from xocto import ranges @@ -13,15 +12,91 @@ def _shuffled(ranges_, *, seed=42): return ranges_ -@pytest.mark.benchmark(group="ranges.any_overlapping") def test_any_overlapping(benchmark): ranges_ = _shuffled([ranges.Range(D(i), D(i + 1)) for i in range(1000)]) any_overlapping = benchmark(ranges.any_overlapping, ranges_) assert any_overlapping is False -@pytest.mark.benchmark(group="ranges.any_gaps") def test_any_gaps(benchmark): ranges_ = _shuffled([ranges.Range(D(i), D(i + 1)) for i in range(1000)]) any_overlapping = benchmark(ranges.any_gaps, ranges_) assert any_overlapping is False + + +class TestFiniteDatetimeRange: + def test_intersection_is_none(self, benchmark): + r1 = ranges.FiniteDatetimeRange( + datetime.datetime(2020, 1, 1), + datetime.datetime(2020, 1, 2), + ) + r2 = ranges.FiniteDatetimeRange( + datetime.datetime(2020, 1, 3), + datetime.datetime(2020, 1, 4), + ) + + result = benchmark(lambda: r2 & r1) + + assert result is None + + def test_intersection_is_not_none(self, benchmark): + r1 = ranges.FiniteDatetimeRange( + datetime.datetime(2020, 1, 1), + datetime.datetime(2020, 1, 3), + ) + r2 = ranges.FiniteDatetimeRange( + datetime.datetime(2020, 1, 2), + datetime.datetime(2020, 1, 4), + ) + + result = benchmark(lambda: r2 & r1) + + assert result == ranges.FiniteDatetimeRange( + datetime.datetime(2020, 1, 2), + datetime.datetime(2020, 1, 3), + ) + + def test_union_is_none(self, benchmark): + r1 = ranges.FiniteDatetimeRange( + datetime.datetime(2020, 1, 1), + datetime.datetime(2020, 1, 2), + ) + r2 = ranges.FiniteDatetimeRange( + datetime.datetime(2020, 1, 3), + datetime.datetime(2020, 1, 4), + ) + + result = benchmark(lambda: r2 | r1) + + assert result is None + + def test_union_is_not_none(self, benchmark): + r1 = ranges.FiniteDatetimeRange( + datetime.datetime(2020, 1, 1), + datetime.datetime(2020, 1, 3), + ) + r2 = ranges.FiniteDatetimeRange( + datetime.datetime(2020, 1, 2), + datetime.datetime(2020, 1, 4), + ) + + result = benchmark(lambda: r2 | r1) + + assert result == ranges.FiniteDatetimeRange( + datetime.datetime(2020, 1, 1), + datetime.datetime(2020, 1, 4), + ) + + def test_sorting(self, benchmark): + sorted_ranges_ = [] + dt = datetime.datetime(2020, 1, 1) + for _ in range(100_000): + sorted_ranges_.append( + ranges.FiniteDatetimeRange(dt, dt + datetime.timedelta(hours=1)) + ) + dt += datetime.timedelta(hours=1) + + ranges_ = _shuffled(sorted_ranges_) + + result = benchmark(lambda: sorted(ranges_)) + assert result == sorted_ranges_ diff --git a/tests/test_ranges.py b/tests/test_ranges.py index 3c81b04..4550d4d 100644 --- a/tests/test_ranges.py +++ b/tests/test_ranges.py @@ -1019,6 +1019,60 @@ def test_finite_range(self): class TestFiniteDatetimeRange: + @pytest.mark.parametrize( + "r1, r2, expected", + [ + [ + ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 1), + end=datetime.datetime(2000, 1, 3), + ), + ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 2), + end=datetime.datetime(2000, 1, 4), + ), + True, + ], + [ + ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 2), + end=datetime.datetime(2000, 1, 4), + ), + ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 1), + end=datetime.datetime(2000, 1, 3), + ), + False, + ], + [ + ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 1), + end=datetime.datetime(2000, 1, 3), + ), + ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 1), + end=datetime.datetime(2000, 1, 5), + ), + # False, since only `start` is considered. + False, + ], + [ + ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 1), + end=datetime.datetime(2000, 1, 3), + ), + ranges.Range( + start=None, + end=datetime.datetime(3000, 1, 1), + boundaries=ranges.RangeBoundaries.EXCLUSIVE_EXCLUSIVE, + ), + False, + ], + ], + ) + def test__lt__(self, r1, r2, expected): + assert (r1 < r2) is expected + class TestUnion: def test_union_of_touching_ranges(self): range = ranges.FiniteDatetimeRange( @@ -1030,11 +1084,13 @@ def test_union_of_touching_ranges(self): end=datetime.datetime(2000, 1, 3), ) - union = range | other - - assert union == ranges.FiniteDatetimeRange( - start=datetime.datetime(2000, 1, 1), - end=datetime.datetime(2000, 1, 3), + assert ( + range | other + == other | range + == ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 1), + end=datetime.datetime(2000, 1, 3), + ) ) def test_union_of_disjoint_ranges(self): @@ -1047,7 +1103,7 @@ def test_union_of_disjoint_ranges(self): end=datetime.datetime(2020, 1, 2), ) - assert range | other is None + assert (range | other is None) and (other | range is None) def test_union_of_overlapping_ranges(self): range = ranges.FiniteDatetimeRange( @@ -1059,13 +1115,59 @@ def test_union_of_overlapping_ranges(self): end=datetime.datetime(2000, 1, 4), ) - union = range | other + assert ( + range | other + == other | range + == ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 1), + end=datetime.datetime(2000, 1, 4), + ) + ) + + class TestIntersection: + def test_intersection_of_touching_ranges(self): + range = ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 1), + end=datetime.datetime(2000, 1, 2), + ) + other = ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 2), + end=datetime.datetime(2000, 1, 3), + ) + + assert (range & other is None) and (other & range is None) + + def test_intersection_of_disjoint_ranges(self): + range = ranges.FiniteDateRange( + start=datetime.datetime(2000, 1, 1), + end=datetime.datetime(2000, 1, 2), + ) + other = ranges.FiniteDatetimeRange( + start=datetime.datetime(2020, 1, 1), + end=datetime.datetime(2020, 1, 2), + ) + + assert (range & other is None) and (other & range is None) - assert union == ranges.FiniteDatetimeRange( + def test_intersection_of_overlapping_ranges(self): + range = ranges.FiniteDatetimeRange( start=datetime.datetime(2000, 1, 1), + end=datetime.datetime(2000, 1, 3), + ) + other = ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 2), end=datetime.datetime(2000, 1, 4), ) + assert ( + range & other + == other & range + == ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 2), + end=datetime.datetime(2000, 1, 3), + ) + ) + class TestLocalize: def test_converts_timezone(self): # Create a datetime range in Sydney, which is diff --git a/xocto/ranges.py b/xocto/ranges.py index add25fe..34581e9 100644 --- a/xocto/ranges.py +++ b/xocto/ranges.py @@ -837,12 +837,31 @@ def __init__(self, start: datetime.datetime, end: datetime.datetime): """ super().__init__(start, end, boundaries=RangeBoundaries.INCLUSIVE_EXCLUSIVE) + def __lt__(self, other: Range[datetime.datetime]) -> bool: + # We're deliberately overriding the base class here for better performance. + if other.start is None: + # We don't need to check anything more if the other range + # is open-ended + return False + else: + return self.start < other.start + def intersection( self, other: Range[datetime.datetime] ) -> Optional["FiniteDatetimeRange"]: """ Intersections with finite ranges will always be finite. """ + if isinstance(other, FiniteDatetimeRange): + # We're deliberately overriding the base class here for better performance. + # We can simplify the implementation since we know we're dealing with finite + # ranges with INCLUSIVE_EXCLUSIVE bounds. + left, right = (self, other) if self.start < other.start else (other, self) + if left.end <= right.start: + return None + else: + return FiniteDatetimeRange(right.start, min(left.end, right.end)) + base_intersection = super().intersection(other) if base_intersection is None: return None @@ -854,6 +873,16 @@ def union(self, other: Range[datetime.datetime]) -> Optional["FiniteDatetimeRang """ Unions between two FiniteDatetimeRanges should produce a FiniteDatetimeRange. """ + if isinstance(other, FiniteDatetimeRange): + # We're deliberately overriding the base class here for better performance. + # We can simplify the implementation since we know we're dealing with finite + # ranges with INCLUSIVE_EXCLUSIVE bounds. + left, right = (self, other) if self.start < other.start else (other, self) + if left.end < right.start: + return None + else: + return FiniteDatetimeRange(left.start, max(left.end, right.end)) + try: base_union = super().union(other) except ValueError: