import { ColorSpace, Oklch } from '@paper/models/src/colors/Color';
import { convertOklabToLinearP3 } from '@paper/models/src/colors/color-conversion';
import { clampChroma, convertOklabToLrgb, modeOklch, modeP3, useMode as enableCuloriMode } from 'culori/fn';
import { assert } from '@paper/models/src/assert';
import { isInGamut } from '@paper/models/src/colors/Color';

// When using the tree-shakeable version of Culori via `culori/fn`, color modes must be
// enabled explicitly. The P3 and OKLCH modes enabled here are needed by `clampChroma`,
// which getMaxSaturationForHueP3 uses to clamp OKLCH colors to the P3 gamut.
enableCuloriMode(modeP3);
enableCuloriMode(modeOklch);

/**
 * The info you need to generate a color map or find the edges of the slice of hue.
 * Produced by `getHueSlice` for a given hue and color space.
 */
export type HueSlice = {
  /** The hue of the gamut slice in degrees */
  hue: number;
  /**
   * The ratio `cuspC / (1 - cuspL)`: how steeply chroma falls from the cusp to white (L = 1, C = 0).
   * It varies by hue and is used as `T` when mapping between OKLCH and the OKHSV-like square.
   * Not to be confused with the `toe()` lightness-compression function in this file.
   */
  toe: number;
  /**
   * The saturation (S = C / L) at the cusp, the most saturated point of the slice; it ends up in the
   * top right of the color picker.
   */
  cuspSaturation: number;
  /** The lightness at the cusp */
  cuspL: number;
  /** The chroma at the cusp (equal to `cuspL * cuspSaturation`) */
  cuspC: number;
};

/**
 * Computes the cusp of the gamut slice for a hue: the most saturated color that still fits in the
 * requested color space, plus the derived values needed to map the slice into a square.
 *
 * @param hue - OKLCH hue in degrees
 * @param colorSpace - `'p3'` finds the cusp with a numeric search; `'rgb'` uses the analytic sRGB solver
 * @returns the hue with its cusp lightness, chroma, saturation (C / L) and `toe` ratio C / (1 - L)
 */
export function getHueSlice(hue: number, colorSpace: 'p3' | 'rgb'): HueSlice {
  const angle = (2 * Math.PI * hue) / 360;
  const unitA = Math.cos(angle);
  const unitB = Math.sin(angle);

  // Maximum saturation (S = C / L) reachable along this hue before leaving the gamut
  const cuspSaturation =
    colorSpace === 'p3' ? getMaxSaturationForHueP3(hue) : getMaxSaturationForHueSrgb(unitA, unitB);

  // Evaluate that saturation at L = 1 in the target linear RGB space. Scaling L (and C with it) by the
  // cube root of 1 / (largest channel) brings the brightest channel to exactly 1: that point is the cusp.
  const atUnitLightness =
    colorSpace === 'p3'
      ? convertOklabToLinearP3({ l: 1, a: cuspSaturation * unitA, b: cuspSaturation * unitB })
      : convertOklabToLrgb({ l: 1, a: cuspSaturation * unitA, b: cuspSaturation * unitB });

  const cuspL = Math.cbrt(1 / Math.max(atUnitLightness.r, atUnitLightness.g, atUnitLightness.b));
  const cuspC = cuspL * cuspSaturation;

  return {
    cuspL,
    cuspC,
    cuspSaturation,
    hue,
    toe: cuspC / (1 - cuspL),
  };
}

/**
 * Returns the largest chroma that stays inside the gamut at the given lightness for this hue slice.
 *
 * Intersects a horizontal line in the (lightness, chroma) plane, running from (lightness, 0) to
 * (lightness, 1), with the gamut boundary of `colorSpace`.
 */
function findMaxChromaForLightness(hueSlice: HueSlice, lightness: number, colorSpace: 'p3' | 'rgb') {
  const hueRadians = hueSlice.hue * (Math.PI / 180);
  const unitA = Math.cos(hueRadians);
  const unitB = Math.sin(hueRadians);

  // Top of the line. Any value above the cusp chroma works; 1 keeps the result on the same scale as
  // OKLCH chroma, which helps when debugging.
  const lineTopChroma = 1;

  const intersectWithGamut = colorSpace === 'p3' ? findGamutIntersectionP3 : findGamutIntersectionSrgb;
  // Start and target lightness are equal, so the line is horizontal at `lightness`
  const lineParameter = intersectWithGamut(unitA, unitB, lightness, lineTopChroma, lightness, hueSlice);

  // The parameter is relative to the line, so scale by the line's top chroma (not the cusp chroma)
  return lineParameter * lineTopChroma;
}

