/* eslint @typescript-eslint/no-unused-vars: off */
import { unreachable } from "@augmedi/type-utils";
import { z } from "zod";

/*
From https://github.com/huggingface/safetensors?tab=readme-ov-file#format

Format:
- 8 bytes: N, an unsigned little-endian 64-bit integer, containing the size of the header
- N bytes: a JSON UTF-8 string representing the header.
  - The header data MUST begin with a { character (0x7B).
  - The header data MAY be trailing padded with whitespace (0x20).
  - The header is a dict like {"TENSOR_NAME": {"dtype": "F16", "shape": [1, 16, 256], "data_offsets": [BEGIN, END]}, "NEXT_TENSOR_NAME": {...}, ...},
    - data_offsets point to the tensor data relative to the beginning of the byte buffer (i.e. not an absolute position in the file), with BEGIN as the starting offset and END as the one-past offset (so total tensor byte size = END - BEGIN).
  - A special key __metadata__ is allowed to contain free form string-to-string map. Arbitrary JSON is not allowed, all values must be strings.
- Rest of the file: byte-buffer.

Notes:
- Duplicate keys are disallowed. Not all parsers may respect this.
- In general the subset of JSON is implicitly decided by serde_json for this library. Anything obscure might be modified at a later time, such as odd ways to represent integers, newlines and escapes in UTF-8 strings. This would only be done for safety concerns.
- Tensor values are not checked against, in particular NaN and +/-Inf could be in the file
- Empty tensors (tensors with 1 dimension being 0) are allowed. They are not storing any data in the databuffer, yet retaining size in the header. They don't really bring a lot of value but are accepted since they are valid tensors from traditional tensor libraries perspective (torch, tensorflow, numpy, ..).
- 0-rank Tensors (tensors with shape []) are allowed, they are merely a scalar.
- The byte buffer needs to be entirely indexed, and cannot contain holes. This prevents the creation of polyglot files.
- Endianness: Little-endian.
- Order: 'C' or row-major.

DATA_TYPE can be one of ["F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL"]
*/

/** Typed-array representations this parser can currently produce for tensor data. */
type ParsedTensorValue = Float32Array | BigInt64Array;

/**
 * A single tensor decoded from a safetensors file.
 *
 * @typeParam V - The typed-array type holding the tensor's elements.
 */
export interface ParsedTensor<V extends ParsedTensorValue = ParsedTensorValue> {
  /** The dtype string as written in the file header (e.g. "F32"). */
  dtype: string;
  /** Dimension sizes; an empty array denotes a 0-rank (scalar) tensor. */
  shape: number[];
  /** Flat element data. Only a subset of safetensors dtypes is supported so far. */
  values: V;
}

/** Tensor element types this parser supports; other safetensors dtypes are rejected. */
const dtypeZod = z.enum(["I64", "F32"]);

/**
 * Schema for the tensor entries of a safetensors header (the `__metadata__`
 * entry, if any, must be removed before parsing with this schema).
 *
 * `data_offsets` is the `[BEGIN, END]` pair relative to the start of the data
 * buffer, so it must be exactly two non-negative integers. Shape dimensions are
 * checked separately, once the element count is known.
 */
const headerZod = z.record(
  z.object({
    dtype: dtypeZod,
    shape: z.array(z.number()),
    data_offsets: z.tuple([
      z.number().int().nonnegative(),
      z.number().int().nonnegative(),
    ]),
  }),
);

/** Supported dtype names, derived from the schema so the two cannot drift apart. */
type Dtype = z.infer<typeof dtypeZod>;

/**
 * Copies the bytes in `[begin, end)` of `buffer` into a new typed array that
 * matches `dtype`.
 *
 * `ArrayBuffer.prototype.slice` returns a fresh buffer starting at byte 0, so
 * the typed-array constructor never sees an unaligned byte offset. If the byte
 * count is not a multiple of the element size, the constructor throws a
 * RangeError.
 *
 * NOTE(review): typed arrays read elements in the host's byte order, while the
 * format is little-endian; this presumably assumes a little-endian host.
 */
function typedArrayFromArrayBuffer(
  buffer: ArrayBuffer,
  dtype: Dtype,
  begin: number,
  end: number,
): ParsedTensor["values"] {
  if (dtype === "I64") {
    return new BigInt64Array(buffer.slice(begin, end));
  }
  if (dtype === "F32") {
    return new Float32Array(buffer.slice(begin, end));
  }
  unreachable(dtype);
}

/**
 * Throws unless `shape` consists of non-negative integers whose product equals
 * the number of elements actually decoded.
 *
 * @param shape - Dimension sizes from the header.
 * @param typedArrayLength - Element count of the decoded typed array.
 * @throws Error("Invalid shape") on a bad dimension or an element-count mismatch.
 */
function assertValidShape(shape: number[], typedArrayLength: number): void {
  // An empty shape (0-rank scalar) leaves the count at 1.
  let elementCount = 1;
  for (const dim of shape) {
    const isNonNegativeInteger = Number.isInteger(dim) && dim >= 0;
    if (!isNonNegativeInteger) {
      throw new Error("Invalid shape");
    }
    elementCount *= dim;
  }
  if (elementCount !== typedArrayLength) {
    throw new Error("Invalid shape");
  }
}

