package org.genemania.engine.actions;

import org.genemania.dto.AttributeDto;
import org.genemania.dto.InteractionVisitor;
import org.genemania.dto.NetworkCombinationRequestDto;
import org.genemania.dto.NetworkCombinationResponseDto;
import org.genemania.dto.NetworkDto;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.data.NodeIds;
import org.genemania.engine.core.integration.CombineNetworksOnly;
import org.genemania.engine.core.integration.CombinedKernelBuilder;
import org.genemania.engine.core.integration.Feature;
import org.genemania.engine.core.integration.FeatureWeightMap;
import org.genemania.engine.matricks.MatrixCursor;
import org.genemania.engine.matricks.SymMatrix;
import org.genemania.engine.matricks.custom.MultiOPCSymMatrix;
import org.genemania.engine.matricks.custom.OuterProductComboSymMatrix;
import org.genemania.exception.ApplicationException;

/* loaded from: input_file:org/genemania/engine/actions/CombineNetworks.class */
public class CombineNetworks {
    // Data cache handed to the network combiner and used to map matrix indices to node ids.
    DataCache cache;
    // Request naming the networks/attributes to combine, the organism, namespace,
    // progress reporter and the visitor that receives the combined interactions.
    NetworkCombinationRequestDto request;

    /**
     * Creates an action that combines the networks and attributes selected in a request.
     *
     * @param dataCache cache used by the combiner and for node-id lookups
     * @param networkCombinationRequestDto the combination request to process
     */
    public CombineNetworks(DataCache dataCache, NetworkCombinationRequestDto networkCombinationRequestDto) {
        cache = dataCache;
        request = networkCombinationRequestDto;
    }

    /**
     * Combines the requested features into a single kernel and forwards every
     * lower-triangle entry of it to the request's {@link InteractionVisitor}.
     *
     * @return a newly constructed {@link NetworkCombinationResponseDto}; the combined
     *     interactions are delivered through the visitor, not through this response
     * @throws ApplicationException if combining or visiting fails
     */
    public NetworkCombinationResponseDto process() throws ApplicationException {
        SymMatrix combined = combine();
        visit(combined);
        return new NetworkCombinationResponseDto();
    }

    /**
     * Builds the combined kernel for the request's organism and namespace, using the
     * weights derived from the request's networks and attributes.
     */
    private SymMatrix combine() throws ApplicationException {
        FeatureWeightMap weights = buildWeightMap();
        // The builder is created before the networks are combined; the arguments below
        // are then evaluated left to right, exactly as in a single chained expression.
        CombinedKernelBuilder kernelBuilder = new CombinedKernelBuilder(cache);
        return kernelBuilder.build(
                request.getOrganismId(),
                request.getNamespace(),
                CombineNetworksOnly.combine(
                        weights,
                        request.getNamespace(),
                        request.getOrganismId(),
                        cache,
                        request.getProgressReporter()),
                weights);
    }

    /**
     * Collects a feature -> weight entry for every requested network and attribute
     * whose weight differs from {@code Constants.DISCRIMINANT_THRESHOLD}.
     * Networks are registered as SPARSE_MATRIX features under group id 1; attributes
     * as ATTRIBUTE_VECTOR features under their own group id.
     */
    private FeatureWeightMap buildWeightMap() {
        FeatureWeightMap weights = new FeatureWeightMap();

        for (NetworkDto network : request.getNetworks()) {
            if (network.getWeight() == Constants.DISCRIMINANT_THRESHOLD) {
                continue;
            }
            Feature feature = new Feature(Constants.NetworkType.SPARSE_MATRIX, 1L, network.getId());
            weights.put(feature, Double.valueOf(network.getWeight()));
        }

        for (AttributeDto attribute : request.getAttributes()) {
            if (attribute.getWeight() == Constants.DISCRIMINANT_THRESHOLD) {
                continue;
            }
            Feature feature = new Feature(
                    Constants.NetworkType.ATTRIBUTE_VECTOR, attribute.getGroupId(), attribute.getId());
            weights.put(feature, Double.valueOf(attribute.getWeight()));
        }

        return weights;
    }

    /**
     * Dispatches on the concrete matrix type: a {@link MultiOPCSymMatrix} is first
     * flattened, anything else is visited directly.
     */
    private void visit(SymMatrix symMatrix) throws ApplicationException {
        if (!(symMatrix instanceof MultiOPCSymMatrix)) {
            visitSymMatrix(symMatrix);
            return;
        }
        visitMultiOPCSymMatrix((MultiOPCSymMatrix) symMatrix);
    }

    /**
     * Adds each outer-product combo (coefficient 1.0) onto the underlying matrix and
     * then visits that matrix. NOTE(review): {@code add} is assumed to accumulate in
     * place, since any return value is ignored — confirm against SymMatrix.
     */
    private void visitMultiOPCSymMatrix(MultiOPCSymMatrix multi) throws ApplicationException {
        OuterProductComboSymMatrix[] parts = multi.getCombos();
        SymMatrix base = multi.getMatrix();
        for (int i = 0; i < parts.length; i++) {
            base.add(1.0d, parts[i]);
        }
        visitSymMatrix(base);
    }

    /**
     * Walks the matrix with its cursor and reports each entry with {@code row > col}
     * to the request's visitor, translating matrix indices into node ids.
     */
    private void visitSymMatrix(SymMatrix symMatrix) throws ApplicationException {
        InteractionVisitor visitor = request.getInteractionVistor();
        NodeIds ids = cache.getNodeIds(request.getOrganismId());
        MatrixCursor it = symMatrix.cursor();
        while (it.next()) {
            int r = it.row();
            int c = it.col();
            // Skip the diagonal and the upper triangle; for a symmetric matrix the
            // lower triangle alone carries every off-diagonal interaction.
            if (r <= c) {
                continue;
            }
            visitor.visit(ids.getIdForIndex(r), ids.getIdForIndex(c), it.val());
        }
    }
}
