import { startTransition, useCallback, useMemo, useState } from 'react';
import {
  AccessorKeyColumnDef,
  ColumnDef,
  RowSelectionState,
  SortingFn,
  SortingState,
  Row as TableRow,
  getCoreRowModel,
  getSortedRowModel,
  sortingFns,
  useReactTable,
} from '@tanstack/react-table';
import { useVirtualizer } from '@tanstack/react-virtual';
import { Checkbox } from '@karnott/form';
import { Column, KeyColumn, RowId, RowShape, StringKeys } from './types';

/**
 * Builds a TanStack `SortingFn` for a data column whose cell values are objects: rows are compared on
 * the `col.sortOnKey` property of their cell value instead of on the cell value itself.
 *
 * The comparison is delegated to, by precedence:
 * - the built-in TanStack sorting function named by `col.sortingFnName`, or
 * - the custom comparator `col.sortingFn(a, b)`.
 * When the column provides neither, every pair of rows compares as equal (0).
 */
function getSortingFn<Row extends RowShape, Key extends StringKeys<Row>>(col: KeyColumn<Row, Key>): SortingFn<Row> {
  const readSortValue = (row: TableRow<Row>, columnId: string) =>
    // @ts-expect-error Too hard to type: the sort value is the `sortOnKey` property of this column's cell value
    row.original[columnId]?.[col.sortOnKey];

  return (rowA, rowB, columnId) => {
    const a = readSortValue(rowA, columnId);
    const b = readSortValue(rowB, columnId);

    if ('sortingFnName' in col && typeof col.sortingFnName === 'string') {
      // Built-in sorting functions read values through `row.getValue()`, so they get minimal row
      // stand-ins exposing the already-extracted nested values.
      const asRow = (value: unknown) => ({ getValue: () => value }) as unknown as TableRow<Row>;
      return sortingFns[col.sortingFnName](asRow(a), asRow(b), columnId);
    }

    if ('sortingFn' in col && typeof col.sortingFn === 'function') {
      return col.sortingFn(a, b);
    }

    return 0;
  };
}

type UseTableProps<Row extends RowShape> = {
  /** Data rows; each row's `id` (stringified) is used as its TanStack row id and selection key. */
  rows: Row[];
  /** Simplified column descriptions, converted to TanStack column definitions. */
  columns: Column<Row>[];
  /** Called with the keys of the new selection state (row ids, as strings) after every selection change. */
  onSelectionChange: (selectedIds: RowId[]) => void;
  /** When true, every column is sortable unless it sets `sort: false`; otherwise only columns with `sort: true`. */
  sortAll?: boolean;
};
/**
 * Builds a sortable (multi-sort), row-selectable TanStack table from simplified `Column` descriptions.
 *
 * Column kinds:
 * - data columns (`key`): show `row[key]` through the optional `displayFn` (or the raw value), `-` when nullish;
 * - selection columns (`type: 'selection'`): one checkbox per row, plus a "select all" header checkbox
 *   unless `header === false`;
 * - display-only columns (`displayFn` without `key`): render `displayFn(row)`.
 * Columns matching none of these get no cell renderer and are dropped.
 *
 * Selection and sorting state updates are wrapped in `startTransition`.
 *
 * @returns the table instance and the generated column definitions, memoized together.
 */
