Skip to content

Commit

Permalink
Merge pull request #187 from octoenergy/optimize-datetime-range-inter…
Browse files Browse the repository at this point in the history
…section-union

Optimise FiniteDatetimeRange __lt__, intersection and union
  • Loading branch information
Peter554 authored Jan 22, 2025
2 parents c0c97ec + 58220aa commit 5f8e226
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Unreleased

- Improve the performance of `FiniteDatetimeRange.intersection`,
`FiniteDatetimeRange.union` and `FiniteDatetimeRange.__lt__` [#187](https://github.com/octoenergy/xocto/pull/187).

## V7.1.0 - 2025-01-13

- Add `ranges.any_gaps` function [#185](https://github.com/octoenergy/xocto/pull/185).
Expand Down
2 changes: 1 addition & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ test:
py.test --benchmark-skip

benchmark:
py.test --benchmark-only --benchmark-autosave --benchmark-compare
py.test --benchmark-only --benchmark-autosave --benchmark-compare --benchmark-group-by=func --benchmark-columns mean,rounds,iterations

mypy:
mypy
Expand Down
83 changes: 79 additions & 4 deletions tests/benchmarks/test_ranges.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import datetime
import random
from decimal import Decimal as D

import pytest

from xocto import ranges


Expand All @@ -13,15 +12,91 @@ def _shuffled(ranges_, *, seed=42):
return ranges_


@pytest.mark.benchmark(group="ranges.any_overlapping")
def test_any_overlapping(benchmark):
ranges_ = _shuffled([ranges.Range(D(i), D(i + 1)) for i in range(1000)])
any_overlapping = benchmark(ranges.any_overlapping, ranges_)
assert any_overlapping is False


@pytest.mark.benchmark(group="ranges.any_gaps")
def test_any_gaps(benchmark):
ranges_ = _shuffled([ranges.Range(D(i), D(i + 1)) for i in range(1000)])
any_overlapping = benchmark(ranges.any_gaps, ranges_)
assert any_overlapping is False


class TestFiniteDatetimeRange:
def test_intersection_is_none(self, benchmark):
r1 = ranges.FiniteDatetimeRange(
datetime.datetime(2020, 1, 1),
datetime.datetime(2020, 1, 2),
)
r2 = ranges.FiniteDatetimeRange(
datetime.datetime(2020, 1, 3),
datetime.datetime(2020, 1, 4),
)

result = benchmark(lambda: r2 & r1)

assert result is None

def test_intersection_is_not_none(self, benchmark):
r1 = ranges.FiniteDatetimeRange(
datetime.datetime(2020, 1, 1),
datetime.datetime(2020, 1, 3),
)
r2 = ranges.FiniteDatetimeRange(
datetime.datetime(2020, 1, 2),
datetime.datetime(2020, 1, 4),
)

result = benchmark(lambda: r2 & r1)

assert result == ranges.FiniteDatetimeRange(
datetime.datetime(2020, 1, 2),
datetime.datetime(2020, 1, 3),
)

def test_union_is_none(self, benchmark):
r1 = ranges.FiniteDatetimeRange(
datetime.datetime(2020, 1, 1),
datetime.datetime(2020, 1, 2),
)
r2 = ranges.FiniteDatetimeRange(
datetime.datetime(2020, 1, 3),
datetime.datetime(2020, 1, 4),
)

result = benchmark(lambda: r2 | r1)

assert result is None

def test_union_is_not_none(self, benchmark):
r1 = ranges.FiniteDatetimeRange(
datetime.datetime(2020, 1, 1),
datetime.datetime(2020, 1, 3),
)
r2 = ranges.FiniteDatetimeRange(
datetime.datetime(2020, 1, 2),
datetime.datetime(2020, 1, 4),
)

result = benchmark(lambda: r2 | r1)

assert result == ranges.FiniteDatetimeRange(
datetime.datetime(2020, 1, 1),
datetime.datetime(2020, 1, 4),
)

def test_sorting(self, benchmark):
sorted_ranges_ = []
dt = datetime.datetime(2020, 1, 1)
for _ in range(100_000):
sorted_ranges_.append(
ranges.FiniteDatetimeRange(dt, dt + datetime.timedelta(hours=1))
)
dt += datetime.timedelta(hours=1)

ranges_ = _shuffled(sorted_ranges_)

result = benchmark(lambda: sorted(ranges_))
assert result == sorted_ranges_
118 changes: 110 additions & 8 deletions tests/test_ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,60 @@ def test_finite_range(self):


class TestFiniteDatetimeRange:
@pytest.mark.parametrize(
"r1, r2, expected",
[
[
ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 1),
end=datetime.datetime(2000, 1, 3),
),
ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 2),
end=datetime.datetime(2000, 1, 4),
),
True,
],
[
ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 2),
end=datetime.datetime(2000, 1, 4),
),
ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 1),
end=datetime.datetime(2000, 1, 3),
),
False,
],
[
ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 1),
end=datetime.datetime(2000, 1, 3),
),
ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 1),
end=datetime.datetime(2000, 1, 5),
),
# False, since only `start` is considered.
False,
],
[
ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 1),
end=datetime.datetime(2000, 1, 3),
),
ranges.Range(
start=None,
end=datetime.datetime(3000, 1, 1),
boundaries=ranges.RangeBoundaries.EXCLUSIVE_EXCLUSIVE,
),
False,
],
],
)
def test__lt__(self, r1, r2, expected):
assert (r1 < r2) is expected

