package noria.ui.examples

import androidx.compose.foundation.Canvas
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.sizeIn
import androidx.compose.runtime.*
import androidx.compose.ui.Modifier
import androidx.compose.ui.geometry.toRect
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asComposePaint
import androidx.compose.ui.graphics.drawscope.drawIntoCanvas
import androidx.compose.ui.unit.dp
import fleet.compose.theme.components.gallery.Gallery
import fleet.compose.theme.components.gallery.gallery
import fleet.compose.theme.toSkija
import noria.NoriaContext
import noria.byReference
import noria.memo
import noria.ui.core.NoriaLogger
import noria.ui.core.boundary
import noria.ui.withModifier
import org.jetbrains.skia.Data
import org.jetbrains.skia.Paint
import org.jetbrains.skia.RuntimeEffect
import org.jetbrains.skia.impl.use
import kotlin.system.measureNanoTime
import kotlin.time.Duration
import kotlin.time.Duration.Companion.nanoseconds
import kotlin.time.Duration.Companion.seconds
import kotlin.time.DurationUnit

/**
 * Invokes [body] exactly once, logs its wall-clock duration at debug level prefixed by [label],
 * and hands back whatever [body] produced.
 */
private fun <T> withMeasureTime(label: String, body: () -> T): T {
  val value: T
  val elapsedNanos = measureNanoTime { value = body() }
  NoriaLogger.logger.debug { "$label took: ${elapsedNanos.nanoseconds}" }
  return value
}

/**
 * SkSL runtime-effect source for the "Snow" example, rendered by [drawWithShader].
 *
 * It declares the uniforms `float3 iResolution` and `float iTime`. [fillShaderData] packs these as
 * `(width, height, 0)` followed by the elapsed time in seconds.
 *
 * The shader runs a fixed 100-iteration loop. Each iteration may add a soft round dot, attenuated with
 * depth `d`, to the accumulator `o`, which is returned as the pixel colour. Presumably these dots are
 * the snowflakes.
 *
 * NOTE(review): `r` negates the y resolution. This is presumably to map Skia's top-left origin onto the
 * bottom-left convention of the original shader. Confirm before changing.
 */
private val snowSkSl = """
uniform float3 iResolution;      // Viewport resolution (pixels)
uniform float  iTime;            // Shader playback time (s)

// Source: @kamoshika_vrc https://twitter.com/kamoshika_vrc/status/1495081980278751234

const float PI2 = 6.28318530718;
float F(vec2 c){
  return fract(sin(dot(c, vec2(12.9898, 78.233))) * 43758.5453);
}

half4 main(float2 FC) {
  vec4 o = vec4(0., 0., 0., 0.);
  float t = iTime;
  vec2 r = iResolution.xy * vec2(1, -1);
  vec3 R=normalize(vec3((FC.xy*2.-r)/r.y,1));
  for(float i=0; i<100; ++i) {
    float I=floor(t/.1)+i;
    float d=(I*.1-t)/R.z;
    vec2 p=d*R.xy+vec2(sin(t+F(I.xx)*PI2)*.3+F(I.xx*.9),t+F(I.xx*.8));
    if (F(I/100+ceil(p))<.03) {
      o+=smoothstep(.1,0.,length(fract(p)-.5))*exp(-d*d*.005);
    }
  }
  return o;
}

""".trimIndent()

/**
 * SkSL runtime-effect source for the "Space" example, rendered by [drawWithShader].
 *
 * It declares the same `float3 iResolution` / `float iTime` uniforms that [fillShaderData] packs.
 *
 * `main` builds a per-pixel direction `d` from the fragment coordinate. It then advances a point `p`
 * along `d` by `f(p)` for 32 iterations (a ray-marching style loop). `f` scrolls and twists space
 * along z as `iTime` grows. The final colour is derived from the end position `p`.
 */
private val spaceSkSl = """
  uniform float3 iResolution;      // Viewport resolution (pixels)
  uniform float  iTime;            // Shader playback time (s)

  // Source: @notargs https://twitter.com/notargs/status/1250468645030858753
  float f(vec3 p) {
      p.z -= iTime * 10.;
      float a = p.z * .1;
      p.xy *= mat2(cos(a), sin(a), -sin(a), cos(a));
      return .1 - length(cos(p.xy) + sin(p.yz));
  }

  half4 main(vec2 fragcoord) { 
      vec3 d = .5 - fragcoord.xy1 / iResolution.y;
      vec3 p=vec3(0);
      for (int i = 0; i < 32; i++) {
        p += f(p) * d;
      }
      return ((sin(p) + vec3(2, 5, 12)) / length(p)).xyz1;
  }
""".trimIndent()

