import { styled } from "@linaria/react";
import { Group } from "@visx/group";
import { ParentSize } from "@visx/responsive";
import { scaleOrdinal } from "@visx/scale";
import { Zoom } from "@visx/zoom";
import type { ProvidedZoom, TransformMatrix } from "@visx/zoom/lib/types";
import { max, min } from "d3";
import {
  type Simulation,
  forceCenter,
  forceCollide,
  forceLink,
  forceManyBody,
  forceSimulation,
  forceX,
  forceY,
} from "d3-force";
import {
  type Dispatch,
  type MouseEvent,
  type SetStateAction,
  type TouchEvent,
  useCallback,
  useEffect,
  useRef,
  useState,
  type FunctionComponent,
} from "react";

import Card from "~/components/Card";
import { getSpectrumPalette } from "~/components/charts/utils";
import { type Platform } from "~/constants";

import { MAX_WEIGHT, MIN_CLUSTER_RADIUS, PLATFORM_SPACING } from "../constants";
import {
  type ActiveEdge,
  type ClusterDatum,
  type Datum,
  type LinkDatum,
  type NodeDatum,
} from "../types";

import { CoordinatedPostingActivity } from "./CoordinatedPostingActivity/CoordinatedPostingActivity";
import { GraphGroup } from "./GraphElements/GraphGroup";
import { RADIUS as NODE_RADIUS } from "./GraphElements/NodeCircle";
import { Legend } from "./Legend";

// Ordinal color scale for link weights. The domain is MAX_WEIGHT, MAX_WEIGHT - 1, …, 1 and is
// paired, in order, with the colors returned by getSpectrumPalette(MAX_WEIGHT).
const LINK_COLORSCALE = scaleOrdinal({
  domain: Array.from({ length: MAX_WEIGHT }, (_, index) => index + 1).reverse(),
  range: getSpectrumPalette(MAX_WEIGHT),
});

// Cluster Forces (cluster-level simulation in useGraphLayout / handleClusterRadiusChange)
// Many-body strength between clusters; negative values repel in d3-force.
const CLUSTER_STRENGTH = -5;
// Factor passed to getScaledRadius to get a cluster's collision/spacing radius.
const RADIUS_COLLISION_FACTOR = 1.2;
// Strengths of the forces pulling clusters toward their target X and Y positions.
const FORCE_X = 0.1;
const FORCE_Y = 0.15;

// Zoom bounds. SCALE_MIN is only the initial minimum; the auto-zoom replaces it.
const SCALE_MAX = 4;
const SCALE_MIN = 1 / 8;
// Extra space added around each cluster's radius when fitting the view to the clusters.
const BBOX_PADDING = 100;
const MIN_SCALE_FACTOR = 0.5; // minimum zoom = auto-fit scale * this, so the user can zoom out past the fitted view

// Target Y of the first (fewest-nodes) cluster row of each platform; later rows get decreasing Y.
const INITIAL_OFFSET = 0;

// Nodes and Links Forces (per-cluster simulations in handleNodeLinksPosition)
const NODE_STRENGTH = -150;
// Controls how fast the node collision radius grows with log2 of the cluster's node count.
const NODE_RADIUS_COLLISION_FACTOR = 0.8;

const LINK_DISTANCE = 5;
const LINK_STRENGTH = 0.1;

// Factor passed to getScaledRadius when deriving a cluster's radius from its node positions.
const NODE_RADIUS_SCALE_FACTOR = 1.25;

// Figure wrapping the zoomable graph; fills the ParentSize container.
const StyledFigure = styled.figure`
  width: 100%;
  height: 100%;
`;

// SVG hosting the graph. `touch-action: none` keeps the browser from handling touch
// gestures itself, presumably so the zoom/drag touch handlers receive them.
const StyledSvg = styled.svg`
  height: 100%;
  width: 100%;
  touch-action: none;
`;

// Identity transform. The wrapper spreads it into the Zoom's initial matrix and overrides the
// translation to the container's center.
const INIT_MATRIX: TransformMatrix = {
  scaleX: 1,
  scaleY: 1,
  translateX: 0,
  translateY: 0,
  skewX: 0,
  skewY: 0,
};

/**
 * Scales a radius by `factor`, capped so the result never exceeds
 * `baseRadius + MIN_CLUSTER_RADIUS`.
 *
 * @param baseRadius - the unscaled radius
 * @param factor - multiplier applied to `baseRadius`
 * @returns `min(baseRadius * factor, baseRadius + MIN_CLUSTER_RADIUS)`
 */