/**
 * Finds the closest valid color inside the provided color space, trying to match visually.
 * You can provide a hueSlice if you already have it; otherwise one is generated from the color's hue.
 *
 * `recursionCount` is the number of refinement attempts already made (callers normally omit it).
 * The in-gamut early return only applies when it is 0.
 */
export function findClosestColorInGamut(
  color: Oklch,
  colorSpace: ColorSpace,
  recursionCount = 0,
  hueSliceProvided?: HueSlice
): Oklch {
  if (recursionCount === 0 && isInGamut(color, colorSpace) === true) {
    return color;
  }

  /**
   *
   *   C
   *   |                                 • clipped color
   *   |     hue cusp •
   *   |
   *   |                           •
   *   |          closest color ↗        • limit (derived from the clipped
   *   |           on the gamut            color by clamping the chroma)
   *   |__________________________________________________________________ L
   *
   * Given a hue cusp, a clipped color, and a chroma limit, the closest color in the gamut
   * is the point on the line between the hue cusp and the limit that is closest to the
   * clipped color.
   *
   * Repeated because the gamut edge isn't always a straight line: each out-of-gamut candidate
   * becomes the next clipped color.
   */

  const hueSlice = hueSliceProvided ?? getHueSlice(color.h, colorSpace);

  let current = color;
  for (let attempt = recursionCount; ; attempt++) {
    assert(hueSlice.hue === current.h, 'findClosestInGamut - hue slice hue does not match color hue');

    // The limit point shares the current color's lightness, at the maximum in-gamut chroma
    const limitL = current.l;
    const limitC = findMaxChromaForLightness(hueSlice, limitL, colorSpace);

    if (current.c === limitC) {
      console.warn('Unexpected: the color to bring into gamut is already on the gamut boundary. Hue:', current.h);
      return current;
    }

    // It's usually just 3-4 attempts
    if (attempt > 20) {
      console.warn('Unexpected: more than 20 attempts to find a color in gamut. Color:', { ...current });
      return { ...current, c: limitC };
    }

    // Direction from the limit towards the cusp
    const towardsCuspL = hueSlice.cuspL - limitL;
    const towardsCuspC = hueSlice.cuspC - limitC;

    // Project the clipped color onto the limit→cusp line (it has no lightness offset from the limit).
    // 1 is the cusp itself; anything larger would overshoot the cusp, which happens in a very sharp
    // triangle where the cusp is the closest point, so cap at 1.
    const t = Math.min(1, ((current.c - limitC) * towardsCuspC) / (towardsCuspL ** 2 + towardsCuspC ** 2));

    if (Number.isNaN(t) || t < 0) {
      console.warn('Unexpected: invalid geometry when finding a color in gamut. Color:', { ...current });
      return { ...current, c: limitC };
    }

    const candidate = { ...current, l: limitL + t * towardsCuspL, c: limitC + t * towardsCuspC };

    // The straight-line assumption can land outside a concave gamut edge; keep refining until inside
    if (isInGamut(candidate, colorSpace) !== false) {
      return candidate;
    }
    current = candidate;
  }
}

/**
 * For a given hue, uses a quick hill climb over OKLCH lightness to find the P3 cusp: the lightness with
 * the highest in-gamut chroma. This is the peak of the triangle when lightness is graphed on X and
 * chroma on Y, and it becomes the top-right corner when a hue slice is mapped into a square.
 *
 * Culori's `clampChroma` supplies the maximum in-gamut chroma at each probed lightness (it works well
 * in P3; the original author saw problems with it in sRGB). Björn Ottosson's analytic method could be
 * adapted instead, but accurately defining the RGB boundaries proved difficult. The original comments
 * report timings between under 1 ms and about 2 ms per call on an Apple M4.
 *
 * @param hue - OKLCH hue in degrees
 * @returns the saturation (chroma / lightness) at the best point found
 */