/**
 * Reads the 8-byte little-endian header-size prefix and the JSON header that
 * follows it.
 *
 * @returns The parsed (not yet schema-validated) header JSON, and the absolute
 *   byte offset where the tensor data buffer begins.
 * @throws Error if the buffer cannot hold the declared header.
 */
function readSafetensorsHeader(buffer: ArrayBuffer): {
  rawHeader: unknown;
  dataOffset: number;
} {
  if (buffer.byteLength < 8) {
    throw new Error("Buffer too small for safetensors header");
  }
  const headerSizeBigUint = new DataView(buffer).getBigUint64(0, true);
  // Only support 32 bit header sizes for now
  if (headerSizeBigUint > 0xffffffff) {
    throw new Error("Header size too large");
  }
  const headerSize = Number(headerSizeBigUint);
  const dataOffset = 8 + headerSize;
  if (dataOffset > buffer.byteLength) {
    throw new Error("Header size exceeds buffer size");
  }
  const rawHeader: unknown = JSON.parse(
    new TextDecoder().decode(new Uint8Array(buffer, 8, headerSize)),
  );
  return { rawHeader, dataOffset };
}

/**
 * Removes the optional `__metadata__` entry from the raw header, after checking
 * that it is a string-to-string map as the format requires.
 *
 * Values that are not plain objects are returned unchanged so that the tensor
 * schema rejects them with its usual error.
 */
function stripHeaderMetadata(rawHeader: unknown): unknown {
  if (
    typeof rawHeader !== "object" ||
    rawHeader === null ||
    Array.isArray(rawHeader)
  ) {
    return rawHeader;
  }
  const { __metadata__: metadata, ...tensorEntries } = rawHeader as Record<
    string,
    unknown
  >;
  if (metadata !== undefined) {
    z.record(z.string()).parse(metadata);
  }
  return tensorEntries;
}

/**
 * Checks that every `[begin, end]` data range is well formed, lies inside the
 * data buffer, and does not overlap any other range.
 *
 * @param ranges - Offsets relative to the start of the data buffer.
 * @param dataLength - Byte length of the data buffer.
 * @throws Error on a malformed, out-of-bounds, or overlapping range.
 */
function assertValidDataRanges(
  ranges: ReadonlyArray<readonly number[]>,
  dataLength: number,
): void {
  for (const range of ranges) {
    const [begin, end] = range;
    // A negative begin would read header bytes as tensor data; an end past the
    // buffer would otherwise be silently clamped by ArrayBuffer.slice.
    if (
      range.length !== 2 ||
      !Number.isInteger(begin) ||
      !Number.isInteger(end) ||
      begin < 0 ||
      end < begin ||
      end > dataLength
    ) {
      throw new Error("Invalid data offsets");
    }
  }
  // Sort by begin, then end, so the result does not depend on header key order.
  // Otherwise a zero-length range sharing its begin with a non-empty range would
  // be flagged or not depending on which key came first.
  const sorted = [...ranges].sort(
    ([beginA, endA], [beginB, endB]) => beginA - beginB || endA - endB,
  );
  let previousEnd = 0;
  for (const [begin, end] of sorted) {
    if (previousEnd > begin) {
      throw new Error("Overlapping ranges");
    }
    previousEnd = end;
  }
}

/**
 * Parses a safetensors file held entirely in memory.
 *
 * The optional `__metadata__` header entry is validated and then discarded.
 * Tensor entries must match `headerZod`, so dtypes it does not accept are
 * rejected.
 *
 * NOTE: the format also forbids unindexed holes in the data buffer; that rule
 * is not enforced here.
 *
 * @param buffer - The complete file contents.
 * @returns The decoded tensors, keyed by tensor name.
 * @throws Error (including ZodError) on a malformed header, invalid or
 *   overlapping data offsets, or a shape that does not match the data.
 */
export function parseSafetensors(
  buffer: ArrayBuffer, // The passed ArrayBuffer should not be used after calling this function, since it may be modified.
): Record<string, ParsedTensor> {
  const { rawHeader, dataOffset } = readSafetensorsHeader(buffer);
  const header = headerZod.parse(stripHeaderMetadata(rawHeader));

  assertValidDataRanges(
    Object.values(header).map(({ data_offsets }) => data_offsets),
    buffer.byteLength - dataOffset,
  );

  const tensors: Record<string, ParsedTensor> = {};
  for (const [
    name,
    {
      dtype,
      shape,
      data_offsets: [begin, end],
    },
  ] of Object.entries(header)) {
    // Offsets in the header are relative to the data buffer, not the file.
    const values = typedArrayFromArrayBuffer(
      buffer,
      dtype,
      begin + dataOffset,
      end + dataOffset,
    );
    assertValidShape(shape, values.length);
    tensors[name] = { dtype, shape, values };
  }

  return tensors;
}