function getScaledRadius(baseRadius: number, factor: number) {
  const scaled = baseRadius * factor;
  const ceiling = baseRadius + MIN_CLUSTER_RADIUS;
  return Math.min(scaled, ceiling);
}

// CPA === "Coordinated Posting Activity" (the panel overlaid on the graph)
const CPA_WIDTH_PERCENTAGE = 30;
const CPA_WIDTH_PIXEL = 500;

/**
 * Width reserved for the CPA panel: `CPA_WIDTH_PERCENTAGE` percent of the container, but at
 * least `CPA_WIDTH_PIXEL`. Intended to match the panel's CSS width
 * `max(CPA_WIDTH_PERCENTAGE%, CPA_WIDTH_PIXEL px)` in WidgetContainer.
 *
 * @param containerWidth - width of the graph container
 * @returns the larger of the percentage-based width and the pixel minimum
 */
function getCPAWidth(containerWidth: number) {
  const percentageWidth = (containerWidth * CPA_WIDTH_PERCENTAGE) / 100;
  return Math.max(percentageWidth, CPA_WIDTH_PIXEL);
}

/**
 * Returns the value that identifies a datum within its type.
 *
 * @param datum - a node, edge or cluster datum
 * @returns `id` for nodes, `index` for edges, `clusterId` for clusters, `null` otherwise
 */
function getDatumIdentifier(datum: Datum) {
  if (datum.type === "node") {
    return datum.id;
  }
  if (datum.type === "edge") {
    return datum.index;
  }
  if (datum.type === "cluster") {
    return datum.clusterId;
  }
  return null;
}

/**
 * Computes the target positions the cluster simulation pulls clusters toward.
 *
 * Platforms are ordered alphabetically. Within a platform, clusters are grouped by node count
 * and the groups are visited in ascending node-count order. Every cluster in a group gets the
 * same target Y. The first group sits at `INITIAL_OFFSET`, and each later group moves down in
 * Y by the previous and current groups' largest collision-scaled radii.
 *
 * The first platform's target X is 0. Each later platform's target X is `PLATFORM_SPACING`
 * past the previous platform's right-most occupied X, shifted by the distance from its own
 * left-most occupied X to its average cluster X.
 *
 * @param clusters - clusters with their current simulated positions (a missing `x` counts as 0)
 * @returns `clusterYPositions` keyed by `clusterId` and `platformXPositions` keyed by platform
 */
function getClusterPositionsMap(clusters: ClusterDatum[]) {
  const clusterYPositions: Record<string, number> = {};
  const platformXPositions: Partial<Record<Platform, number>> = {};
  const platforms = [
    ...new Set(clusters.map((cluster) => cluster.platform)),
  ].sort((a, b) => a.localeCompare(b));

  // platform -> node count -> clusters of that platform with that many nodes
  const groupedPlatformClusters = platforms.reduce(
    (acc, platform) => {
      acc[platform] = {};
      return acc;
    },
    {} as Record<Platform, Record<number, ClusterDatum[]>>,
  );

  clusters.forEach((cluster) => {
    const { nodes, platform } = cluster;
    const key = nodes.length;
    if (!groupedPlatformClusters[platform][key]) {
      groupedPlatformClusters[platform][key] = [];
    }
    groupedPlatformClusters[platform][key].push(cluster);
  });

  // Ensure that clusters with the same number of nodes have the same Y position in the same platform
  let previousPlatformRightX = 0;
  platforms.forEach((platform, pIndex) => {
    const nodeNums = Object.keys(groupedPlatformClusters[platform])
      .map(Number)
      .sort((a, b) => a - b);
    let yPosition = INITIAL_OFFSET;
    let prevRadius = 0;
    let mostRightX = 0;
    let mostLeftX: number | undefined;
    let sumX = 0;
    let numClusters = 0;

    nodeNums.forEach((nodeNum, index) => {
      // Named `sizeGroup` so it does not shadow the `clusters` parameter.
      const sizeGroup = groupedPlatformClusters[platform][nodeNum];
      numClusters += sizeGroup.length;
      sizeGroup.forEach((cluster) => {
        sumX += cluster.x ?? 0;
      });
      const scaledMaxRadius = getScaledRadius(
        max(sizeGroup, (cluster) => cluster.radius) ?? 0,
        RADIUS_COLLISION_FACTOR,
      );

      if (index > 0) {
        yPosition -= prevRadius + scaledMaxRadius;
      }
      prevRadius = scaledMaxRadius;
      sizeGroup.forEach((cluster) => {
        clusterYPositions[cluster.clusterId] = yPosition;
      });

      const groupLeftX =
        min(
          sizeGroup,
          (cluster) =>
            (cluster.x ?? 0) -
            getScaledRadius(cluster.radius, RADIUS_COLLISION_FACTOR),
        ) ?? 0;
      // Compare against undefined rather than testing truthiness: a left edge of exactly 0 is a
      // real value and must not be replaced by a later group's (possibly larger) left edge.
      mostLeftX =
        mostLeftX === undefined ? groupLeftX : Math.min(mostLeftX, groupLeftX);

      const groupRightX =
        max(
          sizeGroup,
          (cluster) =>
            (cluster.x ?? 0) +
            getScaledRadius(cluster.radius, RADIUS_COLLISION_FACTOR),
        ) ?? 0;
      mostRightX = Math.max(groupRightX, mostRightX);
    });
    const averageX = sumX / numClusters;
    /*
    A platform's X location is the previous platform's right-most occupied X, plus the spacing,
    plus the distance between the current platform's left-most occupied X and its average X.
    */
    const offset =
      previousPlatformRightX + PLATFORM_SPACING + (averageX - (mostLeftX ?? 0));
    platformXPositions[platform] = pIndex ? offset : 0;
    previousPlatformRightX = mostRightX;
  });

  return {
    clusterYPositions,
    platformXPositions,
  };
}