function getMaxSaturationForHueP3(hue: number): number {
  const INITIAL_STEP = 0.2; // Initial step size when hunting
  // Precision of the answer. In testing, 0.0005 was enough to get 255 in the upper right corner,
  // except P3 red, which is especially tricky and needs 0.00005. This takes ~12-20 iterations.
  const MIN_STEP = 0.00005;
  const MAX_ITERATIONS = 50; // Prevent infinite loops
  const IMPROVEMENT_THRESHOLD = MIN_STEP * 0.5; // Don't pursue non-meaningful improvements

  // Maximum in-gamut P3 chroma for this hue at lightness `l`
  const maxChromaAt = (l: number) => clampChroma({ mode: 'oklch', l, c: 1, h: hue }, 'oklch', 'p3').c;

  // The average lightness of maximum chromas across hues, so a good place to start
  let lightness = 0.742;
  let bestChroma = maxChromaAt(lightness);
  let currentStep = INITIAL_STEP;
  let iterations = 0;

  while (currentStep >= MIN_STEP && iterations < MAX_ITERATIONS) {
    iterations++;

    // Try stepping up first
    const upLightness = lightness + currentStep;
    const upChroma = maxChromaAt(upLightness);
    if (upChroma > bestChroma + IMPROVEMENT_THRESHOLD) {
      lightness = upLightness;
      bestChroma = upChroma;
      continue;
    }

    // Only try stepping down if stepping up didn't improve
    const downLightness = lightness - currentStep;
    const downChroma = maxChromaAt(downLightness);
    if (downChroma > bestChroma + IMPROVEMENT_THRESHOLD) {
      lightness = downLightness;
      bestChroma = downChroma;
      continue;
    }

    // Neither direction improved: refine the step size
    currentStep *= 0.5;
  }

  // Convert lightness and chroma to saturation
  return bestChroma / lightness;
}

/**
 * Björn Ottosson's "toe" lightness curve, used when converting OKLCH into OKHSV-like coordinates.
 * Maps 0 → 0 and 1 → 1; `toeInverse` undoes it.
 */
function toe(x: number) {
  const k1 = 0.206;
  const k2 = 0.03;
  const k3 = (1 + k1) / (1 + k2);

  // Positive root of the quadratic that toeInverse describes
  const shifted = k3 * x - k1;
  return 0.5 * (shifted + Math.sqrt(shifted * shifted + 4 * k2 * k3 * x));
}

/**
 * Inverse of Björn Ottosson's "toe" lightness curve, used when converting OKHSV-like coordinates
 * back into OKLCH. Maps 0 → 0 and 1 → 1.
 */
function toeInverse(x: number) {
  const k1 = 0.206;
  const k2 = 0.03;
  const k3 = (1 + k1) / (1 + k2);

  const numerator = x * x + k1 * x;
  const denominator = k3 * (x + k2);
  return numerator / denominator;
}

/**
 * Number of Halley's-method refinement steps used by the sRGB max-saturation solver and by the
 * gamut-line intersection functions in this file. Three steps are used because fewer steps left
 * red clipping on the right edge.
 */
export const HALLEY_ITERATIONS = 3;

/**
 * Finds the maximum saturation (S = C / L) possible for a given hue that still fits in sRGB.
 *
 * Adapted from Björn Ottosson's Okhsv reference code. Max saturation is reached when one of the linear
 * r, g or b channels drops to zero. A linear test on (a, b) picks which channel clips first, a polynomial
 * fit gives the initial guess, and `HALLEY_ITERATIONS` steps of Halley's method refine it (one step is
 * usually accurate except for some blue hues where dS/dh is nearly infinite; more steps also stop red
 * from clipping on the right edge).
 *
 * @param a - cosine of the hue angle; `a` and `b` must be normalized so a² + b² = 1
 * @param b - sine of the hue angle
 * @returns the saturation at which the selected channel reaches zero
 */
