/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees.m5;

import java.io.Serializable;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.Vector;
import weka.classifiers.trees.m5.SplitEvaluate;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;

/**
 * Split evaluator that scores candidate split points of one attribute by
 * fitting a ridge-regularised linear model on each side of the split and
 * summing the two scores returned by {@link #getRSS(Matrix, Matrix, Vector)}.
 *
 * <p>Throughout this class an instance is represented as a vector holding its
 * non-class attribute values in order, followed by a constant {@code 1.0} in
 * the last slot. The matrix {@code G} accumulates outer products of those
 * vectors, {@code S} accumulates those vectors weighted by the class value,
 * and {@code AI} is the inverse of {@code G} plus the ridge term.
 *
 * <p>NOTE(review): the linear-algebra descriptions below rely on the MTJ
 * contracts of {@code mult}, {@code rank1} and {@code solve}; confirm against
 * the MTJ version in use.
 */
public final class RidgeRegressionSplitInfo implements Cloneable, Serializable, SplitEvaluate, RevisionHandler {
    private static final long serialVersionUID = 3112734895125452171L;

    /** Ridge value handed to {@link #getAInverse(Matrix, double)} by {@link #attrSplit(int, Instances)}. */
    private static final double RIDGE = 0.01;

    /** Index of the last instance on the left side of the best split found, or -1. */
    private int m_position;
    /** Best (largest) impurity score seen so far; the negated combined RSS score. */
    private double m_maxImpurity;
    /** Index of the attribute being evaluated. */
    private int m_splitAttr;
    /** Attribute value at which the best split was found. */
    private double m_splitValue;
    /** Number of instances covered by the evaluated range. */
    private int m_number;

    /**
     * Creates a split evaluator for the instance range {@code [low, high]}.
     *
     * @param low index of the first instance in the range
     * @param high index of the last instance in the range
     * @param attr index of the attribute to split on
     */
    public RidgeRegressionSplitInfo(int low, int high, int attr) {
        initialize(low, high, attr);
    }

    /**
     * Returns a copy of this object obtained through {@code clone()}.
     *
     * @return the cloned split evaluator
     * @throws Exception if cloning fails
     */
    @Override
    public final SplitEvaluate copy() throws Exception {
        return (RidgeRegressionSplitInfo) clone();
    }

    /**
     * Resets all state for a new evaluation over the range {@code [low, high]}.
     *
     * @param low index of the first instance in the range
     * @param high index of the last instance in the range
     * @param attr index of the attribute to split on
     */
    public final void initialize(int low, int high, int attr) {
        m_splitAttr = attr;
        m_number = high - low + 1;
        m_position = -1;
        m_splitValue = 0.0;
        // Lowest finite double, so any finite score computed later compares greater.
        m_maxImpurity = -Double.MAX_VALUE;
    }

    /**
     * Computes the inverse of {@code G} after adding {@code ridge} to every
     * diagonal entry except the last one. {@code G} itself is not modified;
     * the ridge is applied to a copy.
     *
     * @param G square matrix to regularise and invert
     * @param ridge value added to all but the last diagonal entry
     * @return the solution {@code X} of {@code (G + ridge term) * X = I}
     * @throws Exception if the underlying solve fails
     */
    public final Matrix getAInverse(Matrix G, double ridge) throws Exception {
        Matrix regularised = G.copy();
        // The last diagonal entry (the constant-term slot) is left unpenalised.
        int penalisedRows = regularised.numRows() - 1;
        int diag = 0;
        while (diag < penalisedRows) {
            regularised.add(diag, diag, ridge);
            diag++;
        }
        DenseMatrix identity = Matrices.identity((int) regularised.numRows());
        DenseMatrix inverse = identity.copy();
        regularised.solve((Matrix) identity, (Matrix) inverse);
        return inverse;
    }

    /**
     * Updates the inverse matrix {@code AI} in place for the addition or
     * removal of a single instance, using a rank-one correction built from
     * {@code z = AI * m}, where {@code m} is the instance vector.
     *
     * <p>NOTE(review): this has the shape of a Sherman-Morrison update and
     * presumes {@code AI} is symmetric - confirm.
     *
     * @param AI inverse matrix to update in place
     * @param inst instance being added or removed
     * @param add {@code true} to add the instance, {@code false} to remove it
     * @throws Exception if a matrix operation fails
     */
    public final void updateAInverse(Matrix AI, Instance inst, boolean add) throws Exception {
        DenseVector m = attributeVectorWithConstant(inst);
        Vector z = AI.mult((Vector) m, (Vector) new DenseVector(m.size()));
        double zDotM = z.dot((Vector) m);
        double scale;
        if (add) {
            scale = -1.0 / (1.0 + zDotM);
        } else {
            scale = 1.0 / (1.0 - zDotM);
        }
        AI.rank1(scale, z);
    }