/**
 * Computes the view center and zoom scale that fit all clusters into a `width` × `height` area.
 *
 * The bounding box covers every cluster's circle (a missing `x`/`y` counts as 0), grown by
 * `BBOX_PADDING`. The scale is the largest one at which that box fits in both dimensions.
 *
 * @param clusters - clusters to fit
 * @param width - available width
 * @param height - available height
 * @returns `center` of the bounding box in layout coordinates and the fitting `scale`.
 *   With no clusters there is nothing to fit, so center (0, 0) and scale 1 are returned
 *   instead of an infinite scale.
 */
function getScaleAndCenter(
  clusters: ClusterDatum[],
  width: number,
  height: number,
) {
  if (clusters.length === 0) {
    return { center: { x: 0, y: 0 }, scale: 1 };
  }

  // Start from an empty box rather than one pinned at the origin, so that only the clusters
  // define the extent to fit.
  let minX = Infinity;
  let minY = Infinity;
  let maxX = -Infinity;
  let maxY = -Infinity;

  for (const { x = 0, y = 0, radius } of clusters) {
    const padRadius = radius + BBOX_PADDING;
    minX = Math.min(minX, x - padRadius);
    maxX = Math.max(maxX, x + padRadius);
    minY = Math.min(minY, y - padRadius);
    maxY = Math.max(maxY, y + padRadius);
  }

  const boxWidth = maxX - minX;
  const boxHeight = maxY - minY;

  return {
    center: {
      x: minX + boxWidth / 2,
      y: minY + boxHeight / 2,
    },
    scale: Math.min(width / boxWidth, height / boxHeight),
  };
}

/**
 * Derives a cluster's radius from the current positions of its nodes.
 *
 * The center is the mean node position, with a missing coordinate counting as 0. For each
 * positioned node, both axis offsets from the center are padded by `NODE_RADIUS / 2`, and the
 * largest resulting distance is taken. That distance is scaled with `getScaledRadius` by
 * `NODE_RADIUS_SCALE_FACTOR` and floored at `MIN_CLUSTER_RADIUS`.
 *
 * @param nodes - the nodes of one cluster; nodes without an `x` or `y` do not affect the radius
 * @returns the cluster radius, never less than `MIN_CLUSTER_RADIUS` (also for an empty list)
 */
