package com.google.autofill.detection.ml;

import com.google.autofill.detection.ml.ModelConfig;
import defpackage.bnas;
import defpackage.bnax;
import defpackage.bndd;
import defpackage.bnle;
import defpackage.kxb;
import defpackage.kxq;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/* compiled from: :com.google.android.gms@200914037@20.09.14 (120400-300565878) */
/* loaded from: classes5.dex */
/**
 * Runs a {@link NeuralNetwork} over the signals described by a {@link ModelConfig} and converts the
 * network output into per-field-type predictions.
 *
 * <p>The signal configuration supplies the network input (one value per signal, laid out as a
 * single row), and the field configuration maps each column of the network output back to a field
 * type.
 */
public class Model {
    private final ModelConfig.FieldConfig fieldConfig;
    private final boolean isLiteModel;
    private final NeuralNetwork neuralNetwork;
    /** Read once from the runtime configuration at construction; consulted by {@link #predict}. */
    private final boolean randomizeSignalOrder;
    // NOTE(review): the snapshot is stored but only its randomize flag is read in this class.
    private final RuntimeConfiguration runtimeConfiguration;
    private final ModelConfig.SignalConfig signalConfig;

    /** Field-type predictions for a single input, ordered by descending confidence. */
    public static final class Result {
        // NOTE(review): bnax is an obfuscated list type; presumably immutable — confirm.
        private final bnax orderedFieldPredictions;

        /** A single (field type, confidence) pair produced by the model. */
        public abstract static class FieldPrediction {
            /**
             * Creates a prediction.
             *
             * @param kxqVar the predicted field type
             * @param f the model's confidence for that type
             */
            public static FieldPrediction of(kxq kxqVar, float f) {
                return new AutoValue_Model_Result_FieldPrediction(kxqVar, f);
            }

            /** Returns the model's confidence for {@link #getType()}. */
            public abstract float getConfidence();

            /** Returns the predicted field type. */
            public abstract kxq getType();
        }

        /**
         * Creates a result from predictions in any order.
         *
         * @param list the {@link FieldPrediction}s to hold; they are copied and stably sorted by
         *     descending confidence, so the caller's list is left untouched
         */
        @SuppressWarnings("unchecked") // raw parameter type kept so existing callers still compile
        public Result(List list) {
            List<FieldPrediction> predictions = list;
            this.orderedFieldPredictions = bnax.a((Collection) predictions.stream()
                    .sorted(Comparator.comparingDouble(FieldPrediction::getConfidence).reversed())
                    .collect(Collectors.toList()));
        }

        /** Returns every prediction, highest confidence first. */
        public bnax getFieldPredictions() {
            return this.orderedFieldPredictions;
        }

        /**
         * Returns the leading predictions whose confidence is at least {@code f}.
         *
         * <p>Because the predictions are sorted by descending confidence, iteration stops at the
         * first prediction below the threshold.
         *
         * @param f inclusive confidence threshold
         * @return the matching predictions, highest confidence first
         */
        public bnax getFieldPredictionsAbove(float f) {
            bnas j = bnax.j();
            bnle it = this.orderedFieldPredictions.iterator();
            while (it.hasNext()) {
                FieldPrediction fieldPrediction = (FieldPrediction) it.next();
                if (fieldPrediction.getConfidence() < f) {
                    break;
                }
                j.c(fieldPrediction);
            }
            return j.a();
        }
    }

    /**
     * Creates a model after checking that {@code modelConfig} and {@code neuralNetwork} agree on
     * input and output sizes.
     *
     * @throws IllegalArgumentException if the config and network are incompatible
     */
    public Model(ModelConfig modelConfig, NeuralNetwork neuralNetwork) {
        this(modelConfig, neuralNetwork, true);
    }

    private Model(ModelConfig modelConfig, NeuralNetwork neuralNetwork, boolean z) {
        RuntimeConfiguration snapshot = RuntimeConfiguration.getSnapshot();
        this.runtimeConfiguration = snapshot;
        this.randomizeSignalOrder = snapshot.shouldRandomizeSignalOrder();
        if (z) {
            assertCompatible(modelConfig, neuralNetwork);
        }
        this.signalConfig = modelConfig.getSignalConfig();
        this.fieldConfig = modelConfig.getFieldConfig();
        this.neuralNetwork = neuralNetwork;
        this.isLiteModel = modelConfig.isLiteModel();
    }