export function useTable<Row extends RowShape>({ columns, rows, onSelectionChange, sortAll }: UseTableProps<Row>) {
  /* Selection state */
  const [rowSelection, setRowSelection] = useState<RowSelectionState>({});

  // Pure selection updater for the "select all" checkbox: clears the selection when as many ids are
  // selected as there are rows, otherwise selects every row. It has no side effects (React may run
  // updaters more than once); the caller is notified by `onRowSelectionChange` below.
  const selectAllOrNone = useCallback(
    (selection: RowSelectionState): RowSelectionState =>
      Object.keys(selection).length === rows.length ? {} : Object.fromEntries(rows.map((row) => [row.id, true])),
    [rows],
  );

  /* Sorting state */
  const [sorting, setSorting] = useState<SortingState>([]);

  /* Column definitions */
  const columnDefs = useMemo(() => {
    return (
      columns
        .map((col, i) => {
          const colDef: Partial<AccessorKeyColumnDef<Row>> = {
            enableSorting: col.sort === true || (sortAll === true && col.sort !== false),
            meta: {
              displayTotal: col.total !== false,
              displayOnHover: col.displayOnHover || false,
            },
          };

          // Data column
          if ('key' in col) {
            colDef.accessorKey = col.key;
            colDef.header = () => col.header;
            colDef.cell = ({ getValue, row }) => {
              const value = getValue();
              // NOTE(review): `||` falls back to the raw value whenever `displayFn` returns a falsy
              // result (e.g. 0 or ''), not only when it is absent — confirm this is intended.
              return value === null || value === undefined
                ? '-'
                : (col as KeyColumn<Row, StringKeys<Row>>).displayFn?.({
                    ...(row?.original || {}),
                    [col.key!]: value,
                    row: row?.original || {},
                  } as Parameters<NonNullable<KeyColumn<Row, StringKeys<Row>>['displayFn']>>[0]) || value;
            };
            colDef.sortUndefined = -1;
            if ('sortOnKey' in col && typeof col.sortOnKey === 'string') {
              colDef.sortingFn = getSortingFn(col);
            }
          } else {
            // An id is required, but we don't need it so we generate it on the fly
            colDef.id = `no_data_${i}`;

            if ('type' in col) {
              // Selection column
              if (col.type === 'selection') {
                colDef.header = ({ table }) =>
                  col.header === false ? null : (
                    <div className="cell_no_data">
                      <Checkbox
                        checked={table.getIsAllRowsSelected()}
                        // Routed through the table so the change reaches `onRowSelectionChange`,
                        // which updates the state and notifies `onSelectionChange`.
                        onChange={() => table.setRowSelection(selectAllOrNone)}
                        hideLabel
                        label={`Select all / ${i}`}
                      />
                    </div>
                  );

                colDef.cell = ({ row }) => (
                  <div className="cell_no_data">
                    <Checkbox
                      checked={row.getIsSelected()}
                      onChange={row.getToggleSelectedHandler()}
                      label={`Select ${row.id} / ${i}`}
                      hideLabel
                    />
                  </div>
                );
              }
            }

            // No data column
            else if ('displayFn' in col) {
              colDef.header = () => col.header || null;
              colDef.cell = ({ row }) => <div className="cell_no_data">{col.displayFn?.(row.original)}</div>;
            }
          }
          return colDef;
        })
        // Remove illegal columns (no cell renderer was assigned)
        .filter((colDef) => 'cell' in colDef) as ColumnDef<Row>[]
    );
  }, [columns, sortAll, selectAllOrNone]);

  const table = useReactTable({
    columns: columnDefs,
    data: rows,

    getCoreRowModel: getCoreRowModel(),
    getSortedRowModel: getSortedRowModel(),
    getRowId: (row) => String(row.id),
    state: {
      rowSelection,
      sorting,
    },

    enableRowSelection: true,
    onRowSelectionChange: (updater) => {
      startTransition(() => {
        setRowSelection(updater);
        // TanStack passes either an updater function or a plain value (e.g. `table.setRowSelection(value)`);
        // both must notify the caller, not only the function form.
        const nextSelection = typeof updater === 'function' ? updater(rowSelection) : updater;
        onSelectionChange(Object.keys(nextSelection));
      });
    },

    enableMultiSort: true,
    onSortingChange: (state) => {
      startTransition(() => setSorting(state));
    },
  });

  return useMemo(
    () => ({
      table,
      columnDefs,
    }),
    [columnDefs, table],
  );
}

