/* eslint react/no-unknown-property: off, @typescript-eslint/no-unused-vars: off */
import { assertDefined } from "@augmedi/type-utils";
import { useFrame, useThree, type ThreeEvent } from "@react-three/fiber";
import assert from "assert-ts";
import { maxBy, sortBy } from "lodash-es";
import { useEffect, useMemo, useState } from "react";
import * as THREE from "three";
import { FullScreenQuad } from "three-stdlib";
import { BitSelectMaterial } from "../logic/bit-select-material.js";
import { clickedWithinTolerance } from "../logic/click-detection-bit-select-material.js";
import {
  clickDetectionFragmentShader,
  clickDetectionVertexShader,
  useWhiteMaterial,
} from "../logic/click-detection-shader.js";
import {
  useEmptyTexture,
  useOurGltf,
  type LoadedLabelMask,
  type MaskCompatibleTexture,
} from "../logic/gltf.js";
import { useMemoDisposable } from "../logic/memo-cleanup.js";
import {
  allOverlayMaterialCombinedChannelKeys,
  defaultOverlayMaterialChannelSettings,
  isOverlayMaterialCombinedChannelKey,
  overlayMaterialCombinedChannelKeyToIndex,
  type OverlayMaterialChannelSettings,
  type OverlayMaterialCombinedChannelKeys,
} from "../logic/overlay-material.js";
import { usePickLabelMasks as usePickPaintedLabelMasks } from "../logic/picking.js";

/**
 * Props of `RawModelPreview`.
 */
export interface Props {
  /** URL of the glTF model; passed to `useOurGltf`. */
  gltfUrl: string;
  /**
   * Maps a label id to the name of the glTF mesh that the label covers as a
   * whole. Such labels are highlighted through a per-mesh vertex attribute
   * instead of through a label mask. The mapping must be one-to-one:
   * `RawModelPreview` asserts that no two labels map to the same mesh name.
   */
  meshNamesByWholeMeshLabelId: Map<string, string>;
  /** If set, only glTF meshes whose name is in this set are rendered. */
  visibleGltfMeshNames?: Set<string>;
  /** Label ids to highlight, per combined overlay channel. */
  visibleLabelIdsPerChannel?: {
    [channel in OverlayMaterialCombinedChannelKeys]?: string[];
  };
  /**
   * Per-channel settings overrides; each is merged over
   * `defaultOverlayMaterialChannelSettings`.
   */
  settingsByChannel?: {
    [channel in OverlayMaterialCombinedChannelKeys]?: Partial<OverlayMaterialChannelSettings>;
  };
  /**
   * Labels that a click selects as a whole when `clickedWithinTolerance`
   * accepts the click position. Only `CombinedR` is supported.
   */
  desiredLabelIdsPerChannel?: { CombinedR?: string[] | undefined };
  /** Called with the label ids selected by a click. */
  onClick?: (labelIds: string[]) => void;
  /**
   * If set, supplies the `name` given to each rendered mesh, keyed by the
   * glTF mesh name.
   */
  frozenMeshIdsByMeshName?: Map<string, string>;
}

/**
 * For each combined channel: the label-mask textures to read and, per
 * texture, the set of bit indices (`textureBit`) to select from it.
 */
type GroupedSources = {
  [channel in OverlayMaterialCombinedChannelKeys]: Map<
    MaskCompatibleTexture,
    Set<number> // set of textureBit
  >;
};

/** A render target together with the mask sources that get painted into it. */
interface PaintedState {
  renderTarget: THREE.WebGLRenderTarget;
  groupedSources: GroupedSources;
}

/**
 * One value per combined channel, indexed via
 * `overlayMaterialCombinedChannelKeyToIndex`. This file only ever writes
 * 0 or 255 into it.
 */
type ChannelValues = [number, number, number, number];

/** Whole-mesh highlight values, keyed by glTF mesh name. */
interface CombinedHighlightWholeMeshState {
  channelValuesByMeshName: Map<string, ChannelValues>;
}

