package org.genemania.engine.core.integration.attribute;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.actions.ComputeEnrichment;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.AttributeData;
import org.genemania.engine.core.data.AttributeGroups;
import org.genemania.engine.core.data.NodeIds;
import org.genemania.engine.core.integration.Feature;
import org.genemania.engine.core.integration.FeatureList;
import org.genemania.engine.core.utils.ObjectSelector;
import org.genemania.engine.matricks.Matrix;
import org.genemania.exception.ApplicationException;

/* loaded from: input_file:org/genemania/engine/core/integration/attribute/QueryEnrichedAttributeScorer.class */
/**
 * Scores the attributes of one attribute group by their enrichment in the query gene list.
 *
 * <p>For every attribute (a column of the cached attribute matrix) the scorer computes:
 * <ul>
 *   <li>the column sum over all nodes, and</li>
 *   <li>the column sum restricted to the query nodes.</li>
 * </ul>
 * It feeds these counts to {@link ComputeEnrichment#computeCumulHyperGeo}. Attributes that
 * pass the p-value and count filters are returned in an {@link ObjectSelector}, scored by
 * that value.
 *
 * <p>NOTE(review): the {@code (namespace, organismId, attributeGroupId)} parameter names
 * follow the {@link DataCache} lookups they are passed to; confirm against DataCache.
 */
public class QueryEnrichedAttributeScorer implements IAttributeScorer {
    private static final Logger logger = Logger.getLogger(QueryEnrichedAttributeScorer.class);

    /**
     * An attribute's total column sum must be strictly greater than this value for the
     * attribute to be reported.
     */
    // NOTE(review): the name reads like an inclusive minimum, but buildList compares with
    // '>', so an attribute needs a total count of at least 3. Confirm which was intended.
    private static final int MIN_NUM_TOTAL_GENES_PER_ATTRIBUTE = 2;

    /** Attributes whose enrichment value is not below this bound are not reported. */
    private static final double MAX_REPORTED_PVALUE = 0.99999d;

    /** Source of attribute matrices, attribute groups and node id sets. */
    DataCache cache;
    /** Per-node labels; entries exactly equal to 1.0 mark the selected (query) nodes. */
    Vector labels;
    /** Minimum query-gene count an attribute needs in order to be reported. */
    int minQueryGenesPerAttribute;

    /**
     * @param dataCache cache providing attribute data, attribute groups and node ids
     * @param labels per-node labels; entries equal to 1.0 are treated as query nodes
     * @param minQueryGenesPerAttribute minimum query-gene count for an attribute to be reported
     */
    public QueryEnrichedAttributeScorer(DataCache dataCache, Vector labels, int minQueryGenesPerAttribute) {
        this.cache = dataCache;
        this.labels = labels;
        this.minQueryGenesPerAttribute = minQueryGenesPerAttribute;
    }

    /**
     * Logs, at debug level, every attribute of the group whose column sum exceeds
     * {@link Constants#DISCRIMINANT_THRESHOLD}.
     *
     * @throws ApplicationException if the attribute data cannot be loaded from the cache
     */
    void logAttributeCounts(String namespace, long organismId, long attributeGroupId) throws ApplicationException {
        org.genemania.engine.matricks.Vector columnSums =
                this.cache.getAttributeData(namespace, organismId, attributeGroupId).getData().columnSums();
        // Checked once up front so no per-attribute String.format work happens when debug is off.
        if (!logger.isDebugEnabled()) {
            return;
        }
        for (int i = 0; i < columnSums.getSize(); i++) {
            double sum = columnSums.get(i);
            if (sum > Constants.DISCRIMINANT_THRESHOLD) {
                logger.debug(String.format("attribute %d has col sum %f", Integer.valueOf(i), Double.valueOf(sum)));
            }
        }
    }

    /**
     * Scores every attribute of the given group against the query labels supplied at
     * construction time.
     *
     * @return selector containing the attributes that pass the filters in {@code buildList},
     *         each scored by its {@code computeCumulHyperGeo} value
     * @throws ApplicationException if cached data cannot be loaded
     * @throws IllegalArgumentException if the attribute group id is not known to the cache
     */
    @Override
    public ObjectSelector<Feature> scoreAttributes(String namespace, long organismId, long attributeGroupId) throws ApplicationException {
        Matrix data = this.cache.getAttributeData(namespace, organismId, attributeGroupId).getData();
        int numAttributes = data.numCols();
        int numNodes = data.numRows();

        DenseVector queryMask = computeSelectionMask(organismId, this.labels);

        // Per-attribute column sum over all nodes (a count if the matrix is 0/1 — assumed).
        DenseVector totalCounts = new DenseVector(numAttributes);
        data.columnSums(totalCounts.getData());

        // Per-attribute column sum restricted to query nodes: data^T * queryMask.
        DenseVector queryCounts = new DenseVector(numAttributes);
        data.transMult(queryMask.getData(), queryCounts.getData());

        DenseVector pvals = new DenseVector(numAttributes);
        computePVals(data, numAttributes, numNodes, totalCounts, queryCounts, queryMask, pvals);
        return buildList(namespace, organismId, attributeGroupId, pvals, queryCounts, totalCounts);
    }

