import json
from textwrap import dedent

from meshroom.core import desc
from meshroom.core.graph import Graph
from meshroom.core.node import CompatibilityIssue

from .utils import registeredNodeTypes, overrideNodeTypeVersion


class SimpleNode(desc.Node):
    # Minimal test node: a single file input and a single file output,
    # so that two instances can be chained with one edge (output -> input).
    inputs = [
        desc.File(
            name="input",
            label="Input",
            description="",
            value="",
        ),
    ]
    outputs = [
        desc.File(
            name="output",
            label="Output",
            description="",
            value="",
        ),
    ]


class NodeWithListAttributes(desc.Node):
    # Test node holding an exposed list of files twice: once at the top level ("listInput")
    # and once nested inside a group ("group.listInput"). Used by the partial-serialization
    # tests to exercise connections to whole list attributes and to individual list entries.
    inputs = [
        desc.ListAttribute(
            name="listInput",
            label="List Input",
            description="",
            elementDesc=desc.File(
                name="file",
                label="File",
                description="",
                value="",
            ),
            exposed=True,
        ),
        desc.GroupAttribute(
            name="group",
            label="Group",
            description="",
            groupDesc=[
                desc.ListAttribute(
                    name="listInput",
                    label="List Input",
                    description="",
                    elementDesc=desc.File(
                        name="file",
                        label="File",
                        description="",
                        value="",
                    ),
                    exposed=True,
                ),
            ],
        ),
    ]


def compareGraphsContent(graphA: Graph, graphB: Graph) -> bool:
    """Return whether the content (nodes and edges) of two graphs is considered identical.

    Similar nodes: nodes with the same name, type and compatibility status.
    Similar edges: edges with the same source and destination attribute full names.

    Nodes and edges are compared as sets, so their ordering in each graph does not matter.

    Args:
        graphA: The first graph to compare.
        graphB: The second graph to compare.

    Returns:
        True if both graphs hold the same set of nodes and the same set of edges.
    """

    def _buildNodesSet(graph: Graph) -> set[tuple[str, str, bool]]:
        return {(node.name, node.nodeType, node.isCompatibilityNode) for node in graph.nodes}

    def _buildEdgesSet(graph: Graph) -> set[tuple[str, str]]:
        return {(edge.src.fullName, edge.dst.fullName) for edge in graph.edges}

    nodesSetA, edgesSetA = _buildNodesSet(graphA), _buildEdgesSet(graphA)
    nodesSetB, edgesSetB = _buildNodesSet(graphB), _buildEdgesSet(graphB)

    return nodesSetA == nodesSetB and edgesSetA == edgesSetB


