/* global BigInt */

import { Float16Array } from "@petamoriken/float16";
import * as tf from "@tensorflow/tfjs";
import '@tensorflow/tfjs-backend-cpu';
// import { setWasmPaths } from "@tensorflow/tfjs-backend-wasm";
import * as ort from "onnxruntime-web/all";

import * as perf from "./perf.js";

// Optional ONNX Runtime Web diagnostics; uncomment to enable tracing/logging.
// ort.env.trace = true;
// ort.env.debug = true;
// ort.env.logLevel = "verbose";
// ort.env.wasm.proxy = true;
// ONNX Runtime Web WASM backend settings: request 4 threads and SIMD.
// NOTE(review): multi-threaded WASM generally requires a cross-origin-isolated
// page (SharedArrayBuffer available); confirm the hosting headers if thread
// count matters.
ort.env.wasm.numThreads = 4;
ort.env.wasm.simd = true;

// Disabled wiring for the tf.js WASM backend (pairs with the commented-out
// setWasmPaths import above); kept for reference.
// setWasmPaths({
//   "tfjs-backend-wasm.wasm": new URL(
//     "@tensorflow/tfjs-backend-wasm/dist/tfjs-backend-wasm.wasm",
//     import.meta.url
//   ).href,
//   "tfjs-backend-wasm-simd.wasm": new URL(
//     "@tensorflow/tfjs-backend-wasm/dist/tfjs-backend-wasm-simd.wasm",
//     import.meta.url
//   ).href,
//   "tfjs-backend-wasm-threaded-simd.wasm": new URL(
//     "@tensorflow/tfjs-backend-wasm/dist/tfjs-backend-wasm-threaded-simd.wasm",
//     import.meta.url
//   ).href,
// });
// Run tf.js ops on the CPU backend (its package is imported explicitly above).
// NOTE(review): tf.setBackend returns a Promise that is not awaited here, so
// backend initialisation may still be pending when later code runs.
tf.setBackend("cpu");
console.log("set tf.getBackend() to", tf.getBackend());

// Natural log of 1e-10 (about -23.03).
// NOTE(review): presumably used as a floor standing in for log(0) in
// log-domain arithmetic; confirm against callers.
export const LOG_EPS = Math.log(1e-10);

/**
 * Numerically stable log(exp(x1) + exp(x2)), with the same contract as
 * numpy.logaddexp.
 *
 * The larger argument is factored out so exp() only ever sees a value <= 0.
 * This keeps it from overflowing when the inputs are far apart, and makes
 * -Infinity (log of 0) inputs behave correctly.
 *
 * @param {number} x1 - First log-domain value.
 * @param {number} x2 - Second log-domain value.
 * @returns {number} log(exp(x1) + exp(x2)); NaN if either input is NaN.
 */
export function logaddexp(x1, x2) {
  const hi = Math.max(x1, x2);
  // Covers NaN (propagates), +Infinity, and the both -Infinity case, where
  // `lo - hi` would otherwise be NaN.
  if (!Number.isFinite(hi)) {
    return hi;
  }
  const lo = Math.min(x1, x2);
  return hi + Math.log1p(Math.exp(lo - hi));
}

/**
 * Create a WebNN MLContext for one of the "webnn-<device>" execution providers.
 *
 * @param {string} ep - Execution provider key ("webnn-cpu", "webnn-gpu" or
 *   "webnn-npu" create a context; anything else does not).
 * @returns {Promise<object|null>} The context returned by
 *   `navigator.ml.createContext({ deviceType })`, or null for non-WebNN keys.
 */
export async function ml_context(ep) {
  let deviceType;
  switch (ep) {
    case "webnn-cpu":
      deviceType = "cpu";
      break;
    case "webnn-gpu":
      deviceType = "gpu";
      break;
    case "webnn-npu":
      deviceType = "npu";
      break;
    default:
      // Not a WebNN provider: no context needed.
      return null;
  }

  const context = await navigator.ml.createContext({ deviceType });
  return context;
}