class TestUnion:
def test_union_of_touching_ranges(self):
range = ranges.FiniteDatetimeRange(
Expand All @@ -1030,11 +1084,13 @@ def test_union_of_touching_ranges(self):
end=datetime.datetime(2000, 1, 3),
)

union = range | other

assert union == ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 1),
end=datetime.datetime(2000, 1, 3),
assert (
range | other
== other | range
== ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 1),
end=datetime.datetime(2000, 1, 3),
)
)

def test_union_of_disjoint_ranges(self):
Expand All @@ -1047,7 +1103,7 @@ def test_union_of_disjoint_ranges(self):
end=datetime.datetime(2020, 1, 2),
)

assert range | other is None
assert (range | other is None) and (other | range is None)

def test_union_of_overlapping_ranges(self):
range = ranges.FiniteDatetimeRange(
Expand All @@ -1059,13 +1115,59 @@ def test_union_of_overlapping_ranges(self):
end=datetime.datetime(2000, 1, 4),
)

union = range | other
assert (
range | other
== other | range
== ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 1),
end=datetime.datetime(2000, 1, 4),
)
)

class TestIntersection:
def test_intersection_of_touching_ranges(self):
range = ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 1),
end=datetime.datetime(2000, 1, 2),
)
other = ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 2),
end=datetime.datetime(2000, 1, 3),
)

assert (range & other is None) and (other & range is None)

def test_intersection_of_disjoint_ranges(self):
range = ranges.FiniteDateRange(
start=datetime.datetime(2000, 1, 1),
end=datetime.datetime(2000, 1, 2),
)
other = ranges.FiniteDatetimeRange(
start=datetime.datetime(2020, 1, 1),
end=datetime.datetime(2020, 1, 2),
)

assert (range & other is None) and (other & range is None)

assert union == ranges.FiniteDatetimeRange(
def test_intersection_of_overlapping_ranges(self):
range = ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 1),
end=datetime.datetime(2000, 1, 3),
)
other = ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 2),
end=datetime.datetime(2000, 1, 4),
)

assert (
range & other
== other & range
== ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 2),
end=datetime.datetime(2000, 1, 3),
)
)

class TestLocalize:
def test_converts_timezone(self):
# Create a datetime range in Sydney, which is
Expand Down
29 changes: 29 additions & 0 deletions xocto/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,12 +837,31 @@ def __init__(self, start: datetime.datetime, end: datetime.datetime):
"""
super().__init__(start, end, boundaries=RangeBoundaries.INCLUSIVE_EXCLUSIVE)

def __lt__(self, other: Range[datetime.datetime]) -> bool:
# We're deliberately overriding the base class here for better performance.
if other.start is None:
# We don't need to check anything more if the other range
# is open-ended
return False
else:
return self.start < other.start

def intersection(
self, other: Range[datetime.datetime]
) -> Optional["FiniteDatetimeRange"]:
"""
Intersections with finite ranges will always be finite.
"""
if isinstance(other, FiniteDatetimeRange):
# We're deliberately overriding the base class here for better performance.
# We can simplify the implementation since we know we're dealing with finite
# ranges with INCLUSIVE_EXCLUSIVE bounds.
left, right = (self, other) if self.start < other.start else (other, self)
if left.end <= right.start:
return None
else:
return FiniteDatetimeRange(right.start, min(left.end, right.end))

base_intersection = super().intersection(other)
if base_intersection is None:
return None
Expand All @@ -854,6 +873,16 @@ def union(self, other: Range[datetime.datetime]) -> Optional["FiniteDatetimeRang
"""
Unions between two FiniteDatetimeRanges should produce a FiniteDatetimeRange.
"""
if isinstance(other, FiniteDatetimeRange):
# We're deliberately overriding the base class here for better performance.
# We can simplify the implementation since we know we're dealing with finite
# ranges with INCLUSIVE_EXCLUSIVE bounds.
left, right = (self, other) if self.start < other.start else (other, self)
if left.end < right.start:
return None
else:
return FiniteDatetimeRange(left.start, max(left.end, right.end))

try:
base_union = super().union(other)
except ValueError:
Expand Down

0 comments on commit 5f8e226

Please sign in to comment.