diff --git a/map_machine/road.py b/map_machine/road.py index 9cff773..0cd6c52 100644 --- a/map_machine/road.py +++ b/map_machine/road.py @@ -492,21 +492,17 @@ class Connector: def __init__( self, - road_1: Road, - index_1: int, - road_2: Road, - index_2: int, - _flinger: Flinger, + connections: list[tuple[Road, int]], + flinger: Flinger, scale: float, ) -> None: - self.road_1: Road = road_1 - self.road_2: Road = road_2 + self.connections: list[tuple[Road, int]] = connections + self.road_1: Road = connections[0][0] + self.index_1: int = connections[0][1] - self.index_1: int = index_1 - self.index_2: int = index_2 - - self.layer: float = min(road_1.layer, road_2.layer) + self.layer: float = min(x[0].layer for x in connections) self.scale: float = scale + self.flinger: Flinger = flinger def draw(self, svg: Drawing) -> None: """Draw connection fill.""" @@ -524,20 +520,21 @@ class SimpleConnector(Connector): def __init__( self, - road_1: Road, - index_1: int, - road_2: Road, - index_2: int, + connections: list[tuple[Road, int]], flinger: Flinger, scale: float, ) -> None: - super().__init__(road_1, index_1, road_2, index_2, flinger, scale) - node: OSMNode = road_1.nodes[index_1] + super().__init__(connections, flinger, scale) + + self.road_2: Road = connections[1][0] + self.index_2: int = connections[1][1] + + node: OSMNode = self.road_1.nodes[self.index_1] self.point: np.ndarray = flinger.fling(node.coordinates) def draw(self, svg: Drawing) -> None: """Draw connection fill.""" - circle = svg.circle( + circle: Circle = svg.circle( self.point, self.road_1.width * self.scale / 2, fill=self.road_1.matcher.color.hex, @@ -546,7 +543,7 @@ class SimpleConnector(Connector): def draw_border(self, svg: Drawing) -> None: """Draw connection outline.""" - circle = svg.circle( + circle: Circle = svg.circle( self.point, self.road_1.width * self.scale / 2 + 1, fill=self.road_1.matcher.border_color.hex, @@ -561,27 +558,27 @@ class ComplexConnector(Connector): def __init__( self, - road_1: Road, - index_1: int, - road_2: Road, - index_2: int, + connections: list[tuple[Road, int]], flinger: Flinger, scale: float, ) -> None: - super().__init__(road_1, index_1, road_2, index_2, flinger, scale) + super().__init__(connections, flinger, scale) - length: float = abs(road_2.width - road_1.width) * scale - road_1.line.shorten(index_1, length) - road_2.line.shorten(index_2, length) + self.road_2: Road = connections[1][0] + self.index_2: int = connections[1][1] - node: OSMNode = road_1.nodes[index_1] + length: float = abs(self.road_2.width - self.road_1.width) * scale + self.road_1.line.shorten(self.index_1, length) + self.road_2.line.shorten(self.index_2, length) + + node: OSMNode = self.road_1.nodes[self.index_1] point: np.ndarray = flinger.fling(node.coordinates) points_1: list[np.ndarray] = get_curve_points( - road_1, scale, point, road_1.line.points[index_1] + self.road_1, scale, point, self.road_1.line.points[self.index_1] ) points_2: list[np.ndarray] = get_curve_points( - road_2, scale, point, road_2.line.points[index_2] + self.road_2, scale, point, self.road_2.line.points[self.index_2] ) # fmt: off self.curve_1: PathCommands = [ @@ -622,6 +619,42 @@ class ComplexConnector(Connector): svg.add(path) +class SimpleIntersection(Connector): + """ + Connection between more than two roads. + """ + + def __init__( + self, + connections: list[tuple[Road, int]], + flinger: Flinger, + scale: float, + ) -> None: + super().__init__(connections, flinger, scale) + + def draw(self, svg: Drawing) -> None: + """Draw connection fill.""" + for road, index in self.connections: + node: OSMNode = self.road_1.nodes[self.index_1] + point: np.ndarray = self.flinger.fling(node.coordinates) + circle: Circle = svg.circle( + point, road.width * self.scale / 2, fill=road.matcher.color.hex + ) + svg.add(circle) + + def draw_border(self, svg: Drawing) -> None: + """Draw connection outline.""" + for road, index in self.connections: + node: OSMNode = self.road_1.nodes[self.index_1] + point: np.ndarray = self.flinger.fling(node.coordinates) + circle: Circle = svg.circle( + point, + road.width * self.scale / 2 + 1, + fill=road.matcher.border_color.hex, + ) + svg.add(circle) + + class Roads: """ Whole road structure. @@ -655,20 +688,18 @@ class Roads: layered_roads[road.layer].append(road) for id_ in self.connections: - if len(self.connections[id_]) != 2: - continue connected: list[tuple[Road, int]] = self.connections[id_] - road_1, index_1 = connected[0] - road_2, index_2 = connected[1] connector: Connector - if road_1.width == road_2.width: - connector = SimpleConnector( - road_1, index_1, road_2, index_2, flinger, scale - ) + + if len(self.connections[id_]) == 2: + road_1, _ = connected[0] + road_2, _ = connected[1] + if road_1.width == road_2.width: + connector = SimpleConnector(connected, flinger, scale) + else: + connector = ComplexConnector(connected, flinger, scale) else: - connector = ComplexConnector( - road_1, index_1, road_2, index_2, flinger, scale - ) + connector = SimpleIntersection(connected, flinger, scale) if connector.layer not in layered_connectors: layered_connectors[connector.layer] = []