function getClusterRadius(nodes: NodeDatum[]) {
  const { length } = nodes;
  const cx = nodes.reduce((total, node) => total + (node.x ?? 0), 0) / length;
  const cy = nodes.reduce((total, node) => total + (node.y ?? 0), 0) / length;
  const radius = nodes.reduce((furthest, node) => {
    // Test for a missing coordinate explicitly: 0 is a valid position (forceCenter puts a
    // single-node cluster exactly at the origin), so a truthiness check would skip real nodes.
    if (node.x == null || node.y == null) {
      return furthest;
    }
    const xFromCenter = Math.abs(node.x - cx) + NODE_RADIUS / 2;
    const yFromCenter = Math.abs(node.y - cy) + NODE_RADIUS / 2;
    return Math.max(Math.hypot(yFromCenter, xFromCenter), furthest);
  }, 0);

  const scaledRadius = getScaledRadius(radius, NODE_RADIUS_SCALE_FACTOR);
  return Math.max(MIN_CLUSTER_RADIUS, scaledRadius);
}

/**
 * Sets one cluster's radius and re-targets the cluster simulation around the new size.
 *
 * The target positions are recomputed with `getClusterPositionsMap`. The "x", "y" and
 * "collide" forces are replaced, then the simulation is reheated to alpha 1 and restarted.
 *
 * @param index - position of the cluster in `simulation.nodes()`; when undefined nothing happens
 * @param newRadius - the cluster's new radius
 * @param simulation - the cluster-level force simulation
 */
function handleClusterRadiusChange(
  index: number | undefined,
  newRadius: number,
  simulation: Simulation<ClusterDatum, undefined>,
) {
  if (index === undefined) return;

  const clusterNodes = simulation.nodes();
  clusterNodes[index].radius = newRadius;

  const positions = getClusterPositionsMap(clusterNodes);

  const xForce = forceX<ClusterDatum>(
    (cluster) => positions.platformXPositions[cluster.platform] ?? 0,
  ).strength(FORCE_X);
  const yForce = forceY<ClusterDatum>(
    (cluster) => positions.clusterYPositions[cluster.clusterId],
  ).strength(FORCE_Y);
  const collideForce = forceCollide<ClusterDatum>().radius((cluster) =>
    getScaledRadius(cluster.radius, RADIUS_COLLISION_FACTOR),
  );

  simulation
    .force("x", xForce)
    .force("y", yForce)
    .force("collide", collideForce);
  simulation.nodes(clusterNodes).alpha(1).restart();
}

/**
 * Starts a force simulation that lays out the nodes and links of one cluster.
 *
 * On every tick, the cluster's radius is recomputed from its node positions
 * (`getClusterRadius`). It is then pushed into `clusterSimulation` through
 * `handleClusterRadiusChange`, which re-targets and restarts that simulation.
 *
 * @param index - index of the cluster in `clusterSimulation.nodes()`; when undefined the radius
 *   updates are no-ops
 * @param links - the cluster's links; `forceLink` resolves their endpoints by node `id`
 * @param nodes - the cluster's nodes; the simulation moves them in place
 * @param clusterSimulation - the cluster-level simulation to keep in sync
 * @returns the node simulation. Callers must `stop()` it when tearing down. Otherwise it keeps
 *   ticking, and its tick handler restarts `clusterSimulation` even after that was stopped.
 */
function handleNodeLinksPosition(
  index: number | undefined,
  links: LinkDatum[],
  nodes: NodeDatum[],
  clusterSimulation: Simulation<ClusterDatum, undefined>,
): Simulation<NodeDatum, LinkDatum> {
  // Computed up front: d3's forceCollide evaluates its radius when the force is initialized,
  // which happens synchronously below, so this matches evaluating it in an accessor.
  const collisionRadius = nodes.length
    ? NODE_RADIUS *
      (1 + Math.log2(nodes.length) * NODE_RADIUS_COLLISION_FACTOR)
    : NODE_RADIUS;

  const simulation = forceSimulation<NodeDatum, LinkDatum>(nodes)
    .force(
      "link",
      forceLink<NodeDatum, LinkDatum>(links)
        .id((d) => d.id)
        .distance(LINK_DISTANCE)
        .strength(LINK_STRENGTH),
    )
    .force("charge", forceManyBody().strength(NODE_STRENGTH))
    .force("center", forceCenter())
    .force("collision", forceCollide<NodeDatum>().radius(collisionRadius))
    .on("tick", () => {
      handleClusterRadiusChange(
        index,
        getClusterRadius(simulation.nodes()),
        clusterSimulation,
      );
    });

  return simulation;
}

