import { GLTF } from "three-stdlib";
import THREE, {
  BufferGeometry,
  CatmullRomCurve3,
  Curve,
  Float32BufferAttribute,
  IUniform,
  MeshStandardMaterial,
  Uniform,
  Vector3,
} from "three";
import { memo, useCallback, useEffect, useMemo, useRef } from "react";
import { useGLTF } from "@react-three/drei";
import { extend, Object3DNode } from "@react-three/fiber";
import shaderSource from "./triple-rail-shader.glsl?raw";
import { ControlPoint } from "./types";

/**
 * Number of samples taken along the rail curve. The same value replaces the
 * `__POINTS_COUNT__` placeholder in the shader source (presumably sizing the
 * GLSL uniform arrays — confirm in `triple-rail-shader.glsl`).
 */
const POINTS_COUNT = 128;

/** Values substituted for `__NAME__` placeholders in the shader chunks. */
const shaderConstants = { POINTS_COUNT };

/**
 * Matches one named chunk of `triple-rail-shader.glsl`: a `//__name__` marker
 * followed by everything up to the next `//__` marker or the end of the file.
 *
 * `[\s\S]` is used instead of `.` because `.` does not match `\r` (or other
 * line terminators); with `.` a CRLF checkout of the shader would produce no
 * chunks at all and `undefined` would be spliced into the GLSL.
 */
const SHADER_CHUNK_PATTERN =
  /\/\/__(?<name>\w+)__(?<chunk>[\s\S]+?(?=\/\/__|$))/g;

/** Replaces every `__KEY__` placeholder in `chunk` with `shaderConstants[KEY]`. */
function applyShaderConstants(chunk: string): string {
  let result = chunk;
  for (const [key, value] of Object.entries(shaderConstants)) {
    result = result.replaceAll(`__${key}__`, String(value));
  }
  return result;
}

/**
 * GLSL snippets from `triple-rail-shader.glsl`, keyed by chunk name (this file
 * reads `vertexSetup`, `vertexBody`, `fragmentSetup` and `fragmentBody`), with
 * constant placeholders already substituted. If a name appears more than once,
 * the last occurrence wins.
 */
const shaderChunks: Record<string, string> = {};
for (const { groups } of shaderSource.matchAll(SHADER_CHUNK_PATTERN)) {
  if (groups) {
    shaderChunks[groups.name] = applyShaderConstants(groups.chunk);
  }
}

/**
 * Shape of `/models.glb` as returned by `useGLTF`: the generic GLTF payload
 * plus the named node and material this file reads from it.
 */
type GLTFResult = GLTF & {
  nodes: { TripleRail: THREE.Mesh };
  materials: { Material: THREE.MeshStandardMaterial };
};

/**
 * Builds a new geometry containing `count` copies of `sourceGeometry` laid
 * end-to-end along +Z.
 *
 * Copy `i` is shifted so the source's minimum Z lands at `i * zLength`; copy 0
 * therefore starts at z = 0. Normals and UVs are copied verbatim, and indices
 * are offset so each copy references its own vertices.
 *
 * NOTE(review): reads each attribute's raw `.array`, which assumes
 * non-interleaved attributes. For an InterleavedBufferAttribute `.array` is
 * the whole interleaved buffer — confirm against the exported model.
 *
 * @param sourceGeometry - Indexed geometry with `position`, `normal` and `uv`.
 * @param count - Number of copies to emit (non-integers round up; <= 0 gives
 *   an empty geometry).
 * @returns The repeated geometry and the Z extent of a single source copy.
 * @throws Error if the source lacks an index or a required attribute.
 */
