import com.tanelso2.glmatrix.Mat4
import com.tanelso2.glmatrix.Vec3
import org.khronos.webgl.Float32Array
import org.khronos.webgl.Uint16Array
import org.khronos.webgl.WebGLProgram
import org.khronos.webgl.get
import kotlin.math.PI
import org.khronos.webgl.WebGLRenderingContext as GL

/**
 * Axis-aligned box mesh spanning [cubeMin]..[cubeMax], uploaded to WebGL as 24 vertices
 * (4 per face, so each face carries its own flat normal) and 36 `UNSIGNED_SHORT` indices
 * (two triangles per face). Face order in every buffer is: front, back, top, bottom, left, right.
 *
 * The GL buffers are created and filled during construction; call [clean] to delete them.
 *
 * NOTE(review): `init` calls the open functions [initNormalBuffer] and [initIndexBuffer],
 * so a subclass override runs before that subclass's own properties are initialized.
 */
open class Cube(val gl: GL, private val cubeMin: Vec3, private val cubeMax: Vec3) {
    /** Cube centred on the origin with corners (-1, -1, -1) and (1, 1, 1). */
    constructor(gl: GL) : this(gl, Vec3(-1f, -1f, -1f), Vec3(1f, 1f, 1f))

    /** Convenience constructor taking the two corners as (x, y, z) triples. */
    constructor(gl: GL, cubeMin: Triple<Number, Number, Number>, cubeMax: Triple<Number, Number, Number>) : this(
        gl,
        Vec3(cubeMin.first, cubeMin.second, cubeMin.third),
        Vec3(cubeMax.first, cubeMax.second, cubeMax.third),
    )

    /** Model matrix; [spin] calls `rotateY` on it. */
    var model = Mat4()

    // Declared before `init` so they exist when initBuffers() runs.
    private val positionBuffer = gl.createBuffer()
    val normalBuffer = gl.createBuffer()
    val indexBuffer = gl.createBuffer()

    init {
        initBuffers()
    }

    /** Deletes the position, normal and index buffers; the cube should not be drawn afterwards. */
    fun clean() {
        deleteBuffers()
    }

    private fun deleteBuffers() {
        gl.deleteBuffer(positionBuffer)
        gl.deleteBuffer(normalBuffer)
        gl.deleteBuffer(indexBuffer)
    }

    /** Draws the cube with an identity extra transform; see the four-argument overload. */
    fun draw(projMat: Mat4, viewMat: Mat4, shaderProgram: ShaderProgram) {
        draw(projMat, viewMat, shaderProgram, Mat4())
    }

    /**
     * Binds this cube's buffers and matrices to [shaderProgram] and issues one indexed draw call.
     *
     * `useProgram` is not called here, so the caller is expected to have made the program current.
     *
     * @param projMat projection matrix, uploaded as `uProjectionMatrix`.
     * @param viewMat view matrix; combined as `viewMat * transform * model` into `uModelViewMatrix`.
     * @param shaderProgram program read for the `aVertexPosition`/`aVertexNormal` attributes.
     * @param transform extra transform applied between the view matrix and [model].
     */
    open fun draw(projMat: Mat4, viewMat: Mat4, shaderProgram: ShaderProgram, transform: Mat4) {
        val program = shaderProgram.program

        val modelView = viewMat * transform * model

        // Normal matrix = inverse-transpose of the model-view matrix. Multiplying by Mat4()
        // presumably yields a separate copy, so that invert()/transpose() (called for their
        // side effects) leave modelView untouched — TODO confirm Mat4.times allocates.
        val normalTransform = modelView * Mat4()
        normalTransform.invert()
        normalTransform.transpose()

        // Vertex attributes and index buffer
        setPositionAttribute(program)
        setNormalAttribute(program)
        gl.bindBuffer(GL.ELEMENT_ARRAY_BUFFER, indexBuffer)

        // Look up uniform locations
        val projectionLocation = gl.getUniformLocation(program, "uProjectionMatrix")
        val modelViewLocation = gl.getUniformLocation(program, "uModelViewMatrix")
        val normalTransformLocation = gl.getUniformLocation(program, "uNormalTransformMatrix")

        // Upload uniforms
        gl.uniformMatrix4fv(projectionLocation, false, projMat.array)
        gl.uniformMatrix4fv(modelViewLocation, false, modelView.array)
        gl.uniformMatrix4fv(normalTransformLocation, false, normalTransform.array)

        gl.drawElements(GL.TRIANGLES, INDEX_COUNT, GL.UNSIGNED_SHORT, 0)
    }

    /** Advances the animation by spinning 0.05 degrees about the Y axis. */
    fun update() {
        spin(0.05f)
    }

    /** Rotates [model] about the Y axis by [degrees], converted to radians before calling `rotateY`. */
    fun spin(degrees: Float) {
        model.rotateY(degrees * PI / 180f)
    }

    private fun initBuffers() {
        initPositionBuffer()
        initNormalBuffer()
        initIndexBuffer()
    }