/**
 * Runs the force layout of the clusters and auto-fits the zoom to it while it runs.
 *
 * @param data - clusters to lay out (the simulations mutate these objects in place)
 * @param width - container width, captured once (see widthRef below)
 * @param height - container height, captured once (see heightRef below)
 * @param zoom - visx zoom whose transform is driven by the auto-zoom
 * @param setMinZoomScale - receives the fitted scale multiplied by MIN_SCALE_FACTOR
 * @param initPlatformPlacement - optional initial target X per platform
 * @returns `clusters` (refreshed from the simulation on every tick) and `stopAutoZoom`, which
 *   stops the view from following the layout
 */
const useGraphLayout = (
  data: ClusterDatum[],
  width: number,
  height: number,
  zoom: ProvidedZoom<SVGSVGElement>,
  setMinZoomScale: Dispatch<SetStateAction<number>>,
  initPlatformPlacement?: Record<Platform, number>,
) => {
  /*
    This hook runs the cluster-level force layout. It also starts one node/link simulation per
    cluster (handleNodeLinksPosition). Each of those moves that cluster's nodes and feeds the
    resulting cluster radius back into the cluster simulation.
  */
  const [clusters, setClusters] = useState<ClusterDatum[]>(data);

  /*
    Calling methods of zoom updates zoom itself, which would cause an infinite loop.
    Storing it in a ref keeps zoom out of the dependency arrays.
  */
  const zoomRef = useRef(zoom);

  /*
    In the Playwright browser, width/height kept changing on click (this does not happen
    locally; the cause is unknown). They are stored in refs so the tests pass.
  */
  const widthRef = useRef(width);
  const heightRef = useRef(height);

  /*
    Auto-zooming is handled in simulation.on("tick"). The tick handler is a closure created
    when the effect runs, so a state value would be stale inside it. A ref is read fresh on
    every tick.
  */
  const isAutoZoomingRef = useRef(false);
  const stopAutoZoom = useCallback(
    () => (isAutoZoomingRef.current = false),
    [],
  );

  const handleAutoZoom = useCallback(
    (simulationNodes: ClusterDatum[]) => {
      if (!isAutoZoomingRef.current) {
        return;
      }

      // Fit the clusters into the width not covered by the Coordinated Posting Activity panel,
      // so they do not hide underneath it. When the container is no wider than the panel, the
      // remainder would be zero or negative, giving an infinite or mirrored (negative) scale.
      // In that case, fall back to the full width.
      const occupiedWidth = getCPAWidth(widthRef.current);
      const availableWidth =
        widthRef.current > occupiedWidth
          ? widthRef.current - occupiedWidth
          : widthRef.current;

      const { center, scale } = getScaleAndCenter(
        simulationNodes,
        availableWidth,
        heightRef.current,
      );
      zoomRef.current.setTransformMatrix({
        scaleX: scale,
        scaleY: scale,
        skewX: 0,
        skewY: 0,
        translateX: availableWidth / 2 - center.x * scale,
        translateY: heightRef.current / 2 - center.y * scale,
      });
      setMinZoomScale(scale * MIN_SCALE_FACTOR);
    },
    [setMinZoomScale],
  );

  useEffect(() => {
    isAutoZoomingRef.current = true;
    const clustersCopy = [...data];

    const simulation = forceSimulation<ClusterDatum>(clustersCopy)
      .force("charge", forceManyBody().strength(CLUSTER_STRENGTH))
      .force(
        "collide",
        forceCollide<ClusterDatum>().radius((d) =>
          getScaledRadius(d.radius, RADIUS_COLLISION_FACTOR),
        ),
      )
      .force(
        "x",
        // A platform missing from the placement falls back to 0 instead of a NaN target.
        forceX<ClusterDatum>(
          (d) => initPlatformPlacement?.[d.platform] ?? 0,
        ).strength(FORCE_X),
      );

    const nodeSimulations = clustersCopy.map(({ index, links, nodes }) =>
      handleNodeLinksPosition(index, links, nodes, simulation),
    );

    simulation.on("tick", () => {
      setClusters([...simulation.nodes()]);
      handleAutoZoom(simulation.nodes());
    });

    simulation.on("end", () => {
      isAutoZoomingRef.current = false;
    });

    return () => {
      // Stop the per-cluster simulations first: their tick handlers restart `simulation`, and
      // they share node/cluster objects with any simulation started by the next effect run.
      nodeSimulations.forEach((nodeSimulation) => nodeSimulation.stop());
      simulation.stop();
    };
  }, [data, handleAutoZoom, initPlatformPlacement]);

  return {
    clusters,
    stopAutoZoom,
  };
};

