import kotlin.math.abs
import kotlin.math.floor
import kotlin.properties.Delegates

/**
 * A keyframed animation: a time range plus one [Channel] per animated scalar.
 *
 * Sampling the clip evaluates every channel at the same time and collects the results, in channel order,
 * into a [Pose].
 *
 * @property start start of the clip's time range (from the `range` line; 0 if absent).
 * @property end end of the clip's time range (from the `range` line; 0 if absent).
 * @property channels channels in the order they appear in the source.
 */
class AnimationClip(
    val start: Float,
    val end: Float,
    val channels: List<Channel>
) {
    /** Shorthand for [evaluate], so a clip can be sampled as `clip[time]`. */
    operator fun get(time: Float) = evaluate(time)

    /**
     * Samples every channel at [time].
     *
     * @return a [Pose] built from one value per channel, in the same order as [channels].
     */
    fun evaluate(time: Float): Pose =
        Pose(channels.mapTo(mutableListOf<Float>()) { channel -> channel[time] })

    companion object Loader {

        /** Parser states; each one names the brace-delimited section that is currently open. */
        private enum class State {
            ROOT,
            ANIMATION,
            CHANNEL,
            KEYS,
        }

        /** Token separator: any run of whitespace, including a stray `\r` from CRLF line endings. */
        private val WHITESPACE = Regex("\\s+")

        /**
         * Parses a clip from its line-oriented text form:
         *
         * ```
         * animation {
         *     range <start> <end>
         *     numchannels <n>              (ignored)
         *     channel {
         *         extrapolate <in> <out>
         *         keys … {                 (tokens after "keys" are not checked)
         *             <time> <value> <tangentIn> <tangentOut>
         *         }
         *     }
         * }
         * ```
         *
         * Opening braces are not verified. Malformed or unexpected lines are reported with `console.warn`
         * (1-based line numbers) and skipped. A channel that omits `extrapolate` uses CONSTANT on both sides.
         */
        fun fromSource(source: String): AnimationClip {
            var start = 0.0f
            var end = 0.0f
            val channels = mutableListOf<Channel>()

            var state = State.ROOT

            // Settings of the channel currently being parsed. They are reset whenever a `channel` section
            // opens, so a channel never inherits the previous channel's modes or (shared) key list.
            var extrapolateIn = ExtrapolationMode.CONSTANT
            var extrapolateOut = ExtrapolationMode.CONSTANT
            var keys = mutableListOf<Keyframe>()

            for ((index, line) in source.lines().withIndex()) {
                val lineNumber = index + 1  // 1-based, matching what text editors display
                val tokens = tokenizeLine(line)

                if (tokens.isEmpty())
                    continue

//                console.info("[$lineNumber] \t$state: \t$tokens")

                when (state) {
                    State.ROOT -> {
                        when (tokens[0]) {
                            "animation" -> state = State.ANIMATION  // Opening "{" is not checked.
                            else -> printUnexpectedTokenWarning(lineNumber, line, tokens[0])
                        }
                    }
                    State.ANIMATION -> {
                        when (tokens[0]) {
                            "range" -> {
                                if (tokens.size != 3) {
                                    printArgumentsNumberWarning(lineNumber, line, tokens[0], 2)
                                    continue
                                }

                                val startValue = tokens[1].toFloatOrNull()
                                val endValue = tokens[2].toFloatOrNull()

                                if (startValue == null || endValue == null) {
                                    console.warn("[$lineNumber]\t${tokens[0]}: Expecting 2 numbers, start and end.\n\t$line")
                                    continue
                                }

                                start = startValue
                                end = endValue
                            }
                            "numchannels" -> {} // Not needed: channels are counted as they are parsed.
                            "channel" -> {
                                state = State.CHANNEL  // Opening "{" is not checked.
                                extrapolateIn = ExtrapolationMode.CONSTANT
                                extrapolateOut = ExtrapolationMode.CONSTANT
                                keys = mutableListOf()
                            }
                            "}" -> state = State.ROOT
                            else -> printUnexpectedTokenWarning(lineNumber, line, tokens[0])
                        }
                    }
                    State.CHANNEL -> {
                        when (tokens[0]) {
                            "extrapolate" -> {
                                if (tokens.size != 3) {
                                    printArgumentsNumberWarning(lineNumber, line, tokens[0], 2)
                                    continue
                                }

                                val eIn = ExtrapolationMode.fromString(tokens[1])
                                val eOut = ExtrapolationMode.fromString(tokens[2])

                                if (eIn == null || eOut == null) {
                                    console.warn("[$lineNumber]\t${tokens[0]}: Expecting extrapolation modes.\n\t$line")
                                    continue
                                }

                                extrapolateIn = eIn
                                extrapolateOut = eOut
                            }
                            "keys" -> {
                                state = State.KEYS  // Opening "{" is not checked.
                                keys = mutableListOf()
                            }
                            "}" -> {
                                state = State.ANIMATION

                                // Coefficients depend on neighbouring keys, so they can only be computed once
                                // every key of the channel has been parsed and linked.
                                keys.forEach { key -> key.computeCoefficients() }

                                channels.add(Channel(extrapolateIn, extrapolateOut, keys))
                            }
                            else -> printUnexpectedTokenWarning(lineNumber, line, tokens[0])
                        }
                    }
                    State.KEYS -> {
                        if (tokens[0] == "}") {
                            state = State.CHANNEL
                            continue
                        }

                        // Expected form: time value tangentIn tangentOut
                        if (tokens.size != 4) {
                            printArgumentsNumberWarning(lineNumber, line, tokens[0], 3)
                            continue
                        }

                        val time = tokens[0].toFloatOrNull()
                        if (time == null) {
                            console.warn("[$lineNumber]\t${tokens[0]} must be a number.\n\t$line")
                            continue
                        }
                        val value = tokens[1].toFloatOrNull()
                        if (value == null) {
                            console.warn("[$lineNumber]\t${tokens[1]} must be a number.\n\t$line")
                            continue
                        }

                        val key = Keyframe(
                            time, value,
                            Tangent.fromString(tokens[2], isIn = true),
                            Tangent.fromString(tokens[3], isIn = false),
                        )

                        // Link into the channel's doubly linked key list; slopes and span lookup use it.
                        keys.lastOrNull()?.let { prevKey ->
                            key.prev = prevKey
                            prevKey.next = key
                        }
                        keys.add(key)
                    }
                }
            }

            if (state != State.ROOT) {
                console.warn("Animation source ended inside an unclosed $state section; its contents may be incomplete.")
            }

            console.info("Animation loader info:")
            console.info("Start: $start")
            console.info("End: $end")
            console.info("numChannels: ${channels.size}")

            return AnimationClip(start, end, channels)
        }

        private fun printUnexpectedTokenWarning(lineNumber: Int, line: String, token: String) {
            console.warn("[$lineNumber]\tUnexpected token \"$token\" in line \"$line\"")
        }

        private fun printArgumentsNumberWarning(lineNumber: Int, line: String, token: String, numArgs: Int) {
            console.warn("[$lineNumber]\t$token: Expecting $numArgs arguments in line \"$line\"")
        }

        /**
         * Splits a line into tokens separated by any run of whitespace (spaces, tabs, `\r`, ...).
         * Empty tokens are dropped, so a blank line yields an empty list.
         */
        private fun tokenizeLine(line: String): List<String> =
            line.split(WHITESPACE).filter { it.isNotEmpty() }
    }
}