    /**
     * Collects the attributes of the group that pass all of these filters:
     * <ul>
     *   <li>p-value below {@link #MAX_REPORTED_PVALUE};</li>
     *   <li>query count of at least {@link #minQueryGenesPerAttribute};</li>
     *   <li>total count above {@link #MIN_NUM_TOTAL_GENES_PER_ATTRIBUTE}.</li>
     * </ul>
     * Vector index {@code i} is assumed to correspond to the {@code i}-th attribute id of
     * the group.
     */
    private ObjectSelector<Feature> buildList(String namespace, long organismId, long attributeGroupId,
            DenseVector pvals, DenseVector queryCounts, DenseVector totalCounts) throws ApplicationException {
        ArrayList<Long> attributeIds = this.cache.getAttributeGroups(namespace, organismId)
                .getAttributeGroups().get(Long.valueOf(attributeGroupId));
        if (attributeIds == null) {
            // Fail with context rather than an anonymous NullPointerException in the loop below.
            throw new IllegalArgumentException(String.format(
                    "unknown attribute group id %d", Long.valueOf(attributeGroupId)));
        }

        ObjectSelector<Feature> selector = new ObjectSelector<>();
        for (int i = 0; i < attributeIds.size(); i++) {
            double pval = pvals.get(i);
            boolean informative = pval < MAX_REPORTED_PVALUE
                    && queryCounts.get(i) >= this.minQueryGenesPerAttribute
                    && totalCounts.get(i) > MIN_NUM_TOTAL_GENES_PER_ATTRIBUTE;
            if (informative) {
                Feature feature = new Feature(Constants.NetworkType.ATTRIBUTE_VECTOR,
                        attributeGroupId, attributeIds.get(i).longValue());
                selector.add(feature, Double.valueOf(pval));
            }
        }
        if (logger.isDebugEnabled()) {
            logger.debug(String.format("ranked %d attributes by enrichment based on query list", Integer.valueOf(selector.size())));
        }
        return selector;
    }

    /**
     * Builds one {@code ATTRIBUTE_VECTOR} feature for every (group id, attribute id) pair
     * in the given attribute groups.
     *
     * @param attributeGroups mapping from attribute group id to its attribute ids
     * @return a new feature list, in the iteration order of the underlying map
     */
    public static FeatureList buildFeatureList(AttributeGroups attributeGroups) {
        FeatureList features = new FeatureList();
        for (Long groupId : attributeGroups.getAttributeGroups().keySet()) {
            for (Long attributeId : attributeGroups.getAttributeGroups().get(groupId)) {
                features.add(new Feature(Constants.NetworkType.ATTRIBUTE_VECTOR,
                        groupId.longValue(), attributeId.longValue()));
            }
        }
        return features;
    }

    /**
     * Builds a 0/1 mask over all nodes of the organism.
     *
     * @param organismId selects the node id set whose length sizes the mask
     * @param nodeLabels per-node labels; positions whose label is exactly 1.0 become 1.0 in the mask
     * @return mask of length {@code getNodeIds(organismId).getNodeIds().length}
     */
    DenseVector computeSelectionMask(long organismId, Vector nodeLabels) throws ApplicationException {
        DenseVector mask = new DenseVector(this.cache.getNodeIds(organismId).getNodeIds().length);
        for (int i = 0; i < nodeLabels.size(); i++) {
            if (nodeLabels.get(i) == 1.0d) {
                mask.set(i, 1.0d);
            }
        }
        return mask;
    }

    /**
     * Builds a 0/1 mask over all nodes of the organism, marking each node whose id is in
     * {@code selectedNodeIds}.
     *
     * <p>Behavior for ids unknown to {@code NodeIds.getIndexForId} depends on that method.
     */
    DenseVector computeSelectionMask(long organismId, Collection<Long> selectedNodeIds) throws ApplicationException {
        NodeIds nodeIds = this.cache.getNodeIds(organismId);
        DenseVector mask = new DenseVector(nodeIds.getNodeIds().length);
        for (Long nodeId : selectedNodeIds) {
            mask.set(nodeIds.getIndexForId(nodeId.longValue()), 1.0d);
        }
        return mask;
    }

    /**
     * Fills {@code pvals} with {@code ComputeEnrichment.computeCumulHyperGeo} for each attribute.
     *
     * @param matrix attribute matrix; currently unused, kept for signature compatibility
     * @param numAttributes number of entries to compute (attribute matrix column count)
     * @param numNodes population size passed to the hypergeometric (matrix row count)
     * @param totalCounts per-attribute column sums over all nodes (rounded before use)
     * @param queryCounts per-attribute column sums over query nodes (rounded before use)
     * @param selectionMask query mask; its number of 1.0 entries is the query-set size
     * @param pvals output vector, written at indices {@code [0, numAttributes)}
     */
    void computePVals(Matrix matrix, int numAttributes, int numNodes, DenseVector totalCounts,
            DenseVector queryCounts, DenseVector selectionMask, DenseVector pvals) {
        int numQueryNodes = MatrixUtils.countMatches(selectionMask, 1.0d);
        for (int i = 0; i < numAttributes; i++) {
            pvals.set(i, ComputeEnrichment.computeCumulHyperGeo(
                    Math.round(queryCounts.get(i)), numNodes, numQueryNodes, Math.round(totalCounts.get(i))));
        }
    }
}
