import { Edge, Node, useReactFlow } from "@xyflow/react";
import { HierarchyNode, stratify, tree } from "d3-hierarchy";
import compareNodeIds from "../helpers/compareNodeIds";
import usePersistEditedEdgesAndNodes from "./usePersistEditedEdgesAndNodes";

/** Fallback root-node size (px), used when React Flow has not measured the root (or reports 0). */
const DEFAULT_ROOT_WIDTH = 400;
const DEFAULT_ROOT_HEIGHT = 200;

/** Radial distance (px) allotted to each depth level of the tree. */
const RADIUS_PER_DEPTH = 500;

/** Angular offset (radians) subtracted from every node's angle, rotating the layout by half a turn. */
const ROTATION = Math.PI;

/**
 * Separation for a radial tree: siblings are packed tighter than cousins, and the
 * gap shrinks with depth because outer rings have a larger circumference.
 * d3 calls this only for neighbouring nodes, and the single root has no neighbours,
 * so `a.depth` is at least 1 here.
 */
function radialSeparation(a: HierarchyNode<Node>, b: HierarchyNode<Node>) {
  return (a.parent === b.parent ? 1 : 2) / a.depth;
}

/**
 * Whether a node already has a position of its own. Unplaced nodes are assumed to
 * sit at the origin (0, 0); a node lying on just one axis (e.g. x === 0 after
 * snapping to a grid) still counts as placed.
 */
function isNodePlaced(node: Node) {
  return node.position.x !== 0 || node.position.y !== 0;
}

/**
 * Map each edge target id to the source id of the FIRST edge pointing at it,
 * so parent lookups are O(1) instead of scanning all edges per node.
 */
function indexParentIds(edges: Edge[]) {
  const parentIdByTargetId = new Map<string, string>();
  for (const edge of edges) {
    // Keep the first match to mirror `edges.find(...)` semantics.
    if (!parentIdByTargetId.has(edge.target)) {
      parentIdByTargetId.set(edge.target, edge.source);
    }
  }
  return parentIdByTargetId;
}

/**
 * React hook that arranges React Flow nodes as a radial tree using d3-hierarchy.
 *
 * A node's parent is the `source` of the first edge whose `target` is that node.
 * The node with no incoming edge is the root and is drawn at the centre.
 *
 * @returns An object with two functions:
 *   - `getLaidOutNodes` computes positioned copies of the given nodes.
 *   - `setLaidOutNodes` writes nodes into the React Flow store, laying them out first
 *     (and persisting) when any node is still unplaced, then fits the viewport.
 */
export default function useLayOutWithD3Hierarchy() {
  const persistEditedEdgesAndNodes = usePersistEditedEdgesAndNodes();
  const { fitView, setNodes } = useReactFlow();

  /** Sort every level of the hierarchy by node id, so sibling order does not depend on input order. */
  const sortRootByNodeIds = (root: HierarchyNode<Node>) =>
    root.sort((a, b) => compareNodeIds(a.id, b.id));

  /**
   * Compute radial-tree positions for `nodes`.
   *
   * Each node is placed on a ring whose radius is proportional to its depth. Every
   * position is then shifted by half the root's measured size, so the root is
   * centred on the origin. Nodes come back in hierarchy (root-first) order, not in
   * input order. An empty `nodes` array is returned unchanged.
   *
   * @throws Error from d3 `stratify` when `edges` do not form a single tree over
   *   `nodes`: multiple roots, a cycle, a duplicate id, or a parent id with no node.
   */
  const getLaidOutNodes = ({
    edges,
    nodes,
  }: {
    edges: Edge[];
    nodes: Node[];
  }): Node[] => {
    if (nodes.length === 0) return nodes;

    const parentIdByTargetId = indexParentIds(edges);

    const root = stratify<Node>()
      .id((node) => node.id)
      .parentId((node) => parentIdByTargetId.get(node.id) ?? null)(nodes);

    // `||` (not `??`) on purpose: an unmeasured size of 0 also falls back to the default.
    const rootHeight = root.data.measured?.height || DEFAULT_ROOT_HEIGHT;
    const rootWidth = root.data.measured?.width || DEFAULT_ROOT_WIDTH;

    // For the root, `height` equals the depth of its deepest descendant.
    const radius = RADIUS_PER_DEPTH * root.height;

    const layout = tree<Node>()
      .separation(radialSeparation)
      .size([2 * Math.PI, radius])(sortRootByNodeIds(root));

    // In this layout `x` is an angle (radians) and `y` a distance from the centre;
    // convert the polar coordinates to Cartesian ones.
    return layout.descendants().map((node) => ({
      ...node.data,
      position: {
        x: node.y * Math.cos(node.x - ROTATION) - rootWidth / 2,
        y: node.y * Math.sin(node.x - ROTATION) - rootHeight / 2,
      },
    }));
  };

  /**
   * Store `nodes` in React Flow, then animate the viewport to fit them.
   *
   * If every node already has a position, the nodes are stored as-is. Otherwise the
   * whole graph is re-laid out with `getLaidOutNodes` and the result is persisted.
   */
  const setLaidOutNodes = ({
    edges,
    nodes,
  }: {
    edges: Edge[];
    nodes: Node[];
  }) => {
    if (nodes.every(isNodePlaced)) {
      setNodes(nodes);
    } else {
      setNodes(getLaidOutNodes({ edges, nodes }));

      // NOTE(review): this runs in the same tick as `setNodes` above; if the persist
      // callback reads nodes from the store it may still see the previous state — confirm.
      persistEditedEdgesAndNodes();
    }

    // NOTE: Defer to the next event-loop iteration so the node update above has been
    // applied before fitting the view.
    setTimeout(() => {
      fitView({ duration: 1000 });
    }, 0);
  };

  return { getLaidOutNodes, setLaidOutNodes };
}