/**
 * One animated scalar value, defined by its [keys] plus how it behaves before the first key
 * ([extrapolateIn]) and after the last key ([extrapolateOut]).
 *
 * The span search assumes [keys] are sorted by ascending [Keyframe.keyTime] and linked through
 * [Keyframe.prev]/[Keyframe.next]. It also assumes [Keyframe.computeCoefficients] has been called on every key.
 */
class Channel(
    val extrapolateIn: ExtrapolationMode,
    val extrapolateOut: ExtrapolationMode,
    val keys: List<Keyframe>,
) {
    /** Time of the first key, or `null` when there are no keys. */
    val timeStart = keys.firstOrNull()?.keyTime

    /** Time of the last key, or `null` when there are no keys. */
    val timeEnd = keys.lastOrNull()?.keyTime

    /**
     * Span found by the most recent [evaluate]. The next search starts here, so evaluating at nearby
     * successive times only walks a few links.
     */
    var lastUsedSpan: Keyframe? = keys.firstOrNull()

    operator fun get(time: Float) = evaluate(time)

    /**
     * Samples the channel at [time], applying the extrapolation modes outside the key range.
     *
     * @return the channel value, or 0 when the channel has no keys.
     */
    fun evaluate(time: Float): Float {
        val start = timeStart ?: return 0f  // No keys: nothing to animate.
        val end = timeEnd ?: return 0f

        val (effectiveTime, offset) = extrapolate(
            time,
            start, end,
            keys.first().value, keys.last().value,
        )

        val span = findSpan(effectiveTime) ?: return 0f
        lastUsedSpan = span
        return span.evaluateWithinSpan(effectiveTime) + offset
    }

    /**
     * Maps [time] into the key range according to the extrapolation mode of the side it falls on.
     *
     * - CONSTANT: clamps to the nearest end of the range.
     * - LINEAR: keeps [time]. The first/last span then extrapolates linearly along its tangent.
     * - CYCLE: wraps [time] into `[timeStart, timeEnd)`.
     * - CYCLE_OFFSET: wraps like CYCLE, and shifts the value by `(valueEnd - valueStart)` once per
     *   elapsed period (negative before the range), so the curve continues instead of jumping.
     * - BOUNCE: ping-pongs between `timeStart` and `timeEnd` (a triangle wave with period `2 * duration`).
     *
     * A zero-length range (single key) has no period to repeat, so the cyclic modes fall back to clamping.
     *
     * @return the time at which to sample the keys, and a value offset to add to the sampled value.
     */
    private fun extrapolate(
        time: Float,
        timeStart: Float, timeEnd: Float,
        valueStart: Float, valueEnd: Float
    ): Pair<Float, Float> {
        val mode = when {
            time < timeStart -> extrapolateIn
            time > timeEnd -> extrapolateOut
            else -> return Pair(time, 0f)  // Inside the key range: no extrapolation needed.
        }

        val clamped = if (time < timeStart) timeStart else timeEnd
        val duration = timeEnd - timeStart
        if (duration <= 0f && mode != ExtrapolationMode.LINEAR) {
            // Cyclic modes would take a modulus by zero and produce NaN.
            return Pair(clamped, 0f)
        }

        val elapsed = time - timeStart
        return when (mode) {
            ExtrapolationMode.CONSTANT -> Pair(clamped, 0f)
            ExtrapolationMode.LINEAR -> Pair(time, 0f)
            ExtrapolationMode.CYCLE -> Pair(timeStart + elapsed.mod(duration), 0f)
            ExtrapolationMode.CYCLE_OFFSET -> Pair(
                timeStart + elapsed.mod(duration),
                // floor(...) counts whole periods and is already negative before the range,
                // so the same formula serves both sides.
                (valueEnd - valueStart) * floor(elapsed / duration),
            )
            // elapsed.mod(2 * duration) runs 0 .. 2*duration; |phase - duration| folds the second half back.
            ExtrapolationMode.BOUNCE -> Pair(timeEnd - abs(elapsed.mod(2 * duration) - duration), 0f)
        }
    }

    /**
     * Finds the key whose span contains [time]. It walks the key links from [startingSpan], using
     * [Keyframe.compareTo] to pick a direction.
     *
     * Before the first key, the first key is returned. At or after the last key, the last key is returned.
     *
     * @return the span's starting key, or `null` if [startingSpan] is `null`.
     */
    private fun findSpan(time: Float, startingSpan: Keyframe? = lastUsedSpan): Keyframe? {
        var span: Keyframe = startingSpan ?: return null
        var direction = span.compareTo(time)

        while (direction != 0) {
            val neighbour = when (direction) {
                -1 -> span.prev
                1 -> span.next
                else -> error("This case should not be possible.")
            } ?: break  // Ran off an end of the key list: that end key's span is the answer.
            span = neighbour
            direction = span.compareTo(time)
        }
        return span
    }
}

