/* eslint react/display-name: off, react/prop-types: off, react/no-unknown-property: off */
import { compressBinaryMask, decompressBinaryMask } from "@augmedi/encoding";
import type { quizPb } from "@augmedi/proto-gen";
import type { PlainMessage } from "@bufbuild/protobuf";
import { Sphere } from "@react-three/drei";
import {
  addAfterEffect,
  useFrame,
  useThree,
  type ThreeEvent,
} from "@react-three/fiber";
import assert from "assert-ts";
import {
  forwardRef,
  useEffect,
  useImperativeHandle,
  useMemo,
  useRef,
} from "react";
import * as THREE from "three";
import { colorConstants } from "../color-constants.js";
import { shaderRgbaFromString } from "../logic/color.js";
import { useEmptyTexture, useOurGltf } from "../logic/gltf.js";
import { useMemoDisposable } from "../logic/memo-cleanup.js";
import { OtherMesh } from "./OtherMesh.js";

/** Props accepted by the `ModelPainter` component. */
interface Props {
  // --- What to render -------------------------------------------------------

  /** Model identifier; forwarded to every `OtherMesh` rendered alongside. */
  modelId: string;
  /** URL of the GLTF file that contains the mesh to paint on. */
  gltfUrl: string;
  /** Name of the mesh inside the GLTF to paint on; rendering throws if absent. */
  gltfMeshName: string;
  /** Extra meshes to render next to the painted one: mesh name → partition id. */
  otherPartitionIdByName: Map<string, string>;

  // --- Mask data ------------------------------------------------------------

  /**
   * Masks that are displayed but cannot be edited. They are only loaded
   * together with `initialWriteableMask`, and are decoded using its pixel
   * count (so presumably they must share its dimensions).
   */
  readOnlyMasks: PlainMessage<quizPb.MaskImage>[];
  /** Mask to edit. When omitted, nothing can be painted. */
  initialWriteableMask?: PlainMessage<quizPb.MaskImage>;
  /** Opacity used for the mask highlight colours; defaults to 1.0. */
  writeableMaskOpacity?: number;

  // --- Callbacks ------------------------------------------------------------

  /** Called with `false` after a mask is loaded and `true` on the first paint. */
  onWriteableMaskDirtyChanged?: (dirty: boolean) => void;
  /** Called when the pointer is released over the painted mesh. */
  onIdle?: () => void;

  // --- Brush ----------------------------------------------------------------

  /** Active tool; painting and the brush preview only happen for `Brush`. */
  tool?: ModelPainterTool;
  /** Brush radius in world units (also the preview sphere's scale). */
  brushRadius: number;
}

/**
 * Whether a brush stroke adds to the writeable mask (`Paint`) or removes from
 * it (`Erase`).
 */
const PaintMode = {
  Paint: "Paint",
  Erase: "Erase",
} as const;
type PaintMode = (typeof PaintMode)[keyof typeof PaintMode];

/** Interaction tool selected for `ModelPainter`. */
export enum ModelPainterTool {
  /** Pointer drags on the mesh paint (or, with Shift, erase) the mask; a brush preview sphere is shown. */
  Brush = "Brush",
  /** No painting or brush preview; presumably camera controls handle the pointer instead — confirm with caller. */
  Camera = "Camera",
}

/**
 * In-progress brush stroke. Every field is unset while the user is not
 * painting; while painting, the segment `paintStart` → `paintEnd` is what gets
 * rendered into the mask on the next frame.
 */
interface Brush {
  /** Start of the segment to paint (world space). */
  paintStart?: THREE.Vector3;
  /** End of the segment, i.e. the most recent pointer hit (world space). */
  paintEnd?: THREE.Vector3;
  /** Whether the stroke paints or erases. */
  paintMode?: PaintMode;
}

