import { Quaternion, Vector2 } from 'three';

import type { CartesianPose, Plane } from '@sb/geometry';
import {
  applyCompoundPose,
  cameraPoseFromWristPose,
  castCameraRay,
  findRayPlaneIntersection,
  getPlanePose,
  getTooltipOrientationPerpendicularToPlane,
  invertPose,
  movePlaneInItsZDirection,
  ZERO_POSE,
} from '@sb/geometry';
import type { WristCamera } from '@sb/integrations/implementations/WristCamera/implementation/WristCamera';
import type {
  CameraIntegration,
  CameraIntrinsics,
} from '@sb/integrations/types/cameraTypes';
import { makeNamespacedLog } from '@sb/log';
import type {
  Blob2D,
  LocateMethod,
  RegionOfInterest,
  RoutineRunnerState,
  VisionInterface,
  Space,
} from '@sb/routine-runner';
import { isNotUndefined } from '@sb/utilities';

import { getCalibrationOffset } from './calibrationHelpers';

const log = makeNamespacedLog('runLocate');

/**
 * Offsets a located blob by a user-defined transform that is expressed
 * relative to the plane.
 *
 * The blob pose is first re-expressed in plane coordinates, the transform is
 * applied there, and the result is converted back to base coordinates.
 *
 * @param transform - user-defined offset (translation + rotation), relative to the plane
 * @param blobRelativeToBase - blob pose in base coordinates
 * @param planePoseRelativeToBase - pose of the plane in base coordinates
 * @returns the offset blob pose in base coordinates
 */
function applyUserDefinedTransform(
  transform: CartesianPose,
  blobRelativeToBase: CartesianPose,
  planePoseRelativeToBase: CartesianPose,
): CartesianPose {
  // Base coordinates -> plane coordinates.
  const inversePlanePose = invertPose(planePoseRelativeToBase);
  const blobInPlane = applyCompoundPose(blobRelativeToBase, inversePlanePose);

  // Rotation and translation are handled separately so that the rotational
  // part of the transform never moves the blob's position.
  const blobRotation = new Quaternion(
    blobInPlane.i,
    blobInPlane.j,
    blobInPlane.k,
    blobInPlane.w,
  );

  // The user rotation is the left operand of the product.
  const combinedRotation = new Quaternion(
    transform.i,
    transform.j,
    transform.k,
    transform.w,
  ).multiply(blobRotation);

  // Plane coordinates -> base coordinates, with the translation added
  // component-wise in the plane frame.
  return applyCompoundPose(
    {
      x: blobInPlane.x + transform.x,
      y: blobInPlane.y + transform.y,
      z: blobInPlane.z + transform.z,
      i: combinedRotation.x,
      j: combinedRotation.y,
      k: combinedRotation.z,
      w: combinedRotation.w,
    },
    planePoseRelativeToBase,
  );
}

/**
 * Converts a 2D detection in image pixel coordinates into a 3D pose in the
 * base's coordinate system.
 *
 * A ray is cast from the camera position through the detection's pixel and
 * intersected with `plane`. The returned orientation comes from
 * `getTooltipOrientationPerpendicularToPlane`, given the plane and the
 * detection's rotation.
 *
 * @returns the pose at the intersection, or `undefined` when the ray does
 *   not intersect the plane
 */
function deproject(
  blobResult: { x: number; y: number; rotation: number },
  cameraIntrinsics: CameraIntrinsics,
  cameraPose: CartesianPose,
  plane: Plane,
): CartesianPose | undefined {
  // The ray starts at the camera position (orientation is not part of the origin).
  const { x, y, z } = cameraPose;
  const pixel = new Vector2(blobResult.x, blobResult.y);
  const direction = castCameraRay(cameraIntrinsics, cameraPose, pixel);

  // Ray and plane are expressed in the same coordinate system here, so they
  // can be intersected directly.
  const hit = findRayPlaneIntersection({ x, y, z }, direction, plane);

  if (hit === undefined) {
    return undefined;
  }

  const orientation = getTooltipOrientationPerpendicularToPlane(
    plane,
    blobResult.rotation,
  );

  return {
    ...hit,
    i: orientation.x,
    j: orientation.y,
    k: orientation.z,
    w: orientation.w,
  };
}

/**
 * Looks up the accuracy-calibration correction for a pose.
 *
 * @param pose - pose handed to `getCalibrationOffset` for the lookup
 * @param calibration - accuracy calibration entries; may be empty
 * @returns the offset computed by `getCalibrationOffset`, or `ZERO_POSE`
 *   when there are no calibration entries
 */
function getCalibrationCorrection(
  pose: CartesianPose,
  calibration: Space.AccuracyCalibrationEntry[],
): CartesianPose {
  const hasCalibrationData = calibration.length > 0;

  if (hasCalibrationData) {
    const offset = getCalibrationOffset(pose, calibration);

    log.info('getCalibrationCorrection.offset', 'Calibration offset', offset);

    return offset;
  }

  // Nothing to correct with: fall back to the zero pose.
  log.info('getCalibrationCorrection.none', 'No calibration data found');

  return ZERO_POSE;
}

