import { useRef } from 'react';
import { flushSync } from 'react-dom';
import { Input, InputHandle } from '../../components/input';
import { useComposedRefs } from '@radix-ui/react-compose-refs';
import { parseNumberAndUnit } from './parse-number-and-unit';
import { decimal } from './decimal';

/** CSS units a `NumberInput` value may carry after its number. */
type Unit =
  | 'px'
  | 'em'
  | '%'
  | 'deg'
  | 's'
  | 'ms';

interface NumberInputProps extends React.ComponentProps<typeof Input> {
  /** Lower bound; the default parser and arrow-key nudges clamp numbers to it. */
  min?: number;
  /** Upper bound; the default parser and arrow-key nudges clamp numbers to it. */
  max?: number;

  /**
   * Forwarded to `parseNumberAndUnit`.
   * NOTE(review): presumably restricts values to whole numbers — confirm against that helper.
   */
  integer?: boolean;

  /**
   * Units a stored value may carry. With the default parser, a bare number gets `px` when
   * `px` is listed, otherwise the first entry.
   */
  units?: Unit[];

  /** Non-numeric keywords that are accepted as values. Must be lowercase. */
  keywords?: string[];

  /** Nudge step sizes as `[arrow key, Shift + arrow key]`. */
  increments?: [number, number];
}

/**
 * Text input for numeric values with an optional unit (e.g. "20px", "1.5em") or one of a
 * fixed set of lowercase keywords.
 *
 * By default, input is parsed with `parseNumber` (clamped to `min`/`max`, restricted to
 * `units`/`keywords`) and displayed with `formatNumber`. ArrowUp/ArrowDown nudge the number
 * by `increments[0]` (Shift: `increments[1]`), commit the result synchronously, then re-run
 * `select`. A consumer's `onKeyDown` is always invoked after the built-in handling.
 */
export const NumberInput = ({
  integer = false,
  min = -Infinity,
  max = Infinity,
  units = ['px'],
  keywords = [],
  increments = [1, 10],
  format = (value) => formatNumber({ value, min, max, integer, keywords, units }),
  parse = (value) => parseNumber({ value, min, max, integer, keywords, units }),
  // formatNumber hides "px", so only when px is mixed with other (visible) units is it worth
  // selecting just the numeric part.
  select = units.includes('px') && units.length > 1 ? selectNumber : (input) => input?.select(),
  ...props
}: NumberInputProps) => {
  const ref = useRef<InputHandle>(null);
  const composedRef = useComposedRefs(ref, props.ref);

  return (
    <Input
      {...props}
      ref={composedRef}
      format={format}
      parse={parse}
      select={select}
      onKeyDown={(event) => {
        const isArrowKey = event.key === 'ArrowUp' || event.key === 'ArrowDown';

        // Only the nudge logic needs the input handle; the consumer's handler below must run
        // regardless of whether the ref is attached.
        if (isArrowKey && ref.current) {
          event.preventDefault();

          const direction = event.key === 'ArrowUp' ? 1 : -1;
          const [smallIncrement, largeIncrement] = increments;
          const amount = event.shiftKey ? largeIncrement : smallIncrement;

          // Fallback when there is no usable current value: 0 clamped into [min, max].
          const [defaultUnit = ''] = units;
          const defaultNumber = Math.max(min, Math.min(0, max));
          const defaultValue = defaultNumber + defaultUnit;

          // Prefer what is currently typed; fall back to the controlled value.
          const value = parse(ref.current.value) ?? parse(props.value);

          let number = defaultNumber;
          let unit = defaultUnit;

          if (value != null) {
            [number, unit = defaultUnit] = parseNumberAndUnit(value, integer);
          }

          // A non-numeric value (e.g. a keyword) resets to the default instead of being nudged;
          // an absent value is nudged starting from the default.
          const nudged = Number.isNaN(number)
            ? defaultValue
            : decimal(Math.min(max, Math.max(min, number + amount * direction))) + unit;

          const newValue = parse(nudged) ?? defaultValue;
          flushSync(() => ref.current?.commitValue(newValue));
          select(ref.current);
        }

        props.onKeyDown?.(event);
      }}
    />
  );
};

