package org.genemania.engine.core.integration.gram;

import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.CoAnnotationSet;
import org.genemania.engine.core.data.DatasetInfo;
import org.genemania.engine.core.integration.FeatureList;
import org.genemania.engine.core.integration.FeatureLoader;
import org.genemania.engine.exception.CancellationException;
import org.genemania.engine.matricks.SymMatrix;
import org.genemania.exception.ApplicationException;
import org.genemania.util.ProgressReporter;

/* loaded from: input_file:org/genemania/engine/core/integration/gram/BasicGramBuilder.class */
public class BasicGramBuilder {
    private static final Logger logger = Logger.getLogger(BasicGramBuilder.class);

    DataCache cache;
    String namespace;
    long organismId;

    /**
     * Reporter polled for cancellation by every build/update method. The
     * {@code ProgressReporter} arguments of the individual methods are not consulted;
     * only this instance-level reporter is.
     */
    ProgressReporter progress;

    /**
     * Creates a builder for the feature product matrices (K^T K and K^T T) of one organism.
     *
     * @param cache      data cache used for node ids, dataset info and feature data
     * @param namespace  namespace handed to {@link FeatureLoader}
     * @param organismId organism whose data is loaded
     * @param progress   reporter checked for cancellation during the long-running loops
     */
    public BasicGramBuilder(DataCache cache, String namespace, long organismId, ProgressReporter progress) {
        this.cache = cache;
        this.namespace = namespace;
        this.organismId = organismId;
        this.progress = progress;
    }

    /**
     * Builds the symmetric matrix of pairwise feature products (K^T K).
     *
     * <p>Entry (0,0) is {@code n * n}, where n is the number of node ids of the organism.
     * Entries (i,0)/(0,i) hold the element sum of feature i. Entries (i,j) for i,j &gt;= 1
     * hold the element-wise product sum of features i and j.
     *
     * @param featureList      features to combine; element 0 must be the bias feature
     * @param progressReporter unused; cancellation is checked on the reporter given to the constructor
     * @return a {@code size x size} matrix with {@code size == featureList.size()}
     * @throws ApplicationException  if the first feature is not the bias, or loading a feature fails
     * @throws CancellationException if the constructor-supplied reporter reports cancellation
     */
    public DenseMatrix buildBasicKtK(FeatureList featureList, ProgressReporter progressReporter) throws ApplicationException {
        checkFeatureList(featureList, true);
        int numFeatures = featureList.size();
        int numNodes = this.cache.getNodeIds(this.organismId).getNodeIds().length;
        DenseMatrix ktk = new DenseMatrix(numFeatures, numFeatures);
        FeatureLoader featureLoader = new FeatureLoader(this.cache, this.namespace, this.organismId);

        // Multiply in double: numNodes * numNodes overflows int once numNodes exceeds 46340.
        ktk.set(0, 0, (double) numNodes * numNodes);
        for (int i = 1; i < numFeatures; i++) {
            SymMatrix feature = featureLoader.load(featureList.get(i));
            double elementSum = feature.elementSum();
            ktk.set(i, 0, elementSum);
            ktk.set(0, i, elementSum);
            // Only the lower triangle (j <= i) is computed; each value is mirrored.
            // Feature j is reloaded for every pair rather than kept in memory here.
            for (int j = 1; j <= i; j++) {
                checkCanceled();
                double product = featureLoader.load(featureList.get(j)).elementMultiplySum(feature);
                ktk.set(i, j, product);
                ktk.set(j, i, product);
            }
        }
        return ktk;
    }