class TestImportGraphContent:
    """Importing the content (nodes and edges) of a graph into another graph."""

    def test_importEmptyGraph(self):
        graph = Graph("")

        otherGraph = Graph("")
        nodes = otherGraph.importGraphContent(graph)

        assert len(nodes) == 0
        assert len(graph.nodes) == 0
        # The destination graph must not have gained anything either.
        assert len(otherGraph.nodes) == 0

    def test_importGraphWithSingleNode(self):
        graph = Graph("")

        with registeredNodeTypes([SimpleNode]):
            graph.addNewNode(SimpleNode.__name__)

            otherGraph = Graph("")
            otherGraph.importGraphContent(graph)

            assert compareGraphsContent(graph, otherGraph)

    def test_importGraphWithSeveralNodes(self):
        graph = Graph("")

        with registeredNodeTypes([SimpleNode]):
            graph.addNewNode(SimpleNode.__name__)
            graph.addNewNode(SimpleNode.__name__)

            otherGraph = Graph("")
            otherGraph.importGraphContent(graph)

            assert compareGraphsContent(graph, otherGraph)

    def test_importingGraphWithNodesAndEdges(self):
        graph = Graph("")

        with registeredNodeTypes([SimpleNode]):
            nodeA_1 = graph.addNewNode(SimpleNode.__name__)
            nodeA_2 = graph.addNewNode(SimpleNode.__name__)

            graph.addEdge(nodeA_1.output, nodeA_2.input)

            otherGraph = Graph("")
            otherGraph.importGraphContent(graph)
            assert compareGraphsContent(graph, otherGraph)

    def test_edgeRemappingOnImportingGraphSeveralTimes(self):
        """Importing the same graph twice duplicates its nodes and edges.

        The second import's edge must be remapped onto the newly imported nodes
        instead of reusing nodes created by the first import.
        """
        graph = Graph("")

        with registeredNodeTypes([SimpleNode]):
            nodeA_1 = graph.addNewNode(SimpleNode.__name__)
            nodeA_2 = graph.addNewNode(SimpleNode.__name__)

            graph.addEdge(nodeA_1.output, nodeA_2.input)

            otherGraph = Graph("")
            otherGraph.importGraphContent(graph)
            otherGraph.importGraphContent(graph)

            assert len(otherGraph.nodes) == 4
            # Node types are registered, so no node should fall back to a compatibility node.
            assert len(otherGraph.compatibilityNodes) == 0
            assert len(otherGraph.edges) == 2
            # Each import's edge starts from its own source node: with a broken remapping
            # both edges would share the same source attribute.
            assert len({edge.src.fullName for edge in otherGraph.edges}) == 2

    def test_edgeRemappingOnImportingGraphWithUnkownNodeTypesSeveralTimes(self):
        graph = Graph("")

        with registeredNodeTypes([SimpleNode]):
            nodeA_1 = graph.addNewNode(SimpleNode.__name__)
            nodeA_2 = graph.addNewNode(SimpleNode.__name__)

            graph.addEdge(nodeA_1.output, nodeA_2.input)

        # Node types are no longer registered here: imported nodes become compatibility nodes.
        otherGraph = Graph("")
        otherGraph.importGraphContent(graph)
        otherGraph.importGraphContent(graph)

        assert len(otherGraph.nodes) == 4
        assert len(otherGraph.compatibilityNodes) == 4
        assert len(otherGraph.edges) == 2

    def test_importGraphWithUnknownNodeTypesCreatesCompatibilityNodes(self):
        graph = Graph("")

        with registeredNodeTypes([SimpleNode]):
            graph.addNewNode(SimpleNode.__name__)

        otherGraph = Graph("")
        importedNodes = otherGraph.importGraphContent(graph)

        assert len(importedNodes) == 1
        assert importedNodes[0].isCompatibilityNode

    def test_importGraphContentInPlace(self):
        graph = Graph("")

        with registeredNodeTypes([SimpleNode]):
            nodeA_1 = graph.addNewNode(SimpleNode.__name__)
            nodeA_2 = graph.addNewNode(SimpleNode.__name__)

            graph.addEdge(nodeA_1.output, nodeA_2.input)

            graph.importGraphContent(graph)

            assert len(graph.nodes) == 4

    def test_importGraphContentFromFile(self, graphSavedOnDisk):
        graph: Graph = graphSavedOnDisk

        with registeredNodeTypes([SimpleNode]):
            nodeA_1 = graph.addNewNode(SimpleNode.__name__)
            nodeA_2 = graph.addNewNode(SimpleNode.__name__)

            graph.addEdge(nodeA_1.output, nodeA_2.input)
            graph.save()

            otherGraph = Graph("")
            nodes = otherGraph.importGraphContentFromFile(graph.filepath)

            assert len(nodes) == 2

            assert compareGraphsContent(graph, otherGraph)

    def test_importGraphContentFromFileWithCompatibilityNodes(self, graphSavedOnDisk):
        graph: Graph = graphSavedOnDisk

        with registeredNodeTypes([SimpleNode]):
            nodeA_1 = graph.addNewNode(SimpleNode.__name__)
            nodeA_2 = graph.addNewNode(SimpleNode.__name__)

            graph.addEdge(nodeA_1.output, nodeA_2.input)
            graph.save()

        otherGraph = Graph("")
        nodes = otherGraph.importGraphContentFromFile(graph.filepath)

        assert len(nodes) == 2
        assert len(otherGraph.compatibilityNodes) == 2
        assert not compareGraphsContent(graph, otherGraph)

    def test_importingDifferentNodeVersionCreatesCompatibilityNodes(self, graphSavedOnDisk):
        graph: Graph = graphSavedOnDisk

        with registeredNodeTypes([SimpleNode]):
            with overrideNodeTypeVersion(SimpleNode, "1.0"):
                node = graph.addNewNode(SimpleNode.__name__)
                graph.save()

            with overrideNodeTypeVersion(SimpleNode, "2.0"):
                otherGraph = Graph("")
                nodes = otherGraph.importGraphContentFromFile(graph.filepath)

        assert len(nodes) == 1
        assert len(otherGraph.compatibilityNodes) == 1
        assert otherGraph.node(node.name).issue is CompatibilityIssue.VersionConflict


