import { vec2, vec3 } from 'gl-matrix';

/**
 * Converts an image's orientation metadata into three unit direction vectors.
 *
 * @param {object} image - must expose `imageorientationpatient` as a 6-value
 *   sequence: entries 0-2 are taken as the row direction, entries 3-5 as the
 *   column direction.
 * @returns {{rowVec: vec3, colVec: vec3, depthVec: vec3}} the normalized row and
 *   column directions plus the normalized cross product row x col.
 */
const imageOrientationToUnitVectors = (image) => {
  const orientation = image.imageorientationpatient;
  const rowVec = vec3.normalize(vec3.create(), vec3.fromValues(...orientation.slice(0, 3)));
  const colVec = vec3.normalize(vec3.create(), vec3.fromValues(...orientation.slice(3, 6)));
  // The row and column directions are not guaranteed to be exactly orthogonal
  // (e.g. rounded metadata), so their cross product can be slightly shorter
  // than 1. Normalize it so depthVec is a true unit vector, as the function
  // name promises; it is used as a plane normal and as an OBB axis.
  // A zero cross product (parallel inputs) stays zero under vec3.normalize.
  const depthVec = vec3.normalize(vec3.create(), vec3.cross(vec3.create(), rowVec, colVec));
  return { rowVec, colVec, depthVec };
};

/** Tolerance for near-zero denominators and for inclusive interval checks. */
const EPS = 1e-5;

/** Two unit directions count as (anti)parallel when |cos(angle)| exceeds this (~18 degrees). */
const MIN_COSINE = 0.95;

/**
 * The 12 box edges as pairs of indices into the vertex list built by
 * computeVerticesFromObb; each pair differs in exactly one axis sign.
 * The order matters: it determines the order of edge-derived results.
 */
const OBB_EDGE_INDICES = [
  // edges incident to vertex 6 (-,+,-)
  [6, 1], [6, 3], [6, 4],
  // edges incident to vertex 2 (+,-,+)
  [2, 7], [2, 5], [2, 0],
  // the remaining six edges
  [0, 1], [0, 3], [7, 1], [7, 4], [4, 5], [5, 3],
];

/**
 * Intersects a line segment with a plane.
 *
 * @param {[vec3, vec3]} edge - segment endpoints [start, end].
 * @param {{normal: vec3, distance: number}} plane - the points p with
 *   dot(normal, p) === distance.
 * @returns {vec3|undefined} the crossing point, or undefined when
 *   |dot(normal, end - start)| < EPS (segment nearly parallel to the plane) or
 *   the crossing parameter falls outside [0, 1].
 */
const edgePlaneIntersection = (edge, plane) => {
  const [start, end] = edge;
  const { normal, distance } = plane;
  const direction = vec3.sub(vec3.create(), end, start);
  const denominator = vec3.dot(normal, direction);
  if (Math.abs(denominator) < EPS) {
    return undefined;
  }
  const t = (distance - vec3.dot(normal, start)) / denominator;
  // Written as a positive range test so a NaN parameter is also rejected.
  const withinSegment = t >= 0 && t <= 1;
  if (!withinSegment) {
    return undefined;
  }
  return vec3.scaleAndAdd(vec3.create(), start, direction, t);
};

/**
 * Tells whether two directions are parallel or anti-parallel within tolerance.
 *
 * @param {vec3} v1
 * @param {vec3} v2
 * @returns {boolean} true when |dot(v1, v2)| > MIN_COSINE (meaningful for unit vectors).
 */
const sameDir = (v1, v2) => {
  const absCosine = Math.abs(vec3.dot(v1, v2));
  return absCosine > MIN_COSINE;
};

/**
 * Sign (along axis[0], axis[1], axis[2]) of each box corner. The row order is
 * the vertex numbering that OBB_EDGE_INDICES refers to, so it must not change.
 */
const OBB_CORNER_SIGNS = [
  [1, 1, 1],
  [-1, 1, 1],
  [1, -1, 1],
  [1, 1, -1],
  [-1, -1, -1],
  [1, -1, -1],
  [-1, 1, -1],
  [-1, -1, 1],
];

/**
 * Computes the eight corners of an oriented box.
 *
 * @param {vec3} center - box center (cloned, never modified).
 * @param {vec3[]} axis - the three box axes.
 * @param {vec3} halfWidth - half-extent along each axis.
 * @returns {vec3[]} new corner vectors, center + sum_k sign_k * halfWidth[k] * axis[k],
 *   in the order of OBB_CORNER_SIGNS.
 */