/** Everything `RawModelPreview` derives from its label props for rendering. */
interface CombinedPaintedState {
  /** Painted masks of the visible labels; undefined when no mask is needed. */
  visibleLabelsPaintedState?: PaintedState;
  /** Painted masks of the desired labels; undefined when no mask is needed. */
  desiredLabelsPaintedState?: PaintedState;
  /** Whole-mesh highlight values of the visible labels. */
  wholeMeshState: CombinedHighlightWholeMeshState;
  /**
   * Whether the render targets and whole-mesh vertex attributes have been
   * updated for this state yet. The frame loop flips this in place instead
   * of going through React state.
   */
  didRender: boolean;
}

/**
 * Splits the requested label ids into the two kinds of highlight data used by
 * the overlay material.
 *
 * - Labels listed in `meshNamesByWholeMeshLabelId` cover a whole mesh. For
 *   those, the returned map holds a 4-tuple per mesh name with 255 in every
 *   combined channel that requested one of the mesh's labels and 0 elsewhere.
 * - All other labels are looked up in `labelMasks` and grouped, per combined
 *   channel, by mask texture into the set of bits to select from it. Labels
 *   without a mask are skipped with a console warning.
 *
 * When at least one mask texture is used, a render target sized like the
 * largest of those textures (by pixel area) is allocated. The caller owns it
 * and is responsible for disposing it.
 *
 * @param meshNamesByWholeMeshLabelId Label id -> name of the mesh the label covers entirely.
 * @param labelMasks Loaded label masks, keyed by label id.
 * @param labelIdsPerChannel Label ids to highlight, per combined channel.
 * @returns `[paintedState, channelValuesByMeshName]`; `paintedState` is
 *   `undefined` when no mask texture is needed.
 */
function getPaintedStateAndChannelMap(
  meshNamesByWholeMeshLabelId: Map<string, string>,
  labelMasks: {
    [id: string]: LoadedLabelMask;
  },
  labelIdsPerChannel?: {
    [channel in OverlayMaterialCombinedChannelKeys]?: string[];
  },
): [PaintedState | undefined, Map<string, ChannelValues>] {
  const groupedSources: GroupedSources = {
    CombinedR: new Map(),
    CombinedG: new Map(),
    CombinedB: new Map(),
    CombinedA: new Map(),
  };
  const channelValuesByMeshName = new Map<string, ChannelValues>();

  if (labelIdsPerChannel) {
    for (const [channel, labelIds] of Object.entries(labelIdsPerChannel)) {
      // The object might have other properties, since our type can't restrict that.
      if (!isOverlayMaterialCombinedChannelKey(channel) || !labelIds) {
        continue;
      }
      const channelIndex = overlayMaterialCombinedChannelKeyToIndex(channel);
      const sourcesForChannel = groupedSources[channel];

      for (const labelId of labelIds) {
        const wholeMeshName = meshNamesByWholeMeshLabelId.get(labelId);
        if (wholeMeshName !== undefined) {
          let channelValues = channelValuesByMeshName.get(wholeMeshName);
          if (!channelValues) {
            channelValues = [0, 0, 0, 0];
            channelValuesByMeshName.set(wholeMeshName, channelValues);
          }
          channelValues[channelIndex] = 255;
          continue;
        }

        const labelMask = labelMasks[labelId];
        if (!labelMask) {
          console.warn(`Missing mask for label ${labelId}`);
          continue;
        }
        let textureBitSet = sourcesForChannel.get(labelMask.texture);
        if (!textureBitSet) {
          textureBitSet = new Set();
          sourcesForChannel.set(labelMask.texture, textureBitSet);
        }
        textureBitSet.add(labelMask.textureBit);
      }
    }
  }

  const usedSizedTextures = new Set<
    THREE.Texture & { image: { width: number; height: number } }
  >();
  for (const channel of allOverlayMaterialCombinedChannelKeys) {
    for (const texture of groupedSources[channel].keys()) {
      usedSizedTextures.add(texture);
    }
  }

  if (usedSizedTextures.size === 0) {
    return [undefined, channelValuesByMeshName];
  }

  // Each mask texture is later drawn over the whole render target, so the
  // target takes the resolution of the largest one.
  const largestTexture = assertDefined(
    maxBy([...usedSizedTextures], (t) => t.image.width * t.image.height),
  );
  const renderTarget = new THREE.WebGLRenderTarget(
    largestTexture.image.width,
    largestTexture.image.height,
  );

  return [{ renderTarget, groupedSources }, channelValuesByMeshName];
}