    /** Fills the position buffer with 4 corners per face, in face order front, back, top, bottom, left, right. */
    private fun initPositionBuffer() {
        gl.bindBuffer(GL.ARRAY_BUFFER, positionBuffer)

        val minX = cubeMin.array[0]
        val minY = cubeMin.array[1]
        val minZ = cubeMin.array[2]
        val maxX = cubeMax.array[0]
        val maxY = cubeMax.array[1]
        val maxZ = cubeMax.array[2]

        // Each quad is listed counter-clockwise when viewed from outside the cube.
        val positions = arrayOf(
            // front (z = maxZ)
            minX, minY, maxZ,
            maxX, minY, maxZ,
            maxX, maxY, maxZ,
            minX, maxY, maxZ,

            // back (z = minZ)
            maxX, minY, minZ,
            minX, minY, minZ,
            minX, maxY, minZ,
            maxX, maxY, minZ,

            // top (y = maxY)
            minX, maxY, maxZ,
            maxX, maxY, maxZ,
            maxX, maxY, minZ,
            minX, maxY, minZ,

            // bottom (y = minY)
            minX, minY, minZ,
            maxX, minY, minZ,
            maxX, minY, maxZ,
            minX, minY, maxZ,

            // left (x = minX)
            minX, minY, minZ,
            minX, minY, maxZ,
            minX, maxY, maxZ,
            minX, maxY, minZ,

            // right (x = maxX)
            maxX, minY, maxZ,
            maxX, minY, minZ,
            maxX, maxY, minZ,
            maxX, maxY, maxZ,
        )

        gl.bufferData(GL.ARRAY_BUFFER, Float32Array(positions), GL.STATIC_DRAW)
    }

    /**
     * Fills [normalBuffer] with one outward unit normal per vertex. Each face's normal is
     * repeated for its 4 vertices, in the same face order as the position buffer.
     */
    open fun initNormalBuffer() {
        val faceNormals = arrayOf(
            arrayOf(0f, 0f, 1f),   // front  (+Z)
            arrayOf(0f, 0f, -1f),  // back   (-Z)
            arrayOf(0f, 1f, 0f),   // top    (+Y)
            arrayOf(0f, -1f, 0f),  // bottom (-Y)
            arrayOf(-1f, 0f, 0f),  // left   (-X)
            arrayOf(1f, 0f, 0f),   // right  (+X)
        )

        val normals = mutableListOf<Float>()
        for (normal in faceNormals) {
            repeat(VERTICES_PER_FACE) { normals.addAll(normal) }
        }

        gl.bindBuffer(GL.ARRAY_BUFFER, normalBuffer)
        gl.bufferData(GL.ARRAY_BUFFER, Float32Array(normals.toTypedArray()), GL.STATIC_DRAW)
    }

    /** Fills [indexBuffer] with two triangles (0-1-2, 0-2-3) per 4-vertex face. */
    open fun initIndexBuffer() {
        gl.bindBuffer(GL.ELEMENT_ARRAY_BUFFER, indexBuffer)

        val indices: Array<Short> = arrayOf(
            0, 1, 2, 0, 2, 3,       // front
            4, 5, 6, 4, 6, 7,       // back
            8, 9, 10, 8, 10, 11,    // top
            12, 13, 14, 12, 14, 15, // bottom
            16, 17, 18, 16, 18, 19, // left
            20, 21, 22, 20, 22, 23, // right
        )

        gl.bufferData(
            GL.ELEMENT_ARRAY_BUFFER,
            Uint16Array(indices),
            GL.STATIC_DRAW,
        )
    }

    /** Points the `aVertexPosition` attribute of [shaderProgram] at the position buffer (3 floats per vertex). */
    fun setPositionAttribute(shaderProgram: WebGLProgram) {
        val vertexPosition = gl.getAttribLocation(shaderProgram, "aVertexPosition")
        // -1 means the program has no active attribute of that name; forwarding it would only raise a GL error.
        if (vertexPosition < 0) return

        gl.bindBuffer(GL.ARRAY_BUFFER, positionBuffer)
        gl.vertexAttribPointer(vertexPosition, 3, GL.FLOAT, false, 0, 0)
        gl.enableVertexAttribArray(vertexPosition)
    }

    /** Points the `aVertexNormal` attribute of [shaderProgram] at [normalBuffer] (3 floats per vertex). */
    fun setNormalAttribute(shaderProgram: WebGLProgram) {
        val vertexNormal = gl.getAttribLocation(shaderProgram, "aVertexNormal")
        // -1 means the program has no active attribute of that name (e.g. an unlit shader).
        if (vertexNormal < 0) return

        gl.bindBuffer(GL.ARRAY_BUFFER, normalBuffer)
        gl.vertexAttribPointer(vertexNormal, 3, GL.FLOAT, false, 0, 0)
        gl.enableVertexAttribArray(vertexNormal)
    }

    private companion object {
        /** Each face is a quad with its own 4 vertices so it can carry a flat normal. */
        const val VERTICES_PER_FACE = 4

        /** 6 faces x 2 triangles x 3 indices. */
        const val INDEX_COUNT = 36
    }
}