import { useMemo } from 'react'

import {
  DataGridPremiumProps,
  GridRowClassNameParams,
} from '@mui/x-data-grid-premium'
import { GridApiPremium } from '@mui/x-data-grid-premium/models/gridApiPremium'

import { TREE_CHILD_CLASS, TREE_PARENT_CLASS } from './DataGridStyles'
import { getGroupingColDef } from './TreeExpander'
import { useRowExpansionProps } from './useRowExpansionProps'

/**
 * Options accepted by `useTreeGroupingProps`.
 *
 * @typeParam TDto - Row model type rendered by the grid.
 */
interface TreeGroupingOptions<TDto extends Record<string, any>> {
  /** Options forwarded to `getGroupingColDef` when building the grouping column. */
  groupingColOpts?: Parameters<typeof getGroupingColDef>[0]
  /** Returns the id the grid should use for a row. */
  getRowId: (row: TDto) => string
  /** Returns the path of a row within the tree (forwarded as the grid's `getTreeDataPath`). */
  getTreeDataPath: (row: TDto) => string[]
  /** Optional resolver for a row's CSS class name. */
  getRowClassName?: (params: GridRowClassNameParams<TDto>) => string
}

/**
 * Builds the subset of `DataGridPremium` props needed to render rows in
 * tree-data mode.
 *
 * @param api - Ref to the grid API; passed to `useRowExpansionProps`, which
 *   supplies `isGroupExpandedByDefault`.
 * @param options - Tree configuration.
 * @param options.groupingColOpts - Options passed to `getGroupingColDef` for
 *   the grouping column. `width` defaults to 100 when it is `null`/`undefined`.
 *   The caller's object is never mutated.
 * @param options.getRowClassName - Optional row class resolver. When omitted,
 *   rows whose tree path has an even length get `TREE_CHILD_CLASS` and rows
 *   whose path has an odd length get `TREE_PARENT_CLASS`.
 * @param options.getRowId - Forwarded as the grid's `getRowId`.
 * @param options.getTreeDataPath - Forwarded as the grid's `getTreeDataPath`;
 *   also used by the default row class resolver.
 * @returns Props to spread onto `DataGridPremium`.
 */
export function useTreeGroupingProps<TDto extends Record<string, any>>(
  api: React.MutableRefObject<GridApiPremium>,
  {
    groupingColOpts,
    getRowClassName,
    getRowId,
    getTreeDataPath,
  }: TreeGroupingOptions<TDto>
): Partial<DataGridPremiumProps<TDto>> {
  const rowExpansion = useRowExpansionProps(api)

  // Keyed on the caller's value (a stable `undefined` when omitted) so the
  // column definition keeps its identity across renders. The width default is
  // applied to a copy rather than written back into the caller's object.
  const groupingColDef = useMemo(
    () =>
      getGroupingColDef({
        ...groupingColOpts,
        width: groupingColOpts?.width ?? 100,
      }),
    [groupingColOpts]
  )

  const defaultGetRowClassName = (params: GridRowClassNameParams<TDto>) => {
    // Even-length tree paths get the child class, odd-length paths the parent class.
    const pathLength = getTreeDataPath(params.row).length
    return pathLength % 2 === 0 ? TREE_CHILD_CLASS : TREE_PARENT_CLASS
  }

  return {
    treeData: true,
    isGroupExpandedByDefault: rowExpansion.isGroupExpandedByDefault,
    groupingColDef,
    getRowId,
    getTreeDataPath,
    getRowClassName: getRowClassName ?? defaultGetRowClassName,
  }
}
