
import com.tanelso2.glmatrix.Mat4
import com.tanelso2.glmatrix.Vec3
import com.tanelso2.glmatrix.Vec4
import com.tanelso2.glmatrix.vec3
import org.khronos.webgl.Float32Array
import org.khronos.webgl.Uint16Array
import org.khronos.webgl.WebGLProgram
import org.khronos.webgl.get
import org.khronos.webgl.WebGLRenderingContext as GL

/**
 * A skinned triangle mesh rendered with WebGL.
 *
 * Vertex data is stored flattened: vertex `i` occupies `positions[3i..3i+2]` and `normals[3i..3i+2]`,
 * and [triangles] holds three vertex indices per triangle. [update] blends every vertex over its
 * (joint index, weight) pairs using `world_j * inverse(binding_j)` and uploads the result; [draw]
 * renders whatever is currently in the GPU buffers.
 *
 * @param gl context that owns the buffers created by this skin.
 * @param positions binding-pose vertex positions, 3 floats per vertex.
 * @param normals binding-pose vertex normals, 3 floats per vertex.
 * @param skinWeights per-vertex list of (joint index, weight) pairs.
 * @param triangles vertex indices, 3 per triangle, uploaded as UNSIGNED_SHORT (values above
 *   [Short.MAX_VALUE] are stored wrapped and reinterpreted as unsigned by the index buffer).
 * @param bindings binding matrix of each joint; only their inverses are kept.
 */