/**
 * Packs the uniform block shared by the shaders in this file into a 16-byte Skia [Data] blob.
 *
 * `iResolution` is written as `(width, height, 0)` at byte offsets 0, 4 and 8, followed by `iTime` at
 * offset 12. Every float is encoded little-endian by [putFloat].
 */
private fun fillShaderData(width: Float, height: Float, time: Float): Data {
  val uniformValues = floatArrayOf(width, height, 0f, time)
  val packed = ByteArray(uniformValues.size * Float.SIZE_BYTES)
  uniformValues.forEachIndexed { slot, value ->
    packed.putFloat(slot * Float.SIZE_BYTES, value)
  }
  return Data.makeFromBytes(packed)
}

/**
 * Stores the IEEE-754 bit pattern of [value] into four consecutive bytes starting at [index].
 *
 * Bytes are written in little-endian order, least significant byte first. Writes proceed from
 * `index` upwards, so an out-of-range index fails on the first byte that does not fit.
 */
private fun ByteArray.putFloat(index: Int, value: Float) {
  val bits = value.toRawBits()
  repeat(Int.SIZE_BYTES) { byteOffset ->
    // toByte() keeps only the low 8 bits, so no explicit 0xff mask is needed.
    this[index + byteOffset] = (bits ushr (byteOffset * Byte.SIZE_BITS)).toByte()
  }
}

/**
 * Fills the available space with the output of the SkSL runtime shader [skSl], animated over time.
 *
 * The shader must declare `float3 iResolution` followed by `float iTime`, matching the 16-byte layout
 * produced by [fillShaderData]. `iResolution` receives the canvas width and height; `iTime` receives
 * the seconds elapsed since the first frame observed for this [skSl].
 *
 * A [LaunchedEffect] advances the elapsed time on every frame. The time state is read only inside
 * the draw lambda of [Canvas].
 */
@Composable
private fun NoriaContext.drawWithShader(skSl: String) {
  var startTime by remember(byReference(skSl)) { mutableStateOf<Duration?>(null) }
  var currentTime by remember(byReference(skSl)) { mutableStateOf(0.seconds) }
  LaunchedEffect(byReference(skSl)) {
    while (true) {
      withFrameNanos { nanos ->
        val frameTime = nanos.nanoseconds
        // The first observed frame becomes t = 0 for this shader.
        val start = startTime ?: frameTime.also { startTime = it }
        currentTime = frameTime - start
      }
    }
  }
  boundary {
    // NOTE(review): `memo` is not keyed on skSl here. Confirm it is re-evaluated if this boundary is
    // ever reused with a different shader source.
    val effect = memo { withMeasureTime("SkSl compilation") { RuntimeEffect.makeForShader(skSl) } }
    Canvas(Modifier.fillMaxSize()) {
      val skijaRect = size.toRect().toSkija()
      val time = currentTime.toDouble(DurationUnit.SECONDS).toFloat()
      // A fresh uniform blob and shader are allocated for every frame. Close them deterministically
      // instead of leaving the native objects to the GC. Like the Paint, which is closed right after
      // drawing, they are only needed for the duration of this draw call.
      fillShaderData(skijaRect.width, skijaRect.height, time).use { uniforms ->
        effect.makeShader(uniforms = uniforms, children = null, localMatrix = null).use { shader ->
          Paint().use { paint ->
            paint.shader = shader
            drawIntoCanvas { canvas ->
              canvas.drawRect(size.toRect(), paint.asComposePaint())
            }
          }
        }
      }
    }
  }
}

/**
 * Builds the "Shader effects" gallery page with two animated SkSL examples.
 *
 * Both examples are bounded to at most 400 x 400 dp: "Space", and "Snow" drawn over a black background.
 */
internal fun shaderExample(): Gallery = gallery("Shader effects", NoriaExamples.sourceCodeForFile("Shaders.kt")) {
  // Shared size bound for both examples.
  val bounded = Modifier.sizeIn(maxWidth = 400.dp, maxHeight = 400.dp)
  example("Space") {
    Column {
      withModifier(bounded) { drawWithShader(spaceSkSl) }
    }
  }
  example("Snow") {
    Column {
      withModifier(bounded.background(Color.Black)) { drawWithShader(snowSkSl) }
    }
  }
}