class TestGraphPartialSerialization:
    """Round-trips through `Graph.serializePartial` followed by `Graph._deserialize`."""

    def test_emptyGraph(self):
        """Serializing no node of an empty graph yields an equally empty graph."""
        sourceGraph = Graph("")
        payload = sourceGraph.serializePartial([])

        restoredGraph = Graph("")
        restoredGraph._deserialize(payload)
        assert compareGraphsContent(sourceGraph, restoredGraph)

    def test_serializeAllNodesIsSimilarToStandardSerialization(self):
        """Partially serializing every node matches the output of a full `serialize()`."""
        sourceGraph = Graph("")

        with registeredNodeTypes([SimpleNode]):
            upstream = sourceGraph.addNewNode(SimpleNode.__name__)
            downstream = sourceGraph.addNewNode(SimpleNode.__name__)
            sourceGraph.addEdge(upstream.output, downstream.input)

            partialPayload = sourceGraph.serializePartial([upstream, downstream])
            fullPayload = sourceGraph.serialize()

            fromPartial, fromFull = Graph(""), Graph("")
            fromPartial._deserialize(partialPayload)
            fromFull._deserialize(fullPayload)

            assert compareGraphsContent(sourceGraph, fromPartial)
            assert compareGraphsContent(fromPartial, fromFull)

    def test_listAttributeToListAttributeConnectionIsSerialized(self):
        """A link between two whole list attributes survives a partial round-trip."""
        sourceGraph = Graph("")

        with registeredNodeTypes([NodeWithListAttributes]):
            upstream = sourceGraph.addNewNode(NodeWithListAttributes.__name__)
            downstream = sourceGraph.addNewNode(NodeWithListAttributes.__name__)
            sourceGraph.addEdge(upstream.listInput, downstream.listInput)

            restoredGraph = Graph("")
            restoredGraph._deserialize(sourceGraph.serializePartial([upstream, downstream]))

            restoredDownstream = restoredGraph.node(downstream.name)
            restoredUpstream = restoredGraph.node(upstream.name)
            assert restoredDownstream.listInput.linkParam == restoredUpstream.listInput

    def test_singleNodeWithInputConnectionFromNonSerializedNodeRemovesEdge(self):
        """Serializing only the downstream node drops its link to the omitted upstream node."""
        sourceGraph = Graph("")

        with registeredNodeTypes([SimpleNode]):
            upstream = sourceGraph.addNewNode(SimpleNode.__name__)
            downstream = sourceGraph.addNewNode(SimpleNode.__name__)
            sourceGraph.addEdge(upstream.output, downstream.input)

            payload = sourceGraph.serializePartial([downstream])

            restoredGraph = Graph("")
            restoredGraph._deserialize(payload)

            # Exactly one regular node comes back, with no edge left.
            assert len(restoredGraph.compatibilityNodes) == 0
            assert len(restoredGraph.nodes) == 1
            assert len(restoredGraph.edges) == 0

    def test_serializeSingleNodeWithInputConnectionToListAttributeRemovesListEntry(self):
        """A list entry linked to a non-serialized node is removed from the list."""
        sourceGraph = Graph("")

        with registeredNodeTypes([SimpleNode, NodeWithListAttributes]):
            upstream = sourceGraph.addNewNode(SimpleNode.__name__)
            downstream = sourceGraph.addNewNode(NodeWithListAttributes.__name__)

            downstream.listInput.append("")
            sourceGraph.addEdge(upstream.output, downstream.listInput.at(0))

            restoredGraph = Graph("")
            restoredGraph._deserialize(sourceGraph.serializePartial([downstream]))

            restoredList = restoredGraph.node(downstream.name).listInput
            assert len(restoredList) == 0

    def test_serializeSingleNodeWithInputConnectionToNestedListAttributeRemovesListEntry(self):
        """Same as above, for a list attribute nested inside a group attribute."""
        sourceGraph = Graph("")

        with registeredNodeTypes([SimpleNode, NodeWithListAttributes]):
            upstream = sourceGraph.addNewNode(SimpleNode.__name__)
            downstream = sourceGraph.addNewNode(NodeWithListAttributes.__name__)

            downstream.group.listInput.append("")
            sourceGraph.addEdge(upstream.output, downstream.group.listInput.at(0))

            restoredGraph = Graph("")
            restoredGraph._deserialize(sourceGraph.serializePartial([downstream]))

            restoredList = restoredGraph.node(downstream.name).group.listInput
            assert len(restoredList) == 0