/**
 * Row virtualization for a scrollable table body.
 *
 * Only rows around the visible part of the scroll container (plus an overscan of 30 rows) are returned
 * in `virtualRows`. `paddingTop` and `paddingBottom` give the pixel space taken by the rows that are
 * not rendered, so the caller can pad the body and keep the scroll height consistent.
 *
 * The caller hands the scroll container to `setScrollElement`. `measureElement` is the virtualizer's
 * measuring callback, exposed so rendered rows can report their real heights.
 */
export function useVirtualRows<Row extends RowShape>({ rows }: { rows: TableRow<Row>[] }) {
  const [scrollElement, setScrollElement] = useState<HTMLElement | null>(null);

  // Size estimate for rows not measured yet: the height of the first indexed body row found in the
  // container when it is set, or 54px when there is none (or its height is 0).
  const estimatedRowHeight = useMemo(() => {
    const sampleRow = scrollElement?.querySelector<HTMLElement>(`tbody tr[data-index]`);
    return sampleRow?.offsetHeight || 54;
  }, [scrollElement]);

  const rowVirtualizer = useVirtualizer({
    count: rows.length,
    getScrollElement: () => scrollElement,
    estimateSize: () => estimatedRowHeight,
    // Items are keyed by TanStack row id, falling back to the index.
    getItemKey: (index) => rows.at(index)?.id || index,
    overscan: 30,
  });

  const virtualRows = rowVirtualizer.getVirtualItems();
  const firstRendered = virtualRows.at(0);
  const lastRendered = virtualRows.at(-1);

  // Space above the first rendered row and below the last one.
  const paddingTop = firstRendered?.start || 0;
  const paddingBottom = virtualRows.length === 0 ? 0 : rowVirtualizer.getTotalSize() - (lastRendered?.end || 0);

  return {
    setScrollElement,
    virtualRows,
    paddingBottom,
    paddingTop,
    measureElement: rowVirtualizer.measureElement,
  };
}

type UseTotalRowProps<Row extends RowShape> = {
  /** Rows whose cell values are summed. */
  rows: TableRow<Row>[];
  /** Column definitions; column `i` is read from each row's `getAllCells()[i]` — assumed to be in the same order. */
  columnDefs: ColumnDef<Row>[];
  /** When false, no totals are computed and an empty array is returned. */
  displayTotal: boolean;
};
/**
 * Computes the cells of a "total" row: for every column except the first (reserved for the label),
 * the sum of the numeric cell values over `rows`.
 *
 * A column's total is `null` when the column disables totals (`meta.displayTotal` falsy) or when any
 * of its cell values is neither a number nor `null`/`undefined`. Nullish values count as 0.
 *
 * @returns one entry per column after the first (entry `i` belongs to column `i + 1`), or `[]` when
 *   `displayTotal` is false or there are no columns. When `rows` is empty, columns with totals enabled
 *   are left as array holes (read as `undefined`).
 */
export function useTotalRow<Row extends RowShape>({
  rows,
  columnDefs,
  displayTotal,
}: UseTotalRowProps<Row>): (number | null)[] {
  return useMemo(() => {
    // Without any column there is nothing to total; `Array(-1)` below would throw a RangeError.
    if (!displayTotal || columnDefs.length === 0) return [];

    // the first column is always reserved for the "total" label
    const cells = Array<number | null>(columnDefs.length - 1);

    // column-level total disabling: `null` marks a column whose total is not displayed
    for (let i = 1; i < columnDefs.length; i++) {
      if (!columnDefs[i]?.meta?.displayTotal) cells[i - 1] = null;
    }

    for (const row of rows) {
      // Fetched once per row instead of once per column.
      const rowCells = row.getAllCells();

      for (let i = 0; i < cells.length; i++) {
        if (cells[i] === null) {
          continue;
        }

        const value = rowCells.at(i + 1)!.getValue();

        if (typeof value === 'number' || value === undefined || value === null) {
          cells[i] = (cells[i] || 0) + (value || 0);
        } else {
          // A non-numeric value makes the column's total meaningless; `null` also skips it for later rows.
          cells[i] = null;
        }
      }
    }

    return cells;
  }, [columnDefs, displayTotal, rows]);
}
