import {
  Box,
  Divider,
  Flex,
  Table as ChakraTable,
  TableContainer,
  TableContainerProps,
  Tbody,
  Td,
  Th,
  Thead,
  Tr,
  useColorModeValue,
} from '@chakra-ui/react';
import {
  ColumnDef,
  flexRender,
  getCoreRowModel,
  getExpandedRowModel,
  getFilteredRowModel,
  getGroupedRowModel,
  Table as TanstackTable,
  useReactTable,
} from '@tanstack/react-table';

import { GlobalFilter } from './filters/global-filter';
import { useTableSearchQuery } from './use-table-search-query';

/**
 * Props accepted by `GroupedTable`.
 *
 * Extends Chakra's `TableContainerProps`; any prop not listed here is
 * forwarded to the surrounding `TableContainer`.
 */
interface GroupedTableProps<TData> extends TableContainerProps {
  /** Rows handed to the table instance as its `data` option. */
  data: TData[];
  /** Column definitions handed to the table instance as its `columns` option. */
  columns: ColumnDef<TData>[];
  /**
   * Name of the displayed entity. Used as the table's `aria-label` and to
   * build the `${entityName}Search` search parameter name.
   */
  entityName: string;
  /**
   * Key of `TData` to group rows by. Narrowed to string keys with the
   * built-in `Extract` utility because the value is used as a column id
   * (computed key in `columnVisibility`, entry in `grouping`).
   */
  groupBy: Extract<keyof TData, string>;
}

/**
 * Table that renders `data` grouped by the `groupBy` column, with a global
 * text filter shown above the header.
 *
 * The table starts with `groupBy` as its only grouping, with that column
 * hidden and with all rows expanded. The global filter value and setter come
 * from `useTableSearchQuery`, keyed by `${entityName}Search`.
 *
 * @param data - Rows to display.
 * @param columns - Column definitions for the rows.
 * @param entityName - Used for the `aria-label` and the search parameter name.
 * @param groupBy - String key of `TData` whose column the rows are grouped by.
 * @param rest - Remaining `TableContainerProps`, forwarded to the container.
 */
export const GroupedTable = <TData,>({
  data,
  columns,
  entityName,
  groupBy,
  ...rest
}: GroupedTableProps<TData>) => {
  const [globalFilter, setGlobalFilter] = useTableSearchQuery({
    searchParam: `${entityName}Search`,
  });

  const table = useReactTable({
    data,
    columns,
    getCoreRowModel: getCoreRowModel(),
    getExpandedRowModel: getExpandedRowModel(),
    getGroupedRowModel: getGroupedRowModel(),
    getFilteredRowModel: getFilteredRowModel(),
    onGlobalFilterChange: setGlobalFilter,
    globalFilterFn: 'includesString',
    state: {
      globalFilter,
    },
    initialState: {
      // The grouped column is hidden; its value is shown in the group title row instead.
      columnVisibility: {
        [groupBy]: false,
      },
      grouping: [groupBy],
      expanded: true,
    },
  });

  const tableBorderColor = useColorModeValue('gray.200', 'whiteAlpha.300');

  return (
    <TableContainer
      border="1px"
      borderColor={tableBorderColor}
      rounded="md"
      overflow={'auto'}
      {...rest}
      // Fall back to 'full' only when the caller gave no minW; `??` keeps an
      // explicit falsy value such as 0, which `||` would have discarded.
      minW={rest.minW ?? 'full'}
    >
      <Box w={'fit-content'}>
        <TableFilter table={table} />
        <ChakraTable aria-label={entityName} variant="grouped" size="sm">
          <TableHeader table={table} />
          <TableBody table={table} />
        </ChakraTable>
      </Box>
    </TableContainer>
  );
};

/**
 * Toolbar shown above the table: the global filter input followed by a
 * full-width divider.
 *
 * @param table - Table instance passed through to `GlobalFilter`.
 */
function TableFilter<TData>({ table }: { table: TanstackTable<TData> }) {
  const dividerColor = useColorModeValue('gray.300', 'gray.600');

  // Fixed-width wrapper around the search input.
  const searchInput = (
    <Box w="xs">
      <GlobalFilter table={table} />
    </Box>
  );

  return (
    <Box minW="full" w="full">
      <Flex justifyContent={['start', 'end']} py={3} px={[2, 6]}>
        {searchInput}
      </Flex>
      <Divider borderBottomColor={dividerColor} w="full" />
    </Box>
  );
}

/**
 * Renders one header row per header group of the table.
 *
 * The first header in each row spans two columns and gets twice its column
 * size as width; every other header spans one column at its own size.
 *
 * @param table - Table instance whose header groups are rendered.
 */
function TableHeader<TData>({ table }: { table: TanstackTable<TData> }) {
  const headerGroups = table.getHeaderGroups();

  return (
    <Thead>
      {headerGroups.map((group) => (
        <Tr key={group.id}>
          {group.headers.map((header, position) => {
            const isLeading = position === 0;
            const columnSize = header.column.getSize();
            // Placeholder headers render an empty cell.
            const content = header.isPlaceholder
              ? null
              : flexRender(header.column.columnDef.header, header.getContext());

            return (
              <Th
                key={header.id + position}
                colSpan={isLeading ? 2 : 1}
                sx={{ width: isLeading ? columnSize * 2 : columnSize }}
              >
                {content}
              </Th>
            );
          })}
        </Tr>
      ))}
    </Thead>
  );
}

/**
 * Renders the table's row model.
 *
 * Rows without sub-rows are rendered as data rows, one cell per visible
 * column, with the first cell spanning two columns. Rows with sub-rows are
 * rendered as a single title cell showing the row's grouping value.
 *
 * @param table - Table instance whose row model is rendered.
 */
function TableBody<TData>({ table }: { table: TanstackTable<TData> }) {
  const titleRowBg = useColorModeValue('gray.25', 'gray.800');
  const titleRowBorder = useColorModeValue('gray.200', 'whiteAlpha.300');
  const { rows } = table.getRowModel();

  return (
    <Tbody>
      {rows.map((row) => {
        const isDataRow = row.subRows.length === 0;

        if (isDataRow) {
          return (
            <Tr key={row.id} h={12}>
              {row.getVisibleCells().map((cell, position) => (
                <Td key={cell.id} colSpan={position === 0 ? 2 : 1} overflowX="clip">
                  {flexRender(cell.column.columnDef.cell, cell.getContext())}
                </Td>
              ))}
            </Tr>
          );
        }

        // Group title row: one cell stretched across every cell of the row
        // (hidden ones included), labelled with the grouping value.
        const titleSpan = row.getAllCells().length;
        return (
          <Tr
            key={row.id}
            backgroundColor={titleRowBg}
            borderTop="1px"
            borderBottom="1px"
            borderColor={titleRowBorder}
          >
            <Td
              py={3}
              colSpan={titleSpan}
              fontWeight="semibold"
              textColor="gray.500"
              lineHeight="6"
            >
              {row.groupingValue as string}
            </Td>
          </Tr>
        );
      })}
    </Tbody>
  );
}
