Skip to content

Commit

Permalink
Implement interval tree
Browse files Browse the repository at this point in the history
  • Loading branch information
Peter554 committed Feb 2, 2025
1 parent 9b1a335 commit 64fb3fa
Show file tree
Hide file tree
Showing 2 changed files with 338 additions and 0 deletions.
200 changes: 200 additions & 0 deletions tests/test_intervaltree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import pytest

from xocto.intervaltree import Interval, IntervalTree


class TestIntervalTree:
class TestStart:
def test_returns_none_if_tree_is_empty(self):
assert IntervalTree().start() is None

def test_returns_start_if_tree_is_not_empty(self):
tree = IntervalTree(
[
Interval(1, 2, "red"),
Interval(2, 3, "green"),
Interval(0, 1, "blue"),
]
)
assert tree.start() == 0

class TestEnd:
def test_returns_none_if_tree_is_empty(self):
assert IntervalTree().end() is None

def test_returns_end_if_tree_is_not_empty(self):
tree = IntervalTree(
[
Interval(1, 2, "red"),
Interval(2, 3, "green"),
Interval(0, 1, "blue"),
]
)
assert tree.end() == 3

class TestLen:
def test_returns_0_for_empty_tree(self):
assert len(IntervalTree()) == 0

def test_returns_length_of_tree(self):
tree = IntervalTree(
[
Interval(1, 2, "red"),
Interval(2, 3, "green"),
Interval(0, 1, "blue"),
]
)
assert len(tree) == 3

class TestIsEmpty:
def test_returns_true_if_tree_is_empty(self):
assert IntervalTree().is_empty()

def test_returns_false_if_tree_is_not_empty(self):
tree = IntervalTree([Interval(0, 1, "red")])
assert not tree.is_empty()

class TestAll:
def test_returns_all_the_intervals_sorted(self):
tree = IntervalTree(
[
Interval(1, 2, "red"),
Interval(2, 3, "green"),
Interval(0, 1, "blue"),
]
)
assert _as_tuples(tree.all()) == [
(0, 1, "blue"),
(1, 2, "red"),
(2, 3, "green"),
]

class TestInsert:
def test_inserts_into_tree(self):
tree = IntervalTree()
assert tree.is_empty()

tree.insert(Interval(0, 1, "red"))

assert _as_tuples(tree.all()) == [(0, 1, "red")]

def test_can_insert_same_interval_twice_with_different_data(self):
tree = IntervalTree()
assert tree.is_empty()

tree.insert(Interval(0, 1, "red"))
tree.insert(Interval(0, 1, "blue"))

assert len(tree) == 2
assert set(_as_tuples(tree.all())) == {(0, 1, "red"), (0, 1, "blue")}

def test_can_insert_same_interval_twice_with_equal_data(self):
tree = IntervalTree()
assert tree.is_empty()

tree.insert(Interval(0, 1, "red"))
tree.insert(Interval(0, 1, "red"))

assert len(tree) == 2
assert set(_as_tuples(tree.all())) == {(0, 1, "red"), (0, 1, "red")}

class TestRemove:
def test_removes_from_tree(self):
tree = IntervalTree()
interval_red = tree.insert(Interval(0, 1, "red"))
tree.insert(Interval(0, 1, "blue"))
assert len(tree) == 2

tree.remove(interval_red)
assert len(tree) == 1
assert _as_tuples(tree.all()) == [(0, 1, "blue")]

def test_only_removes_matching_interval(self):
tree = IntervalTree()
interval_red_1 = tree.insert(Interval(0, 1, "red"))
tree.insert(Interval(0, 1, "red"))
assert len(tree) == 2

tree.remove(interval_red_1)

assert len(tree) == 1
assert _as_tuples(tree.all()) == [(0, 1, "red")]

def test_does_not_error_if_interval_does_not_exist_in_tree(self):
tree = IntervalTree()
interval = tree.insert(Interval(0, 1, "red"))
assert not tree.is_empty()

tree.remove(interval)
assert tree.is_empty()

# Does not error.
tree.remove(interval)

class TestOverlapping:
@pytest.mark.parametrize(
"start, end, expected_results",
[
[0, 0, []],
[3, 3, []],
[
0,
2,
[
(0, 1, "blue"),
(1, 2, "red"),
],
],
[
1,
3,
[
(1, 2, "red"),
(2, 3, "green"),
],
],
[
0.5,
2.5,
[
(0, 1, "blue"),
(1, 2, "red"),
(2, 3, "green"),
],
],
],
)
def test_returns_intervals_overlapping_the_query(
self, start, end, expected_results
):
tree = IntervalTree(
[
Interval(1, 2, "red"),
Interval(2, 3, "green"),
Interval(0, 1, "blue"),
]
)
assert _as_tuples(tree.overlapping(start, end)) == expected_results

