package org.genemania.engine.core.integration.gram;

import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.integration.Feature;
import org.genemania.engine.core.integration.FeatureList;
import org.genemania.engine.core.integration.FeatureLoader;
import org.genemania.engine.core.integration.FeatureWeightMap;
import org.genemania.engine.core.integration.Solver;
import org.genemania.engine.exception.CancellationException;
import org.genemania.engine.matricks.Matrix;
import org.genemania.engine.matricks.SymMatrix;
import org.genemania.exception.ApplicationException;
import org.genemania.util.ProgressReporter;

/* loaded from: input_file:org/genemania/engine/core/integration/gram/AutomaticGramBuilder.class */
/**
 * Assembles the linear system (a symmetric Gram matrix and a right-hand-side vector) used to
 * compute feature weights automatically from a labelled gene set, and passes it to
 * {@link Solver#solve}.
 *
 * <p>Row/column 0 of the system is filled from the label pair counts only. Features at indices
 * {@code 1..size-1} of the feature list are loaded through a {@link FeatureLoader}.
 */
public class AutomaticGramBuilder {
    private static final Logger logger = Logger.getLogger(AutomaticGramBuilder.class);
    DataCache cache;
    String namespace;
    long organismId;
    FeatureList featureList;
    Vector labels;
    ProgressReporter progress;

    /**
     * Creates a builder for one organism/namespace and label vector.
     *
     * @param dataCache cache handed to the {@link FeatureLoader} and used for attribute data
     * @param str namespace passed to the cache and feature loader
     * @param j organism id passed to the cache and feature loader
     * @param featureList features to include; a new {@code FeatureList} is built from it with the
     *     flag {@code true} (NOTE(review): confirm the meaning of that flag)
     * @param vector label vector; positives/negatives are located with
     *     {@code MatrixUtils.find(labels, 1.0)} / {@code MatrixUtils.find(labels, -1.0)}
     * @param progressReporter reporter polled for cancellation while the system is built
     */
    public AutomaticGramBuilder(DataCache dataCache, String str, long j, FeatureList featureList, Vector vector, ProgressReporter progressReporter) {
        this.cache = dataCache;
        this.namespace = str;
        this.organismId = j;
        this.featureList = new FeatureList(featureList, true);
        this.labels = vector;
        this.progress = progressReporter;
    }

    /**
     * Builds the Gram matrix and right-hand side for the configured features and labels, then
     * solves the system.
     *
     * <p>NOTE(review): {@code progressReporter} is not used. Cancellation is checked against,
     * and the solver receives, the reporter passed to the constructor. Confirm this is intended.
     *
     * @param progressReporter currently ignored (see note above)
     * @return the weights computed by {@link Solver#solve}
     * @throws ApplicationException if feature validation, loading or solving fails
     * @throws CancellationException if the constructor's progress reporter reports cancellation
     */
    public FeatureWeightMap build(ProgressReporter progressReporter) throws ApplicationException {
        this.featureList.validate();

        int[] positives = MatrixUtils.find(this.labels, 1.0d);
        int[] negatives = MatrixUtils.find(this.labels, -1.0d);
        int numPos = positives.length;
        int numNeg = negatives.length;
        int numFeatures = this.featureList.size();

        // These pair counts grow quadratically with the label set and can exceed
        // Integer.MAX_VALUE, so they are computed in long rather than int.
        long posPosPairs = (long) numPos * (numPos - 1);
        long posNegPairs = 2L * numPos * numNeg;
        // NOTE(review): with no pairs at all (e.g. a single positive and no negatives) this
        // divides by zero and yields Infinity, which is then fed to the solver.
        double norm = 1.0d / (posPosPairs + posNegPairs);

        double posValue = (2.0d * numPos == 0 && numNeg == 0 ? 2.0d * numNeg : 2.0d * numNeg) / (numPos + numNeg);
        double negValue = ((-2.0d) * numPos) / (numPos + numNeg);
        double posPosProduct = posValue * posValue;
        double posNegProduct = posValue * negValue;

        DenseMatrix gram = new DenseMatrix(numFeatures, numFeatures);
        DenseVector rhs = new DenseVector(numFeatures);
        gram.set(0, 0, norm);
        rhs.set(0, norm * ((posPosProduct * posPosPairs) + (posNegProduct * posNegPairs)));

        // Per-feature sub-matrices: positives x positives, and positives x negatives.
        Matrix[] posBlocks = new SymMatrix[numFeatures];
        Matrix[] posNegBlocks = new Matrix[numFeatures];
        FeatureLoader loader = new FeatureLoader(this.cache, this.namespace, this.organismId, false, true);

        for (int f = 1; f < numFeatures; f++) {
            Feature feature = this.featureList.get(f);
            posBlocks[f] = loader.load(feature, positives);
            posNegBlocks[f] = loader.load(feature, positives, negatives);

            double posSum = posBlocks[f].elementSum();
            double posNegSum = posNegBlocks[f].elementSum();
            rhs.set(f, (posPosProduct * posSum) + (2.0d * posNegProduct * posNegSum));

            double edgeEntry = norm * (posSum + (2.0d * posNegSum));
            gram.set(f, 0, edgeEntry);
            gram.set(0, f, edgeEntry);

            // Fill the lower triangle (including the diagonal) and mirror it.
            for (int g = 1; g <= f; g++) {
                if (this.progress.isCanceled()) {
                    throw new CancellationException();
                }
                // NOTE(review): the DISCRIMINANT_THRESHOLD term looks like a decompiler
                // substitution for a 0.0 literal; confirm its value before relying on it.
                double entry = Constants.DISCRIMINANT_THRESHOLD
                        + posBlocks[f].elementMultiplySum(posBlocks[g])
                        + (2.0d * posNegBlocks[f].elementMultiplySum(posNegBlocks[g]));
                gram.set(f, g, entry);
                gram.set(g, f, entry);
            }
        }

        logger.debug("solving system of size " + this.featureList.size());
        return Solver.solve(gram, rhs, this.featureList, this.progress);
    }

    /**
     * Logs, at debug level, every column of the attribute data for the given id whose column
     * sum exceeds {@link Constants#DISCRIMINANT_THRESHOLD}.
     *
     * @param j id passed to {@code DataCache.getAttributeData} (presumably an attribute group
     *     id; verify against callers)
     * @throws ApplicationException if the attribute data cannot be loaded
     */
    void logAttributeCounts(long j) throws ApplicationException {
        org.genemania.engine.matricks.Vector columnSums = this.cache.getAttributeData(this.namespace, this.organismId, j).getData().columnSums();
        int numColumns = columnSums.getSize();
        for (int col = 0; col < numColumns; col++) {
            double sum = columnSums.get(col);
            if (sum > Constants.DISCRIMINANT_THRESHOLD) {
                logger.debug(String.format("attribute %d has col sum %f", Integer.valueOf(col), Double.valueOf(sum)));
            }
        }
    }
}