/**
 * Packs bit indices in the range [0, 32) into four 8-bit masks: bits 0-7 go
 * into byte 0, bits 8-15 into byte 1, and so on.
 */
function packBitMasks(bits: Iterable<number>): number[] {
  const bitMasks = [0, 0, 0, 0];
  for (const bit of bits) {
    const byteOffset = Math.floor(bit / 8);
    assert(byteOffset >= 0 && byteOffset < 4);
    const bitOffset = bit % 8;
    assert(bitOffset >= 0 && bitOffset < 8);
    bitMasks[byteOffset] |= 1 << bitOffset;
  }
  return bitMasks;
}

/**
 * Paints the selected mask bits of `paintedState` into its render target.
 *
 * For every (channel, mask texture) pair in `groupedSources`, the bit-select
 * material is configured with the texture, the packed bit masks and the
 * output channel index, then drawn as a full-screen quad into
 * `paintedState.renderTarget`. Auto-clear is disabled while drawing so the
 * renderer does not clear the target between passes; how the passes combine
 * is up to `BitSelectMaterial`.
 *
 * The renderer's previous render target and `autoClear` value are restored
 * afterwards, even if a draw throws.
 */
function renderPaintedStateTexture(
  gl: THREE.WebGLRenderer,
  fullScreenQuad: FullScreenQuad<BitSelectMaterial>,
  paintedState: PaintedState,
) {
  const { renderTarget, groupedSources } = paintedState;
  const bitSelectUniforms = fullScreenQuad.material.uniforms;

  const previousRenderTarget = gl.getRenderTarget();
  const previousAutoClear = gl.autoClear;
  gl.setRenderTarget(renderTarget);
  gl.autoClear = false;
  try {
    for (const [channelKey, textureMap] of Object.entries(groupedSources)) {
      if (!isOverlayMaterialCombinedChannelKey(channelKey)) {
        continue;
      }
      const channelIndex = overlayMaterialCombinedChannelKeyToIndex(channelKey);

      for (const [texture, bitSet] of textureMap) {
        bitSelectUniforms.inputTexture.value = texture;
        bitSelectUniforms.inputBitMasks.value = packBitMasks(bitSet);
        bitSelectUniforms.outputChannel.value = channelIndex;
        fullScreenQuad.render(gl);
      }
    }
  } finally {
    gl.autoClear = previousAutoClear;
    gl.setRenderTarget(previousRenderTarget);
  }
}

/**
 * Tells whether the whole-mesh label of `meshName` is among the desired
 * labels. Only the `CombinedR` channel of `desiredLabelIdsPerChannel` is
 * consulted. A mesh without a whole-mesh label (or whose label id is the
 * empty string) is never desired.
 */
function isWholeMeshLabelOfMeshDesired(
  wholeMeshLabelIdsByMeshName: Map<string, string>,
  desiredLabelIdsPerChannel:
    | {
        CombinedR?: string[] | undefined;
      }
    | undefined,
  meshName: string,
): boolean {
  const labelId = wholeMeshLabelIdsByMeshName.get(meshName);
  if (!labelId) {
    return false;
  }
  return desiredLabelIdsPerChannel?.CombinedR?.includes(labelId) === true;
}