/** Imperative handle exposed by `ModelPainter` through its forwarded ref. */
export interface ModelPainterRef {
  /**
   * Reads the current writeable mask back from the GPU, thresholds it to a
   * binary mask and returns it compressed, with the render target's size.
   *
   * @throws Error when no `initialWriteableMask` was provided.
   */
  getWriteableMask: () => PlainMessage<quizPb.MaskImage>;
}

/**
 * Paints a binary mask onto a single mesh of a GLTF model.
 *
 * The mask being edited lives in a WebGL render target sized like
 * `initialWriteableMask`. Its A channel holds the writeable mask and its R
 * channel holds the union of `readOnlyMasks`. While the Brush tool is active
 * and the pointer is dragged across the mesh, the mesh is re-rendered into the
 * render target using its UVs as clip-space positions. Every fragment whose
 * world position lies within `brushRadius` of the current stroke segment is
 * painted, or erased when Shift is held. Each frame the render target texture
 * is assigned to the GLTF material's highlight texture uniforms.
 *
 * The forwarded ref exposes `getWriteableMask()` to read the edited mask back.
 */
export const ModelPainter = forwardRef<ModelPainterRef, Props>(
  (
    {
      modelId,
      gltfUrl,
      gltfMeshName,
      otherPartitionIdByName,
      readOnlyMasks,
      initialWriteableMask,
      onWriteableMaskDirtyChanged,
      writeableMaskOpacity = 1.0,
      onIdle,
      tool,
      brushRadius,
    },
    ref,
  ) => {
    const { meshes: gltfMeshes, material } = useOurGltf(gltfUrl);
    const gltfMesh = gltfMeshes.find((m) => m.name === gltfMeshName);
    if (!gltfMesh) {
      throw new Error(`Mesh "${gltfMeshName}" not found in GLTF`);
    }

    // Current stroke segment. Pointer handlers move paintEnd while painting.
    // After each rendered frame the callback below advances paintStart to
    // paintEnd, so every frame paints only the newest part of the stroke.
    const brushRef = useRef<Brush>({});
    useEffect(() => {
      // addAfterEffect registers a *global* callback and returns a function
      // that unregisters it. Returning it as the cleanup removes the callback
      // on unmount. Otherwise it would keep running forever (and be
      // registered twice under StrictMode).
      const removeAfterEffect = addAfterEffect(() => {
        if (brushRef.current.paintMode !== undefined) {
          brushRef.current = {
            paintMode: brushRef.current.paintMode,
            paintStart: brushRef.current.paintEnd,
            paintEnd: brushRef.current.paintEnd,
          };
        } else {
          brushRef.current = {};
        }
      });
      return removeAfterEffect;
    }, []);

    // Last pointer hit on the mesh, used for the hover preview of the brush.
    const hoverPointRef = useRef<THREE.Vector3 | undefined>();

    const paintShaderMaterial = useMemoDisposable(
      () =>
        new THREE.ShaderMaterial({
          depthWrite: false,
          depthTest: false,
          blending: THREE.CustomBlending,
          blendEquation: THREE.AddEquation, // Will be overridden in the render loop
          blendSrc: THREE.SrcAlphaFactor,
          blendDst: THREE.OneMinusSrcAlphaFactor,
          side: THREE.DoubleSide,
          uniforms: {
            brushStart: { value: new THREE.Vector3() },
            brushEnd: { value: new THREE.Vector3() },
            brushRadius: { value: 0 },
          },
          vertexShader: `
          varying vec3 vWorldPosition;

          void main() {
            // Calculate the world position for the fragment shader to determine the distance from the brush
            vWorldPosition = (modelMatrix * vec4(position, 1.0)).xyz;
            // Output colors to the UV position, effectively copying the texture to the render target  
            gl_Position = vec4(uv * 2.0 - 1.0, 0.0, 1.0);
          }
        `,
          fragmentShader: `
          uniform vec3 brushStart;
          uniform vec3 brushEnd;
          uniform float brushRadius;
          
          varying vec3 vWorldPosition;

          void main() {
            vec3 brushStartToBrushEnd = brushEnd - brushStart;
            float brushLength = length(brushStartToBrushEnd);

            float distanceToClosestBrushPoint;
            if (brushLength < 0.0001) {
              distanceToClosestBrushPoint = length(vWorldPosition - brushStart);
            } else {
              vec3 brushDirection = brushStartToBrushEnd / brushLength;
              float distanceAlongBrush = clamp(dot(vWorldPosition - brushStart, brushDirection), 0.0, brushLength);
              vec3 closestBrushPoint = brushStart + brushDirection * distanceAlongBrush;
              distanceToClosestBrushPoint = length(vWorldPosition - closestBrushPoint);
            }
            
            float alpha = 1.0 - step(brushRadius, distanceToClosestBrushPoint);
            gl_FragColor = vec4(1.0, 1.0, 1.0, alpha);
          }
        `,
        }),
      [],
    );

    // Scene that holds only a copy of the painted mesh, drawn with the paint
    // shader into the mask render target.
    const offscreenScene = useMemo(() => new THREE.Scene(), []);
    const offscreenMeshRef = useRef<THREE.Mesh>(null!);
    useEffect(() => {
      const offscreenMesh = new THREE.Mesh(
        gltfMesh.geometry,
        paintShaderMaterial,
      );
      // The world matrix is copied from the on-screen mesh every frame, so the
      // offscreen mesh must not recompute it from its own position/rotation/scale.
      offscreenMesh.matrixAutoUpdate = false;
      offscreenMeshRef.current = offscreenMesh;
      offscreenScene.clear();
      offscreenScene.add(offscreenMesh);
    }, [offscreenScene, gltfMesh, paintShaderMaterial]);

    const threeState = useThree();

    const renderTargetRef = useRef<THREE.WebGLRenderTarget>();
    const renderTargetDirtyRef = useRef(false);
    // (Re)create the render target whenever the masks change, and upload the
    // initial mask data into it. onWriteableMaskDirtyChanged is intentionally
    // not a dependency: a new callback identity must not discard the user's
    // unsaved painting. NOTE(review): a new readOnlyMasks array identity does
    // reset the render target; callers presumably memoize it — confirm.
    useEffect(() => {
      if (!initialWriteableMask) {
        renderTargetRef.current = undefined;
        renderTargetDirtyRef.current = false;
        return;
      }

      const gl = threeState.gl;

      const renderTarget = new THREE.WebGLRenderTarget(
        initialWriteableMask.width,
        initialWriteableMask.height,
      );
      renderTargetRef.current = renderTarget;
      // HACK The renderTarget is only fully initialized when setRenderTarget is
      // called. We don't actually need it bound for the copy though.
      gl.setRenderTarget(renderTarget);
      gl.setRenderTarget(null);

      const pixelCount =
        initialWriteableMask.width * initialWriteableMask.height;
      const uncompressedSingleChannelMask = decompressBinaryMask(
        initialWriteableMask.compressedData,
        pixelCount,
      );
      const uncompressedMultiChannelMask = new Uint8Array(pixelCount * 4);
      // Selected label color stored in the A channel
      for (let i = 0; i < pixelCount; i++) {
        uncompressedMultiChannelMask[i * 4 + 3] =
          uncompressedSingleChannelMask[i] * 255;
      }
      // Other labels color stored in the R channel (union of all read-only masks)
      for (const readOnlyMask of readOnlyMasks) {
        const otherUncompressedSingleChannelMask = decompressBinaryMask(
          readOnlyMask.compressedData,
          pixelCount,
        );
        for (let i = 0; i < pixelCount; i++) {
          if (0 === uncompressedMultiChannelMask[i * 4]) {
            uncompressedMultiChannelMask[i * 4] =
              otherUncompressedSingleChannelMask[i] * 255;
          }
        }
      }

      // We can't easily write directly to a render target, so we copy through a DataTexture.
      const dataTexture = new THREE.DataTexture(
        uncompressedMultiChannelMask,
        initialWriteableMask.width,
        initialWriteableMask.height,
        THREE.RGBAFormat,
      );
      dataTexture.needsUpdate = true;
      try {
        gl.copyTextureToTexture(
          new THREE.Vector2(),
          dataTexture,
          renderTarget.texture,
        );
      } finally {
        dataTexture.dispose();
      }

      renderTargetDirtyRef.current = false;
      onWriteableMaskDirtyChanged?.(false);

      return () => {
        renderTarget.dispose();
        if (renderTargetRef.current === renderTarget) {
          renderTargetRef.current = undefined;
        }
      };
    }, [initialWriteableMask, readOnlyMasks, threeState.gl]);

    useImperativeHandle(
      ref,
      (): ModelPainterRef => ({
        getWriteableMask: () => {
          const renderTarget = renderTargetRef.current;
          if (!renderTarget) {
            throw new Error(
              "No writeable mask loaded. Did you pass initialWriteableMask?",
            );
          }

          const pixelCount = renderTarget.width * renderTarget.height;

          const renderTargetData = new Uint8Array(4 * pixelCount);
          threeState.gl.readRenderTargetPixels(
            renderTarget,
            0,
            0,
            renderTarget.width,
            renderTarget.height,
            renderTargetData,
          );

          // Threshold the A channel (the writeable mask) back to 0/1.
          const uncompressedMask = new Uint8Array(pixelCount);
          for (let i = 0; i < pixelCount; i++) {
            uncompressedMask[i] = renderTargetData[i * 4 + 3] > 127 ? 1 : 0;
          }
          const compressedMask = compressBinaryMask(uncompressedMask);

          return {
            width: renderTarget.width,
            height: renderTarget.height,
            compressedData: compressedMask,
          };
        },
      }),
      [threeState.gl],
    );

    const brushSphereRef = useRef<THREE.Mesh>(null!);
    const sceneMesh = useRef<THREE.Mesh>(null!);
    const emptyTexture = useEmptyTexture();

    useFrame(({ gl, camera }) => {
      const renderTarget = renderTargetRef.current;
      const brush = brushRef.current;

      if (
        tool === ModelPainterTool.Brush &&
        renderTarget &&
        brush.paintMode !== undefined &&
        brush.paintEnd
      ) {
        // Actively painting: show the brush tinted by mode and stamp the
        // current stroke segment into the mask.
        brushSphereRef.current.visible = true;
        brushSphereRef.current.position.copy(brush.paintEnd);
        assert(
          brushSphereRef.current.material instanceof THREE.MeshBasicMaterial,
        );
        brushSphereRef.current.material.color.set(
          brush.paintMode === PaintMode.Paint ? "green" : "red",
        );

        // Copy into the uniforms' own vectors instead of aliasing the pointer
        // event's intersection points. A missing paintStart degrades to a dot
        // at paintEnd rather than an undefined uniform.
        const uniforms = paintShaderMaterial.uniforms;
        uniforms.brushStart.value.copy(brush.paintStart ?? brush.paintEnd);
        uniforms.brushEnd.value.copy(brush.paintEnd);
        uniforms.brushRadius.value = brushRadius;
        // Paint: src*a + dst*(1-a). Erase: dst*(1-a) - src*a, clamped at 0.
        // NOTE(review): the shader outputs white, so painting also sets the R
        // (read-only masks) channel inside the stroke and erasing clears it —
        // confirm this is intended.
        paintShaderMaterial.blendEquation =
          brush.paintMode === PaintMode.Paint
            ? THREE.AddEquation
            : THREE.ReverseSubtractEquation;

        const offscreenMesh = offscreenMeshRef.current;
        offscreenMesh.matrix.copy(sceneMesh.current.matrixWorld);
        offscreenMesh.matrixWorld.copy(sceneMesh.current.matrixWorld);

        gl.setRenderTarget(renderTarget);
        const oldAutoClear = gl.autoClear;
        // Accumulate into the existing mask instead of clearing it.
        gl.autoClear = false;
        gl.render(offscreenScene, camera);
        gl.autoClear = oldAutoClear;
        gl.setRenderTarget(null);

        if (!renderTargetDirtyRef.current) {
          renderTargetDirtyRef.current = true;
          onWriteableMaskDirtyChanged?.(true);
        }
      } else if (tool === ModelPainterTool.Brush && hoverPointRef.current) {
        // Hovering without painting: neutral brush preview.
        brushSphereRef.current.visible = true;
        brushSphereRef.current.position.copy(hoverPointRef.current);
        assert(
          brushSphereRef.current.material instanceof THREE.MeshBasicMaterial,
        );
        brushSphereRef.current.material.color.set("white");
      } else {
        brushSphereRef.current.visible = false;
      }

      // Feed the mask texture (or an empty one) to the GLTF material's shader.
      if (material.userData.shaderUniforms) {
        material.userData.shaderUniforms.standaloneHighlightTexture.value =
          renderTarget?.texture ?? emptyTexture;
        const ambientColor = shaderRgbaFromString(
          colorConstants.introductionGroupSeenColor,
        );
        material.userData.shaderUniforms.colorStandaloneA.value = [
          ambientColor[0],
          ambientColor[1],
          ambientColor[2],
          writeableMaskOpacity,
        ];
        material.userData.shaderUniforms.combinedHighlightTexture.value =
          renderTarget?.texture ?? emptyTexture;
        material.userData.shaderUniforms.colorCombinedR.value = [
          1,
          1,
          1,
          writeableMaskOpacity,
        ];
      }
    });

    /**
     * Wraps a pointer handler so it only fires when the event hit `meshRef`,
     * passing along that mesh's intersection (not necessarily the nearest hit).
     */
    const wrapWithMeshIntersection =
      (
        meshRef: React.RefObject<THREE.Mesh>,
        cb: (
          ev: ThreeEvent<PointerEvent>,
          intersection: THREE.Intersection,
        ) => void,
      ) =>
      (ev: ThreeEvent<PointerEvent>) => {
        const intersection = ev.intersections.find(
          (intersection) => intersection.object === meshRef.current,
        );
        if (intersection) {
          cb(ev, intersection);
        }
      };

    return (
      <>
        <mesh
          key={gltfMesh.id}
          ref={sceneMesh}
          geometry={gltfMesh.geometry}
          material={material}
          scale={gltfMesh.scale}
          rotation={gltfMesh.rotation}
          onPointerDown={wrapWithMeshIntersection(
            sceneMesh,
            (ev, intersection) => {
              brushRef.current = {
                paintMode: ev.shiftKey ? PaintMode.Erase : PaintMode.Paint,
                paintStart: intersection.point,
                paintEnd: intersection.point,
              };
            },
          )}
          onPointerUp={() => {
            brushRef.current = {};
            onIdle?.();
          }}
          onPointerMove={wrapWithMeshIntersection(
            sceneMesh,
            (ev, intersection) => {
              hoverPointRef.current = intersection.point;
              if (brushRef.current.paintMode !== undefined) {
                brushRef.current = {
                  paintMode: ev.shiftKey ? PaintMode.Erase : PaintMode.Paint,
                  paintStart: brushRef.current.paintStart ?? intersection.point,
                  paintEnd: intersection.point,
                };
              }
            },
          )}
          onPointerOut={() => {
            hoverPointRef.current = undefined;
            brushRef.current = {};
          }}
        />
        {Array.from(otherPartitionIdByName).map(([meshName, partitionId]) => (
          <OtherMesh
            key={meshName}
            modelId={modelId}
            partitionId={partitionId}
            meshName={meshName}
          />
        ))}
        <Sphere
          ref={brushSphereRef}
          scale={brushRadius}
          userData={{ ignoreForFit: true }}
        >
          <meshBasicMaterial toneMapped={false} />
        </Sphere>
      </>
    );
  },
);