/**
 * One key of a [Channel]. It also owns the cubic segment ("span") from this key to [next].
 *
 * [prev]/[next] link the key to its neighbours. They are body properties, so they are excluded from the
 * data-class `equals`/`hashCode`/`toString`. [computeCoefficients] must be called after the links are set
 * and before [evaluateWithinSpan]. The coefficients use `Delegates.notNull`, so reading one earlier throws.
 */
data class Keyframe(
    val keyTime: Float, val value: Float,
    val tangentIn: Tangent, val tangentOut: Tangent) {

    var prev: Keyframe? = null
    var next: Keyframe? = null

    // Cubic coefficients: value(u) = d + c*u + b*u^2 + a*u^3
    var a by Delegates.notNull<Float>()
    var b by Delegates.notNull<Float>()
    var c by Delegates.notNull<Float>()
    var d by Delegates.notNull<Float>()

    /**
     * Computes [a]..[d] for this key's span.
     *
     * With a [next] key, the span is the cubic Hermite segment from (this value, outgoing slope) to
     * (next value, next's incoming slope), parameterized by u in [0, 1]. Without one, it is a straight line
     * through this value with the outgoing slope, parameterized by raw time since [keyTime].
     */
    fun computeCoefficients() {
        val p0 = value
        val v0 = getSlope(false)
        val nextKey = next

        if (nextKey == null) {
            // Last key: linear continuation.
            a = 0f
            b = 0f
            c = v0
            d = p0
            return
        }

        val p1 = nextKey.value
        val spanLength = nextKey.keyTime - keyTime
        // Slopes are per unit of time. Scale them to per unit of u, which runs 0..1 across the span.
        val nV0 = spanLength * v0
        val nV1 = spanLength * nextKey.getSlope(true)

        // Hermite basis applied to (p0, p1, nV0, nV1).
        a = 2 * p0 - 2 * p1 + nV0 + nV1
        b = -3 * p0 + 3 * p1 - 2 * nV0 - nV1
        c = nV0
        d = p0
    }

    /**
     * Evaluates this key's span at [time].
     *
     * Before [keyTime], it extrapolates linearly along the incoming slope. Otherwise it evaluates the cubic.
     * u is normalized across the span when there is a [next] key, and is raw elapsed time for the last key.
     * A zero-length span (next key at the same time) yields this key's value.
     */
    fun evaluateWithinSpan(time: Float): Float {
        if (time < keyTime) {
            return getSlope(true) * (time - keyTime) + value
        }

        val nextKey = next
        val u = when {
            nextKey == null -> time - keyTime
            nextKey.keyTime == keyTime -> 0f  // Avoid 0/0 in inverseLerp.
            else -> inverseLerp(time, keyTime, nextKey.keyTime)
        }

        // Horner form of d + c*u + b*u^2 + a*u^3.
        return d + u * (c + u * (b + u * a))
    }

    /**
     * Locates [time] relative to this key's span.
     *
     * @return -1 if [time] is before [keyTime], 1 if it is after [next]'s time, 0 if it lies within the span.
     *   The last key's span is open-ended, so that key never returns 1.
     */
    operator fun compareTo(time: Float): Int {
        val nextKey = next
        return when {
            nextKey == null -> if (time >= keyTime) 0 else -1
            time > nextKey.keyTime -> 1
            time < keyTime -> -1
            else -> 0
        }
    }

    /**
     * Slope (value per unit time) of the incoming ([isTangentIn] = true) or outgoing tangent.
     *
     * FLAT is 0 and FIXED is the stored slope. LINEAR is the secant to the neighbour on the tangent's
     * side. SMOOTH is the secant from the previous to the next key. A missing neighbour is replaced by this key.
     */
    fun getSlope(isTangentIn: Boolean): Float {
        val prevKey = prev ?: this
        val nextKey = next ?: this

        val tangent: Tangent = if (isTangentIn) tangentIn else tangentOut

        return when (tangent.tangentType) {
            Tangent.Companion.TangentType.FLAT -> 0.0f
            Tangent.Companion.TangentType.LINEAR ->
                if (tangent.isTangentIn) slopeBetween(prevKey, this) else slopeBetween(this, nextKey)
            Tangent.Companion.TangentType.SMOOTH -> slopeBetween(prevKey, nextKey)
            Tangent.Companion.TangentType.FIXED -> tangent.slope
        }
    }

    /**
     * Secant slope from [first] to [second]. Returns 0 when both share a key time (the same key, or
     * duplicate times), which would otherwise divide by zero.
     */
    private fun slopeBetween(first: Keyframe, second: Keyframe): Float =
        if (second.keyTime == first.keyTime) 0f
        else (second.value - first.value) / (second.keyTime - first.keyTime)
}

