/* eslint @typescript-eslint/no-empty-object-type: off, @typescript-eslint/no-explicit-any: off */
import { quizPb } from "@augmedi/proto-gen";
import { useGLTF } from "@react-three/drei";
import { useThree } from "@react-three/fiber";
import assert from "assert-ts";
import { isPlainObject } from "lodash-es";
import * as THREE from "three";
import type {
  GLTF,
  GLTFLoader,
  GLTFLoaderPlugin,
  GLTFParser,
} from "three-stdlib";
import { getKtxLoader } from "./ktx";
import { useMemoDisposable } from "./memo-cleanup";
import { onBeforeCompileOverlayShader } from "./overlay-material";

// Type-only brand: a `declare const` emits no runtime value. It exists only so
// that a plain THREE.Texture is not accidentally accepted where a mask texture
// is required.
declare const maskCompatibleTextureBrand: unique symbol;

/**
 * A THREE.Texture whose image has numeric `width`/`height` and which has been
 * reconfigured for use as a label mask. Within this file, values of this type
 * are only produced by `convertToMaskCompatibleTexture`, which validates and
 * mutates the texture before casting it to this type.
 */
export type MaskCompatibleTexture = THREE.Texture & {
  image: { width: number; height: number };
  [maskCompatibleTextureBrand]: true;
};

/** Result of {@link useOurGltf}: the loaded meshes plus data derived from them. */
export interface OurGltfContent {
  /** The top-level children of the glTF scene, all of which must be meshes. */
  meshes: THREE.Mesh[];
  /**
   * Label masks declared by the meshes' `AUGMEDI_mesh_with_label_masks`
   * extension, keyed by label id.
   */
  labelMasks: { [id: string]: LoadedLabelMask };
  /**
   * A clone of the material shared by all meshes, with the overlay shader
   * hooked in through `onBeforeCompile`.
   */
  material: THREE.MeshBasicMaterial | THREE.MeshStandardMaterial;
}

/** A single label mask resolved against its loaded mask texture. */
export interface LoadedLabelMask {
  texture: MaskCompatibleTexture;
  // NOTE(review): presumably the index of the bit inside the mask texels that
  // marks this label — confirm against the overlay shader.
  textureBit: number;
}

/** Data that TextureCollectorPlugin attaches to `gltf.asset.extras`. */
interface ExtrasFromTextureCollector {
  /** Converted mask textures keyed by their glTF texture index. */
  maskTexturesByIndex: Map<number, MaskCompatibleTexture>;
}

// Name of the glTF extension (looked up in `object.userData.gltfExtensions`)
// that carries the label-mask metadata.
const maskLabelExtensionName = "AUGMEDI_mesh_with_label_masks";

/**
 * Type guard: whether `o` is a non-null object exposing numeric `width` and
 * `height` properties (own or inherited).
 *
 * Takes `unknown` so that a missing image (`null` / `undefined`, e.g. an
 * unloaded texture) yields `false` rather than a `TypeError` from the `in`
 * operator, which rejects non-object operands.
 *
 * @param o - Value to inspect, typically `texture.image`.
 * @returns `true` if both dimensions are present and are numbers.
 */
function hasWidthHeight(o: unknown): o is { width: number; height: number } {
  if (typeof o !== "object" || o === null) {
    return false;
  }
  return (
    "width" in o &&
    typeof o.width === "number" &&
    "height" in o &&
    typeof o.height === "number"
  );
}

/**
 * Reconfigures a loaded texture in place so it can be sampled as an
 * unsigned-integer RGBA8 ("RGBA8UI") mask, and returns the same object
 * branded as a {@link MaskCompatibleTexture}.
 *
 * Sampling is switched to nearest-neighbour without mipmaps, `flipY` is
 * disabled, and `needsUpdate` is set so the new settings take effect.
 *
 * @param texture - Texture to convert; it is mutated, not copied.
 * @returns The same texture instance, typed as a mask texture.
 * @throws Error if the image lacks numeric dimensions, the format is not
 *   RGBA, or the component type is neither signed nor unsigned byte.
 */
