/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.completion.full.line.local.generation.processor;

import java.util.LinkedHashSet;
import java.util.Set;
import kotlin.Metadata;
import kotlin._Assertions;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.completion.full.line.local.generation.UtilsKt;
import org.jetbrains.completion.full.line.local.generation.generation.SearchState;
import org.jetbrains.completion.full.line.local.generation.processor.DistributionProcessor;

@Metadata(mv={1, 9, 0}, k=1, xi=48, d1={"\u00000\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0006\n\u0000\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0000\n\u0002\u0010\u0011\n\u0002\u0010\u0013\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018\u00002\u00020\u0001B\u0015\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u00a2\u0006\u0002\u0010\u0006J#\u0010\u0007\u001a\u00020\b2\f\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u000b0\n2\u0006\u0010\f\u001a\u00020\rH\u0016\u00a2\u0006\u0002\u0010\u000eR\u000e\u0010\u0004\u001a\u00020\u0005X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004\u00a2\u0006\u0002\n\u0000\u00a8\u0006\u000f"}, d2={"Lorg/jetbrains/completion/full/line/local/generation/processor/CumulativeProbabilityConstraint;", "Lorg/jetbrains/completion/full/line/local/generation/processor/DistributionProcessor;", "threshold", "", "size", "", "(DI)V", "process", "", "logProbDistribution", "", "", "state", "Lorg/jetbrains/completion/full/line/local/generation/generation/SearchState;", "([[DLorg/jetbrains/completion/full/line/local/generation/generation/SearchState;)V", "intellij.fullLine.local"})
@SourceDebugExtension(value={"SMAP\nCumulativeProbabilityConstraint.kt\nKotlin\n*S Kotlin\n*F\n+ 1 CumulativeProbabilityConstraint.kt\norg/jetbrains/completion/full/line/local/generation/processor/CumulativeProbabilityConstraint\n+ 2 fake.kt\nkotlin/jvm/internal/FakeKt\n*L\n1#1,35:1\n1#2:36\n*E\n"})
public final class CumulativeProbabilityConstraint
implements DistributionProcessor {
    private final double threshold;
    private final int size;

    public CumulativeProbabilityConstraint(double threshold, int size) {
        this.threshold = threshold;
        this.size = size;
    }

    @Override
    public void process(@NotNull double[][] logProbDistribution, @NotNull SearchState state) {
        boolean bl;
        Intrinsics.checkNotNullParameter((Object)logProbDistribution, (String)"logProbDistribution");
        Intrinsics.checkNotNullParameter((Object)state, (String)"state");
        if (state.getLength() > 0) {
            return;
        }
        boolean bl2 = bl = ((Object[])logProbDistribution).length == 1;
        if (_Assertions.ENABLED && !bl) {
            boolean $i$a$-assert-CumulativeProbabilityConstraint$process$22 = false;
            String $i$a$-assert-CumulativeProbabilityConstraint$process$22 = "The distribution must have only one beam at the first iteration of Beam Search";
            throw new AssertionError((Object)$i$a$-assert-CumulativeProbabilityConstraint$process$22);
        }
        double[] logProbs = logProbDistribution[0];
        double cumSum = 0.0;
        int[] topIndices = UtilsKt.topk1d(logProbs, this.size);
        Set remainingIndices = new LinkedHashSet();
        for (int i : topIndices) {
            if (cumSum >= this.threshold) break;
            cumSum += Math.exp(logProbs[i]);
            remainingIndices.add(i);
        }
        if (cumSum >= this.threshold) {
            int n = logProbs.length;
            for (int i = 0; i < n; ++i) {
                if (remainingIndices.contains(i)) continue;
                logProbs[i] = Double.NEGATIVE_INFINITY;
            }
        }
    }
}