/**
 * Runs the tolerance-based click check at the pointer position of `ev`.
 *
 * The event's `unprojectedPoint` is projected back through `camera` into
 * normalized device coordinates. It is then mapped to integer pixel
 * coordinates relative to the canvas' bounding rectangle, with the origin in
 * the bottom-left corner (y grows upwards, as in NDC). Those coordinates are
 * handed to `clickedWithinTolerance`.
 *
 * The event object is not modified.
 *
 * @returns The result of `clickedWithinTolerance` for that position.
 */
function detectClickedWithinTolerance(
  ev: ThreeEvent<MouseEvent>,
  gl: THREE.WebGLRenderer,
  scene: THREE.Scene,
  camera: THREE.Camera,
): boolean {
  const rect = gl.domElement.getBoundingClientRect();
  // `Vector3.project` mutates its receiver; project a copy so the event's
  // `unprojectedPoint` stays intact for anything else that reads it.
  const ndcPoint = ev.unprojectedPoint.clone().project(camera);
  const xPixel = Math.round(((ndcPoint.x + 1) / 2) * rect.width);
  const yPixel = Math.round(((ndcPoint.y + 1) / 2) * rect.height);
  return clickedWithinTolerance(xPixel, yPixel, gl, scene, camera);
}

/**
 * Renders the meshes of the glTF model at `gltfUrl` with overlay highlights
 * and click picking.
 *
 * Highlighting:
 * - Visible labels that map to a whole mesh (via `meshNamesByWholeMeshLabelId`)
 *   are written into that mesh's `combinedWholeMeshHighlight` vertex
 *   attribute. All other visible labels are painted from their masks into a
 *   render target bound as `combinedHighlightTexture` on the shared material.
 * - Per-channel colors come from `settingsByChannel` merged over
 *   `defaultOverlayMaterialChannelSettings`. The alpha is modulated by a
 *   cosine pulse driven by `pulseFrequency` and scaled by `pulseStrength`.
 *
 * Clicking (only when `onClick` is set; interactions with `ev.delta` of 10 or
 * more are treated as drags and ignored):
 * - If `desiredLabelIdsPerChannel.CombinedR` is set and
 *   `clickedWithinTolerance` accepts the click, `onClick` receives exactly
 *   those ids.
 * - Otherwise, a click on a mesh reports the sorted, de-duplicated union of
 *   the labels picked at the hit UV and the mesh's whole-mesh label.
 */