function convertToMaskCompatibleTexture(
  texture: THREE.Texture,
): MaskCompatibleTexture {
  const hasByteComponents =
    texture.type === THREE.ByteType || texture.type === THREE.UnsignedByteType;
  const isUsableAsMask =
    hasWidthHeight(texture.image) &&
    texture.format === THREE.RGBAFormat &&
    hasByteComponents;
  if (!isUsableAsMask) {
    throw new Error("Texture does not have the right format for use as a mask");
  }

  // Signed bytes are reinterpreted as unsigned: only the bit pattern matters.
  // `needsUpdate` is listed last so it is assigned after everything else.
  Object.assign(texture, {
    flipY: false,
    format: THREE.RGBAIntegerFormat,
    type: THREE.UnsignedByteType,
    internalFormat: "RGBA8UI",
    magFilter: THREE.NearestFilter,
    minFilter: THREE.NearestFilter,
    generateMipmaps: false,
    needsUpdate: true,
  });

  return texture as MaskCompatibleTexture;
}

/**
 * Reads and parses the label-mask glTF extension attached to an object.
 *
 * @param mesh - Any Object3D; the extension JSON is looked up in
 *   `userData.gltfExtensions`.
 * @returns The parsed extension content, or `undefined` when the object does
 *   not carry the extension.
 */
function getMaskLabelExtensionFromObject3D(
  mesh: THREE.Object3D,
): quizPb.AugmediMeshWithLabelMasksGltfExtensionContent | undefined {
  const rawExtension = mesh.userData?.gltfExtensions?.[maskLabelExtensionName];
  return rawExtension === undefined
    ? undefined
    : quizPb.AugmediMeshWithLabelMasksGltfExtensionContent.fromJson(
        rawExtension,
      );
}

/**
 * glTF loader plugin that loads every texture referenced by a label mask in
 * the `AUGMEDI_mesh_with_label_masks` extension, converts each into a
 * {@link MaskCompatibleTexture}, and publishes the result as
 * `gltf.asset.extras.maskTexturesByIndex` (keyed by glTF texture index).
 */
class TextureCollectorPlugin implements GLTFLoaderPlugin {
  public name: string = "TextureCollectorPlugin";

  afterRoot = this.#afterRoot.bind(this);

  constructor(private parser: GLTFParser) {}

  /** Collects the distinct texture indices referenced by any label mask in any scene. */
  #collectMaskTextureIndices(result: GLTF): Set<number> {
    const maskTextureIndices = new Set<number>();
    for (const scene of result.scenes) {
      scene.traverse((object) => {
        const extensionContent = getMaskLabelExtensionFromObject3D(object);
        if (!extensionContent) {
          return;
        }
        for (const mask of Object.values(extensionContent.labelMasks)) {
          maskTextureIndices.add(mask.textureIndex);
        }
      });
    }
    return maskTextureIndices;
  }

  /**
   * Loads and converts all mask textures, then stores them on
   * `result.asset.extras`.
   *
   * @throws Error if `asset.extras` exists but is not a plain object, or if
   *   it already contains `maskTexturesByIndex`.
   */
  async #afterRoot(result: GLTF): Promise<void> {
    const maskTextureIndices = this.#collectMaskTextureIndices(result);

    const maskTexturesByIndex = new Map<number, MaskCompatibleTexture>();
    await Promise.all(
      [...maskTextureIndices].map(async (textureIndex) => {
        // NOTE(review): this mutates the texture instance returned by the
        // parser; if the same texture is also used as a regular material map,
        // that use will see the mask settings too.
        const texture = convertToMaskCompatibleTexture(
          await this.parser.loadTexture(textureIndex),
        );
        maskTexturesByIndex.set(textureIndex, texture);
      }),
    );

    if (result.asset.extras === undefined) {
      result.asset.extras = {};
    }
    const extras = result.asset.extras;
    if (!isPlainObject(extras)) {
      throw new Error("extras is defined, but not an object");
    }
    if ("maskTexturesByIndex" in extras) {
      throw new Error("extras.maskTexturesByIndex is already defined");
    }
    Object.assign(extras, {
      maskTexturesByIndex,
    } satisfies ExtrasFromTextureCollector);
  }
}