const computeVerticesFromObb = (center, axis, halfWidth) =>
  OBB_CORNER_SIGNS.map((signs) => {
    const corner = vec3.clone(center);
    // Accumulate axis 0, then 1, then 2.
    for (let k = 0; k < 3; k++) {
      vec3.scaleAndAdd(corner, corner, axis[k], signs[k] * halfWidth[k]);
    }
    return corner;
  });

/**
 * Pairs box vertices into the twelve edges listed in OBB_EDGE_INDICES.
 * (The "Vectices" spelling is kept because Obb calls this name.)
 *
 * @param {vec3[]} vertices - the eight corners produced by computeVerticesFromObb.
 * @returns {Array<[vec3, vec3]>} edges that reference (do not copy) the vertex objects.
 */
const computeEdgesFromVectices = (vertices) =>
  OBB_EDGE_INDICES.map(([from, to]) => [vertices[from], vertices[to]]);

/**
 * Builds the plane of an image slice.
 *
 * @param {object} image - must expose `imageorientationpatient` and `imagepositionpatient`.
 * @returns {{normal: vec3, distance: number}} the points p with
 *   dot(normal, p) === distance, where normal is the image's depth vector and
 *   the plane passes through `imagepositionpatient`.
 */
const imageToPlane = (image) => {
  const normal = imageOrientationToUnitVectors(image).depthVec;
  const distance = vec3.dot(normal, image.imagepositionpatient);
  return { normal, distance };
};

/**
 * Oriented bounding box (OBB): a center, three axis directions and the
 * half-extent along each axis.
 *
 * `vertices` (8 corners) and `edges` (12 vertex pairs) are computed once in the
 * constructor; mutating `center`, `axis` or `halfWidth` afterwards does not
 * update them.
 */
class Obb {
  /**
   * @param {ArrayLike<number>} [center] - 3 components; anything else becomes the origin.
   * @param {ArrayLike<number>[]} [axis] - three 3-component directions; anything else
   *   becomes three zero vectors. testPoint compares projections onto these axes
   *   with halfWidth, so they are presumably unit length - TODO confirm at call sites.
   * @param {ArrayLike<number>} [halfWidth] - 3 half-extents; anything else becomes zeros.
   */
  constructor(center, axis, halfWidth) {
    this.center = center && center.length === 3 ? vec3.fromValues(...center) : vec3.create();
    this.axis =
      axis && axis.length === 3 && axis[0].length === 3
        ? [vec3.fromValues(...axis[0]), vec3.fromValues(...axis[1]), vec3.fromValues(...axis[2])]
        : [vec3.create(), vec3.create(), vec3.create()];
    this.halfWidth =
      halfWidth && halfWidth.length === 3 ? vec3.fromValues(...halfWidth) : vec3.create();
    // Derive the geometry from the validated copies rather than the raw
    // arguments, so the defaults above actually apply (e.g. `new Obb()` works).
    this.vertices = computeVerticesFromObb(this.center, this.axis, this.halfWidth);
    this.edges = computeEdgesFromVectices(this.vertices);
  }

  /**
   * Projects all vertices onto `axis`.
   *
   * @param {vec3} axis - projection direction (used as-is, not normalized).
   * @returns {{min: number, max: number}} smallest and largest dot(axis, vertex).
   */
  axisInterval(axis) {
    const { vertices } = this;
    let min = vec3.dot(axis, vertices[0]);
    let max = min;
    for (let i = 1; i < vertices.length; i++) {
      const projection = vec3.dot(axis, vertices[i]);
      if (projection < min) {
        min = projection;
      }
      if (projection > max) {
        max = projection;
      }
    }
    return { min, max };
  }

  /**
   * Intersects the box edges with a plane.
   *
   * @param {{normal: vec3, distance: number}} plane - points p with dot(normal, p) === distance.
   * @returns {vec3[]} one point per edge that crosses the plane. Edges nearly
   *   parallel to the plane are skipped, and a vertex lying on the plane may be
   *   reported once per incident edge.
   */
  planeIntersection(plane) {
    return this.edges.map((e) => edgePlaneIntersection(e, plane)).filter((p) => !!p);
  }