/** Arguments shared by `formatNumber` and `parseNumber`; mirrors the matching `NumberInput` props. */
interface NumberFnParams {
  /** The text to format or parse. */
  value: string;

  /** Allowed non-numeric keywords (lowercase). */
  keywords: string[];
  /** Allowed units. */
  units: Unit[];

  /** Inclusive lower bound applied by `parseNumber`. */
  min: number;
  /** Inclusive upper bound applied by `parseNumber`. */
  max: number;
  /** Forwarded to `parseNumberAndUnit` by `parseNumber`. */
  integer: boolean;
}

/**
 * Turns a stored value into the text shown in the input.
 *
 * - When `px` is an allowed unit, a `px` suffix directly after a digit is hidden ("20px" -> "20").
 * - A known keyword is shown with its first character upper-cased ("auto" -> "Auto").
 * - Anything else is returned unchanged.
 */
function formatNumber({ value, keywords, units }: NumberFnParams) {
  const pxAfterDigit = /(\d)px$/;

  if (units.includes('px') && pxAfterDigit.test(value)) {
    return value.replace(pxAfterDigit, '$1');
  }

  if (!keywords.includes(value)) {
    return value;
  }

  // Destructure by code point so the first visible character is capitalised.
  const [first = '', ...rest] = value;
  return first.toUpperCase() + rest.join('');
}

/**
 * Parses user input into the value we store (e.g. "20" -> "20px", " AUTO " -> "auto").
 *
 * - Keyword input is matched case-insensitively, ignoring surrounding whitespace, and returned
 *   lowercased.
 * - Numbers are clamped to [`min`, `max`].
 * - An explicit unit is kept only if it is in `units`. A bare number gets `px` when `px` is
 *   allowed, otherwise the first entry of `units` (or no unit if `units` is empty).
 *
 * @returns The normalised value, or `null` when the input is rejected (any "Infinity" form,
 *   unknown keywords, disallowed units).
 */
function parseNumber({ value, min, max, integer, keywords, units }: NumberFnParams) {
  const [defaultUnit = ''] = units;

  // Never let an "Infinity" literal through as a number. Leading whitespace is optional, and
  // both signs are covered, because JS number parsing (e.g. parseFloat) accepts all of these
  // forms.
  if (/^\s*[-+]?Infinity/.test(value)) {
    return null;
  }

  const [parsed, unit] = parseNumberAndUnit(value, integer);

  if (Number.isNaN(parsed)) {
    // Not numeric: accept it only if it is one of the (lowercase) keywords.
    const keyword = value.trim().toLowerCase();
    return keywords.includes(keyword) ? keyword : null;
  }

  const number = Math.min(max, Math.max(min, parsed));

  if ((units as string[]).includes(unit)) {
    return number + unit;
  }

  if (!unit) {
    return number + (units.includes('px') ? 'px' : defaultUnit);
  }

  // An explicit unit that isn't allowed.
  return null;
}

/**
 * Selects the numeric part of the input's text without the unit (e.g. "1.5em" selects "1.5").
 *
 * The selection is measured against the number as it is actually written (".5em", "+2em",
 * "0.50em", " 20px"), not against its normalised `toString()` form, which can have a different
 * length. Text that doesn't parse as a number (e.g. a keyword) is selected entirely.
 */
function selectNumber(input: HTMLInputElement | null) {
  if (!input) return;

  const { value } = input;
  const [number] = parseNumberAndUnit(value);

  if (Number.isNaN(number)) {
    input.select();
    return;
  }

  // Optional leading whitespace, then sign, digits / decimal point, optional exponent.
  // "2em" is not treated as an exponent because "e" must be followed by digits.
  const match = /^(\s*)([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:e[-+]?\d+)?)/i.exec(value);

  if (match) {
    const [, leading = '', written = ''] = match;
    input.setSelectionRange(leading.length, leading.length + written.length);
    return;
  }

  // Numeric per parseNumberAndUnit but not a plain literal (e.g. "Infinity"):
  // fall back to the length of the normalised number.
  input.setSelectionRange(0, number.toString().length);
}