function repeatGeometry(
  sourceGeometry: BufferGeometry,
  count = 1
): [BufferGeometry, number] {
  const positionAttribute = sourceGeometry.getAttribute("position");
  const normalAttribute = sourceGeometry.getAttribute("normal");
  const uvAttribute = sourceGeometry.getAttribute("uv");
  const sourceIndicesBufferAttribute = sourceGeometry.getIndex();

  if (!positionAttribute || !normalAttribute || !uvAttribute) {
    throw new Error("Source geometry needs position, normal and uv attributes");
  }
  if (!sourceIndicesBufferAttribute) {
    throw new Error("No source indices");
  }

  const sourceVertices = positionAttribute.array;
  const sourceNormals = normalAttribute.array;
  const sourceUvs = uvAttribute.array;
  const sourceIndices = sourceIndicesBufferAttribute.array;
  const verticesPerCopy = sourceVertices.length / 3;

  // Z extent of one copy. Both bounds start at ±Infinity so a model lying
  // entirely at negative (or entirely positive) Z is measured correctly.
  let minZ = Infinity;
  let maxZ = -Infinity;
  for (let i = 2; i < sourceVertices.length; i += 3) {
    if (sourceVertices[i] < minZ) minZ = sourceVertices[i];
    if (sourceVertices[i] > maxZ) maxZ = sourceVertices[i];
  }
  if (minZ > maxZ) {
    // No vertices at all: treat the copy as zero-length instead of ±Infinity.
    minZ = 0;
    maxZ = 0;
  }
  const zLength = maxZ - minZ;

  // Same number of iterations as `for (i = 0; i < count; i++)`.
  const copies = count > 0 ? Math.ceil(count) : 0;

  const vertices = new Float32Array(sourceVertices.length * copies);
  const normals = new Float32Array(sourceNormals.length * copies);
  const uvs = new Float32Array(sourceUvs.length * copies);
  const indices: number[] = [];

  for (let copy = 0; copy < copies; copy++) {
    const zOffset = zLength * copy - minZ;
    const vertexBase = copy * sourceVertices.length;
    for (let j = 0; j < sourceVertices.length; j++) {
      // Only the Z component (every third value) is shifted.
      vertices[vertexBase + j] =
        sourceVertices[j] + (j % 3 === 2 ? zOffset : 0);
    }

    normals.set(sourceNormals, copy * sourceNormals.length);
    uvs.set(sourceUvs, copy * sourceUvs.length);

    const indexOffset = copy * verticesPerCopy;
    for (let j = 0; j < sourceIndices.length; j++) {
      indices.push(sourceIndices[j] + indexOffset);
    }
  }

  const geometry = new BufferGeometry();
  geometry.setAttribute("position", new Float32BufferAttribute(vertices, 3));
  geometry.setAttribute("normal", new Float32BufferAttribute(normals, 3));
  geometry.setAttribute("uv", new Float32BufferAttribute(uvs, 2));
  // Passing a plain array lets three.js pick a 16- or 32-bit index buffer.
  geometry.setIndex(indices);

  return [geometry, zLength];
}

// Register three's CatmullRomCurve3 with react-three-fiber via `extend`, so it
// can be written as a `<catmullRomCurve3>` JSX element.
// NOTE(review): no `<catmullRomCurve3>` element is rendered in this file;
// presumably it is used elsewhere — verify before removing.
extend({ CatmullRomCurve3 });

// Module augmentation that types the `<catmullRomCurve3>` element registered
// by the `extend` call above.
declare module "@react-three/fiber" {
  interface ThreeElements {
    // NOTE(review): Object3DNode is the helper type for Object3D subclasses,
    // while CatmullRomCurve3 is a Curve — confirm this typing is intended.
    catmullRomCurve3: Object3DNode<CatmullRomCurve3, CatmullRomCurve3>;
  }
}

/** Props accepted by the TripleRail component. */
interface TripleRailProps {
  /** Path the rail is bent along; sampled at `POINTS_COUNT` points. */
  curve: Curve<Vector3>;
  /**
   * Points carrying a `roll` angle, which is interpolated along the curve and
   * applied with `Vector3.applyAxisAngle` (i.e. radians). Each point's
   * `position` is spread into a `Vector3` and matched to the first curve
   * sample within 1 unit of it.
   */
  controlPoints: ControlPoint[];
  /** When true, arrow helpers are drawn for every sampled tangent and normal. */
  debug?: boolean;
}