/** Props of CoordinationGraph. */
interface CoordinationGraphProps {
  /** Forwarded to every GraphGroup. */
  activeEdges: ActiveEdge;
  /** Forwarded to every GraphGroup; the wrapper also counts clusters whose `clusterId` is in it. */
  activeIds: Set<string>;
  /** Used by the wrapper for the Coordinated Posting Activity panel; CoordinationGraph does not read it. */
  conversationId: string;
  /** Clusters to lay out and render. */
  data: ClusterDatum[];
  /** Container height, as measured by the wrapper's ParentSize. */
  height: number;
  /** Optional initial target X per platform for the cluster layout's x force. */
  initPlatformPlacement?: Record<Platform, number>;
  /** Forwarded to every GraphGroup. */
  nodeLabel: string;
  /** Forwarded to every GraphGroup. */
  onNodeClick: (node: NodeDatum | undefined) => void;
  /** Forwarded to every GraphGroup. */
  onLinkClick: (link: LinkDatum | undefined) => void;
  /** Forwarded to every GraphGroup. */
  onClusterClick: (cluster: ClusterDatum | undefined) => void;
  /** The currently selected datum, or null when nothing is selected. */
  selectedDatum: Datum | null;
  /** Receives the auto-zoom's fitted scale multiplied by MIN_SCALE_FACTOR. */
  setMinZoomScale: Dispatch<SetStateAction<number>>;
  /** Used to toggle the selection when a datum is selected in the graph. */
  setSelectedDatum: Dispatch<SetStateAction<Datum | null>>;
  /** Container width, as measured by the wrapper's ParentSize. */
  width: number;
  /** visx zoom state and handlers provided by the wrapper's Zoom. */
  zoom: ProvidedZoom<SVGSVGElement>;
}

/**
 * Renders the cluster graph inside a pannable, zoomable SVG.
 *
 * Cluster positions come from `useGraphLayout`, which auto-fits the zoom while its simulation
 * runs. Dragging or using the wheel stops that auto-fitting. `handleSelect`, passed to each
 * GraphGroup, toggles the selected datum.
 */
const CoordinationGraph: FunctionComponent<CoordinationGraphProps> = (
  props,
) => {
  const {
    activeEdges,
    activeIds,
    data,
    height,
    initPlatformPlacement,
    nodeLabel,
    onNodeClick,
    onLinkClick,
    onClusterClick,
    selectedDatum,
    setMinZoomScale,
    setSelectedDatum,
    width,
    zoom,
  } = props;

  const { clusters, stopAutoZoom } = useGraphLayout(
    data,
    width,
    height,
    zoom,
    setMinZoomScale,
    initPlatformPlacement,
  );

  // Selecting the datum that is already selected (same type and identifier) clears the
  // selection; selecting anything else replaces it.
  const handleSelect = useCallback(
    (datum: Datum) => {
      setSelectedDatum((prevDatum) => {
        if (
          !prevDatum ||
          prevDatum.type !== datum.type ||
          getDatumIdentifier(prevDatum) !== getDatumIdentifier(datum)
        ) {
          return datum;
        }
        return null;
      });
    },
    [setSelectedDatum],
  );

  // Starting a drag also hands control of the view to the user.
  const handleZoomDrag = useCallback(
    (event: MouseEvent | TouchEvent) => {
      zoom.dragStart(event);
      stopAutoZoom();
    },
    [stopAutoZoom, zoom],
  );

  // A mouse button released outside the SVG never reaches onMouseUp here. That would leave the
  // zoom in its dragging state, panning the graph on plain hover once the cursor comes back.
  const handleMouseLeave = useCallback(() => {
    if (zoom.isDragging) {
      zoom.dragEnd();
    }
  }, [zoom]);

  return (
    <StyledSvg
      ref={zoom.containerRef}
      id="coordination-graph"
      onMouseDown={handleZoomDrag}
      onMouseLeave={handleMouseLeave}
      onMouseMove={zoom.dragMove}
      onMouseUp={zoom.dragEnd}
      onTouchEnd={zoom.dragEnd}
      onTouchMove={zoom.dragMove}
      onTouchStart={handleZoomDrag}
      onWheel={stopAutoZoom}
      role="group"
    >
      <Group role="list" transform={zoom.toString()}>
        {clusters.map((cluster) => (
          <GraphGroup
            key={cluster.clusterId}
            activeEdges={activeEdges}
            activeIds={activeIds}
            cluster={cluster}
            handleSelect={handleSelect}
            linkColorScale={LINK_COLORSCALE}
            nodeLabel={nodeLabel}
            onClusterClick={onClusterClick}
            onLinkClick={onLinkClick}
            onNodeClick={onNodeClick}
            selectedDatum={selectedDatum}
          />
        ))}
      </Group>
    </StyledSvg>
  );
};