/**
 * Tangent specification for one side of a keyframe.
 *
 * @property tangentType how the slope is determined.
 * @property isTangentIn `true` for the incoming side, `false` for the outgoing side.
 * @property slope explicit slope; only set by [fromString] for [TangentType.FIXED], 0 otherwise.
 */
data class Tangent(val tangentType: TangentType, val isTangentIn: Boolean, val slope: Float = 0f) {

    companion object {

        /** Supported tangent kinds. */
        enum class TangentType {
            FLAT, LINEAR, SMOOTH, FIXED
        }

        /**
         * Parses a tangent token. A number yields a FIXED tangent with that slope. "flat", "linear" and
         * "smooth" select the matching kind. Any other token triggers a warning and falls back to SMOOTH.
         */
        fun fromString(token: String, isIn: Boolean): Tangent {
            val fixedSlope = token.toFloatOrNull()
            if (fixedSlope != null) {
                return Tangent(TangentType.FIXED, isIn, fixedSlope)
            }

            val namedType = when (token) {
                "flat" -> TangentType.FLAT
                "linear" -> TangentType.LINEAR
                "smooth" -> TangentType.SMOOTH
                else -> {
                    console.warn("Warning: $token is not a valid tangent type. Using SMOOTH instead.")
                    TangentType.SMOOTH
                }
            }
            return Tangent(namedType, isIn)
        }
    }
}

/** How a channel behaves outside the time range covered by its keys. */
enum class ExtrapolationMode {
    CONSTANT, LINEAR, CYCLE, CYCLE_OFFSET, BOUNCE;


    companion object {
        /**
         * Parses an extrapolation keyword. Each keyword is the lower-case form of a constant's name:
         * "constant", "linear", "cycle", "cycle_offset" or "bounce".
         *
         * @return the matching mode, or `null` if [token] is not one of those keywords (matching is
         *   case-sensitive).
         */
        fun fromString(token: String): ExtrapolationMode? =
            values().firstOrNull { mode -> mode.name.lowercase() == token }
    }
}

/**
 * Inverse of linear interpolation: the fraction of the way [t] lies from [t0] (0) to [t1] (1).
 * The result is not clamped, so values outside [t0, t1] map outside [0, 1]. When t0 == t1, the float
 * division yields NaN or ±Infinity.
 */
private fun inverseLerp(t: Float, t0: Float, t1: Float): Float {
    val distanceFromStart = t - t0
    val intervalLength = t1 - t0
    return distanceFromStart / intervalLength
}