/**
 * glTF loader plugin that adds a zero-filled, normalized attribute named
 * `combinedWholeMeshHighlight` (4 unsigned bytes per vertex) to the
 * BufferGeometry of every mesh in every loaded scene.
 */
class GeometryExtenderPlugin implements GLTFLoaderPlugin {
  public name: string = "GeometryExtenderPlugin";

  afterRoot = this.#afterRoot.bind(this);

  /** @throws Error (as a rejection) if a mesh geometry has no `position` attribute. */
  async #afterRoot(result: GLTF): Promise<void> {
    const addHighlightAttribute = (node: THREE.Object3D): void => {
      if (!(node instanceof THREE.Mesh)) {
        return;
      }
      const geometry = node.geometry;
      if (!(geometry instanceof THREE.BufferGeometry)) {
        return;
      }

      const vertexCount = geometry.attributes?.position?.count;
      if (vertexCount === undefined) {
        throw new Error("Mesh has no vertex positions");
      }

      // One RGBA byte quadruple per vertex, read as 0..1 floats in shaders.
      const highlight = new THREE.Uint8BufferAttribute(
        new Uint8Array(vertexCount * 4),
        4,
      );
      highlight.normalized = true;
      geometry.setAttribute("combinedWholeMeshHighlight", highlight);
    };

    result.scenes.forEach((scene) => scene.traverse(addHighlightAttribute));
  }
}

/**
 * Registers this file's glTF plugins on `loader`, in order: the mask texture
 * collector, then the geometry extender. Also sets the KTX2 loader obtained
 * from `getKtxLoader` for the given renderer.
 */
function extendLoader(loader: GLTFLoader, gl: THREE.WebGLRenderer) {
  const pluginFactories: Array<(parser: GLTFParser) => GLTFLoaderPlugin> = [
    (parser) => new TextureCollectorPlugin(parser),
    () => new GeometryExtenderPlugin(),
  ];
  for (const createPlugin of pluginFactories) {
    loader.register(createPlugin);
  }
  loader.setKTX2Loader(getKtxLoader(gl));
}

/**
 * Returns the direct children of `gltfScene`, all of which must be meshes.
 *
 * @throws Error if the scene has no children or any child is not a THREE.Mesh.
 */
function getTopLevelMeshes(gltfScene: THREE.Object3D): THREE.Mesh[] {
  if (gltfScene.children.length === 0) {
    throw new Error("Expected at least 1 object in the scene");
  }
  const meshes = gltfScene.children.filter(
    (child): child is THREE.Mesh => child instanceof THREE.Mesh,
  );
  if (meshes.length !== gltfScene.children.length) {
    throw new Error("Expected all objects in the scene to be meshes");
  }
  return meshes;
}

/**
 * Builds the label-mask lookup from each mesh's `AUGMEDI_mesh_with_label_masks`
 * extension, resolving every mask's texture index against the textures that
 * were collected at load time.
 *
 * @throws Error if a label id is declared more than once across the meshes, or
 *   if a mask is present but no mask textures were collected.
 */
function collectLabelMasks(
  meshes: readonly THREE.Mesh[],
  maskTexturesByIndex: ReadonlyMap<number, MaskCompatibleTexture> | undefined,
): { [id: string]: LoadedLabelMask } {
  // A Map instead of probing a `{}` literal with `in`: that would report
  // inherited keys such as "constructor" or "toString" as duplicates, and
  // assigning an id of "__proto__" would replace the prototype instead of
  // adding a key. Object.fromEntries defines an own property for every key.
  const labelMasks = new Map<string, LoadedLabelMask>();
  for (const mesh of meshes) {
    const extensionContent = getMaskLabelExtensionFromObject3D(mesh);
    if (!extensionContent) {
      continue;
    }
    for (const [id, { textureIndex, textureBit }] of Object.entries(
      extensionContent.labelMasks,
    )) {
      if (maskTexturesByIndex === undefined) {
        // useGLTF caches by URL, so the glTF may have been loaded elsewhere
        // without this file's plugins.
        throw new Error(
          "Mask textures were not collected; was the glTF loaded without TextureCollectorPlugin?",
        );
      }
      const texture = maskTexturesByIndex.get(textureIndex);
      assert(texture !== undefined);
      if (labelMasks.has(id)) {
        throw new Error(`Label mask ${id} is defined in multiple meshes`);
      }
      labelMasks.set(id, { texture, textureBit });
    }
  }
  return Object.fromEntries(labelMasks);
}