export const RawModelPreview = ({
  gltfUrl,
  meshNamesByWholeMeshLabelId,
  visibleGltfMeshNames,
  visibleLabelIdsPerChannel,
  settingsByChannel,
  desiredLabelIdsPerChannel,
  onClick,
  frozenMeshIdsByMeshName,
}: Props) => {
  const { meshes: gltfMeshes, labelMasks, material } = useOurGltf(gltfUrl);
  const { gl, scene, camera } = useThree();

  const [combinedPaintedState, setCombinedPaintedState] =
    useState<CombinedPaintedState>({
      visibleLabelsPaintedState: undefined,
      desiredLabelsPaintedState: undefined,
      wholeMeshState: { channelValuesByMeshName: new Map() },
      didRender: false,
    });

  const fullScreenQuad = useMemoDisposable(
    () => new FullScreenQuad(new BitSelectMaterial()),
    [],
  );

  // Stored in the userData of meshes whose whole-mesh label is not desired.
  // Its `uTexture` uniform is the painted desired-labels texture.
  const clickDetectionMaterial = useMemoDisposable(
    () =>
      new THREE.ShaderMaterial({
        uniforms: {
          uTexture: {
            // A sampler uniform needs the render target's texture, not the
            // render target itself (same as `combinedHighlightTexture` below).
            value:
              combinedPaintedState.desiredLabelsPaintedState?.renderTarget
                .texture,
          },
        },
        vertexShader: clickDetectionVertexShader,
        fragmentShader: clickDetectionFragmentShader,
      }),
    [combinedPaintedState.desiredLabelsPaintedState],
  );

  // Inverse of `meshNamesByWholeMeshLabelId`. The assertion rejects maps in
  // which two labels point at the same mesh.
  const wholeMeshLabelIdsByMeshName = useMemo(() => {
    const wholeMeshLabelIdsByMeshName = new Map(
      [...meshNamesByWholeMeshLabelId.entries()].map(([labelId, meshName]) => [
        meshName,
        labelId,
      ]),
    );
    assert(
      wholeMeshLabelIdsByMeshName.size === meshNamesByWholeMeshLabelId.size,
    );
    return wholeMeshLabelIdsByMeshName;
  }, [meshNamesByWholeMeshLabelId]);

  const emptyTexture = useEmptyTexture();
  const whiteMaterial = useWhiteMaterial();
  const pickPaintedLabelMasks = usePickPaintedLabelMasks(labelMasks);

  useEffect(() => {
    const [visibleLabelsPaintedState, channelValuesByMeshName] =
      getPaintedStateAndChannelMap(
        meshNamesByWholeMeshLabelId,
        labelMasks,
        visibleLabelIdsPerChannel,
      );
    // Only the painted texture is kept for the desired labels; whole-mesh
    // desired labels are handled through `whiteMaterial` below.
    const [desiredLabelsPaintedState] = getPaintedStateAndChannelMap(
      meshNamesByWholeMeshLabelId,
      labelMasks,
      desiredLabelIdsPerChannel,
    );

    setCombinedPaintedState({
      visibleLabelsPaintedState,
      desiredLabelsPaintedState,
      wholeMeshState: { channelValuesByMeshName },
      didRender: false,
    });

    // The render targets created above belong to this effect run.
    return () => {
      visibleLabelsPaintedState?.renderTarget.dispose();
      desiredLabelsPaintedState?.renderTarget.dispose();
    };
  }, [
    visibleLabelIdsPerChannel,
    labelMasks,
    meshNamesByWholeMeshLabelId,
    desiredLabelIdsPerChannel,
  ]);

  useFrame(({ clock }) => {
    if (!combinedPaintedState.didRender) {
      if (combinedPaintedState.visibleLabelsPaintedState) {
        renderPaintedStateTexture(
          gl,
          fullScreenQuad,
          combinedPaintedState.visibleLabelsPaintedState,
        );
      }
      if (combinedPaintedState.desiredLabelsPaintedState) {
        renderPaintedStateTexture(
          gl,
          fullScreenQuad,
          combinedPaintedState.desiredLabelsPaintedState,
        );
      }

      for (const mesh of gltfMeshes) {
        const channelValues =
          combinedPaintedState.wholeMeshState.channelValuesByMeshName.get(
            mesh.name,
          ) ?? [0, 0, 0, 0];

        const attribute = mesh.geometry.getAttribute(
          "combinedWholeMeshHighlight",
        );
        const buffer = attribute.array;
        assert(buffer instanceof Uint8Array);

        // All vertices always receive the same values, so checking the first
        // vertex is enough to know whether the attribute is current.
        const isUpToDate =
          buffer[0] === channelValues[0] &&
          buffer[1] === channelValues[1] &&
          buffer[2] === channelValues[2] &&
          buffer[3] === channelValues[3];
        if (!isUpToDate) {
          // HACK The buffer that we are writing to here is shared between all
          // components that call useOurGltf with the same URL. This will break
          // if there are multiple separate canvases showing the same model on
          // the same page. This is pretty hard to avoid without leaking memory,
          // so we'll fix it when/if we actually need multiple canvases.
          assert(buffer.length % 4 === 0);
          for (let i = 0; i < buffer.length; i += 4) {
            buffer[i + 0] = channelValues[0];
            buffer[i + 1] = channelValues[1];
            buffer[i + 2] = channelValues[2];
            buffer[i + 3] = channelValues[3];
          }
          attribute.needsUpdate = true;
        }
      }

      // Deliberately mutated in place: this is a per-state "done" flag that
      // must not trigger a React re-render.
      combinedPaintedState.didRender = true;
    }

    const shaderUniforms = material.userData.shaderUniforms;
    if (!shaderUniforms) {
      return;
    }

    shaderUniforms.combinedHighlightTexture.value =
      combinedPaintedState.visibleLabelsPaintedState?.renderTarget.texture ??
      emptyTexture;

    const elapsedTime = clock.getElapsedTime();
    for (const channel of allOverlayMaterialCombinedChannelKeys) {
      const settings = {
        ...defaultOverlayMaterialChannelSettings,
        ...settingsByChannel?.[channel],
      };

      // Oscillates between 1 and 0; scaled by pulseStrength it dims the alpha.
      const pulseProgress =
        Math.cos(elapsedTime * settings.pulseFrequency * Math.PI * 2) * 0.5 +
        0.5;
      const color = [...settings.color];
      color[3] *= 1 - pulseProgress * settings.pulseStrength;

      shaderUniforms[`color${channel}`].value = color;
    }
  });

  return (
    <>
      {/* HACK - background mesh to detect clicks outside the visible meshes and run the tolerance check */}
      {desiredLabelIdsPerChannel?.CombinedR && (
        <mesh
          userData={{ ignoreForFit: true }}
          onClick={(ev) => {
            const desiredLabelIds = desiredLabelIdsPerChannel?.CombinedR;
            // Same drag filter as the per-mesh handler below.
            const isClick = ev.delta < 10;
            if (
              onClick &&
              isClick &&
              desiredLabelIds &&
              detectClickedWithinTolerance(ev, gl, scene, camera)
            ) {
              onClick(desiredLabelIds);
              ev.stopPropagation();
            }
          }}
        >
          {/* NOTE(review): negative radius presumably turns the sphere inside out so it can be hit from within — confirm */}
          <sphereGeometry args={[-100, 16, 16]} />
          <meshStandardMaterial color="white" />
        </mesh>
      )}
      {gltfMeshes
        .filter(
          (mesh) =>
            !visibleGltfMeshNames || visibleGltfMeshNames.has(mesh.name),
        )
        .map((mesh) => (
          <mesh
            key={mesh.uuid}
            name={frozenMeshIdsByMeshName?.get(mesh.name)}
            geometry={mesh.geometry}
            material={material}
            userData={{
              clickDetectionMaterial: isWholeMeshLabelOfMeshDesired(
                wholeMeshLabelIdsByMeshName,
                desiredLabelIdsPerChannel,
                mesh.name,
              )
                ? whiteMaterial
                : clickDetectionMaterial,
            }}
            position={mesh.position}
            scale={mesh.scale}
            rotation={mesh.rotation}
            onClick={(ev) => {
              const isClick = ev.delta < 10;
              if (!onClick || !isClick) {
                return;
              }

              // Tolerance check first: an accepted click selects exactly the
              // desired labels.
              const desiredLabelIds = desiredLabelIdsPerChannel?.CombinedR;
              if (
                desiredLabelIds &&
                detectClickedWithinTolerance(ev, gl, scene, camera)
              ) {
                onClick(desiredLabelIds);
                ev.stopPropagation();
                return;
              }

              // Otherwise fall back to the labels under the hit UV plus this
              // mesh's whole-mesh label. The pick result is copied so it is
              // never mutated in place.
              const pickedLabelIds: string[] = ev.uv
                ? [...pickPaintedLabelMasks(ev.uv)]
                : [];
              const wholeMeshLabelId = wholeMeshLabelIdsByMeshName.get(
                mesh.name,
              );
              if (wholeMeshLabelId !== undefined) {
                pickedLabelIds.push(wholeMeshLabelId);
              }
              onClick(sortBy([...new Set(pickedLabelIds)]));
              ev.stopPropagation();
            }}
          />
        ))}
    </>
  );
};