class TestOverlappingPoint:
@pytest.mark.parametrize(
"point, expected_results",
[
[0, [(0, 1, "blue")]],
[3, []],
[1.5, [(1, 2, "red")]],
],
)
def test_returns_intervals_overlapping_the_query(self, point, expected_results):
tree = IntervalTree(
[
Interval(1, 2, "red"),
Interval(2, 3, "green"),
Interval(0, 1, "blue"),
]
)
assert _as_tuples(tree.overlapping_point(point)) == expected_results


def _as_tuples(intervals):
return [(interval.start, interval.end, interval.data) for interval in intervals]
138 changes: 138 additions & 0 deletions xocto/intervaltree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import Any, Generic, Optional, TypeVar

import intervaltree # type: ignore[import-untyped]

from xocto.types.generic import Comparable


TBoundary = TypeVar("TBoundary", bound=Comparable[Any])
TData = TypeVar("TData")


class Interval(Generic[TBoundary, TData]):
_start: TBoundary
_end: TBoundary
_data: TData
_interval: Optional[intervaltree.IntervalTree]

__slots__ = ("_start", "_end", "_data", "_interval")

def __init__(
self,
start: TBoundary,
end: TBoundary,
data: TData,
) -> None:
if start > end:
raise ValueError("interval.start must be <= interval.end")
self._start = start
self._end = end
self._data = data
self._interval = None

@property
def start(self) -> TBoundary:
return self._start

@property
def end(self) -> TBoundary:
return self._end

@property
def data(self) -> TData:
return self._data

def __eq__(self, other: Any) -> bool:
# Do not simply check that fields match here.
# Using object equivalence allows us to insert the same interval multiple times.
return other is self


class IntervalTree(Generic[TBoundary, TData]):
"""
An interval tree.
https://en.wikipedia.org/wiki/Interval_tree
An interval is always inclusive-exclusive.
This allows for efficient querying or intervals overlapping a given interval/point.
Note that there is some cost to building the tree, so there is a tradeoff here.
"""

def __init__(
self, intervals: Optional[Iterable[Interval[TBoundary, TData]]] = None
) -> None:
interval_tree_intervals = [
intervaltree.Interval(interval.start, interval.end, interval)
for interval in (intervals or [])
]
for interval_tree_interval, interval in zip(
interval_tree_intervals, (intervals or [])
):
interval._interval = interval_tree_interval
self._tree = intervaltree.IntervalTree(interval_tree_intervals)

def insert(
self, interval: Interval[TBoundary, TData]
) -> Interval[TBoundary, TData]:
interval_tree_interval = intervaltree.Interval(
interval.start, interval.end, interval
)
interval._interval = interval_tree_interval
self._tree.add(interval_tree_interval)
return interval

def remove(self, interval: Interval[TBoundary, TData]) -> None:
try:
self._tree.remove(interval=interval._interval) # noqa: SLF001
except ValueError:
# If the interval wasn't present
return

def remove_all(self) -> None:
self._tree.clear()

def start(self) -> TBoundary | None:
if self._tree.is_empty():
return None
return self._tree.begin()

def end(self) -> TBoundary | None:
if self._tree.is_empty():
return None
return self._tree.end()

def __len__(self) -> int:
return len(self._tree)

def is_empty(self) -> bool:
return self._tree.is_empty()

def all(self) -> list[Interval[TBoundary, TData]]:
"""
Return all the intervals within the tree.
The returned intervals will be sorted.
"""
return [interval.data for interval in sorted(self._tree.items())]

def overlapping(
self, start: TBoundary, end: TBoundary
) -> list[Interval[TBoundary, TData]]:
"""
Return all the intervals overlapping the passed interval.
The returned intervals will be sorted.
"""
matching_intervals = self._tree.overlap(start, end)
return [interval.data for interval in sorted(matching_intervals)]

def overlapping_point(self, point: Any) -> list[Interval[TBoundary, TData]]:
"""
Return all the intervals overlapping the passed point.
The returned intervals will be sorted.
"""
matching_intervals = self._tree.at(point)
return [interval.data for interval in sorted(matching_intervals)]

0 comments on commit 64fb3fa

Please sign in to comment.