/*
 * Decompiled with CFR 0.152.
 */
package io.kinference.ndarray.extensions.broadcasting;

import io.kinference.ndarray.arrays.MutableShortNDArray;
import io.kinference.ndarray.arrays.ShortBinaryOperation;
import io.kinference.ndarray.arrays.ShortNDArray;
import io.kinference.ndarray.extensions.broadcasting.BroadcastingInfo;
import io.kinference.ndarray.extensions.broadcasting.ReshapeViewKt;
import java.util.Arrays;
import kotlin.Metadata;
import kotlin.Unit;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function4;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import org.jetbrains.annotations.NotNull;

@Metadata(mv={1, 8, 0}, k=2, xi=48, d1={"\u0000\u0018\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\u001a(\u0010\u0000\u001a\u00020\u00012\u0006\u0010\u0002\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u00032\u0006\u0010\u0005\u001a\u00020\u00012\u0006\u0010\u0006\u001a\u00020\u0007H\u0000\u001a(\u0010\b\u001a\u00020\u00012\u0006\u0010\u0002\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u00032\u0006\u0010\u0005\u001a\u00020\u00012\u0006\u0010\u0006\u001a\u00020\u0007H\u0002\u00a8\u0006\t"}, d2={"broadcastTwoTensorsShort", "Lio/kinference/ndarray/arrays/MutableShortNDArray;", "left", "Lio/kinference/ndarray/arrays/ShortNDArray;", "right", "dest", "op", "Lio/kinference/ndarray/arrays/ShortBinaryOperation;", "executeWithoutBroadcasting", "ndarray-core"})
@SourceDebugExtension(value={"SMAP\nBroadcastTwoArgumentsShort.kt\nKotlin\n*S Kotlin\n*F\n+ 1 BroadcastTwoArgumentsShort.kt\nio/kinference/ndarray/extensions/broadcasting/BroadcastTwoArgumentsShortKt\n+ 2 fake.kt\nkotlin/jvm/internal/FakeKt\n*L\n1#1,147:1\n1#2:148\n*E\n"})
public final class BroadcastTwoArgumentsShortKt {
    @NotNull
    public static final MutableShortNDArray broadcastTwoTensorsShort(@NotNull ShortNDArray left, @NotNull ShortNDArray right, @NotNull MutableShortNDArray dest, @NotNull ShortBinaryOperation op) {
        Intrinsics.checkNotNullParameter((Object)left, (String)"left");
        Intrinsics.checkNotNullParameter((Object)right, (String)"right");
        Intrinsics.checkNotNullParameter((Object)dest, (String)"dest");
        Intrinsics.checkNotNullParameter((Object)op, (String)"op");
        Object[] objectArray = new ShortNDArray[]{left, right};
        BroadcastingInfo broadcastingInfo = BroadcastingInfo.Companion.create(CollectionsKt.listOf((Object[])objectArray));
        if (!Arrays.equals(dest.getShape(), broadcastingInfo.getDestShape())) {
            boolean bl = false;
            String string2 = "Destination has incorrect shape, expected: " + ArraysKt.joinToString$default((int[])broadcastingInfo.getDestShape(), null, null, null, (int)0, null, null, (int)63, null) + ", actual " + ArraysKt.joinToString$default((int[])dest.getShape(), null, null, null, (int)0, null, null, (int)63, null);
            throw new IllegalArgumentException(string2.toString());
        }
        if (broadcastingInfo.getBroadcastingAxes().isEmpty()) {
            return BroadcastTwoArgumentsShortKt.executeWithoutBroadcasting(left, right, dest, op);
        }
        int totalAxesToBroadcast = broadcastingInfo.getBroadcastAlongLastAxis() ? broadcastingInfo.getBroadcastingAxes().size() - 1 : broadcastingInfo.getBroadcastingAxes().size();
        int[][] nArray = broadcastingInfo.getBroadcastingShapes();
        int[] leftBroadcastingShape = (int[])((Object[])nArray)[0];
        int[] rightBroadcastingShape = (int[])((Object[])nArray)[1];
        int[] destBroadcastingShape = broadcastingInfo.getBroadcastingDestShape();
        int destBlocksInRow = ArraysKt.last((int[])destBroadcastingShape) / dest.getArray().getBlockSize();
        int[] leftOffsets = ReshapeViewKt.makeOffsets(leftBroadcastingShape, ArraysKt.last((int[])leftBroadcastingShape) / left.getArray().getBlockSize());
        int[] rightOffsets = ReshapeViewKt.makeOffsets(rightBroadcastingShape, ArraysKt.last((int[])rightBroadcastingShape) / right.getArray().getBlockSize());
        int[] destOffsets = ReshapeViewKt.makeOffsets(destBroadcastingShape, destBlocksInRow);
        boolean leftIsScalar = broadcastingInfo.getBroadcastAlongLastAxis() && ArraysKt.last((int[])leftBroadcastingShape) == 1;
        boolean rightIsScalar = broadcastingInfo.getBroadcastAlongLastAxis() && ArraysKt.last((int[])rightBroadcastingShape) == 1;
        short[][] leftBlocks = left.getArray().getBlocks();
        short[][] rightBlocks = right.getArray().getBlocks();
        short[][] destBlocks = dest.getArray().getBlocks();
        Function4 leftIsScalarFun2 = (Function4)new Function4<Integer, Integer, Integer, Integer, Unit>(destBroadcastingShape, leftBlocks, destBlocksInRow, destBlocks, rightBlocks, op){
            final /* synthetic */ int[] $destBroadcastingShape;
            final /* synthetic */ short[][] $leftBlocks;
            final /* synthetic */ int $destBlocksInRow;
            final /* synthetic */ short[][] $destBlocks;
            final /* synthetic */ short[][] $rightBlocks;
            final /* synthetic */ ShortBinaryOperation $op;
            {
                this.$destBroadcastingShape = $destBroadcastingShape;
                this.$leftBlocks = $leftBlocks;
                this.$destBlocksInRow = $destBlocksInRow;
                this.$destBlocks = $destBlocks;
                this.$rightBlocks = $rightBlocks;
                this.$op = $op;
                super(4);
            }

            public final void invoke(int leftOffset, int rightOffset, int destOffset, int axisToBroadcastIdx) {
                int shapeIdx = axisToBroadcastIdx * 2;
                int batchSize = this.$destBroadcastingShape[shapeIdx];
                for (int batchIdx = 0; batchIdx < batchSize; ++batchIdx) {
                    short leftScalar = this.$leftBlocks[leftOffset][0];
                    for (int blockIdx = 0; blockIdx < this.$destBlocksInRow; ++blockIdx) {
                        short[] destBlock = this.$destBlocks[destOffset + blockIdx];
                        short[] rightBlock = this.$rightBlocks[rightOffset + blockIdx];
                        int n = destBlock.length;
                        for (int idx = 0; idx < n; ++idx) {
                            destBlock[idx] = this.$op.invoke(leftScalar, rightBlock[idx]);
                        }
                    }
                }
            }
        };
        Function4 rightIsScalarFun2 = (Function4)new Function4<Integer, Integer, Integer, Integer, Unit>(destBroadcastingShape, rightBlocks, destBlocksInRow, destBlocks, leftBlocks, op){
            final /* synthetic */ int[] $destBroadcastingShape;
            final /* synthetic */ short[][] $rightBlocks;
            final /* synthetic */ int $destBlocksInRow;
            final /* synthetic */ short[][] $destBlocks;
            final /* synthetic */ short[][] $leftBlocks;
            final /* synthetic */ ShortBinaryOperation $op;
            {
                this.$destBroadcastingShape = $destBroadcastingShape;
                this.$rightBlocks = $rightBlocks;
                this.$destBlocksInRow = $destBlocksInRow;
                this.$destBlocks = $destBlocks;
                this.$leftBlocks = $leftBlocks;
                this.$op = $op;
                super(4);
            }

            public final void invoke(int leftOffset, int rightOffset, int destOffset, int axisToBroadcastIdx) {
                int shapeIdx = axisToBroadcastIdx * 2;
                int batchSize = this.$destBroadcastingShape[shapeIdx];
                for (int batchIdx = 0; batchIdx < batchSize; ++batchIdx) {
                    short rightScalar = this.$rightBlocks[rightOffset][0];
                    for (int blockIdx = 0; blockIdx < this.$destBlocksInRow; ++blockIdx) {
                        short[] destBlock = this.$destBlocks[destOffset + blockIdx];
                        short[] leftBlock = this.$leftBlocks[leftOffset + blockIdx];
                        int n = destBlock.length;
                        for (int idx = 0; idx < n; ++idx) {
                            destBlock[idx] = this.$op.invoke(leftBlock[idx], rightScalar);
                        }
                    }
                }
            }
        };
        Function4 defaultFun2 = (Function4)new Function4<Integer, Integer, Integer, Integer, Unit>(destBlocksInRow, leftBlocks, rightBlocks, destBlocks, op){
            final /* synthetic */ int $destBlocksInRow;
            final /* synthetic */ short[][] $leftBlocks;
            final /* synthetic */ short[][] $rightBlocks;
            final /* synthetic */ short[][] $destBlocks;
            final /* synthetic */ ShortBinaryOperation $op;
            {
                this.$destBlocksInRow = $destBlocksInRow;
                this.$leftBlocks = $leftBlocks;
                this.$rightBlocks = $rightBlocks;
                this.$destBlocks = $destBlocks;
                this.$op = $op;
                super(4);
            }

            public final void invoke(int leftOffset, int rightOffset, int destOffset, int axisToBroadcastIdx) {
                for (int blockIdx = 0; blockIdx < this.$destBlocksInRow; ++blockIdx) {
                    short[] leftBlock = this.$leftBlocks[leftOffset + blockIdx];
                    short[] rightBlock = this.$rightBlocks[rightOffset + blockIdx];
                    short[] destBlock = this.$destBlocks[destOffset + blockIdx];
                    int n = destBlock.length;
                    for (int idx = 0; idx < n; ++idx) {
                        destBlock[idx] = this.$op.invoke(leftBlock[idx], rightBlock[idx]);
                    }
                }
            }
        };
        Function4 broadcastingFun = leftIsScalar ? leftIsScalarFun2 : (rightIsScalar ? rightIsScalarFun2 : defaultFun2);
        BroadcastTwoArgumentsShortKt.broadcastTwoTensorsShort$broadcast(totalAxesToBroadcast, (Function4<? super Integer, ? super Integer, ? super Integer, ? super Integer, Unit>)broadcastingFun, destBroadcastingShape, leftOffsets, rightOffsets, destOffsets, leftBroadcastingShape, rightBroadcastingShape, 0, 0, 0, 0);
        return dest;
    }