function getMaxSaturationForHueSrgb(a: number, b: number) {
  // [k0, k1, k2, k3, k4, wl, wm, ws]: polynomial fit, then the LMS → channel weights
  type ChannelCoefficients = readonly [number, number, number, number, number, number, number, number];

  let coefficients: ChannelCoefficients;
  if (-1.88170328 * a - 0.80936493 * b > 1) {
    // Red goes below zero first
    coefficients = [
      1.19086277, 1.76576728, 0.59662641, 0.75515197, 0.56771245, 4.0767416621, -3.3077115913, 0.2309699292,
    ];
  } else if (1.81444104 * a - 1.19445276 * b > 1) {
    // Green goes below zero first
    coefficients = [
      0.73956515, -0.45954404, 0.08285427, 0.1254107, 0.14503204, -1.2684380046, 2.6097574011, -0.3413193965,
    ];
  } else {
    // Blue goes below zero first
    coefficients = [
      1.35733652, -0.00915799, -1.1513021, -0.50559606, 0.00692167, -0.0041960863, -0.7034186147, 1.707614701,
    ];
  }
  const [k0, k1, k2, k3, k4, wl, wm, ws] = coefficients;

  // Initial guess from the polynomial approximation
  let saturation = k0 + k1 * a + k2 * b + k3 * a * a + k4 * a * b;

  // How each LMS cube root changes with saturation for this hue
  const kL = +0.3963377774 * a + 0.2158037573 * b;
  const kM = -0.1055613458 * a - 0.0638541728 * b;
  const kS = -0.0894841775 * a - 1.291485548 * b;

  for (let step = 0; step < HALLEY_ITERATIONS; step++) {
    const lRoot = 1 + saturation * kL;
    const mRoot = 1 + saturation * kM;
    const sRoot = 1 + saturation * kS;

    // Channel value f(S) and its first and second derivatives with respect to S
    const f = wl * (lRoot * lRoot * lRoot) + wm * (mRoot * mRoot * mRoot) + ws * (sRoot * sRoot * sRoot);
    const f1 = wl * (3 * kL * lRoot * lRoot) + wm * (3 * kM * mRoot * mRoot) + ws * (3 * kS * sRoot * sRoot);
    const f2 = wl * (6 * kL * kL * lRoot) + wm * (6 * kM * kM * mRoot) + ws * (6 * kS * kS * sRoot);

    saturation -= (f * f1) / (f1 * f1 - 0.5 * f * f2);
  }

  return saturation;
}

/** Row-major 3×3 matrix from cubed LMS values to linear RGB; rows are the r, g and b channels. */
type LmsToLinearRgbMatrix = readonly [
  readonly [number, number, number],
  readonly [number, number, number],
  readonly [number, number, number],
];

/** LMS → linear sRGB, as in Björn Ottosson's Oklab reference implementation. */
const LMS_TO_LINEAR_SRGB: LmsToLinearRgbMatrix = [
  [4.0767416621, -3.3077115913, 0.2309699292],
  [-1.2684380046, 2.6097574011, -0.3413193965],
  [-0.0041960863, -0.7034186147, 1.707614701],
];

/** LMS → linear Display P3. */
const LMS_TO_LINEAR_P3: LmsToLinearRgbMatrix = [
  [3.127768971361874, -2.2571357625916395, 0.12936679122976516],
  [-1.0910090184377979, 2.413331710306922, -0.32232269186912466],
  [-0.02601080193857028, -0.508041331704167, 1.5340521336427373],
];

/**
 * Shared gamut-line intersection (Björn Ottosson's `find_gamut_intersection`), parameterized by the
 * LMS → linear RGB matrix of the target gamut.
 *
 * The line runs from (L0, 0) at t = 0 to (L1, C1) at t = 1 in the hue's (lightness, chroma) plane.
 * Below the cusp, the gamut edge is the straight line from black to the cusp, so the answer is direct.
 * Above the cusp, intersecting the triangle's cusp-to-white edge gives a first guess. That guess is then
 * refined with `HALLEY_ITERATIONS` Halley steps towards the first channel that reaches 1.
 *
 * @returns the intersection parameter t relative to the line, NOT a point on the triangle
 */