  /**
   * Splits the box edges by the plane of `image`.
   *
   * @param {object} image - needs `imageorientationpatient` and `imagepositionpatient`.
   * @returns {{close: Array<[vec3, vec3]>, far: Array<[vec3, vec3]>}} edge pieces with
   *   signed plane distance <= 0 (close) and > 0 (far); crossing edges are cut at
   *   the plane and contribute one piece to each list.
   */
  getPlaneCutEdges(image) {
    const { normal, distance } = imageToPlane(image);
    const pointPlaneDistance = (p) => vec3.dot(normal, p) - distance;

    const close = [];
    const far = [];

    this.edges.forEach((e) => {
      const p1Dist = pointPlaneDistance(e[0]);
      const p2Dist = pointPlaneDistance(e[1]);
      if (p1Dist <= 0 && p2Dist <= 0) {
        close.push(e);
      } else if (p1Dist > 0 && p2Dist > 0) {
        far.push(e);
      } else {
        // Exactly one endpoint is far (> 0), so p1Dist !== p2Dist and t is in
        // [0, 1]. Interpolating from the distances already computed avoids the
        // EPS/range rejections of edgePlaneIntersection, which could otherwise
        // produce an undefined midpoint for short or nearly coplanar edges.
        const t = p1Dist / (p1Dist - p2Dist);
        const midPoint = vec3.lerp(vec3.create(), e[0], e[1], t);
        const e1 = [e[0], midPoint];
        const e2 = [midPoint, e[1]];
        // A point exactly on the plane counts as close, matching the branches above.
        const firstIsClose = p1Dist <= 0;
        close.push(firstIsClose ? e1 : e2);
        far.push(firstIsClose ? e2 : e1);
      }
    });

    return { close, far };
  }

  /**
   * Intersects the box with the plane of each image.
   *
   * @param {object[]} images - each needs `imageorientationpatient` and `imagepositionpatient`.
   * @returns {vec3[][]} per image, the edge/plane crossing points, or [] when
   *   there are fewer than three (not enough to bound a polygon).
   */
  imageIntersection(images) {
    return images.map((image) => {
      const points = this.planeIntersection(imageToPlane(image));
      return points.length < 3 ? [] : points;
    });
  }

  /**
   * Separating-axis test between this box and `obb2`.
   *
   * @param {Obb} obb2
   * @returns {boolean} true when no separating axis is found (the boxes overlap or touch).
   */
  testObbIntersection(obb2) {
    const obb1 = this;
    // Candidate axes: the 3 face normals of each box plus the 9 cross products
    // of an obb1 axis with an obb2 axis (edge-edge directions).
    const candidateAxes = [...obb1.axis, ...obb2.axis];
    for (let i = 0; i < 3; i++) {
      for (let j = 0; j < 3; j++) {
        candidateAxes.push(vec3.cross(vec3.create(), obb1.axis[i], obb2.axis[j]));
      }
    }
    // A parallel axis pair yields a zero vector; both intervals then collapse
    // to [0, 0] and overlap, so it never reports a false separation.
    return candidateAxes.every((axis) => {
      const interval1 = obb1.axisInterval(axis);
      const interval2 = obb2.axisInterval(axis);
      return interval2.min <= interval1.max && interval1.min <= interval2.max;
    });
  }

  /**
   * @param {object[]} images
   * @returns {boolean[]} per image, whether its slab (Obb.fromImage with the
   *   default depth) intersects this box.
   */
  testImageIntersection(images) {
    return images.map((image) => this.testObbIntersection(Obb.fromImage(image)));
  }

  /**
   * Checks which image positions project inside this box's extent along `axis`.
   *
   * @param {vec3} axis - projection direction.
   * @param {object[]} images - each needs `imagepositionpatient`.
   * @returns {boolean[]|undefined} one flag per image (inclusive, with EPS
   *   slack); undefined when `images` is missing or the first image has no
   *   `imagepositionpatient`. An empty list gives an empty array.
   */
  testImagesInInterval(axis, images) {
    if (!images || (images.length > 0 && !images[0].imagepositionpatient)) {
      return undefined;
    }
    const interval = this.axisInterval(axis);
    return images.map(({ imagepositionpatient }) => {
      const offset = vec3.dot(axis, imagepositionpatient);
      return offset <= interval.max + EPS && offset >= interval.min - EPS;
    });
  }

  /**
   * @param {vec3} p
   * @returns {boolean} true when |dot(p - center, axis[i])| <= halfWidth[i] for all three axes.
   */
  testPoint(p) {
    const dir = vec3.sub(vec3.create(), p, this.center);
    for (let i = 0; i < 3; i++) {
      const axis = this.axis[i];
      const distance = vec3.dot(dir, axis);
      if (distance > this.halfWidth[i] || distance < -this.halfWidth[i]) {
        return false;
      }
    }
    return true;
  }