/**
 * Renders the triple-rail model tiled along, and deformed onto, `curve`.
 *
 * `repeatGeometry` tiles the rail mesh from `/models.glb` 100 times along +Z.
 * The curve is then sampled at `POINTS_COUNT` points. The sampled positions,
 * cumulative arc lengths, tangents and rolled normals are handed to the custom
 * shader chunks as uniforms.
 */
function TripleRail({ curve, controlPoints, debug }: TripleRailProps) {
  const result = useGLTF("/models.glb") as unknown as GLTFResult;
  const [geometry, zLength] = useMemo(() => {
    return repeatGeometry(result.nodes.TripleRail.geometry, 100);
  }, [result]);

  // react-three-fiber never disposes objects passed through `<primitive>`, so
  // free the generated geometry's GPU buffers when it is replaced or unmounted.
  useEffect(() => () => geometry.dispose(), [geometry]);

  // Set in onBeforeCompile once three.js builds the shader program; undefined
  // until then.
  const materialUniformsRef = useRef<Record<string, IUniform>>();
  const materialRef = useRef<MeshStandardMaterial>(null!);

  const curveData = useMemo(() => {
    const curvePoints = curve.getPoints(POINTS_COUNT - 1);
    const curveLengths = curve.getLengths(POINTS_COUNT - 1);

    // For each control point, the index of the first curve sample within
    // 1 unit of its position, or -1 if no sample is that close.
    const controlPointsCurveIndices = controlPoints.map((controlPoint) => {
      const target = new Vector3(...controlPoint.position);
      return curvePoints.findIndex(
        (curvePoint) => curvePoint.distanceTo(target) < 1
      );
    });

    // Roll at each sample, linearly interpolated by arc length between the
    // two control points that bracket it.
    const curveRolls = curveLengths.map((length) => {
      const nextControlPointIndex = controlPointsCurveIndices.findIndex(
        (curveIndex) => curveLengths[curveIndex] > length
      );
      // -1: no matched control point lies further along the curve.
      // 0: the sample precedes the first control point, so there is no
      // previous one to interpolate from (indexing [-1] would throw).
      // In both cases the sample is left unrolled.
      if (nextControlPointIndex <= 0) {
        return 0;
      }
      const prevControlPointIndex = nextControlPointIndex - 1;
      const prevCurveIndex = controlPointsCurveIndices[prevControlPointIndex];
      const nextCurveIndex = controlPointsCurveIndices[nextControlPointIndex];
      // The previous control point matched no sample, so its arc length is
      // unknown and interpolation would produce NaN.
      if (prevCurveIndex === -1) {
        return 0;
      }
      const distToPrev = length - curveLengths[prevCurveIndex];
      const distToNext = curveLengths[nextCurveIndex] - length;
      const distTotal = distToPrev + distToNext;
      const nextInfluence = 1 - distToNext / distTotal;
      const prevInfluence = 1 - distToPrev / distTotal;
      return (
        controlPoints[nextControlPointIndex].roll * nextInfluence +
        controlPoints[prevControlPointIndex].roll * prevInfluence
      );
    });

    // Tangents at the same raw curve parameters that getPoints() sampled.
    const curveTangents = Array.from({ length: POINTS_COUNT }).map((_, i) =>
      curve.getTangent(i / (curvePoints.length - 1))
    );

    // Normal = tangent x chord-to-next-sample, rotated by the roll about the
    // tangent, and flipped when needed to stay on the same side as the
    // previous normal.
    // NOTE(review): where the curve is straight, tangent and chord are nearly
    // parallel, so the cross product is ~0 and its direction is unstable.
    // For the last sample the modulo wraps to sample 0; if the curve is closed
    // these two points likely coincide, giving a zero-length chord. Consider
    // `curve.computeFrenetFrames` — confirm against the shader's expectations.
    const curveNormals = Array.from({ length: POINTS_COUNT }).reduce<Vector3[]>(
      (normals, _, i) => {
        const tangent = curveTangents[i];
        const point = curvePoints[i];
        const nextPoint = curvePoints[(i + 1) % curvePoints.length];
        const forward = nextPoint.clone().sub(point).normalize();
        const normal = new Vector3().copy(tangent).cross(forward).normalize();
        normal.applyAxisAngle(tangent, curveRolls[i]);
        if (normals.length > 0 && normals[normals.length - 1].dot(normal) < 0) {
          normal.multiplyScalar(-1);
        }
        normals.push(normal);
        return normals;
      },
      [] as Vector3[]
    );

    return {
      points: curvePoints,
      lengths: curveLengths,
      tangents: curveTangents,
      normals: curveNormals,
      rolls: curveRolls,
    };
  }, [curve, controlPoints]);

  // Push fresh curve data (and the tile length) into an already-compiled
  // program. Before the first compile the uniforms don't exist yet, and the
  // initial values are supplied by onBeforeCompile instead.
  useEffect(() => {
    const uniforms = materialUniformsRef.current;
    if (!uniforms) {
      return;
    }

    uniforms.zLength.value = zLength;
    uniforms.points.value = curveData.points;
    uniforms.lengths.value = curveData.lengths;
    uniforms.tangents.value = curveData.tangents;
    uniforms.normals.value = curveData.normals;

    materialRef.current.needsUpdate = true;
  }, [curveData, zLength]);

  // Shader chunks are fixed at module load, so the cache key never changes.
  const customProgramCacheKey = useCallback(
    () => JSON.stringify(shaderChunks),
    []
  );

  return (
    <mesh frustumCulled={false}>
      {debug && (
        <>
          {curveData.tangents.map((tangent, i) => {
            return (
              <arrowHelper
                key={"tangent" + i}
                args={[tangent, curveData.points[i], 10, "orange"]}
              />
            );
          })}

          {curveData.normals.map((normal, i) => {
            return (
              <arrowHelper
                key={"normal" + i}
                args={[normal, curveData.points[i], 10, "teal"]}
              />
            );
          })}
        </>
      )}
      <primitive object={geometry} attach="geometry" />
      <meshStandardMaterial
        ref={materialRef}
        customProgramCacheKey={customProgramCacheKey}
        onBeforeCompile={(shader: THREE.Shader) => {
          shader.vertexShader = shader.vertexShader
            .replace(
              "void main() {",
              `${shaderChunks.vertexSetup}\nvoid main() {`
            )
            .replace(
              "#include <begin_vertex>",
              `#include <begin_vertex>\n${shaderChunks.vertexBody}`
            );

          shader.fragmentShader = shader.fragmentShader.replace(
            "void main() {",
            `${shaderChunks.fragmentSetup}\nvoid main() {\n${shaderChunks.fragmentBody}`
          );

          // Keep a handle on the live uniforms object (Object.assign mutates
          // it in place) so the effect above can update values later.
          materialUniformsRef.current = shader.uniforms;
          Object.assign(shader.uniforms, {
            zLength: new Uniform(zLength),
            interval: new Uniform(3),
            points: new Uniform(curveData.points),
            lengths: new Uniform(curveData.lengths),
            tangents: new Uniform(curveData.tangents),
            normals: new Uniform(curveData.normals),
          });
        }}
        map={result.materials.Material.map}
      />
    </mesh>
  );
}

/**
 * Public `TripleRail`: the component wrapped in `React.memo`, so it only
 * re-renders when a prop changes under a shallow (`Object.is`) comparison.
 */
const MemoizedTripleRail = memo(TripleRail);

export { MemoizedTripleRail as TripleRail };