    /**
     * Builds the column vector of products between each feature and the co-annotation
     * target (K^T T).
     *
     * <p>As a side effect, the diagonal of {@code coAnnotationSet}'s co-annotation matrix is
     * overwritten in place with {@link Constants#DISCRIMINANT_THRESHOLD}.
     *
     * @param featureList      features to combine; element 0 must be the bias feature
     * @param coAnnotationSet  supplies the co-annotation matrix, the "bHalf" vector and the constant
     * @param progressReporter unused; cancellation is checked on the reporter given to the constructor
     * @return a {@code size x 1} matrix with {@code size == featureList.size()}
     * @throws ApplicationException  if the first feature is not the bias, or loading a feature fails
     * @throws CancellationException if the constructor-supplied reporter reports cancellation
     */
    public DenseMatrix buildKtT(FeatureList featureList, CoAnnotationSet coAnnotationSet, ProgressReporter progressReporter) throws ApplicationException {
        checkFeatureList(featureList, true);
        DatasetInfo datasetInfo = this.cache.getDatasetInfo(this.organismId);
        int numCategories = datasetInfo.getNumCategories()[Constants.getIndexForGoBranch(coAnnotationSet.getGoBranch())];
        int numGenes = datasetInfo.getNumGenes();
        SymMatrix coAnnotation = coAnnotationSet.GetCoAnnotationMatrix();
        DenseVector bHalf = coAnnotationSet.GetBHalf();
        double constant = coAnnotationSet.GetConstant().doubleValue();
        // Mutates the matrix owned by coAnnotationSet.
        coAnnotation.setDiag(Constants.DISCRIMINANT_THRESHOLD);
        int numFeatures = featureList.size();
        if (logger.isDebugEnabled()) {
            logger.debug("Number of Genes " + numGenes + ", Number of Categories " + numCategories + ", Number of networks: " + (numFeatures - 1));
            // long arithmetic: the int product overflows for realistic gene/category counts.
            logger.debug("biasValue: " + ((long) numGenes * numGenes * numCategories));
        }
        DenseMatrix ktt = new DenseMatrix(numFeatures, 1);
        ktt.set(0, 0, (MatrixUtils.sum((Vector) bHalf) * numGenes) + coAnnotation.elementSum() + (constant * numGenes * numGenes));
        FeatureLoader featureLoader = new FeatureLoader(this.cache, this.namespace, this.organismId);
        logger.debug("Ktt bias value is " + ktt.get(0, 0));
        for (int i = 1; i < numFeatures; i++) {
            checkCanceled();
            SymMatrix feature = featureLoader.load(featureList.get(i));
            ktt.set(i, 0, computeKttElement(numGenes, feature, coAnnotation, bHalf, constant));
        }
        return ktt;
    }

    /**
     * Computes the K^T T entry for one feature matrix W:
     * {@code sum(W .* C) + sum(W * b) + c * sum(W)}, where C is {@code coAnnotation},
     * b is {@code bHalf} and c is {@code constant}.
     *
     * <p>NOTE(review): assumes {@code SymMatrix.mult(x, y)} writes W·x into y.
     *
     * @param numGenes     length of the temporary vector receiving W * b
     * @param feature      the feature matrix W
     * @param coAnnotation the co-annotation matrix C
     * @param bHalf        the vector b
     * @param constant     the scalar c
     * @return the combined product value
     */
    public static double computeKttElement(int numGenes, SymMatrix feature, SymMatrix coAnnotation, DenseVector bHalf, double constant) {
        double elementSum = feature.elementSum();
        DenseVector weightedB = new DenseVector(numGenes);
        feature.mult(bHalf.getData(), weightedB.getData());
        return feature.elementMultiplySum(coAnnotation) + MatrixUtils.sum((Vector) weightedB) + (elementSum * constant);
    }

    /**
     * Extends a previously built K^T K with rows/columns for additional features.
     *
     * <p>The top-left {@code oldFeatures.size()} square is copied from {@code oldKtK}. Products
     * between old and new features, the new features' bias entries, and products among the
     * new features are then computed.
     *
     * @param oldKtK           matrix previously built for {@code oldFeatures}
     * @param oldFeatures      features of {@code oldKtK}; element 0 must be the bias feature
     * @param newFeatures      features to append; must not start with the bias feature (may be empty)
     * @param progressReporter unused; cancellation is checked on the reporter given to the constructor
     * @return a new square matrix of size {@code oldFeatures.size() + newFeatures.size()}
     * @throws ApplicationException  if the bias placement is wrong, or loading a feature fails
     * @throws CancellationException if the constructor-supplied reporter reports cancellation
     */
    public DenseMatrix updateBasicKtK(DenseMatrix oldKtK, FeatureList oldFeatures, FeatureList newFeatures, ProgressReporter progressReporter) throws ApplicationException {
        checkFeatureList(oldFeatures, true);
        checkFeatureList(newFeatures, false);
        int numOld = oldFeatures.size();
        int numNew = newFeatures.size();
        int total = numOld + numNew;

        logger.debug("allocating new KtK and copying data over");
        DenseMatrix ktk = new DenseMatrix(total, total);
        for (int row = 0; row < numOld; row++) {
            for (int col = 0; col < numOld; col++) {
                ktk.set(row, col, oldKtK.get(row, col));
            }
        }

        FeatureLoader featureLoader = new FeatureLoader(this.cache, this.namespace, this.organismId);
        logger.debug("preloading new features");
        SymMatrix[] added = new SymMatrix[numNew];
        for (int k = 0; k < numNew; k++) {
            added[k] = featureLoader.load(newFeatures.get(k));
        }

        logger.debug(String.format("computing products between %d new and %d old features", Integer.valueOf(numNew), Integer.valueOf(numOld)));
        // Start at 1: the bias row/column of the new features is filled in below.
        for (int i = 1; i < numOld; i++) {
            SymMatrix old = featureLoader.load(oldFeatures.get(i));
            for (int k = 0; k < numNew; k++) {
                checkCanceled();
                double product = old.elementMultiplySum(added[k]);
                ktk.set(i, numOld + k, product);
                ktk.set(numOld + k, i, product);
            }
        }

        logger.debug(String.format("computing products between %d new features, and their biases", Integer.valueOf(numNew)));
        for (int k = 0; k < numNew; k++) {
            SymMatrix feature = added[k];
            int row = numOld + k;
            double elementSum = feature.elementSum();
            ktk.set(row, 0, elementSum);
            ktk.set(0, row, elementSum);
            // Lower triangle only, reusing the preloaded matrices; values are mirrored.
            for (int m = 0; m <= k; m++) {
                checkCanceled();
                double product = feature.elementMultiplySum(added[m]);
                ktk.set(row, numOld + m, product);
                ktk.set(numOld + m, row, product);
            }
        }
        return ktk;
    }