    /**
     * Builds the matrix {@code G} for a set of instances from their instance
     * vectors (non-class values followed by a constant 1.0).
     *
     * @param insts instances to accumulate
     * @return a {@code numAttributes x numAttributes} matrix produced by a
     *         rank-k update with the transposed data matrix
     * @throws Exception if a matrix operation fails
     */
    public final Matrix getG(Instances insts) throws Exception {
        int classIndex = insts.classIndex();
        int numAttributes = insts.numAttributes();
        int numInstances = insts.numInstances();
        // One column per instance, one row per vector component.
        DenseMatrix dataTransposed = new DenseMatrix(numAttributes, numInstances);
        for (int col = 0; col < numInstances; col++) {
            Instance current = insts.instance(col);
            int row = 0;
            for (int attr = 0; attr < numAttributes; attr++) {
                if (attr != classIndex) {
                    dataTransposed.set(row, col, current.value(attr));
                    row++;
                }
            }
            dataTransposed.set(numAttributes - 1, col, 1.0);
        }
        DenseMatrix accumulator = new DenseMatrix(numAttributes, numAttributes);
        return accumulator.rank1((Matrix) dataTransposed);
    }

    /**
     * Updates {@code G} in place by a rank-one update with the instance
     * vector, scaled by +1 when adding and -1 when removing.
     *
     * @param G matrix to update in place
     * @param inst instance being added or removed
     * @param add {@code true} to add the instance, {@code false} to remove it
     * @throws Exception if a matrix operation fails
     */
    public final void updateG(Matrix G, Instance inst, boolean add) throws Exception {
        DenseVector vals = attributeVectorWithConstant(inst);
        double sign = add ? 1.0 : -1.0;
        G.rank1(sign, (Vector) vals);
    }

    /**
     * Builds the vector {@code S}: for every instance, each non-class value
     * times the class value is accumulated into its slot, and the class value
     * itself is accumulated into the last slot.
     *
     * @param insts instances to accumulate
     * @return a vector of length {@code numAttributes}
     * @throws Exception if a vector operation fails
     */
    public final Vector getS(Instances insts) throws Exception {
        int classIndex = insts.classIndex();
        int numAttributes = insts.numAttributes();
        int numInstances = insts.numInstances();
        DenseVector S = new DenseVector(numAttributes);
        for (int n = 0; n < numInstances; n++) {
            Instance current = insts.instance(n);
            double target = current.classValue();
            int slot = 0;
            for (int attr = 0; attr < numAttributes; attr++) {
                if (attr != classIndex) {
                    S.add(slot, current.value(attr) * target);
                    slot++;
                }
            }
            S.add(S.size() - 1, target);
        }
        return S;
    }

    /**
     * Updates {@code S} in place for the addition or removal of one instance.
     * Removal is implemented by accumulating with the negated class value.
     *
     * @param S vector to update in place
     * @param inst instance being added or removed
     * @param add {@code true} to add the instance, {@code false} to remove it
     * @throws Exception if a vector operation fails
     */
    public final void updateS(Vector S, Instance inst, boolean add) throws Exception {
        int classIndex = inst.classIndex();
        int numAttributes = inst.numAttributes();
        double signedTarget = inst.classValue();
        if (!add) {
            signedTarget = -signedTarget;
        }
        int slot = 0;
        for (int attr = 0; attr < numAttributes; attr++) {
            if (attr != classIndex) {
                S.add(slot, inst.value(attr) * signedTarget);
                slot++;
            }
        }
        S.add(S.size() - 1, signedTarget);
    }