  /**
   * Builds the slab occupied by a single image.
   *
   * @param {object} im - needs `rows`, `columns`, `pixelspacing` (2 values),
   *   `imageorientationpatient` and `imagepositionpatient`.
   * @param {number} [depthSpacing=3] - slab thickness along the image normal.
   * @returns {Obb} box with axes [rowVec, colVec, depthVec], spanning the image
   *   plane from `imagepositionpatient` and centered on that plane along the normal.
   */
  static fromImage(im, depthSpacing = 3) {
    // NOTE(review): under DICOM conventions rowVec (direction along a row) spans
    // `columns` pixels spaced by pixelspacing[1], and colVec spans `rows` pixels
    // spaced by pixelspacing[0]. Here rowVec is paired with rows * pixelspacing[0];
    // the two agree only for square images with isotropic spacing. Confirm how
    // these fields are populated before swapping (fromImagePair does the same).
    const size = vec2.fromValues(im.rows, im.columns);
    const spacing = vec3.fromValues(...im.pixelspacing, depthSpacing);
    const { rowVec, colVec, depthVec } = imageOrientationToUnitVectors(im);
    const firstCorner = vec3.clone(im.imagepositionpatient);

    const opposedCorner = vec3.clone(im.imagepositionpatient);
    vec3.scaleAndAdd(opposedCorner, opposedCorner, rowVec, size[0] * spacing[0]);
    vec3.scaleAndAdd(opposedCorner, opposedCorner, colVec, size[1] * spacing[1]);

    const center = vec3.lerp(vec3.create(), firstCorner, opposedCorner, 0.5);
    const axis = [vec3.clone(rowVec), vec3.clone(colVec), vec3.clone(depthVec)];
    const halfWidth = vec3.fromValues(
      (size[0] * spacing[0]) / 2,
      (size[1] * spacing[1]) / 2,
      spacing[2] / 2
    );
    return new Obb(center, axis, halfWidth);
  }

  /**
   * Builds the box spanning two parallel images of the same size and spacing.
   *
   * @param {object} im1 - first image; its orientation defines the box axes.
   * @param {object} im2 - second image.
   * @param {number} [depthSpacing=3] - extra thickness added along the normal.
   * @returns {Obb|undefined} the box, or undefined (after a console warning) when
   *   sizes or pixel spacings differ, orientations are not within MIN_COSINE, or
   *   two distinct images share the same position.
   */
  static fromImagePair(im1, im2, depthSpacing = 3) {
    // NOTE(review): same rows/columns vs rowVec/colVec extent pairing as fromImage.
    const size = vec2.fromValues(im1.rows, im1.columns);
    const spacing = vec3.fromValues(...im1.pixelspacing, depthSpacing);
    const { rowVec, colVec, depthVec } = imageOrientationToUnitVectors(im1);
    const { rowVec: rowVec2, colVec: colVec2 } = imageOrientationToUnitVectors(im2);

    if (im1.rows !== im2.rows || im1.columns !== im2.columns) {
      console.warn(`Row or columns mismatch`);
      return;
    }

    // vec2.equals only compares the first two (in-plane) spacing components.
    if (!vec2.equals(spacing, vec2.fromValues(...im2.pixelspacing))) {
      console.warn(`Pixel spacing not equal`);
      return;
    }

    if (im1 !== im2 && vec3.equals(im1.imagepositionpatient, im2.imagepositionpatient)) {
      console.warn(`Both images at same location - Time series?`);
      return;
    }

    if (!sameDir(rowVec, rowVec2) || !sameDir(colVec, colVec2)) {
      console.warn(`Incompatible image orientation `);
      return;
    }

    const firstCorner = vec3.clone(im1.imagepositionpatient);

    const opposedCorner = vec3.clone(im2.imagepositionpatient);
    vec3.scaleAndAdd(opposedCorner, opposedCorner, rowVec, size[0] * spacing[0]);
    vec3.scaleAndAdd(opposedCorner, opposedCorner, colVec, size[1] * spacing[1]);

    // Separation of the two image planes measured along im1's normal.
    const AB = vec3.sub(vec3.create(), im2.imagepositionpatient, im1.imagepositionpatient);
    const distance = Math.abs(vec3.dot(AB, depthVec));

    const center = vec3.lerp(vec3.create(), firstCorner, opposedCorner, 0.5);
    const axis = [vec3.clone(rowVec), vec3.clone(colVec), vec3.clone(depthVec)];
    const halfWidth = vec3.fromValues(
      (size[0] * spacing[0]) / 2,
      (size[1] * spacing[1]) / 2,
      (distance + spacing[2]) / 2
    );
    return new Obb(center, axis, halfWidth);
  }
}

export { Obb, Obb as default };