// Full-size overlay holding the legend and the CPA card on top of the graph. The overlay ignores
// pointer events so the SVG underneath stays interactive. Its direct child divs (presumably the
// Legend's root and the Card) opt back in, align to the right, and are sized
// max(30%, 400px) tall by max(CPA_WIDTH_PERCENTAGE%, CPA_WIDTH_PIXEL px) wide.
const WidgetContainer = styled.div`
  position: absolute;
  display: flex;
  top: 0;
  left: 0;
  pointer-events: none;
  height: 100%;
  width: 100%;
  padding: var(--spacing-3xl);
  flex-direction: column;
  justify-content: space-between;

  > div {
    align-self: flex-end;
    pointer-events: auto;
    height: max(30%, 400px);
    width: max(${CPA_WIDTH_PERCENTAGE}%, ${CPA_WIDTH_PIXEL}px);
  }
`;

// Relatively positioned, presumably so the absolutely positioned WidgetContainer is laid out
// against it (StyledFigure sets no position of its own).
const StyledParentSize = styled(ParentSize)`
  position: relative;
`;

// Public props: those of CoordinationGraph minus the size, zoom and min-zoom setter, which the
// wrapper supplies itself (from ParentSize, Zoom and its own state).
type CoordinationGraphWrapperProps = Omit<
  CoordinationGraphProps,
  "height" | "width" | "zoom" | "setMinZoomScale"
>;

/**
 * Sizes the coordination graph to its parent and provides pan/zoom through visx `Zoom`. It also
 * overlays the legend and the Coordinated Posting Activity panel on the graph.
 *
 * Nothing is rendered until the parent reports a positive width and height.
 */
const CoordinationGraphWrapper: FunctionComponent<
  CoordinationGraphWrapperProps
> = ({
  activeEdges,
  activeIds,
  conversationId,
  data,
  initPlatformPlacement,
  nodeLabel,
  onLinkClick,
  onNodeClick,
  onClusterClick,
  selectedDatum,
  setSelectedDatum,
}) => {
  // Lower zoom bound: starts at SCALE_MIN and is updated by the graph's auto-zoom.
  const [minZoomScale, setMinZoomScale] = useState(SCALE_MIN);

  // Number of clusters whose clusterId is currently active.
  const numOfActiveClusters = data.reduce(
    (count, cluster) => (activeIds.has(cluster.clusterId) ? count + 1 : count),
    0,
  );

  return (
    <StyledParentSize>
      {({ width, height }) => {
        if (!(width > 0 && height > 0)) {
          return null;
        }

        return (
          <StyledFigure>
            <Zoom<SVGSVGElement>
              height={height}
              initialTransformMatrix={{
                ...INIT_MATRIX,
                translateX: width / 2,
                translateY: height / 2,
              }}
              scaleXMax={SCALE_MAX}
              scaleXMin={minZoomScale}
              scaleYMax={SCALE_MAX}
              scaleYMin={minZoomScale}
              width={width}
            >
              {(zoom) => (
                <>
                  <CoordinationGraph
                    activeEdges={activeEdges}
                    activeIds={activeIds}
                    conversationId={conversationId}
                    data={data}
                    height={height}
                    initPlatformPlacement={initPlatformPlacement}
                    nodeLabel={nodeLabel}
                    onClusterClick={onClusterClick}
                    onLinkClick={onLinkClick}
                    onNodeClick={onNodeClick}
                    selectedDatum={selectedDatum}
                    setMinZoomScale={setMinZoomScale}
                    setSelectedDatum={setSelectedDatum}
                    width={width}
                    zoom={zoom}
                  />
                  <WidgetContainer>
                    <Legend linkColorScale={LINK_COLORSCALE} />
                    <Card>
                      <CoordinatedPostingActivity
                        conversationId={conversationId}
                        numClusters={numOfActiveClusters}
                        selectedDatum={selectedDatum}
                      />
                    </Card>
                  </WidgetContainer>
                </>
              )}
            </Zoom>
          </StyledFigure>
        );
      }}
    </StyledParentSize>
  );
};

export default CoordinationGraphWrapper;