    /**
     * Extends a previously built K^T T vector with entries for additional features.
     *
     * <p>The first {@code oldFeatures.size()} entries are copied from {@code oldKtT}. The
     * co-annotation diagonal is reset to {@link Constants#DISCRIMINANT_THRESHOLD} (mutating
     * {@code coAnnotationSet}'s matrix) so new entries use the same target as
     * {@link #buildKtT}; this is idempotent if {@code buildKtT} already did it.
     *
     * @param oldKtT           vector previously built for {@code oldFeatures}
     * @param oldFeatures      features of {@code oldKtT}; element 0 must be the bias feature
     * @param newFeatures      features to append; must not start with the bias feature (may be empty)
     * @param coAnnotationSet  supplies the co-annotation matrix, the "bHalf" vector and the constant
     * @param progressReporter unused; cancellation is checked on the reporter given to the constructor
     * @return a new {@code (oldFeatures.size() + newFeatures.size()) x 1} matrix
     * @throws ApplicationException  if the bias placement is wrong, or loading a feature fails
     * @throws CancellationException if the constructor-supplied reporter reports cancellation
     */
    public DenseMatrix updateKtT(DenseMatrix oldKtT, FeatureList oldFeatures, FeatureList newFeatures, CoAnnotationSet coAnnotationSet, ProgressReporter progressReporter) throws ApplicationException {
        checkFeatureList(oldFeatures, true);
        checkFeatureList(newFeatures, false);
        int numGenes = this.cache.getDatasetInfo(this.organismId).getNumGenes();
        SymMatrix coAnnotation = coAnnotationSet.GetCoAnnotationMatrix();
        DenseVector bHalf = coAnnotationSet.GetBHalf();
        double constant = coAnnotationSet.GetConstant().doubleValue();
        coAnnotation.setDiag(Constants.DISCRIMINANT_THRESHOLD);
        int numOld = oldFeatures.size();
        int numNew = newFeatures.size();

        DenseMatrix ktt = new DenseMatrix(numOld + numNew, 1);
        for (int i = 0; i < numOld; i++) {
            ktt.set(i, 0, oldKtT.get(i, 0));
        }
        FeatureLoader featureLoader = new FeatureLoader(this.cache, this.namespace, this.organismId);
        for (int k = 0; k < numNew; k++) {
            checkCanceled();
            SymMatrix feature = featureLoader.load(newFeatures.get(k));
            ktt.set(numOld + k, 0, computeKttElement(numGenes, feature, coAnnotation, bHalf, constant));
        }
        return ktt;
    }

    /**
     * Validates the placement of the bias feature at index 0.
     *
     * @param featureList  list to validate; an empty list counts as "no bias present"
     * @param biasExpected {@code true} if element 0 must be the bias, {@code false} if it must not be
     * @throws ApplicationException if the list violates the requested placement
     */
    public static void checkFeatureList(FeatureList featureList, boolean biasExpected) throws ApplicationException {
        // Guard against get(0) on an empty list; an empty list of new features is legitimate.
        boolean startsWithBias = featureList.size() > 0
                && featureList.get(0).getType() == Constants.NetworkType.BIAS;
        if (biasExpected && !startsWithBias) {
            throw new ApplicationException("must include bias in first row/col");
        }
        if (!biasExpected && startsWithBias) {
            throw new ApplicationException("must not include bias in first row/col");
        }
    }

    /**
     * Throws {@link CancellationException} if the constructor-supplied reporter reports cancellation.
     */
    private void checkCanceled() throws ApplicationException {
        if (this.progress.isCanceled()) {
            throw new CancellationException();
        }
    }
}