/**
 * Build ONNX Runtime Web session options for an execution-provider key.
 *
 * "wasm", "webgl" and "webgpu" map directly to a provider name. The
 * "webnn-cpu" / "webnn-gpu" / "webnn-npu" keys map to a WebNN provider object
 * carrying the device type and the supplied MLContext.
 *
 * @param {string} ep - Execution provider key.
 * @param {object} [ml_context] - WebNN context, used only for "webnn-*" keys.
 * @returns {{executionProviders: Array<string|object>}} Session options.
 * @throws {Error} For any other key.
 */
export function onnx_session_opts(ep, ml_context) {
  switch (ep) {
    case "wasm":
    case "webgl":
    case "webgpu":
      return { executionProviders: [ep] };

    case "webnn-cpu":
    case "webnn-gpu":
    case "webnn-npu": {
      // The device type is the suffix after "webnn-".
      const device = ep.slice("webnn-".length);
      const webnnProvider = {
        name: "webnn",
        deviceType: device,
        context: ml_context,
        // preferredOutputLocation: "ml-tensor",
      };
      return { executionProviders: [webnnProvider] };
    }

    default:
      throw new Error(`Unsupported execution provider: ${ep}`);
  }
}

/**
 * Copy a tf.js tensor into a new ONNX Runtime tensor of the requested type.
 *
 * The tensor data is read synchronously with `x.dataSync()`.
 *
 * When `type` is "float16", numeric sources are re-encoded with
 * `array2webnn_float16` into raw 16-bit patterns. Numeric here means the
 * float32, int32 and bool dtypes. tf.js itself has no half-precision dtype.
 *
 * @param {string} type - ONNX tensor type, e.g. "float32" or "float16".
 * @param {tf.Tensor} x - Source tensor; its `shape` becomes the ONNX dims.
 * @returns {ort.Tensor} The new ONNX Runtime tensor.
 */
export function tf2onnx(type, x) {
  perf.time("tf2onnx");
  try {
    let ds = x.dataSync();

    const isNumeric =
      x.dtype === "float32" || x.dtype === "int32" || x.dtype === "bool";
    if (type === "float16" && isNumeric) {
      ds = array2webnn_float16(ds);
    }

    return new ort.Tensor(type, ds, x.shape);
  } finally {
    // Close the timer even if tensor construction throws.
    perf.timeEnd("tf2onnx");
  }
}

/**
 * Convert an ONNX Runtime tensor into a tf.js tensor with the same dims.
 *
 * Two sources are handled:
 *  - `ml_context` given and `x.location === "ml-tensor"`: the bytes are read
 *    back with `ml_context.readTensor(x.mlTensorData)`. Only float32 and
 *    float16 are supported; float16 is widened to float32.
 *  - otherwise: the data comes from `x.getData()`. float16 data (raw 16-bit
 *    patterns) is widened to float32; any other type goes to `tf.tensor` as is.
 *
 * @param {ort.Tensor} x - Source tensor.
 * @param {object|null} [ml_context=null] - WebNN context providing `readTensor`.
 * @returns {Promise<tf.Tensor>} The converted tensor.
 * @throws {Error} "Unsupported type: <type>" for ml-tensor inputs that are
 *   neither float32 nor float16.
 */
export async function onnx2tf(x, ml_context = null) {
  perf.time("onnx2tf");
  try {
    if (ml_context && x.location === "ml-tensor") {
      perf.time("onnx2tf:readTensor");
      const t_buffer = await ml_context.readTensor(x.mlTensorData);
      perf.timeEnd("onnx2tf:readTensor");

      if (x.type === "float32") {
        perf.time("onnx2tf:Float32Array");
        const t = tf.tensor(new Float32Array(t_buffer), x.dims);
        perf.timeEnd("onnx2tf:Float32Array");
        return t;
      }
      if (x.type === "float16") {
        perf.time("onnx2tf:Float16Array");
        const f32 = new Float32Array(new Float16Array(t_buffer));
        const t = tf.tensor(f32, x.dims);
        perf.timeEnd("onnx2tf:Float16Array");
        return t;
      }
      throw new Error(`Unsupported type: ${x.type}`);
    }

    const data = await x.getData();

    if (x.type === "float16") {
      // `data` holds raw half-precision bit patterns and may be a view into a
      // larger buffer, so decode only its own byteOffset/length window.
      const f16 = new Float16Array(data.buffer, data.byteOffset, data.length);
      return tf.tensor(new Float32Array(f16), x.dims);
    }

    return tf.tensor(data, x.dims);
  } finally {
    // Close the overall timer on both success and the unsupported-type throw.
    perf.timeEnd("onnx2tf");
  }
}