function findGamutIntersectionWithMatrix(
  a: number, // Hue normalized a
  b: number, // Hue normalized b
  L1: number, // Target lightness
  C1: number, // Target chroma
  L0: number, // Starting lightness
  hueSlice: HueSlice,
  lmsToLinearRgb: LmsToLinearRgbMatrix
): number {
  const cuspL = hueSlice.cuspL;
  const cuspC = hueSlice.cuspC;

  // Lower half: intersect with the black-to-cusp edge
  if ((L1 - L0) * cuspC - (cuspL - L0) * C1 <= 0) {
    return (cuspC * L0) / (C1 * cuspL + cuspC * (L0 - L1));
  }

  // Upper half: first intersect with the cusp-to-white edge of the triangle...
  let intersection = (cuspC * (L0 - 1)) / (C1 * (cuspL - 1) + cuspC * (L0 - L1));

  // ...then step Halley's method. Rates of change of the LMS cube roots along the line:
  const dL = L1 - L0;
  const dC = C1;
  const k_l = +0.3963377774 * a + 0.2158037573 * b;
  const k_m = -0.1055613458 * a - 0.0638541728 * b;
  const k_s = -0.0894841775 * a - 1.291485548 * b;
  const l_dt = dL + dC * k_l;
  const m_dt = dL + dC * k_m;
  const s_dt = dL + dC * k_s;

  for (let i = 0; i < HALLEY_ITERATIONS; i++) {
    const L = L0 * (1 - intersection) + intersection * L1;
    const C = intersection * C1;
    const l_ = L + C * k_l;
    const m_ = L + C * k_m;
    const s_ = L + C * k_s;
    const l = l_ * l_ * l_;
    const m = m_ * m_ * m_;
    const s = s_ * s_ * s_;
    const ldt = 3 * l_dt * l_ * l_;
    const mdt = 3 * m_dt * m_ * m_;
    const sdt = 3 * s_dt * s_ * s_;
    const ldt2 = 6 * l_dt * l_dt * l_;
    const mdt2 = 6 * m_dt * m_dt * m_;
    const sdt2 = 6 * s_dt * s_dt * s_;

    // Take the smallest Halley step over the r, g and b channels (the first channel to hit its limit)
    let step = Infinity;
    for (const [wl, wm, ws] of lmsToLinearRgb) {
      const channel = wl * l + wm * m + ws * s - 1;
      const channel1 = wl * ldt + wm * mdt + ws * sdt;
      const channel2 = wl * ldt2 + wm * mdt2 + ws * sdt2;
      const u = channel1 / (channel1 * channel1 - 0.5 * channel * channel2);
      // A negative u means the channel moves away from its limit; 1e5 effectively ignores it
      // (the reference implementation uses FLT_MAX here)
      step = Math.min(step, u >= 0 ? -channel * u : 1e5);
    }
    intersection += step;
  }

  return intersection;
}

/**
 * Takes a line and determines where the sRGB gamut boundary intersects it.
 * Useful for finding the maximum chroma for a given lightness and hue combo.
 * Note the result is the intersection relative to the line, NOT a point on the triangle.
 */
function findGamutIntersectionSrgb(
  a: number, // Hue normalized a
  b: number, // Hue normalized b
  L1: number, // Target lightness
  C1: number, // Target chroma
  L0: number, // Starting lightness
  hueSlice: HueSlice
): number {
  return findGamutIntersectionWithMatrix(a, b, L1, C1, L0, hueSlice, LMS_TO_LINEAR_SRGB);
}

/**
 * Takes a line and determines where the Display P3 gamut boundary intersects it.
 * Useful for finding the maximum chroma for a given lightness and hue combo.
 * Note the result is the intersection relative to the line, NOT a point on the triangle.
 */
function findGamutIntersectionP3(
  a: number, // Hue normalized a
  b: number, // Hue normalized b
  L1: number, // Target lightness
  C1: number, // Target chroma
  L0: number, // Starting lightness
  hueSlice: HueSlice
): number {
  return findGamutIntersectionWithMatrix(a, b, L1, C1, L0, hueSlice, LMS_TO_LINEAR_P3);
}

/**
 * Turns OKHSV-like picker coordinates into an OKLCH color on the given hue slice.
 *
 * Follows Björn Ottosson's `okhsv_to_srgb` up to the OKLCH step. The point is first placed as if the
 * gamut were a perfect triangle. It is then compensated for the toe and rescaled against the gamut of
 * `colorSpace`. Björn's code scales the lightness here, presumably to dedicate more of the picker to
 * useful colors.
 *
 * @param x - saturation-like horizontal coordinate
 * @param y - value-like vertical coordinate
 * @param colorSpace - gamut used for the rescaling step
 * @param hueSlice - slice for the hue, as returned by getHueSlice
 * @returns the OKLCH color at the slice's hue; black (l = 0, c = 0) when the math produces NaN,
 *   which can happen at lightness 0 because of a division by 0
 */