    /**
     * Computes {@code S'*AI*G*AI*S - 2*S'*AI*S} from the supplied quantities.
     * No term depending only on the class values is added here.
     *
     * @param G accumulated outer-product matrix
     * @param AI inverse matrix matching {@code G}
     * @param S accumulated class-weighted vector
     * @return the score described above
     * @throws Exception if a matrix operation fails
     */
    public final double getRSS(Matrix G, Matrix AI, Vector S) throws Exception {
        Vector aiS = AI.mult(S, (Vector) new DenseVector(S.size()));
        Vector gAiS = G.mult(aiS, (Vector) new DenseVector(aiS.size()));
        Vector aiGAiS = AI.mult(gAiS, (Vector) new DenseVector(gAiS.size()));
        double score = 0.0;
        // Two separate passes keep the floating-point accumulation order fixed.
        int k = 0;
        while (k < S.size()) {
            score += S.get(k) * aiGAiS.get(k);
            k++;
        }
        k = 0;
        while (k < S.size()) {
            score -= 2.0 * S.get(k) * aiS.get(k);
            k++;
        }
        return score;
    }

    /**
     * Scans candidate split points of attribute {@code attr} and records the
     * one with the largest negated combined score of the two sides.
     *
     * <p>Nothing is evaluated when fewer than four instances are present.
     * Otherwise the first {@code len} instances seed the left side and the
     * rest seed the right side; instances are then moved one at a time from
     * right to left, with the matrices updated incrementally. A candidate is
     * only considered when the midpoint between consecutive values is
     * strictly below the next value.
     *
     * <p>NOTE(review): presumably {@code insts} is sorted by {@code attr} -
     * verify against the caller.
     *
     * @param attr index of the attribute to split on
     * @param insts instances to evaluate
     * @throws Exception if a matrix operation fails
     */
    @Override
    public final void attrSplit(int attr, Instances insts) throws Exception {
        int low = 0;
        int high = insts.numInstances() - 1;
        initialize(low, high, attr);
        if (m_number < 4) {
            return;
        }

        // Minimum number of instances kept on either side of a split.
        int len;
        if (m_number < 5) {
            len = 1;
        } else {
            len = m_number / 5;
        }
        m_position = low;

        Instances leftSubset = new Instances(insts, low, len);
        Instances rightSubset = new Instances(insts, len, m_number - len);
        Matrix leftG = getG(leftSubset);
        Matrix rightG = getG(rightSubset);
        Matrix leftInverse = getAInverse(leftG, RIDGE);
        Matrix rightInverse = getAInverse(rightG, RIDGE);
        Vector leftS = getS(leftSubset);
        Vector rightS = getS(rightSubset);

        int lastCandidate = high - len - 1;
        for (int pos = low + len; pos <= lastCandidate; pos++) {
            Instance moving = insts.instance(pos);
            Instance following = insts.instance(pos + 1);

            // Move the instance at pos from the right side to the left side.
            updateS(leftS, moving, true);
            updateS(rightS, moving, false);
            updateG(leftG, moving, true);
            updateG(rightG, moving, false);
            updateAInverse(leftInverse, moving, true);
            updateAInverse(rightInverse, moving, false);

            double followingValue = following.value(attr);
            double splitCandidate = (moving.value(attr) + followingValue) * 0.5;
            if (splitCandidate < followingValue) {
                double combined = getRSS(leftG, leftInverse, leftS) + getRSS(rightG, rightInverse, rightS);
                if (-combined > m_maxImpurity) {
                    m_maxImpurity = -combined;
                    m_splitValue = splitCandidate;
                    m_position = pos;
                }
            }
        }
    }

    /**
     * @return the best impurity score recorded by the last evaluation
     */
    @Override
    public double maxImpurity() {
        return m_maxImpurity;
    }

    /**
     * @return the index of the attribute being evaluated
     */
    @Override
    public int splitAttr() {
        return m_splitAttr;
    }

    /**
     * @return the recorded split position
     */
    @Override
    public int position() {
        return m_position;
    }

    /**
     * @return the recorded split value
     */
    @Override
    public double splitValue() {
        return m_splitValue;
    }

    /**
     * @return the revision string extracted by {@code RevisionUtils}
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10169 $");
    }

    /**
     * Builds the instance vector: the non-class attribute values in order,
     * followed by a constant 1.0 stored in the last slot.
     */
    private static DenseVector attributeVectorWithConstant(Instance inst) {
        int classIndex = inst.classIndex();
        int numAttributes = inst.numAttributes();
        DenseVector vec = new DenseVector(numAttributes);
        int slot = 0;
        for (int attr = 0; attr < numAttributes; attr++) {
            if (attr != classIndex) {
                vec.set(slot, inst.value(attr));
                slot++;
            }
        }
        vec.set(numAttributes - 1, 1.0);
        return vec;
    }
}