/**
 * Create a zero-initialised ONNX Runtime tensor.
 *
 * With `ml_context`, a WebNN tensor is created through
 * `ml_context.createTensor` and wrapped with `ort.Tensor.fromMLTensor`.
 * NOTE(review): zero contents in that path rely on WebNN's allocation
 * behaviour; confirm.
 * Otherwise a CPU tensor backed by a fresh typed array is returned. Only
 * int64, float16 and float32 are supported.
 *
 * @param {number[]} dims - Tensor shape; [] denotes a scalar (one element).
 * @param {string} type - ONNX/WebNN data type name.
 * @param {object|null} [ml_context=null] - WebNN context; selects the ML path.
 * @param {boolean} [readable=false] - Passed to createTensor (ML path only).
 * @param {boolean} [writable=false] - Passed to createTensor (ML path only).
 * @returns {Promise<ort.Tensor>} The new tensor.
 * @throws {Error} "Unsupported type: <type>" for other types on the CPU path.
 */
export async function onnx_tensor_zeros(
  dims,
  type,
  ml_context = null,
  readable = false,
  writable = false
) {
  perf.time("onnx_tensor_zeros");
  try {
    if (ml_context) {
      const ml_desc = {
        dataType: type,
        shape: dims,
        readable: readable,
        writable: writable,
      };
      const ml_tensor = await ml_context.createTensor(ml_desc);
      return ort.Tensor.fromMLTensor(ml_tensor, { dataType: type, dims: dims });
    }

    // Initial value 1 so a scalar (dims = []) has one element instead of
    // making reduce() throw on an empty array.
    const len = dims.reduce((a, b) => a * b, 1);

    // Freshly allocated typed arrays are already zero-filled.
    if (type === "int64") {
      return new ort.Tensor(type, new BigInt64Array(len), dims);
    }
    if (type === "float16") {
      // float16 payloads are raw 16-bit patterns; all-zero bits encode +0.
      return new ort.Tensor(type, new Uint16Array(len), dims);
    }
    if (type === "float32") {
      return new ort.Tensor(type, new Float32Array(len), dims);
    }
    throw new Error(`Unsupported type: ${type}`);
  } finally {
    perf.timeEnd("onnx_tensor_zeros");
  }
}

/**
 * Overwrite a WebNN-backed ONNX Runtime tensor with zeros.
 *
 * A zero-filled host buffer with one element per tensor element is written
 * into `x.mlTensorData` via `ml_context.writeTensor`.
 *
 * @param {object} ml_context - Context exposing `writeTensor(mlTensor, data)`
 *   (presumably the WebNN MLContext the tensor belongs to).
 * @param {ort.Tensor} x - Tensor with `type`, `dims` and `mlTensorData`.
 */
export async function zero_out_ml_tensor(ml_context, x) {
  perf.time("zero_out_ml_tensor");
  try {
    // Initial value 1 so a scalar (dims = []) has one element instead of
    // making reduce() throw on an empty array.
    const len = x.dims.reduce((a, b) => a * b, 1);

    // Typed arrays start zero-filled. Pick one whose element width matches
    // the tensor's element width, so the buffer's byte length equals the
    // tensor's byte length.
    let buffer;
    switch (x.type) {
      case "float16":
        buffer = new Uint16Array(len);
        break;
      case "int64":
        buffer = new BigInt64Array(len);
        break;
      case "uint64":
        buffer = new BigUint64Array(len);
        break;
      case "int8":
      case "uint8":
      case "bool":
        // All-zero bytes are 0 / false for every 1-byte type.
        buffer = new Uint8Array(len);
        break;
      default:
        // 4-byte types (float32, int32, uint32, ...): all-zero bits are 0.
        buffer = new Float32Array(len);
    }

    ml_context.writeTensor(x.mlTensorData, buffer);
  } finally {
    perf.timeEnd("zero_out_ml_tensor");
  }
}