    /**
     * Checks that the number of signals equals the first layer's input size and that the last
     * layer's output size equals the number of supported field types.
     *
     * @throws IllegalArgumentException if either size does not match
     */
    private static void assertCompatible(ModelConfig modelConfig, NeuralNetwork neuralNetwork) {
        Layer layer = (Layer) neuralNetwork.getLayers().get(0);
        Layer layer2 = (Layer) bndd.d(neuralNetwork.getLayers());
        if (modelConfig.getSignalConfig().getSignals().size() != layer.inputSize() || layer2.outputSize() != modelConfig.getFieldConfig().numberOfSupportedTypes()) {
            throw new IllegalArgumentException("Model config is not compatible with neural network.");
        }
    }

    /** Converts every column of row 0 of the network output into a {@link Result.FieldPrediction}. */
    private Result buildResult(Matrix matrix) {
        List<Result.FieldPrediction> predictions = IntStream.range(0, matrix.cols())
                .mapToObj(col -> lambda$buildResult$0$Model(matrix, col))
                .collect(Collectors.toList());
        return new Result(predictions);
    }

    /**
     * Creates a model without the compatibility check performed by the public constructor.
     *
     * @deprecated use {@link #Model(ModelConfig, NeuralNetwork)}, which validates that the config
     *     and network are compatible
     */
    @Deprecated
    public static Model obsoleteCreate(ModelConfig modelConfig, NeuralNetwork neuralNetwork) {
        return new Model(modelConfig, neuralNetwork, false);
    }

    /** Returns the field configuration that maps output columns to field types. */
    public ModelConfig.FieldConfig getFieldConfig() {
        return this.fieldConfig;
    }

    /** Returns the underlying network. */
    public NeuralNetwork getNeuralNetwork() {
        return this.neuralNetwork;
    }

    /** Returns the signal configuration that supplies the network inputs. */
    public ModelConfig.SignalConfig getSignalConfig() {
        return this.signalConfig;
    }

    /** Returns whether the model config was flagged as a lite model. */
    public boolean isLiteModel() {
        return this.isLiteModel;
    }

    /**
     * Maps output column {@code i} of {@code matrix} to a prediction whose type comes from the
     * field configuration and whose confidence is {@code matrix.get(0, i)}.
     *
     * <p>The decompiler-generated name is retained because the method is public.
     */
    public final Result.FieldPrediction lambda$buildResult$0$Model(Matrix matrix, int i) {
        return Result.FieldPrediction.of(this.fieldConfig.getTypeAtIndex(i), matrix.get(0, i));
    }

    /**
     * Generates every configured signal for {@code kxbVar}, runs the network, and returns the
     * resulting predictions.
     *
     * @param kxbVar the input the signals are generated from
     * @return predictions ordered by descending confidence
     * @throws ExecutionException if anything fails; an {@code ExecutionException} is rethrown
     *     unchanged and any other {@code Throwable} is wrapped
     */
    public Result predict(kxb kxbVar) {
        try {
            bnax signals = this.signalConfig.getSignals();
            int signalCount = signals.size();
            ArrayMatrix input = new ArrayMatrix(1, signalCount);
            // Rotating the start index changes only the order in which signals are evaluated; each
            // value is still written to its own column, so the network input is laid out the same.
            // With no signals there is nothing to rotate (and Random.nextInt(0) would throw).
            int offset = (this.randomizeSignalOrder && signalCount > 0)
                    ? new Random().nextInt(signalCount)
                    : 0;
            for (int i = 0; i < signalCount; i++) {
                int index = (i + offset) % signalCount;
                input.set(0, index, (float) ((Signal) signals.get(index)).generate(kxbVar));
            }
            return buildResult(this.neuralNetwork.execute(input));
        } catch (ExecutionException e) {
            // Already the declared failure type; don't wrap it twice.
            throw e;
        } catch (Throwable th) {
            throw new ExecutionException(th);
        }
    }

    /** Calls {@code reset()} on every configured signal. */
    public void reset() {
        bnle it = this.signalConfig.getSignals().iterator();
        while (it.hasNext()) {
            ((Signal) it.next()).reset();
        }
    }
}