/**
 * Returns the single material shared by all `meshes`.
 *
 * @throws Error if the meshes do not all use the same material instance, or if
 *   that material is not a MeshBasicMaterial or MeshStandardMaterial (or a
 *   subclass of either).
 */
function getSharedMaterial(
  meshes: readonly THREE.Mesh[],
): THREE.MeshBasicMaterial | THREE.MeshStandardMaterial {
  const material = meshes[0].material;
  if (meshes.some((mesh) => mesh.material !== material)) {
    throw new Error("Expected all meshes to have the same material");
  }
  if (
    !(material instanceof THREE.MeshBasicMaterial) &&
    !(material instanceof THREE.MeshStandardMaterial)
  ) {
    throw new Error(
      "Expected a THREE.MeshBasicMaterial or THREE.MeshStandardMaterial",
    );
  }
  return material;
}

/**
 * Loads a glTF through drei's `useGLTF`, with this file's loader plugins
 * registered, and derives the meshes, label masks, and overlay material from
 * it.
 *
 * The returned material is a clone of the meshes' shared material, patched via
 * `onBeforeCompile`. The clone is disposed through the `dispose` callback
 * handed to `useMemoDisposable`.
 *
 * @param gltfUrl - URL of the glTF asset to load.
 * @throws Error if the scene layout, label masks, or material do not meet the
 *   expectations documented on the helpers above.
 */
export function useOurGltf(gltfUrl: string): OurGltfContent {
  const { gl } = useThree();

  const {
    scene: gltfScene,
    asset: { extras: gltfExtras },
  } = useGLTF(gltfUrl, undefined, undefined, (loader) =>
    extendLoader(loader, gl),
  );

  const content = useMemoDisposable((): OurGltfContent & {
    dispose(): void;
  } => {
    const meshes = getTopLevelMeshes(gltfScene);

    const maskTexturesByIndex = (
      gltfExtras as Partial<ExtrasFromTextureCollector> | undefined
    )?.maskTexturesByIndex;
    const labelMasks = collectLabelMasks(meshes, maskTexturesByIndex);

    const material = getSharedMaterial(meshes);
    const overlayMaterial = material.clone();
    overlayMaterial.onBeforeCompile = (shader) =>
      onBeforeCompileOverlayShader(shader, overlayMaterial);

    return {
      meshes,
      labelMasks,
      material: overlayMaterial,
      dispose() {
        overlayMaterial.dispose();
      },
    };
  }, [gltfScene, gltfExtras]);

  return content;
}

/**
 * Returns a memoized 1×1 RGBA (unsigned byte) DataTexture whose only texel is
 * (0, 0, 0, 0). It is created once per component via `useMemoDisposable`.
 *
 * NOTE(review): if this texture is bound in place of a MaskCompatibleTexture
 * (RGBA8UI, sampled through an unsigned-integer sampler), it should probably
 * use the same integer format and nearest filtering. Confirm against the
 * overlay shader.
 */
export function useEmptyTexture() {
  const emptyTexture = useMemoDisposable(() => {
    const texture = new THREE.DataTexture(
      new Uint8Array([0, 0, 0, 0]),
      1,
      1,
      THREE.RGBAFormat,
    );
    // three.js only uploads a DataTexture's data after its version has been
    // bumped. Without this, the renderer silently binds its own internal
    // placeholder texture instead of this one.
    texture.needsUpdate = true;
    return texture;
  }, []);
  return emptyTexture;
}