export function getOklchFromXY(x: number, y: number, colorSpace: 'p3' | 'rgb', hueSlice: HueSlice): Oklch {
  const hueTurns = hueSlice.hue / 360;
  const unitA = Math.cos(2 * Math.PI * hueTurns);
  const unitB = Math.sin(2 * Math.PI * hueTurns);

  const S0 = 0.5;
  const T = hueSlice.toe;
  const k = 1 - S0 / hueSlice.cuspSaturation;

  // Lightness and chroma at the top of the column (v = 1), assuming a perfect triangle
  const denominator = S0 + T - T * k * x;
  const topL = 1 - (x * S0) / denominator;
  const topC = (x * T * S0) / denominator;

  // The same top point with the toe compensated
  const topLWithToe = toeInverse(topL);
  const topCWithToe = (topC * topLWithToe) / topL;

  // Move down the column by y, then compensate the toe on the result
  const rawL = y * topL;
  const rawC = y * topC;
  const lightness = toeInverse(rawL);
  const chroma = (rawC * lightness) / rawL;

  // Rescale so the column's top sits on the boundary of the requested gamut
  const boundary =
    colorSpace === 'p3'
      ? convertOklabToLinearP3({ l: topLWithToe, a: unitA * topCWithToe, b: unitB * topCWithToe })
      : convertOklabToLrgb({ l: topLWithToe, a: unitA * topCWithToe, b: unitB * topCWithToe });
  const scale = Math.cbrt(1 / Math.max(boundary.r, boundary.g, boundary.b, 0));
  const l = lightness * scale;
  const c = chroma * scale;

  if (Number.isNaN(l) || Number.isNaN(c)) {
    return { mode: 'oklch', l: 0, c: 0, h: hueSlice.hue };
  }

  return { mode: 'oklch', l, c, h: hueSlice.hue };
}

/**
 * Takes an OKLCH color and converts it into OKHSV-like picker coordinates; the inverse of
 * getOklchFromXY for the same hue slice and color space.
 *
 * Follows Björn Ottosson's `srgb_to_okhsv` starting from OKLCH.
 *
 * @param color - OKLCH lightness, chroma and hue (degrees)
 * @param colorSpace - gamut used to undo the lightness rescaling
 * @param hueSlice - slice for the color's hue, as returned by getHueSlice
 * @returns `[s, v]` (x, y) coordinates; `[0, 0]` when a division by 0 produced NaN, which indicates black
 */
export function getXYFromOklch(
  { l, c, h }: Omit<Oklch, 'mode'>,
  colorSpace: 'p3' | 'rgb',
  hueSlice: HueSlice
): [number, number] {
  const normalizedHue = h / 360;
  const a_ = Math.cos(2 * Math.PI * normalizedHue);
  const b_ = Math.sin(2 * Math.PI * normalizedHue);

  const S_0 = 0.5;
  const T = hueSlice.toe;
  const k = 1 - S_0 / hueSlice.cuspSaturation;

  // Top of the column (v = 1) through this color, as if the gamut were a perfect triangle
  const t = T / (c + l * T);
  const L_v = t * l;
  const C_v = t * c;

  // The same top point with the toe compensated
  const L_vt = toeInverse(L_v);
  const C_vt = (C_v * L_vt) / L_v;

  // Undo the lightness rescaling that maps the column's top onto the gamut boundary
  const scaleSource =
    colorSpace === 'p3'
      ? convertOklabToLinearP3({ l: L_vt, a: a_ * C_vt, b: b_ * C_vt })
      : convertOklabToLrgb({ l: L_vt, a: a_ * C_vt, b: b_ * C_vt });
  const scale_L = Math.cbrt(1 / Math.max(scaleSource.r, scaleSource.g, scaleSource.b, 0));

  // Only lightness needs to go back through the toe: saturation below depends on C_v alone,
  // so the rescaled chroma is not needed.
  const L = toe(l / scale_L);

  const v = L / L_v;
  const s = ((S_0 + T) * C_v) / (T * S_0 + T * k * C_v);

  // If any division by 0 happened we might get a NaN, which indicates black
  if (Number.isNaN(s) || Number.isNaN(v)) {
    return [0, 0];
  }

  return [s, v];
}