class TestGraphCopy:
    """Behavior of `Graph.copy`."""

    def test_graphCopyIsIdenticalToOriginalGraph(self):
        """While node types are registered, a copy has the same nodes and edges as the original."""
        original = Graph("")

        with registeredNodeTypes([SimpleNode]):
            source = original.addNewNode(SimpleNode.__name__)
            target = original.addNewNode(SimpleNode.__name__)
            original.addEdge(source.output, target.input)

            duplicate = original.copy()
            assert compareGraphsContent(original, duplicate)

    def test_graphCopyWithUnknownNodeTypesDiffersFromOriginalGraph(self):
        """Copying after the node types are unregistered yields content differing from the original."""
        original = Graph("")

        with registeredNodeTypes([SimpleNode]):
            source = original.addNewNode(SimpleNode.__name__)
            target = original.addNewNode(SimpleNode.__name__)
            original.addEdge(source.output, target.input)

        # Outside the registration context: SimpleNode is unknown at copy time.
        duplicate = original.copy()
        assert not compareGraphsContent(original, duplicate)


class TestImportGraphContentFromMinimalGraphData:
    """Deserialization of hand-written, minimal graph data (only a few node fields given)."""

    def test_nodeWithoutVersionInfoIsUpgraded(self):
        """A node serialized without version info loads as a regular node, not a compatibility node."""
        graph = Graph("")

        with (
            registeredNodeTypes([SimpleNode]),
            overrideNodeTypeVersion(SimpleNode, "2.0"),
        ):
            sampleGraphContent = dedent("""
            {
                "SimpleNode_1": { "nodeType": "SimpleNode" }
            }
            """)
            graph._deserialize(json.loads(sampleGraphContent))

            assert len(graph.nodes) == 1
            assert len(graph.compatibilityNodes) == 0

    def test_connectionsToMissingNodesAreDiscarded(self):
        """An input linked to a node absent from the serialized data does not produce an edge."""
        graph = Graph("")

        with registeredNodeTypes([SimpleNode]):
            sampleGraphContent = dedent("""
            {
                "SimpleNode_1": {
                    "nodeType": "SimpleNode", "inputs": { "input": "{NotSerializedNode.output}" }
                }
            }
            """)
            graph._deserialize(json.loads(sampleGraphContent))

            # The node itself is still loaded; only the dangling connection is dropped.
            assert len(graph.nodes) == 1
            assert len(graph.edges) == 0