class Skin(
    private val gl: GL,
    private val positions: Array<Float>,
    private val normals: Array<Float>,
    private val skinWeights: List<List<Pair<Int, Float>>>,
    private val triangles: Array<Short>,
    bindings: List<Mat4>,
) {

    /** Object-space model matrix; never changed, so it stays the identity. */
    private val model = Mat4()

    /** Inverse binding matrix per joint; index `j` pairs with joint `j` of the skeleton. */
    private val inverseBindings: List<Mat4> = invertBindings(bindings)

    private var hasInitializedTransformed = false

    // Skinned positions/normals from the last update(), flattened like positions/normals.
    private val transformedPositions = mutableListOf<Float>()
    private val transformedNormals = mutableListOf<Float>()

    private val positionBuffer = gl.createBuffer()
    private val normalBuffer = gl.createBuffer()
    private val indexBuffer = gl.createBuffer()

    init {
        initBuffers()
    }

    /**
     * Re-skins the mesh for [skeleton]'s current pose and uploads the result to the GPU.
     *
     * Does nothing unless `skeleton.hasUpdated` is set or this is the first call. Clears
     * `skeleton.hasUpdated` once the new pose has been uploaded.
     */
    fun update(skeleton: Skeleton) {
        if (!skeleton.hasUpdated && hasInitializedTransformed) {
            return
        }

        // M_j = W_j * B_j^{-1} for every joint j, plus the matching normal matrices (M_j^{-1})^T.
        val jointTransforms = computeSkinningTransforms(skeleton)
        val jointNormalTransforms = computeSkinningNormalTransforms(jointTransforms)

        // Blend positions and normals per vertex, then upload both buffers.
        computeTransformed(jointTransforms, jointNormalTransforms)
        setTransformedPositionBuffer()
        setTransformedNormalBuffer()

        // NOTE(review): this clears a flag on the Skeleton itself, so any other Skin driven by the
        // same Skeleton and updated afterwards will skip this pose — confirm skeletons aren't shared.
        skeleton.hasUpdated = false
        hasInitializedTransformed = true
    }

    /**
     * Fills [transformedPositions] and [transformedNormals] with linear-blend-skinned vertices:
     * `p' = sum_j w_j * (M_j * p)` and `n' = normalize(sum_j w_j * normalize(N_j * n))`.
     */
    private fun computeTransformed(
        jointTransforms: List<Mat4>,
        jointNormalTransforms: List<Mat4>
    ) {
        transformedPositions.clear()
        transformedNormals.clear()

        for ((vertexIndex, weights) in skinWeights.withIndex()) {
            val offset = vertexIndex * 3
            val bindingPosition = Vec3(positions[offset], positions[offset + 1], positions[offset + 2])
            val bindingNormal = Vec3(normals[offset], normals[offset + 1], normals[offset + 2])
            bindingNormal.normalize()

            var transformedPosition = Vec3()
            var transformedNormal = Vec3()

            for ((jointIndex, weight) in weights) {
                // Joints without a binding matrix are reported by the loader; skip them here
                // instead of throwing on every update.
                val jointTransform = jointTransforms.getOrNull(jointIndex) ?: continue
                val jointNormalTransform = jointNormalTransforms.getOrNull(jointIndex) ?: continue

                // w * (M * p)
                val weightedPosition: Vec3 = jointTransform * bindingPosition
                vec3.scale(weightedPosition.array, weightedPosition.array, weight)

                // w * normalize(M^-T * n); n is extended with w = 0 (a direction), so translation
                // does not affect it.
                val weightedNormal = Vec3(jointNormalTransform * Vec4(bindingNormal, 0), isNormal = true)
                weightedNormal.normalize()
                vec3.scale(weightedNormal.array, weightedNormal.array, weight)

                transformedPosition += weightedPosition
                transformedNormal += weightedNormal
            }

            // Blending unit normals shortens them, so re-normalize the sum.
            transformedNormal.normalize()

            for (component in 0..2) {
                transformedPositions.add(transformedPosition.array[component])
            }
            for (component in 0..2) {
                transformedNormals.add(transformedNormal.array[component])
            }
        }
    }

    /** Returns `(M^T)^{-1}` (equal to `(M^{-1})^T`) for each skinning matrix `M`. */
    private fun computeSkinningNormalTransforms(jointTransforms: List<Mat4>): List<Mat4> =
        jointTransforms.map { jointTransform ->
            // Work on a copy: transpose() and invert() are called on the matrix itself.
            val normalTransform = jointTransform * Mat4()
            normalTransform.transpose()
            normalTransform.invert()
            normalTransform
        }

    /**
     * Returns `world_j * inverseBinding_j` for each joint with a binding matrix. Joints that the
     * skeleton does not have use the identity as their world transform.
     */
    private fun computeSkinningTransforms(skeleton: Skeleton): List<Mat4> =
        inverseBindings.mapIndexed { index, inverseBinding ->
            val world = if (index < skeleton.jointList.size) skeleton.jointList[index].worldModel else Mat4()
            world * inverseBinding
        }

    /** Draws the skin with no parent transform. */
    fun draw(projMatrix: Mat4, viewMatrix: Mat4, shaderProgram: ShaderProgram) {
        draw(projMatrix, viewMatrix, shaderProgram, Mat4())
    }

    /**
     * Draws the current vertex buffers as indexed triangles.
     *
     * Binds the attributes `aVertexPosition` and `aVertexNormal` and sets the uniforms
     * `uProjectionMatrix`, `uModelViewMatrix`, `uViewMatrix` and `uNormalTransformMatrix` of
     * [shaderProgram]. NOTE(review): `useProgram` is not called here, so presumably the caller has
     * already made [shaderProgram] current — confirm.
     *
     * @param parentModel transform of the parent object, applied between the view and this skin's model.
     */
    fun draw(projMatrix: Mat4, viewMatrix: Mat4, shaderProgram: ShaderProgram, parentModel: Mat4 = Mat4()) {
        val program = shaderProgram.program

        val modelView = viewMatrix * parentModel * model

        // Normals are transformed by the inverse transpose of the model-view matrix.
        val normalTransform = modelView * Mat4()
        normalTransform.invert()
        normalTransform.transpose()

        // Set attributes
        setPositionAttribute(program)
        setNormalAttribute(program)

        // Bind index buffer (triangles)
        gl.bindBuffer(GL.ELEMENT_ARRAY_BUFFER, indexBuffer)

        // Get uniform locations
        val projLocation = gl.getUniformLocation(program, "uProjectionMatrix")
        val modelViewLocation = gl.getUniformLocation(program, "uModelViewMatrix")
        val viewLocation = gl.getUniformLocation(program, "uViewMatrix")
        val normalTransformLocation = gl.getUniformLocation(program, "uNormalTransformMatrix")

        // Set uniforms
        gl.uniformMatrix4fv(projLocation, false, projMatrix.array)
        gl.uniformMatrix4fv(modelViewLocation, false, modelView.array)
        gl.uniformMatrix4fv(viewLocation, false, viewMatrix.array)
        gl.uniformMatrix4fv(normalTransformLocation, false, normalTransform.array)

        // Draw: `triangles` holds one entry per index, so its size is the index count.
        gl.drawElements(GL.TRIANGLES, triangles.size, GL.UNSIGNED_SHORT, 0)
    }

    /** True when the skin has no vertex positions. */
    fun isEmpty(): Boolean = positions.isEmpty()

    /**
     * Deletes buffers. Do not use this object after calling this.
     */
    fun clean() {
        deleteBuffers()
    }

    private fun deleteBuffers() {
        gl.deleteBuffer(positionBuffer)
        gl.deleteBuffer(normalBuffer)
        gl.deleteBuffer(indexBuffer)
    }

    // The position/normal buffers are rewritten on every skeleton update, so they are created with
    // the DYNAMIC_DRAW usage hint rather than STATIC_DRAW.

    private fun setTransformedNormalBuffer() {
        gl.bindBuffer(GL.ARRAY_BUFFER, normalBuffer)
        gl.bufferData(GL.ARRAY_BUFFER, Float32Array(transformedNormals.toTypedArray()), GL.DYNAMIC_DRAW)
    }

    private fun setTransformedPositionBuffer() {
        gl.bindBuffer(GL.ARRAY_BUFFER, positionBuffer)
        gl.bufferData(GL.ARRAY_BUFFER, Float32Array(transformedPositions.toTypedArray()), GL.DYNAMIC_DRAW)
    }

    /** Uploads the binding pose and the triangle indices. */
    private fun initBuffers() {
        initPositionBuffer()
        initNormalBuffer()
        initIndexBuffer()
    }

    private fun initPositionBuffer() {
        gl.bindBuffer(GL.ARRAY_BUFFER, positionBuffer)
        gl.bufferData(GL.ARRAY_BUFFER, Float32Array(positions), GL.DYNAMIC_DRAW)
    }

    private fun initNormalBuffer() {
        gl.bindBuffer(GL.ARRAY_BUFFER, normalBuffer)
        gl.bufferData(GL.ARRAY_BUFFER, Float32Array(normals), GL.DYNAMIC_DRAW)
    }

    private fun initIndexBuffer() {
        gl.bindBuffer(GL.ELEMENT_ARRAY_BUFFER, indexBuffer)
        // Uint16Array converts each element modulo 2^16, so indices stored as wrapped
        // (negative) Shorts come back out as their unsigned values.
        gl.bufferData(GL.ELEMENT_ARRAY_BUFFER, Uint16Array(triangles), GL.STATIC_DRAW)
    }

    private fun invertBindings(bindings: List<Mat4>): List<Mat4> = bindings.map { it.inverse() }

    private fun setPositionAttribute(shaderProgram: WebGLProgram) {
        val vertexPosition = gl.getAttribLocation(shaderProgram, "aVertexPosition")
        // -1 means the program has no such active attribute; the calls below would only raise a GL error.
        if (vertexPosition < 0) return

        gl.bindBuffer(GL.ARRAY_BUFFER, positionBuffer)
        gl.vertexAttribPointer(vertexPosition, 3, GL.FLOAT, false, 0, 0)
        gl.enableVertexAttribArray(vertexPosition)
    }

    private fun setNormalAttribute(shaderProgram: WebGLProgram) {
        val vertexNormals = gl.getAttribLocation(shaderProgram, "aVertexNormal")
        // -1 means the program has no such active attribute; the calls below would only raise a GL error.
        if (vertexNormals < 0) return

        gl.bindBuffer(GL.ARRAY_BUFFER, normalBuffer)
        // `normalized` has no effect for FLOAT data; normals are normalized on the CPU in update().
        gl.vertexAttribPointer(vertexNormals, 3, GL.FLOAT, false, 0, 0)
        gl.enableVertexAttribArray(vertexNormals)
    }


    companion object Loader {

        private enum class Mode {
            ROOT,
            POSITIONS, NORMALS, SKINWEIGHTS,
            TRIANGLES, BINDINGS, MATRIX,
        }

        /** Largest vertex index an UNSIGNED_SHORT index buffer can hold. */
        private const val MAX_INDEX = 0xFFFF

        /** Number of floats in a binding matrix: four groups of three. */
        private const val BINDING_ELEMENT_COUNT = 12

        /**
         * Parses a skin from its text form.
         *
         * The input has top-level sections `positions`, `normals`, `skinweights`, `triangles` and
         * `bindings`, each closed by a line starting with `}`. `bindings` contains `matrix` sections
         * of [BINDING_ELEMENT_COUNT] numbers. Malformed lines are skipped with a `console.warn`
         * naming the (1-based) line number.
         */
        fun fromSource(gl: GL, source: String): Skin {
            val positions = mutableListOf<Float>()
            val normals = mutableListOf<Float>()
            val skinweights = mutableListOf<List<Pair<Int, Float>>>()
            val triangles = mutableListOf<Short>()
            val bindings = mutableListOf<Mat4>()

            var matrixElements = mutableListOf<Float>()
            var mode = Mode.ROOT

            for ((index, line) in source.split('\n').withIndex()) {
                // Report 1-based line numbers so warnings match what editors display.
                val lineNumber = index + 1
                val tokens = tokenizeLine(line)

                if (tokens.isEmpty()) {
                    continue
                }

                when (mode) {
                    Mode.ROOT -> {
                        // Expect the name of a section; anything else is ignored.
                        mode = when (tokens[0]) {
                            "positions" -> Mode.POSITIONS
                            "normals" -> Mode.NORMALS
                            "skinweights" -> Mode.SKINWEIGHTS
                            "triangles" -> Mode.TRIANGLES
                            "bindings" -> Mode.BINDINGS
                            else -> Mode.ROOT
                        }
                    }
                    Mode.POSITIONS -> {
                        // Expect 3 numbers or '}'
                        if (tokens[0] == "}") {
                            mode = Mode.ROOT
                            continue
                        }
                        if (tokens.size != 3) {
                            console.warn("$lineNumber: Error: Expecting } or 3 numbers in position.")
                            continue
                        }
                        val xyz = parseFloats(tokens)
                        if (xyz == null) {
                            console.warn("$lineNumber: Error: Expecting 3 numbers in position")
                            continue
                        }
                        positions.addAll(xyz)
                    }
                    Mode.NORMALS -> {
                        // Expect 3 numbers or '}'
                        if (tokens[0] == "}") {
                            mode = Mode.ROOT
                            continue
                        }
                        if (tokens.size != 3) {
                            console.warn("$lineNumber: Error: Expecting } or 3 numbers in normal.")
                            continue
                        }
                        val xyz = parseFloats(tokens)
                        if (xyz == null) {
                            console.warn("$lineNumber: Error: Expecting 3 numbers in normal, got $tokens instead.")
                            continue
                        }
                        normals.addAll(xyz)
                    }
                    Mode.SKINWEIGHTS -> {
                        // One line per vertex: numAttachments, then (joint, weight) pairs, or '}'
                        if (tokens[0] == "}") {
                            mode = Mode.ROOT
                            continue
                        }

                        val numAttachments = tokens[0].toIntOrNull()
                        if (numAttachments == null) {
                            console.warn("$lineNumber: Error: Expecting an int for numAttachments")
                            continue
                        }
                        if (tokens.size != 1 + 2 * numAttachments) {
                            console.warn(
                                "$lineNumber: Warning: numAttachments is $numAttachments " +
                                    "but ${tokens.size - 1} joint/weight values follow."
                            )
                        }

                        val vertexSkinWeight = mutableListOf<Pair<Int, Float>>()
                        // Only complete (joint, weight) pairs are read; a trailing lone token is ignored.
                        for (i in 1 until tokens.size - 1 step 2) {
                            val joint = tokens[i].toIntOrNull()
                            val weight = tokens[i + 1].toFloatOrNull()

                            if (joint == null) {
                                console.warn("$lineNumber: Error: Joint ${tokens[i]} must be an int.")
                                continue
                            }
                            if (weight == null) {
                                console.warn("$lineNumber: Error: Weight ${tokens[i + 1]} must be a number.")
                                continue
                            }

                            vertexSkinWeight.add(joint to weight)
                        }

                        skinweights.add(vertexSkinWeight)
                    }
                    Mode.TRIANGLES -> {
                        // Expect 3 vertex indices or '}'
                        if (tokens[0] == "}") {
                            mode = Mode.ROOT
                            continue
                        }
                        if (tokens.size != 3) {
                            console.warn("$lineNumber: Error: Expecting } or 3 numbers in triangle.")
                            continue
                        }

                        // Parse as Int rather than Short: UNSIGNED_SHORT indices go up to 65535.
                        val indices = tokens.mapNotNull { it.toIntOrNull() }
                        if (indices.size != tokens.size) {
                            console.warn("$lineNumber: Error: Expecting 3 numbers in triangle, got $tokens instead")
                            continue
                        }
                        if (indices.any { it !in 0..MAX_INDEX }) {
                            console.warn("$lineNumber: Error: Triangle indices must be in 0..$MAX_INDEX, got $tokens instead")
                            continue
                        }

                        // Values above Short.MAX_VALUE wrap to negative Shorts here; the index buffer's
                        // Uint16Array turns them back into the original unsigned values.
                        indices.mapTo(triangles) { it.toShort() }
                    }
                    Mode.BINDINGS -> {
                        // Expect 'matrix' or '}'; anything else is ignored.
                        if (tokens[0] == "}") {
                            mode = Mode.ROOT
                        } else if (tokens[0] == "matrix") {
                            mode = Mode.MATRIX
                        }
                    }
                    Mode.MATRIX -> {
                        // Expect 12 numbers total, on any number of lines, then '}'
                        if (tokens[0] == "}") {
                            if (matrixElements.size != BINDING_ELEMENT_COUNT) {
                                console.warn(
                                    "$lineNumber: Warning: Binding matrix has ${matrixElements.size} " +
                                        "numbers, expected $BINDING_ELEMENT_COUNT."
                                )
                            }
                            bindings.add(bindingElementsToMat4(matrixElements))
                            matrixElements = mutableListOf()
                            mode = Mode.BINDINGS
                            continue
                        }

                        for (token in tokens) {
                            val num = token.toFloatOrNull()
                            if (num == null) {
                                console.warn("$lineNumber: Error: $token is expected to be a number.")
                                continue
                            }
                            matrixElements.add(num)
                        }
                    }
                }
            }

            if (mode != Mode.ROOT) {
                console.warn("Warning: Reached end of input inside an unclosed $mode section.")
            }
            warnOnInconsistentData(positions, normals, skinweights, triangles, bindings.size)

            return Skin(
                gl,
                positions.toTypedArray(),
                normals.toTypedArray(),
                skinweights,
                triangles.toTypedArray(),
                bindings
            )
        }

        /**
         * Logs warnings for parsed data that would make the skin render incorrectly: mismatched
         * per-vertex counts, joints without a binding matrix, and out-of-range triangle indices.
         */
        private fun warnOnInconsistentData(
            positions: List<Float>,
            normals: List<Float>,
            skinWeights: List<List<Pair<Int, Float>>>,
            triangles: List<Short>,
            jointCount: Int,
        ) {
            val vertexCount = positions.size / 3
            if (normals.size != positions.size) {
                console.warn("Warning: Skin has $vertexCount positions but ${normals.size / 3} normals.")
            }
            if (skinWeights.size != vertexCount) {
                console.warn("Warning: Skin has $vertexCount positions but ${skinWeights.size} skin weights.")
            }
            if (skinWeights.any { weights -> weights.any { (joint, _) -> joint !in 0 until jointCount } }) {
                console.warn("Warning: Skin weights reference joints outside the $jointCount bindings; they will be ignored.")
            }
            // `and 0xFFFF` undoes the Short wrap-around applied to indices above Short.MAX_VALUE.
            if (triangles.any { (it.toInt() and 0xFFFF) >= vertexCount }) {
                console.warn("Warning: Triangles reference vertices beyond the $vertexCount positions.")
            }
        }

        /** Parses every token as a Float, or returns null if any token is not a number. */
        private fun parseFloats(tokens: List<String>): List<Float>? =
            tokens.mapNotNull { it.toFloatOrNull() }.takeIf { it.size == tokens.size }

        /**
         * Splits a line into whitespace-separated tokens, dropping empty ones. '\r' counts as
         * whitespace so CRLF files parse like LF ones (otherwise a closing "}" would arrive as
         * "}\r" and never end its section).
         */
        private fun tokenizeLine(line: String): List<String> =
            line.split(' ', '\t', '\r', '\n').filter { it.isNotEmpty() }

        /**
         * Builds a [Mat4] from the elements of a binding matrix, read as four groups of three.
         *
         * Each group is followed by a fourth element: 0 for the first three groups and 1 for the
         * last, so the last group lands in elements 12..14 (the translation slot, assuming Mat4 is
         * column-major like gl-matrix). Missing elements are treated as 0; extras are ignored.
         * The input list is not modified.
         */
        private fun bindingElementsToMat4(elements: List<Float>): Mat4 {
            val padded = List(BINDING_ELEMENT_COUNT) { elements.getOrElse(it) { 0f } }
            val matrixElements = padded.chunked(3).flatMapIndexed { group, xyz ->
                xyz + (if (group == 3) 1f else 0f)
            }
            return Mat4(matrixElements.toTypedArray())
        }
    }
}