/** Everything `runLocate` needs to detect objects in a camera image and turn them into poses. */
interface RunLocateArgs {
  /** Source of the color frame and of the camera intrinsics. */
  wristCamera: WristCamera;
  /** Camera integration handed to `wristCamera.getColorFrame`. */
  camera: CameraIntegration;
  /** Region passed to the blob and shape detectors. */
  regionOfInterest: RegionOfInterest;
  /** Detection method; its `kind` selects which vision call is made. */
  method: LocateMethod;
  /** Reference plane used to turn 2D detections into 3D poses. */
  plane: Plane;
  /** Maximum number of detections to keep (highest score first); `undefined` keeps all. */
  resultsLimit: number | undefined;
  /** Optional user-defined offset, relative to the plane, applied to every result; `null` disables it. */
  transform: CartesianPose | null;
  /** Optional accuracy calibration entries; `null` skips the calibration correction. */
  accuracyCalibration: Space.AccuracyCalibrationEntry[] | null;
  /** Vision backend that performs the 2D detection. */
  vision: VisionInterface;
  /** Distance by which `plane` is moved along its own Z direction (units are not visible here). */
  objectHeight: number;
  /** Returns the current routine runner state; `kinematicState.wristPose` is read from it. */
  getState: () => RoutineRunnerState;
}

/**
 * Captures a color frame from the wrist camera and runs the 2D detection
 * selected by `method.kind` on it.
 *
 * @returns the detections in image coordinates. For
 *   `ChessboardDetection2D` this is a single synthetic 1x1 blob located at
 *   the center corner of the chessboard.
 * @throws Error when `method.kind` is not supported, or when chessboard
 *   detection returns too few corners to contain a center corner
 */
async function get2DBlobs({
  wristCamera,
  camera,
  method,
  regionOfInterest,
  vision,
}: RunLocateArgs): Promise<Blob2D[]> {
  const image = await wristCamera.getColorFrame(camera);

  switch (method.kind) {
    case 'BlobDetection2D': {
      return vision.detect2DBlobs(image, regionOfInterest, method.settings);
    }
    case 'ShapeDetection2D': {
      return vision.detect2DShapes(
        image,
        method.templateImage,
        regionOfInterest,
        method.settings,
      );
    }
    // This locates the center of the chessboard used for accuracy calibration
    case 'ChessboardDetection2D': {
      const result = await vision.getChessboardCorners(image, 3, 3);

      // Index 4 is the middle entry of the nine corners of a 3x3 request.
      const centerCornerIndex = 4;

      // Fail with a descriptive error instead of an opaque TypeError when
      // detection yields fewer corners than expected.
      if (result.corners.length <= centerCornerIndex) {
        throw new Error(
          `Chessboard detection returned ${result.corners.length} corners; cannot locate the center corner`,
        );
      }

      const centerCorner = result.corners[centerCornerIndex];

      return [
        {
          x: centerCorner.x,
          y: centerCorner.y,
          width: 1,
          height: 1,
          score: 0,
          rotation: 0,
          contour: [],
        },
      ];
    }
    default: {
      // Interpolating the method object itself would render as
      // "[object Object]"; report the discriminant instead.
      const { kind } = method as { kind?: unknown };

      throw new Error(`Unsupported method ${String(kind)}`);
    }
  }
}

/**
 * Detects objects in a wrist camera image and converts each detection into
 * a pose in base coordinates.
 *
 * Steps:
 * 1. Run the configured 2D detection on a fresh color frame.
 * 2. Keep the `resultsLimit` highest-scoring detections (all of them when
 *    the limit is `undefined`).
 * 3. Deproject each detection onto the plane moved along its Z direction by
 *    `objectHeight`.
 * 4. Optionally apply the accuracy calibration correction and the
 *    user-defined transform.
 *
 * Detections whose camera ray does not intersect the plane are dropped.
 *
 * @returns the kept detections paired with their poses, highest score first
 * @throws Error when the wrist camera has no intrinsics
 */
export async function runLocate(args: RunLocateArgs): Promise<
  Array<{
    blob: Blob2D;
    pose: CartesianPose;
  }>
> {
  const blobs2D = await get2DBlobs(args);

  const intrinsics = await args.wristCamera.getIntrinsics();

  if (!intrinsics) {
    throw new Error('No intrinsics found for wrist camera');
  }

  const { wristPose } = args.getState().kinematicState;

  const cameraPose = cameraPoseFromWristPose(wristPose);

  log.info(
    'blobs',
    'Blobs detected',
    blobs2D.map((item) => ({ ...item, contour: undefined })),
  );

  // Sort a copy so the array returned by the vision backend is not mutated.
  const filteredBlobs2D = [...blobs2D]
    .sort((a, b) => b.score - a.score)
    .slice(0, args.resultsLimit ?? blobs2D.length);

  // Move plane in z direction by object height
  const heightAdjustedPlane = movePlaneInItsZDirection(
    args.plane,
    args.objectHeight,
  );

  const planePoseRelativeToBase = getPlanePose(heightAdjustedPlane);

  log.info(
    'Accuracy calibration',
    'Accuracy calibration',
    args.accuracyCalibration,
  );

  const blobsAndPoses = filteredBlobs2D
    .map((blob) => {
      // Intersect the camera ray with the height-adjusted plane. Using the
      // unadjusted plane here would make `objectHeight` a no-op: the plane
      // pose is otherwise only used to convert into plane coordinates and
      // back, where a shift along the plane's Z cancels out.
      let pose = deproject(blob, intrinsics, cameraPose, heightAdjustedPlane);

      if (!pose) {
        return undefined;
      }

      if (args.accuracyCalibration) {
        const cameraCorrection = getCalibrationCorrection(
          pose,
          args.accuracyCalibration,
        );

        // Directly apply the calibration correction to the intersection point
        pose = applyCompoundPose(pose, cameraCorrection);
      }

      // apply user-defined offset and rotation
      if (args.transform) {
        pose = applyUserDefinedTransform(
          args.transform,
          pose,
          planePoseRelativeToBase,
        );
      }

      return { blob, pose };
    })
    .filter(isNotUndefined);

  log.info(
    'blobs.transformed',
    'Transformed blobs',
    blobsAndPoses.map((item) => ({
      ...item,
      blob: { ...item.blob, contour: undefined },
    })),
  );

  return blobsAndPoses;
}
