import { styled } from "@linaria/react";
import {
  type AccessorFnColumnDef,
  type CellContext,
  type ColumnDef,
  type ColumnSort,
  type Header,
  type Row,
  type SortingState,
  type RowSelectionState,
  flexRender,
  getCoreRowModel,
  getExpandedRowModel,
  getFilteredRowModel,
  getPaginationRowModel,
  getSortedRowModel,
  useReactTable,
} from "@tanstack/react-table";
import { useVirtualizer, type Virtualizer } from "@tanstack/react-virtual";
import {
  type StyledComponent,
  type TableHTMLAttributes,
  type FunctionComponent,
  type ReactNode,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from "react";

import { useDebounce } from "~/hooks/useDebounce";

import { LoadingIndicator } from "./library";
import { FormSize } from "./library/Form";
import Icon from "./library/Icon";
import Input from "./library/Input";
import { type PageChangeData, TablePagination } from "./library/Pagination";

/**
 * Props passed to a table's `postContentRenderer`, which is rendered right
 * after each data row, inside that row's own `<tbody>`.
 */
export type PostContentProps<T> = {
  /** number of visible cells in the row this content follows (handy for a `colSpan`) */
  columnCount: number;
  /** the original data object of the row this content follows */
  data: T;
  /** the tanstack-table id of the row this content follows */
  id: string;
};

/**
 * Props for a `RowDetailRenderer`.
 * NOTE(review): nothing in this file builds these props — confirm against the
 * callers how each value is supplied.
 */
export type RowDetailProps<T> = {
  /** always the literal `"true"`; presumably spread onto the rendered detail element */
  "aria-expanded": "true";
  /** presumably the number of columns the detail should span — confirm */
  columnCount: number;
  /** presumably the row's original data object — confirm */
  data: T;
  /** presumably the row id — confirm */
  id: string;
  /** presumably collapses/expands the detail view — confirm */
  toggleExpand: (event: unknown) => void;
};

/** A React component that renders the detail view for a row. */
export type RowDetailRenderer<T> = FunctionComponent<RowDetailProps<T>>;
/**
 * A custom cell renderer for a `TableColumn` (`renderFn`).
 *
 * It receives the row's original data object, a handler that toggles the
 * row's selection state (row selection is what the table uses to drive detail
 * views), and the row's id. Note: `getCellRenderer` currently invokes it as a
 * plain function and falls back to the cell's raw value when it returns
 * `null`/`undefined`.
 */
export type CellRenderer<T> = FunctionComponent<{
  data: T;
  onToggleRow: (event: unknown) => void;
  rowId: string;
}>;

/**
 * Configuration for one column of a `Table`. Each entry is translated into a
 * tanstack-table column definition, in the order given.
 */
export interface TableColumn<T extends object> {
  /** include if the default accessor (the `field` property) doesn't correctly extract the value for this column for global filtering, etc */
  accessorFn?: (row: T, index: number) => any;
  /** the name of the object's property to render in this column (if any); also used as the column id — without it, the column's index (as a string) is the id */
  field?: string;
  /** default is true; set to false to opt out */
  filterable?: boolean;
  /** include if the default filter function (tanstack's `"includesString"`) doesn't work for this column; it receives the row's original data object */
  filterFn?: (row: T, column: string, value: any) => boolean;
  /** a custom display name for this column (defaults to `field`), or `null` for no display name */
  label?: string | null;
  /** include if you need full control over what goes inside this column's cells */
  renderFn?: CellRenderer<T>;
  /** default is true; set to false to opt out */
  sortable?: boolean;
  /** include if the sorting order needs to be controlled; any nullable columns must have this set to go back to non-sorting */
  sortDescFirst?: boolean;
  /** include if the default sort function doesn't work for this column's data; it receives the two rows' original data objects and is ignored when `sortable` is false */
  sortFn?: (arg0: T, arg1: T) => number;
}

/** Payload passed to `TableProps.onCustomSort` when a sortable column header is clicked. */
export interface CustomTableSortEvent {
  /** the column id: the column's `field`, or its index (as a string) when it has no `field` */
  field: string;
  /** the sort direction the column is switching to ('asc' or 'desc'), or `false` when sorting is being cleared */
  sorted: false | "asc" | "desc";
}

/** Payload passed to `TableProps.onCustomFilter` with a column's (debounced) filter query. */
export interface CustomTableFilterEvent {
  /** the column id: the column's `field`, or its index (as a string) when it has no `field` */
  field: string;
  /** the text typed into the column's filter input, or `undefined` if none has been entered */
  query: string | undefined;
}

/**
 * Props for `Table`. `className` goes on the wrapper `<div>`; any other
 * `<table>` attributes not listed here (including `id` and `aria-busy`) are
 * forwarded to the rendered `<table>` element.
 */
export interface TableProps<T extends object>
  extends StyledComponent,
    TableHTMLAttributes<HTMLTableElement> {
  /** hide every column's filter input */
  disableFilter?: boolean;
  /** allow at most one selected row at a time (cell renderers toggle selection via `onToggleRow`) */
  disableMultiRowSelection?: boolean;
  /** returns an optional class name for each data row's `<tr>` */
  getClassNameForRow?: (row: { data: T }) => string | undefined;
  /** forwarded to the `<table>`; also prefixes the row and filter-input ids and is the pagination control's `aria-controls` target */
  id: string;
  /** the initial sorting state (a single sort or a list of them); later changes are ignored */
  initialSort?: ColumnSort | ColumnSort[]; // TODO: unify this w/ CustomTableSortEvent
  /** called with a clicked cell's row data and its column's `field`; not called for columns without a `field` */
  onCellClick?: (rowObj: T, fieldName: string) => void;
  /** when provided, tanstack filtering is manual (the data is treated as already filtered) and column filter queries are reported here */
  onCustomFilter?: (filterEvent: CustomTableFilterEvent) => void;
  /** when provided, tanstack sorting is manual (the data is treated as already sorted) and header sort clicks are reported here */
  onCustomSort?: (sortEvent: CustomTableSortEvent) => void;
  /** called with the clicked row's original data object */
  onRowClick?: (rowObj: T) => void;
  /** rows per page when `withPagination` is set (default 20) */
  pageSize?: number;
  /** used as tanstack's global filter value */
  searchTerm?: string;
  /** column configuration, in display order */
  tableColumns: readonly TableColumn<T>[];
  /** the rows to display; replacing this array clears row selection and scrolls the table back to the top */
  tableData: T[];
  /** virtualize rows (default `true`); when `false`, every row is rendered */
  virtual?: boolean;
  /** paginate the (filtered) rows and render a pagination control below the table */
  withPagination?: boolean;
  /**
   * Include if you want to be able to expand table rows and show detail.
   * It is rendered right after each data row, inside that row's own
   * `<tbody>`, and must return a `<tr>` element.
   *
   * @note there doesn't seem to be a way to ensure this at compile-time yet:
   * https://stackoverflow.com/a/71515552/635678
   * https://github.com/microsoft/TypeScript/issues/21699
   * so please do the right thing
   */
  postContentRenderer?: FunctionComponent<PostContentProps<T>>;
  /** content shown when there are no rows (default: "No data to display") */
  renderEmptyState?: () => ReactNode;
}

/** `TableProps` for a paginated table: `pageSize` is required and `withPagination` must be `true`. */
export interface PaginatedTableProps<T extends object> extends TableProps<T> {
  pageSize: number;
  withPagination: true;
}

/** Props for the internal `TableHeaderCell` component. */
interface TableHeaderCellProps<T extends object> {
  /** hide this header's filter input */
  disableFilter?: boolean;
  /** the tanstack-table header to render */
  header: Header<T, unknown>;
  /** receives this column's debounced filter query */
  onCustomFilter?: TableProps<T>["onCustomFilter"];
  /** receives this column's sort toggles */
  onCustomSort?: TableProps<T>["onCustomSort"];
  /** the owning table's id, used to build the filter input's id */
  tableId: string;
}

/** Fallback empty-state content for tables that don't supply `renderEmptyState`. */
function defaultRenderEmptyState() {
  return "No data to display";
}

/**
 * Builds the optional part of a tanstack-table column definition for `col`.
 *
 * A key is only added when it carries a real value, because tanstack table
 * doesn't cope well with some options being explicitly `undefined`:
 * - `sortingFn`: wraps `col.sortFn` (applied to the rows' original data),
 *   unless sorting is disabled or no `sortFn` was given
 * - `filterFn`: wraps `col.filterFn` (applied to the row's original data), or
 *   falls back to `"includesString"`, unless filtering is disabled
 * - `accessorFn`: passed through unchanged when given
 */
const getOptionalColumnConfig = function <T extends object>(
  col: TableColumn<T>,
) {
  const config: Partial<AccessorFnColumnDef<T>> = {};
  const { sortFn: compareRows, filterFn: matchesRow } = col;

  if (col.sortable !== false && compareRows !== undefined) {
    config.sortingFn = (a: Row<T>, b: Row<T>) =>
      compareRows(a.original, b.original);
  }

  if (col.filterable !== false) {
    config.filterFn =
      matchesRow === undefined
        ? "includesString"
        : (row: Row<T>, columnId: string, filterValue: any) =>
            matchesRow(row.original, columnId, filterValue);
  }

  if (col.accessorFn !== undefined) {
    config.accessorFn = col.accessorFn;
  }

  return config;
};

/**
 * An unpadded row used by `TableBody` for the spacer rows that stand in for
 * virtualized-away rows, and to keep alternating row striping consistent.
 */
export const SkipRow = styled.tr`
  padding: 0;
`;

/**
 * Renders one `<th>` for a tanstack-table header.
 *
 * - Sortable columns wrap their contents in a toggle button with a direction
 *   icon (or a same-size placeholder) and expose `aria-sort`. Each click is
 *   reported via `onCustomSort` with the order being switched to.
 * - Filterable columns get a text input (unless `disableFilter`). Its value is
 *   debounced (350ms), then applied as the column's tanstack filter value and
 *   reported via `onCustomFilter`.
 */
const TableHeaderCell = function <T extends object>(
  props: TableHeaderCellProps<T>,
) {
  const { disableFilter, header, onCustomFilter, onCustomSort, tableId } =
    props;
  const { column } = header;
  const { columnDef } = column;
  const contents = flexRender(columnDef.header, header.getContext());
  const sorted = column.getIsSorted();
  const ariaSort =
    sorted === "asc"
      ? "ascending"
      : sorted === "desc"
      ? "descending"
      : undefined;
  const sortIcon = sorted ? (
    <Icon
      family="untitled"
      name={ariaSort === "ascending" ? "arrow-up" : "arrow-down"}
    />
  ) : (
    <span data-icon="true" />
  ); /* placeholder to avoid layout shift */
  /* the filter input's accessible name: the column's text label when it has
   * one, otherwise its id (its `field`, or its index when it has no `field`) */
  const filterLabel =
    typeof columnDef.header === "string" && columnDef.header !== ""
      ? columnDef.header
      : column.id;
  const [columnFilterValue, setColumnFilterValue] = useState<string>();
  const debouncedValue = useDebounce(columnFilterValue, 350);
  /* the query last reported via `onCustomFilter`. Starting at `undefined`
   * means nothing is reported on mount. Re-renders that only change the
   * identity of `column`/`onCustomFilter` don't re-report the same query,
   * which could otherwise loop when the consumer sets state from the callback. */
  const lastReportedQuery = useRef<string | undefined>(undefined);
  const sortHandler = useCallback(
    (event: unknown) => {
      /* report the order that the toggle below is about to switch to */
      onCustomSort?.({
        field: String(columnDef.id),
        sorted: column.getNextSortingOrder(),
      });
      column.getToggleSortingHandler()?.(event);
    },
    [column, columnDef, onCustomSort],
  );
  useEffect(() => {
    column.setFilterValue(debouncedValue);

    if (!onCustomFilter || lastReportedQuery.current === debouncedValue) {
      return;
    }

    lastReportedQuery.current = debouncedValue;
    onCustomFilter({
      field: String(column.id),
      query: debouncedValue,
    });
  }, [column, debouncedValue, onCustomFilter]);

  return (
    <th aria-sort={ariaSort} colSpan={header.colSpan}>
      {column.getCanSort() ? (
        <button onClick={sortHandler} type="button">
          {contents}
          {sortIcon}
        </button>
      ) : (
        contents
      )}
      {column.getCanFilter() && !disableFilter && (
        <Input
          aria-label={`filter by ${filterLabel}`}
          data-lpignore="true"
          data-size={FormSize.sm}
          id={`${tableId}-filter-${column.id}`}
          onChange={setColumnFilterValue}
          placeholder="Search"
          preIcon={<Icon family="untitled" name="search-md" />}
          type="text"
          value={String(columnFilterValue ?? "")}
        />
      )}
    </th>
  );
};

/**
 * Adapts an optional consumer `CellRenderer` into a tanstack-table `cell`
 * renderer.
 *
 * The consumer is called as a plain function (not mounted as an element). It
 * receives the row's original data, the row's selection-toggle handler and the
 * row's id. When there is no consumer, or it returns `null`/`undefined`, the
 * cell falls back to the column's raw value.
 */
const getCellRenderer = function <T>(consumerRenderer?: CellRenderer<T>) {
  return (props: CellContext<T, any>) => {
    if (consumerRenderer == null) {
      return props.getValue();
    }

    const { row } = props;
    const rendered = consumerRenderer({
      data: row.original,
      onToggleRow: row.getToggleSelectedHandler(),
      rowId: row.id,
    });

    return rendered ?? props.getValue();
  };
};

/**
 * Builds the tanstack-table instance for `Table`, plus the state and handlers
 * it renders with.
 *
 * - Translates `tableColumns` into tanstack column definitions (memoized on
 *   `tableColumns`).
 * - Owns the sorting and row-selection state. Row selection also drives
 *   detail views; see the comment on `state.rowSelection`.
 * - Switches tanstack to manual filtering/sorting when the consumer handles
 *   them (`onCustomFilter` / `onCustomSort`). Switches to manual pagination
 *   (no slicing) when `withPagination` is off.
 * - Clears row selection and scrolls back to the top whenever `tableData`
 *   changes.
 *
 * @returns the table instance, the row/cell click and page-change handlers,
 *   and `tableProps`. `tableProps` holds every prop not consumed here, for
 *   spreading onto the `<table>` element.
 */
const useTable = function <T extends object>(
  props: TableProps<T> | PaginatedTableProps<T>,
) {
  /* `className`, `disableFilter` and `virtual` are pulled out (and unused)
   * only so they don't end up in `otherProps`, which goes onto the `<table>` */
  const {
    className: _className,
    disableFilter: _disableFilter,
    disableMultiRowSelection,
    getClassNameForRow,
    onCellClick,
    onCustomFilter,
    onCustomSort,
    onRowClick,
    pageSize = 20,
    postContentRenderer,
    renderEmptyState = defaultRenderEmptyState,
    searchTerm,
    tableColumns,
    tableData,
    withPagination,
    initialSort = [],
    virtual: _virtual,
    ...otherProps
  } = props;
  /* some data structures we don't want to re-create every render cycle */
  const data = useMemo(() => tableData ?? [], [tableData]);
  const columns: ColumnDef<T>[] = useMemo(
    () =>
      tableColumns?.map((c, idx) => ({
        accessorKey: c.field,
        cell: getCellRenderer<T>(c.renderFn),
        enableColumnFilter: c.filterable !== false,
        enableSorting: c.sortable !== false,
        header: c.label === null ? "" : c.label ?? c.field,
        /* columns without a `field` are identified by their index */
        id: c.field ?? String(idx),
        sortDescFirst: c.sortDescFirst,
        /* tanstack table doesn't like undefined values for some things */
        ...getOptionalColumnConfig<T>(c),
      })) ?? [],
    [tableColumns],
  );
  /* `initialSort` only seeds the state; later changes to it are ignored */
  const [sortingState, setSortingState] = useState<SortingState>(() =>
    Array.isArray(initialSort) ? initialSort : [initialSort],
  );
  const [selectionState, setSelectionState] = useState<RowSelectionState>({});
  const table = useReactTable({
    columns,
    data,
    enableMultiRowSelection: !disableMultiRowSelection,
    getCoreRowModel: getCoreRowModel(),
    getExpandedRowModel: getExpandedRowModel(),
    getFilteredRowModel: getFilteredRowModel(),
    getPaginationRowModel: getPaginationRowModel(),
    getSortedRowModel: getSortedRowModel(),
    manualFiltering: !!onCustomFilter,
    manualPagination: !withPagination,
    manualSorting: !!onCustomSort,
    onRowSelectionChange: setSelectionState,
    onSortingChange: setSortingState,
    state: {
      globalFilter: searchTerm,
      /* use tanstack table's row-selection mechanism to control rendering
       * of any detail views, since its native expansion mechanism seems to
       * only support showing and hiding sub-rows which use the same column
       * set as the parent row */
      rowSelection: selectionState,
      sorting: sortingState,
    },
  });
  const onPageChange = useCallback(
    (pageChange: PageChangeData) => {
      table.setPageIndex(pageChange.page);
    },
    [table],
  );
  const onClickRow = useCallback(
    (e: React.MouseEvent<HTMLTableRowElement>) => {
      /* nothing to report; skip the row lookup */
      if (!onRowClick) {
        return;
      }

      /* data rows carry their tanstack row id in `data-for` */
      const rowId = e.currentTarget?.getAttribute("data-for");
      const row = table.getRowModel().rows.find((row) => row.id === rowId);

      if (!row) {
        return;
      }

      onRowClick(row.original);
    },
    [onRowClick, table],
  );
  const onClickCell = useCallback(
    (e: React.MouseEvent<HTMLTableCellElement>) => {
      /* nothing to report; skip the row lookup */
      if (!onCellClick) {
        return;
      }

      const cellIndex = e.currentTarget?.cellIndex ?? -1;
      const rowId = e.currentTarget?.closest("tr")?.getAttribute("data-for");
      const row = table.getRowModel().rows.find((row) => row.id === rowId);
      /* maps the cell's DOM position onto `tableColumns`, assuming one `<td>`
       * per configured column. `cellIndex` may be -1 (the fallback above) or
       * otherwise out of range, so guard the lookup instead of throwing */
      const column = tableColumns[cellIndex]?.field;

      if (!row || !column) {
        return;
      }

      onCellClick(row.original, column);
    },
    [onCellClick, table, tableColumns],
  );
  const tableRef = useRef<HTMLTableElement>(null);

  useEffect(() => {
    if (!table || !tableData) {
      return;
    }

    /* without pagination, a single page holds every row */
    table.setPageSize(withPagination ? pageSize : tableData.length);
  }, [table, tableData, pageSize, withPagination]);

  useEffect(() => {
    /* new data: drop any selection (and so any open detail views) and go
     * back to the top */
    setSelectionState({});

    if (tableRef.current) {
      tableRef.current.scrollTop = 0;
    }
  }, [tableData]);

  return {
    getClassNameForRow,
    numRecords: tableData?.length ?? 0,
    onPageChange,
    onCellClick: onClickCell,
    onCustomFilter,
    onCustomSort,
    onRowClick: onClickRow,
    pageSize,
    postContentRenderer,
    renderEmptyState,
    selectionState,
    table,
    tableProps: otherProps,
    tableRef,
    withPagination,
  };
};

/** Props for `ColumnRow`, which renders a single data row. */
interface ColumnRowProps<T extends object> {
  /** returns an optional class name for the row's `<tr>` */
  getClassNameForRow: TableProps<T>["getClassNameForRow"];
  /** click handler attached to each of the row's `<td>`s */
  onCellClick?: (e: React.MouseEvent<HTMLTableCellElement>) => void;
  /** click handler attached to the row's `<tr>` */
  onRowClick?: (e: React.MouseEvent<HTMLTableRowElement>) => void;
  /** the tanstack-table row to render */
  row: Row<T>;
  /** the owning table's id, used to build the row's DOM id */
  tableId: string;
}

/**
 * Converts a cell's raw value into the text for its `title` tooltip.
 *
 * Nullish values, and objects that only have the default
 * `Object.prototype.toString` (or no callable `toString`), give `""` instead
 * of an unhelpful `"[object Object]"`. Every other value (strings, numbers,
 * dates, arrays, …) is stringified as before.
 */
const getCellTitle = (value: unknown): string => {
  if (value === null || value === undefined) {
    return "";
  }

  if (
    typeof value === "object" &&
    (typeof value.toString !== "function" ||
      value.toString === Object.prototype.toString)
  ) {
    return "";
  }

  return `${value}`;
};

/**
 * Renders one data row as a `<tr>` with one `<td>` per visible cell.
 *
 * The row's tanstack id is stored in `data-for`, which the table's click
 * handlers use to find the row again.
 */
const ColumnRow = <T extends object>(props: ColumnRowProps<T>) => {
  const { getClassNameForRow, row, tableId, onCellClick, onRowClick } = props;

  return (
    <tr
      key={row.id}
      className={getClassNameForRow?.({ data: row.original })}
      data-for={row.id}
      id={`${tableId}-${row.id}`}
      onClick={onRowClick}
    >
      {row.getVisibleCells().map((cell) => (
        <td
          key={cell.id}
          onClick={onCellClick}
          title={getCellTitle(cell.getValue())}
        >
          {flexRender(cell.column.columnDef.cell, cell.getContext())}
        </td>
      ))}
    </tr>
  );
};

/**
 * Renders a table's data rows, virtualized by `rowVirtualizer`.
 *
 * Only the rows the virtualizer currently reports (visible + overscan) are
 * rendered. Spacer rows stand in for the rest, so the scrollable height stays
 * correct. With a `postContentRenderer`, each data row gets its own `<tbody>`
 * holding the row followed by its post content. Otherwise, all rows share a
 * single `<tbody>`.
 */
const TableBody = function <T extends object>(props: {
  getClassNameForRow: TableProps<T>["getClassNameForRow"];
  onCellClick?: (e: React.MouseEvent<HTMLTableCellElement>) => void;
  onRowClick?: (e: React.MouseEvent<HTMLTableRowElement>) => void;
  rows: Row<T>[];
  postContentRenderer?: TableProps<T>["postContentRenderer"];
  rowVirtualizer: Virtualizer<HTMLTableElement, Element>;
  tableId: string;
}) {
  const {
    getClassNameForRow,
    onCellClick,
    onRowClick,
    postContentRenderer: PostContentRenderer,
    rows,
    rowVirtualizer,
    tableId,
  } = props;

  /* with virtualized rows, we only render a subset of the table rows (the visible rows + overscan rows above and below the visible rows),
   * so we need to make sure that the total height of the scrollable table region accounts for everything we don't render.
   * we do that by adding
   *      1. a "padding" row above the rendered rows whose height is equal to the top position of the first virtual row
   *      (also equal to the total height of all the non-rendered rows above the first virtual row)
   *      2. a "padding" row below the rendered rows whose height is equal to the total height of all rows minus the bottom position
   *      of the last virtual row (also equal to the total height of all non-rendered rows below the last virtual row)
   */
  const virtualRows = rowVirtualizer.getVirtualItems();
  const firstVirtualRow = virtualRows[0];
  const lastVirtualRow = virtualRows[virtualRows.length - 1];
  const paddingTop = firstVirtualRow?.start || 0;
  const paddingBottom = lastVirtualRow
    ? rowVirtualizer.getTotalSize() - (lastVirtualRow.end || 0)
    : 0;

  if (PostContentRenderer) {
    /* every data row has its own `<tbody>` here, so the padding rows get
     * their own too. Without them, the rows that aren't rendered would take
     * up no space at all. They are only emitted when there is something to
     * pad, so non-virtual tables render exactly as before. */
    return (
      <>
        {paddingTop > 0 && (
          <tbody>
            <SkipRow>
              <td style={{ height: `${paddingTop}px` }} />
            </SkipRow>
          </tbody>
        )}
        {virtualRows.map((virtualRow) => {
          const row = rows[virtualRow.index];

          /* defensive: skip any index the virtualizer reports outside `rows` */
          if (!row) {
            return null;
          }

          return (
            <tbody key={row.id}>
              <ColumnRow
                getClassNameForRow={getClassNameForRow}
                onCellClick={onCellClick}
                onRowClick={onRowClick}
                row={row}
                tableId={tableId}
              />
              <PostContentRenderer
                columnCount={row.getVisibleCells().length}
                data={row.original}
                id={row.id}
              />
            </tbody>
          );
        })}
        {paddingBottom > 0 && (
          <tbody>
            <SkipRow>
              <td style={{ height: `${paddingBottom}px` }} />
            </SkipRow>
          </tbody>
        )}
      </>
    );
  }

  return (
    <tbody>
      {
        /* because the first data row we render may not be the first row in the table due to virtual rows,
         * we always need to render at least one SkipRow to maintain the tbody's scrollHeight. as a result, if
         * the first row we render has an even index, we need to make sure that there's an even number of rows
         * above it so that the alternate table row striping stays consistent
         */
        firstVirtualRow !== undefined && firstVirtualRow.index % 2 === 0 ? (
          <SkipRow />
        ) : (
          <></>
        )
      }
      <SkipRow>
        {paddingTop > 0 && <td style={{ height: `${paddingTop}px` }} />}
      </SkipRow>
      {virtualRows.map((virtualRow) => {
        const row = rows[virtualRow.index];

        /* defensive: skip any index the virtualizer reports outside `rows` */
        if (!row) {
          return null;
        }

        return (
          <ColumnRow
            key={row.id}
            getClassNameForRow={getClassNameForRow}
            onCellClick={onCellClick}
            onRowClick={onRowClick}
            row={row}
            tableId={tableId}
          />
        );
      })}
      <SkipRow>
        {paddingBottom > 0 && <td style={{ height: `${paddingBottom}px` }} />}
      </SkipRow>
    </tbody>
  );
};

/**
 * A sortable, filterable, optionally paginated and (by default) virtualized
 * data table built on tanstack table.
 *
 * While `aria-busy` is `true` (or `"true"`), a loading indicator replaces the
 * body. When there are no rows, `renderEmptyState` is shown instead. The
 * wrapper `<div>` gets `className` and a `--column-count` CSS variable; other
 * `<table>` attributes are forwarded to the `<table>` itself.
 */
const Table = function <T extends object>(
  props: TableProps<T> | PaginatedTableProps<T>,
) {
  const {
    "aria-busy": ariaBusy,
    className,
    disableFilter,
    id,
    tableData,
    virtual = true,
  } = props;
  const {
    getClassNameForRow,
    onCellClick,
    onCustomFilter,
    onCustomSort,
    onPageChange,
    onRowClick,
    pageSize,
    postContentRenderer,
    renderEmptyState,
    table,
    tableProps,
    tableRef,
    withPagination,
  } = useTable<T>(props);
  const { rows } = table.getRowModel();
  /* `aria-busy` is Booleanish, so the string "false" must not count as busy */
  const isBusy = ariaBusy === true || ariaBusy === "true";
  /* null-safe, like `useTable`'s own handling of `tableData` */
  const dataLength = tableData?.length ?? 0;
  const rowVirtualizer = useVirtualizer({
    getScrollElement: () => tableRef.current,
    estimateSize: () => 50,
    /* a non-virtual table overscans by the whole data set, so every row renders */
    overscan: !virtual && dataLength > 0 ? dataLength : 10,
    count: rows.length,
  });

  return (
    <div
      className={className}
      style={{ "--column-count": table.getAllColumns().length }}
    >
      <table ref={tableRef} {...tableProps}>
        <thead>
          {table.getHeaderGroups().map((headerGroup) => (
            <tr key={headerGroup.id}>
              {headerGroup.headers.map((header) => (
                <TableHeaderCell
                  key={header.id}
                  disableFilter={disableFilter}
                  header={header}
                  onCustomFilter={onCustomFilter}
                  onCustomSort={onCustomSort}
                  tableId={id}
                />
              ))}
            </tr>
          ))}
        </thead>
        {isBusy ? (
          <tbody>
            <tr>
              <LoadingIndicator as="td" data-slot="loading-indicator" />
            </tr>
          </tbody>
        ) : rows.length === 0 ? (
          <tbody>
            <tr>
              <td colSpan={table.getAllColumns().length} data-slot="empty">
                {renderEmptyState()}
              </td>
            </tr>
          </tbody>
        ) : (
          <TableBody
            getClassNameForRow={getClassNameForRow}
            onCellClick={onCellClick}
            onRowClick={onRowClick}
            postContentRenderer={postContentRenderer}
            rows={rows}
            rowVirtualizer={rowVirtualizer}
            tableId={id}
          />
        )}
      </table>
      {/* a ternary rather than `&&`, so a `pageSize` of 0 doesn't render "0" */}
      {withPagination && id && pageSize > 0 ? (
        <TablePagination
          aria-controls={id}
          onPageChange={onPageChange}
          rowCount={table.getFilteredRowModel().rows.length}
          rowsPerPage={pageSize}
        />
      ) : null}
    </div>
  );
};

export default Table;