    private static final MutableShortNDArray executeWithoutBroadcasting(ShortNDArray left, ShortNDArray right, MutableShortNDArray dest, ShortBinaryOperation op) {
        short[][] leftBlocks = left.getArray().getBlocks();
        short[][] rightBlocks = right.getArray().getBlocks();
        short[][] destBlocks = dest.getArray().getBlocks();
        int n = ((Object[])destBlocks).length;
        for (int blockIdx = 0; blockIdx < n; ++blockIdx) {
            short[] destBlock = destBlocks[blockIdx];
            short[] leftBlock = leftBlocks[blockIdx];
            short[] rightBlock = rightBlocks[blockIdx];
            int n2 = destBlock.length;
            for (int idx = 0; idx < n2; ++idx) {
                destBlock[idx] = op.invoke(leftBlock[idx], rightBlock[idx]);
            }
        }
        return dest;
    }

    private static final void broadcastTwoTensorsShort$broadcast(int totalAxesToBroadcast, Function4<? super Integer, ? super Integer, ? super Integer, ? super Integer, Unit> broadcastingFun, int[] destBroadcastingShape, int[] leftOffsets, int[] rightOffsets, int[] destOffsets, int[] leftBroadcastingShape, int[] rightBroadcastingShape, int leftOffset, int rightOffset, int destOffset, int axisToBroadcastIdx) {
        if (axisToBroadcastIdx == totalAxesToBroadcast) {
            broadcastingFun.invoke((Object)leftOffset, (Object)rightOffset, (Object)destOffset, (Object)axisToBroadcastIdx);
        } else {
            int shapeIdx = axisToBroadcastIdx * 2;
            int batchSize = destBroadcastingShape[shapeIdx];
            int dimSize = destBroadcastingShape[shapeIdx + 1];
            for (int batchIdx = 0; batchIdx < batchSize; ++batchIdx) {
                int leftBatchOffset = leftOffset + leftOffsets[shapeIdx] * batchIdx;
                int rightBatchOffset = rightOffset + rightOffsets[shapeIdx] * batchIdx;
                int destBatchOffset = destOffset + destOffsets[shapeIdx] * batchIdx;
                for (int dimIdx = 0; dimIdx < dimSize; ++dimIdx) {
                    int leftFullOffset = leftBatchOffset + dimIdx % leftBroadcastingShape[shapeIdx + 1] * leftOffsets[shapeIdx + 1];
                    int rightFullOffset = rightBatchOffset + dimIdx % rightBroadcastingShape[shapeIdx + 1] * rightOffsets[shapeIdx + 1];
                    int destFullOffset = destBatchOffset + dimIdx * destOffsets[shapeIdx + 1];
                    BroadcastTwoArgumentsShortKt.broadcastTwoTensorsShort$broadcast(totalAxesToBroadcast, broadcastingFun, destBroadcastingShape, leftOffsets, rightOffsets, destOffsets, leftBroadcastingShape, rightBroadcastingShape, leftFullOffset, rightFullOffset, destFullOffset, axisToBroadcastIdx + 1);
                }
            }
        }
    }
}