/**
 * Widen an ONNX Runtime float16 tensor into a float32 tf.js tensor.
 *
 * @param {ort.Tensor} arr - Tensor whose `data` holds the raw 16-bit
 *   half-precision patterns (must be CPU-resident so `data` is accessible).
 * @returns {tf.Tensor} float32 tensor with shape `arr.dims`.
 */
export function onnx_float16_2tf_float32(arr) {
  const bits = arr.data;
  // `bits` may be a view into a larger ArrayBuffer (non-zero byteOffset or a
  // shorter length). Decode only the view's own window, not the whole buffer.
  const f16 = new Float16Array(bits.buffer, bits.byteOffset, bits.length);
  const f32 = new Float32Array(f16);

  // tf.tensor is a factory function, not a constructor: call it without `new`.
  return tf.tensor(f32, arr.dims);
}

/**
 * Build an ONNX Runtime "int64" tensor from integer values.
 *
 * @param {Iterable<number|bigint|string>|ArrayLike<number|bigint|string>} arr -
 *   Values convertible with BigInt(); plain arrays and typed arrays both work.
 * @param {number[]} [dims] - Tensor shape; if falsy, a 1-D shape [arr.length].
 * @returns {ort.Tensor} The int64 tensor.
 * @throws {RangeError|SyntaxError|TypeError} When a value is not convertible
 *   with BigInt() (e.g. 1.5).
 */
export function array2onnx_int64(arr, dims) {
  // Use BigInt64Array.from rather than arr.map: on a typed array such as
  // Int32Array, map() would try to store BigInts back into the same typed
  // array type and throw.
  const values = BigInt64Array.from(arr, (v) => BigInt(v));

  return new ort.Tensor("int64", values, dims ? dims : [values.length]);
}

/**
 * Build an ONNX Runtime "float16" tensor from numeric values.
 *
 * The values are encoded to half precision with Float16Array. The resulting
 * 16-bit patterns are handed to ONNX Runtime as a Uint16Array over the same
 * buffer.
 *
 * @param {Iterable<number>|ArrayLike<number>} arr - Values to encode.
 * @param {number[]|null} [dims=null] - Tensor shape; if falsy, a 1-D shape
 *   with one entry per encoded value.
 * @returns {ort.Tensor} The float16 tensor.
 */
export function array2onnx_float16(arr, dims = null) {
  const halves = new Float16Array(arr);
  const shape = dims || [halves.length];
  const rawBits = new Uint16Array(halves.buffer);
  return new ort.Tensor("float16", rawBits, shape);
}

/**
 * Encode numbers as half precision and return the raw 16-bit patterns.
 *
 * @param {Iterable<number>|ArrayLike<number>} arr - Values to encode.
 * @returns {Uint16Array} A view over the Float16Array's buffer, one entry per
 *   value.
 */
export function array2webnn_float16(arr) {
  const { buffer } = new Float16Array(arr);
  return new Uint16Array(buffer);
}

/**
 * Parse a comma-separated list of decimal integers, e.g. "1,3,224,224".
 *
 * @param {string} s - Comma-separated integers; whitespace around entries is
 *   tolerated.
 * @returns {number[]} The parsed integers; [] for an empty or blank string.
 *   An entry that does not start with a decimal number becomes NaN.
 */
export function to_int_list(s) {
  // Without this guard, "".split(",") yields [""], which parses to [NaN].
  if (s.trim() === "") {
    return [];
  }
  // Explicit radix 10: without it, parseInt would read "0x10" as hex 16.
  return s.split(",").map((part) => Number.parseInt(part, 10));
}

// export function topk(arr, k) {
//   let indices = new Array(arr.length);
//   for (let i = 0; i < arr.length; i++) {
//     indices[i] = i;
//   }

//   function compare_non_numbers(a, b) {
//     if (isNaN(arr[a]) && isNaN(arr[b])) {
//       return 0;
//     } else if (isNaN(arr[a])) {
//       return 1;
//     } else if (isNaN(arr[b])) {
//       return -1;
//     }

//     return arr[b] - arr[a]
//   }
  
//   indices.sort(compare_non_numbers);
  
//   indices = indices.slice(0, k);
//   const topkValues = indices.map(i => arr[i]);

//   return { indices: indices, values: topkValues };
// }