import { useRef, useEffect, useImperativeHandle, forwardRef, RefObject, ReactNode } from 'react';

/** Legacy `KeyboardEvent.keyCode` value for Tab, kept as a fallback for `e.key`. */
const KEYCODE_TAB = 9;

/**
 * CSS selector for descendants that can receive focus via Tab.
 *
 * Disabled controls, `type="hidden"` inputs and `tabindex="-1"` elements are
 * excluded on purpose: the browser never tabs onto them. If one of them were
 * treated as the first/last element, the wrap-around check in `useFocusTrap`
 * would never match and Tab could escape the trap. Plain `input` (rather than
 * an allow-list of `type` values) also covers inputs with no `type` attribute,
 * which default to text.
 */
const FOCUSABLE_SELECTOR = [
  'a[href]',
  'area[href]',
  'button:not([disabled])',
  'input:not([disabled]):not([type="hidden"])',
  'select:not([disabled])',
  'textarea:not([disabled])',
  '[tabindex]',
  '[contenteditable]:not([contenteditable="false"])',
]
  .map((selector) => `${selector}:not([tabindex="-1"])`)
  .join(', ');

/** Returns the descendants of `container` matching FOCUSABLE_SELECTOR, in document order. */
function getFocusableElements(container: HTMLElement): HTMLElement[] {
  return Array.from(container.querySelectorAll<HTMLElement>(FOCUSABLE_SELECTOR));
}

/**
 * Keeps keyboard focus inside the element the returned ref is attached to
 * while `isActive` is true.
 *
 * On activation it remembers the element that currently has focus and moves
 * focus to the first focusable descendant. It also makes Tab / Shift+Tab wrap
 * around at the ends of the container. When `isActive` becomes false, or the
 * component unmounts while active, focus is handed back to the remembered
 * element.
 *
 * @param isActive - Whether the trap is currently engaged.
 * @returns A ref to attach to the container `<div>`.
 */
function useFocusTrap(isActive: boolean): RefObject<HTMLDivElement> {
  const elRef = useRef<HTMLDivElement>(null);

  useEffect(() => {
    const container = elRef.current;
    if (!isActive || !container) {
      return undefined;
    }

    // Remember who had focus before the trap engaged so it can be restored.
    const active = document.activeElement;
    const previouslyFocused = active instanceof HTMLElement ? active : null;

    getFocusableElements(container)[0]?.focus();

    // Declared inside the effect so the exact same function instance is
    // added and removed, and the effect has no stale external dependency.
    const handleKeyDown = (e: KeyboardEvent): void => {
      const isTabPressed = e.key === 'Tab' || e.keyCode === KEYCODE_TAB;
      if (!isTabPressed) {
        return;
      }

      // Re-query on every keypress: the trapped content may have changed
      // since the trap was activated.
      const focusable = getFocusableElements(container);
      const first = focusable[0];
      const last = focusable[focusable.length - 1];
      if (!first || !last) {
        // Nothing tabbable inside: keep Tab from moving focus out of the trap.
        e.preventDefault();
        return;
      }

      const current = document.activeElement;
      if (e.shiftKey && current === first) {
        // Shift+Tab on the first element wraps to the last.
        last.focus();
        e.preventDefault();
      } else if (!e.shiftKey && current === last) {
        // Tab on the last element wraps to the first.
        first.focus();
        e.preventDefault();
      }
    };

    container.addEventListener('keydown', handleKeyDown);
    return () => {
      container.removeEventListener('keydown', handleKeyDown);
      // Runs on deactivation and on unmount: return focus to the trigger.
      previouslyFocused?.focus();
    };
  }, [isActive]);

  return elRef;
}

/** Props for the FocusTrap component. */
interface FocusTrapProps {
  /** While true, keyboard focus is kept inside the rendered container. */
  isActive: boolean;
  /** Content rendered inside the trap container. */
  children: ReactNode;
}

/**
 * Renders `children` inside a `<div className="trap">` whose keyboard focus is
 * managed by `useFocusTrap(isActive)`.
 *
 * A ref passed to this component receives that same container `<div>`.
 */
const FocusTrap = forwardRef<HTMLDivElement, FocusTrapProps>(({ isActive, children }, ref) => {
  const elRef = useFocusTrap(isActive);

  // The container's `ref=` slot is taken by the trap's internal ref, so the
  // forwarded ref is exposed through useImperativeHandle instead of being
  // dropped. The div below is always rendered, so `elRef.current` is attached
  // by the time React invokes this handle factory.
  useImperativeHandle(ref, () => elRef.current as HTMLDivElement, [elRef]);

  return (
    <div className="trap" ref={elRef}>
      {children}
    </div>
  );
});

FocusTrap.displayName = 'FocusTrap';

export default FocusTrap;
