/*
 * Decompiled with CFR 0.152.
 */
package weka.filters.unsupervised.instance;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.rules.ZeroR;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedAttributesHandler;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.UnsupervisedFilter;

/**
 * Filter that removes the instances a cleansing classifier gets wrong (or,
 * when inverted, the ones it gets right).
 *
 * <p>The first batch is buffered by {@link #input(Instance)} and cleansed in
 * {@link #batchFinished()}; instances of later batches are passed through
 * unchanged. Cleansing is done either by training and testing on the same
 * data or, when at least two folds are configured, by cross-validation.
 *
 * <p>Command-line options (see {@link #listOptions()}):
 * {@code -W <classifier specification>}, {@code -C <class index>},
 * {@code -F <number of folds>}, {@code -T <threshold>},
 * {@code -I <max iterations>} and {@code -V}.
 */
public class RemoveMisclassified
extends Filter
implements UnsupervisedFilter,
OptionHandler,
WeightedAttributesHandler,
WeightedInstancesHandler {
    /** Serialization version identifier. */
    static final long serialVersionUID = 5469157004717663171L;
    /** Classifier whose predictions decide which instances are kept. */
    protected Classifier m_cleansingClassifier = new ZeroR();
    /** Index of the class attribute to cleanse on; negative means "derive from the data". */
    protected int m_classIndex = -1;
    /** Number of cross-validation folds; values below 2 disable cross-validation. */
    protected int m_numOfCrossValidationFolds = 0;
    /** Maximum number of cleansing iterations; values below 1 mean no limit. */
    protected int m_numOfCleansingIterations = 0;
    /** Maximum absolute error tolerated for a numeric class prediction. */
    protected double m_numericClassifyThreshold = 0.1;
    /** If true, the misclassified instances are output instead of the correct ones. */
    protected boolean m_invertMatching = false;
    /** Set once the first batch has been cleansed; later input is passed straight through. */
    protected boolean m_firstBatchFinished = false;

    /**
     * Returns the capabilities of the configured classifier, or a fully
     * disabled capability set if no classifier is configured. In both cases
     * the minimum number of instances is set to 0.
     *
     * @return the capabilities of this filter
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result;
        if (this.getClassifier() == null) {
            result = super.getCapabilities();
            result.disableAll();
        } else {
            result = this.getClassifier().getCapabilities();
        }
        result.setMinimumNumberInstances(0);
        return result;
    }

    /**
     * Sets the input format; the output format is identical to it. Also
     * resets the first-batch state so the next batch is cleansed again.
     *
     * @param instanceInfo the structure of the incoming instances
     * @return always true
     * @throws Exception if the superclass rejects the format
     */
    @Override
    public boolean setInputFormat(Instances instanceInfo) throws Exception {
        super.setInputFormat(instanceInfo);
        this.setOutputFormat(instanceInfo);
        this.m_firstBatchFinished = false;
        return true;
    }

    /**
     * Determines the class index to cleanse on: the explicitly configured
     * index if non-negative, otherwise the class index already set on the
     * data, otherwise the last attribute.
     *
     * @param data the data being cleansed
     * @return the class index to use
     */
    private int resolveClassIndex(Instances data) {
        if (this.m_classIndex >= 0) {
            return this.m_classIndex;
        }
        if (data.classIndex() >= 0) {
            return data.classIndex();
        }
        return data.numAttributes() - 1;
    }

    /**
     * Checks whether the cleansing classifier predicts the given instance
     * correctly. For a numeric class the prediction must lie within
     * {@code m_numericClassifyThreshold} of the actual value (inclusive);
     * otherwise the predicted value must equal the actual class value.
     *
     * @param inst the instance to classify
     * @param numericClass whether the class attribute is numeric
     * @return true if the prediction counts as correct
     * @throws Exception if the classifier fails to classify the instance
     */
    private boolean isCorrectlyClassified(Instance inst, boolean numericClass) throws Exception {
        double predicted = this.m_cleansingClassifier.classifyInstance(inst);
        if (numericClass) {
            return predicted >= inst.classValue() - this.m_numericClassifyThreshold
                && predicted <= inst.classValue() + this.m_numericClassifyThreshold;
        }
        return predicted == inst.classValue();
    }

    /**
     * Cleanses the data by repeatedly training the classifier on the current
     * set and testing it on that same set, keeping only the correctly
     * classified instances. Stops when an iteration removes nothing or the
     * configured maximum number of iterations has been performed.
     *
     * @param data the data to cleanse
     * @return the retained instances, or the removed ones if matching is
     *         inverted; the class index is restored to that of {@code data}
     * @throws Exception if building or applying the classifier fails
     */
    private Instances cleanseTrain(Instances data) throws Exception {
        Instances buildSet = new Instances(data);
        Instances inverseSet = new Instances(data, data.numInstances());
        int classIndex = this.resolveClassIndex(data);
        int count = 0;
        int iterations = 0;
        // count holds the set size before the previous pass; equal sizes mean nothing was removed.
        while (count != buildSet.numInstances()) {
            if (this.m_numOfCleansingIterations > 0 && ++iterations > this.m_numOfCleansingIterations) {
                break;
            }
            count = buildSet.numInstances();
            buildSet.setClassIndex(classIndex);
            this.m_cleansingClassifier.buildClassifier(buildSet);
            boolean numericClass = buildSet.classAttribute().isNumeric();
            Instances kept = new Instances(buildSet, buildSet.numInstances());
            for (int i = 0; i < buildSet.numInstances(); ++i) {
                Instance inst = buildSet.instance(i);
                if (this.isCorrectlyClassified(inst, numericClass)) {
                    kept.add(inst);
                } else if (this.m_invertMatching) {
                    inverseSet.add(inst);
                }
            }
            buildSet = kept;
        }
        Instances result = this.m_invertMatching ? inverseSet : buildSet;
        result.setClassIndex(data.classIndex());
        return result;
    }

    /**
     * Cleanses the data by repeated cross-validation: in each iteration every
     * fold is classified by a model built on the remaining folds and only the
     * correctly classified instances are kept. Stops when an iteration
     * removes nothing, when fewer instances than folds remain, or when the
     * configured maximum number of iterations has been performed.
     *
     * @param data the data to cleanse
     * @return the retained instances, or the removed ones if matching is
     *         inverted; the class index is restored to that of {@code data}
     * @throws Exception if building or applying the classifier fails
     */
    private Instances cleanseCross(Instances data) throws Exception {
        Instances crossSet = new Instances(data);
        Instances inverseSet = new Instances(data, data.numInstances());
        int folds = this.m_numOfCrossValidationFolds;
        int classIndex = this.resolveClassIndex(data);
        int count = 0;
        int iterations = 0;
        while (count != crossSet.numInstances() && crossSet.numInstances() >= folds) {
            count = crossSet.numInstances();
            if (this.m_numOfCleansingIterations > 0 && ++iterations > this.m_numOfCleansingIterations) {
                break;
            }
            crossSet.setClassIndex(classIndex);
            boolean numericClass = crossSet.classAttribute().isNumeric();
            if (crossSet.classAttribute().isNominal()) {
                crossSet.stratify(folds);
            }
            Instances kept = new Instances(crossSet, crossSet.numInstances());
            for (int fold = 0; fold < folds; ++fold) {
                Instances train = crossSet.trainCV(folds, fold);
                this.m_cleansingClassifier.buildClassifier(train);
                Instances test = crossSet.testCV(folds, fold);
                for (int i = 0; i < test.numInstances(); ++i) {
                    Instance inst = test.instance(i);
                    if (this.isCorrectlyClassified(inst, numericClass)) {
                        kept.add(inst);
                    } else if (this.m_invertMatching) {
                        inverseSet.add(inst);
                    }
                }
            }
            crossSet = kept;
        }
        Instances result = this.m_invertMatching ? inverseSet : crossSet;
        result.setClassIndex(data.classIndex());
        return result;
    }

    /**
     * Accepts an instance for filtering. Until the first batch has been
     * finished the instance is buffered; afterwards it is pushed to the
     * output queue unchanged.
     *
     * @param instance the incoming instance
     * @return true if the instance was made available for output, false if
     *         it was only buffered
     * @throws NullPointerException if no input format has been defined
     * @throws Exception if the superclass operations fail
     */
    @Override
    public boolean input(Instance instance) throws Exception {
        if (this.inputFormatPeek() == null) {
            throw new NullPointerException("No input instance format defined");
        }
        if (this.m_NewBatch) {
            this.resetQueue();
            this.m_NewBatch = false;
        }
        if (this.m_firstBatchFinished) {
            this.push(instance);
            return true;
        }
        this.bufferInput(instance);
        return false;
    }

    /**
     * Signals the end of a batch. For the first batch, the data returned by
     * {@code getInputFormat()} (presumably the instances buffered by
     * {@link #input(Instance)}) is cleansed, using cross-validation if at
     * least two folds are configured, and the result is pushed to the output
     * queue.
     *
     * @return true if there are instances pending output
     * @throws IllegalStateException if no input format has been defined
     * @throws Exception if cleansing fails
     */
    @Override
    public boolean batchFinished() throws Exception {
        if (this.getInputFormat() == null) {
            throw new IllegalStateException("No input instance format defined");
        }
        if (!this.m_firstBatchFinished) {
            Instances filtered = this.m_numOfCrossValidationFolds < 2
                ? this.cleanseTrain(this.getInputFormat())
                : this.cleanseCross(this.getInputFormat());
            for (int i = 0; i < filtered.numInstances(); ++i) {
                this.push(filtered.instance(i), false);
            }
            this.m_firstBatchFinished = true;
            this.flushInput();
        }
        this.m_NewBatch = true;
        return this.numPendingOutput() != 0;
    }

    /**
     * Returns an enumeration describing the available options.
     *
     * @return the options understood by {@link #setOptions(String[])}
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>(6);
        newVector.addElement(new Option("\tFull class name of classifier to use, followed\n\tby scheme options. eg:\n\t\t\"weka.classifiers.bayes.NaiveBayes -D\"\n\t(default: weka.classifiers.rules.ZeroR)", "W", 1, "-W <classifier specification>"));
        newVector.addElement(new Option("\tAttribute on which misclassifications are based.\n\tIf < 0 will use any current set class or default to the last attribute.", "C", 1, "-C <class index>"));
        newVector.addElement(new Option("\tThe number of folds to use for cross-validation cleansing.\n\t(<2 = no cross-validation - default).", "F", 1, "-F <number of folds>"));
        newVector.addElement(new Option("\tThreshold for the max error when predicting numeric class.\n\t(Value should be >= 0, default = 0.1).", "T", 1, "-T <threshold>"));
        // -I takes one argument, so its synopsis names it like the other options do.
        newVector.addElement(new Option("\tThe maximum number of cleansing iterations to perform.\n\t(<1 = until fully cleansed - default)", "I", 1, "-I <max iterations>"));
        newVector.addElement(new Option("\tInvert the match so that correctly classified instances are discarded.\n", "V", 0, "-V"));
        return newVector.elements();
    }

    /**
     * Parses the given options. Options that are absent are reset to their
     * defaults (ZeroR, class index -1, 0 folds, threshold 0.1, 0 iterations,
     * no inversion). Integer-valued options are parsed as doubles and
     * truncated.
     *
     * @param options the option strings to parse
     * @throws Exception if the classifier specification is invalid, a value
     *         cannot be parsed, or unknown options remain
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String classifierString = Utils.getOption('W', options);
        if (classifierString.length() == 0) {
            classifierString = ZeroR.class.getName();
        }
        String[] classifierSpec = Utils.splitOptions(classifierString);
        if (classifierSpec.length == 0) {
            throw new Exception("Invalid classifier specification string");
        }
        // The first token is the class name; blank it so only scheme options remain.
        String classifierName = classifierSpec[0];
        classifierSpec[0] = "";
        this.setClassifier(AbstractClassifier.forName(classifierName, classifierSpec));

        String cString = Utils.getOption('C', options);
        this.setClassIndex(cString.length() != 0 ? (int)Double.parseDouble(cString) : -1);

        String fString = Utils.getOption('F', options);
        this.setNumFolds(fString.length() != 0 ? (int)Double.parseDouble(fString) : 0);

        String tString = Utils.getOption('T', options);
        this.setThreshold(tString.length() != 0 ? Double.parseDouble(tString) : 0.1);

        String iString = Utils.getOption('I', options);
        this.setMaxIterations(iString.length() != 0 ? (int)Double.parseDouble(iString) : 0);

        this.setInvert(Utils.getFlag('V', options));
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings as an option array.
     *
     * @return option strings suitable for {@link #setOptions(String[])}
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-W");
        options.add("" + this.getClassifierSpec());
        options.add("-C");
        options.add(String.valueOf(this.getClassIndex()));
        options.add("-F");
        options.add(String.valueOf(this.getNumFolds()));
        options.add("-T");
        options.add(String.valueOf(this.getThreshold()));
        options.add("-I");
        options.add(String.valueOf(this.getMaxIterations()));
        if (this.getInvert()) {
            options.add("-V");
        }
        return options.toArray(new String[0]);
    }

    /**
     * Returns a short description of this filter.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "A filter that removes instances which are incorrectly classified. Useful for removing outliers.";
    }

    /**
     * Returns the tip text for the classifier property.
     *
     * @return the tip text
     */
    public String classifierTipText() {
        return "The classifier upon which to base the misclassifications.";
    }

    /**
     * Sets the classifier used to decide which instances are misclassified.
     *
     * @param classifier the cleansing classifier
     */
    public void setClassifier(Classifier classifier) {
        this.m_cleansingClassifier = classifier;
    }

    /**
     * Returns the classifier used to decide which instances are misclassified.
     *
     * @return the cleansing classifier
     */
    public Classifier getClassifier() {
        return this.m_cleansingClassifier;
    }

    /**
     * Builds the classifier specification string: the classifier's class
     * name, followed by its joined options if it is an {@link OptionHandler}.
     *
     * @return the classifier specification
     */
    protected String getClassifierSpec() {
        Classifier c = this.getClassifier();
        String className = c.getClass().getName();
        if (c instanceof OptionHandler) {
            return className + " " + Utils.joinOptions(((OptionHandler)((Object)c)).getOptions());
        }
        return className;
    }

    /**
     * Returns the tip text for the class index property.
     *
     * @return the tip text
     */
    public String classIndexTipText() {
        return "Index of the class upon which to base the misclassifications. If < 0 will use any current set class or default to the last attribute.";
    }

    /**
     * Sets the index of the attribute to treat as the class.
     *
     * @param classIndex the class index; negative to derive it from the data
     */
    public void setClassIndex(int classIndex) {
        this.m_classIndex = classIndex;
    }

    /**
     * Returns the configured class index.
     *
     * @return the class index
     */
    public int getClassIndex() {
        return this.m_classIndex;
    }

    /**
     * Returns the tip text for the number-of-folds property.
     *
     * @return the tip text
     */
    public String numFoldsTipText() {
        return "The number of cross-validation folds to use. If < 2 then no cross-validation will be performed.";
    }

    /**
     * Sets the number of cross-validation folds.
     *
     * @param numOfFolds the number of folds; below 2 disables cross-validation
     */
    public void setNumFolds(int numOfFolds) {
        this.m_numOfCrossValidationFolds = numOfFolds;
    }

    /**
     * Returns the number of cross-validation folds.
     *
     * @return the number of folds
     */
    public int getNumFolds() {
        return this.m_numOfCrossValidationFolds;
    }

    /**
     * Returns the tip text for the threshold property.
     *
     * @return the tip text
     */
    public String thresholdTipText() {
        return "Threshold for the max allowable error when predicting a numeric class. Should be >= 0.";
    }

    /**
     * Sets the maximum error tolerated for numeric class predictions.
     *
     * @param threshold the threshold
     */
    public void setThreshold(double threshold) {
        this.m_numericClassifyThreshold = threshold;
    }

    /**
     * Returns the maximum error tolerated for numeric class predictions.
     *
     * @return the threshold
     */
    public double getThreshold() {
        return this.m_numericClassifyThreshold;
    }

    /**
     * Returns the tip text for the maximum-iterations property.
     *
     * @return the tip text
     */
    public String maxIterationsTipText() {
        return "The maximum number of iterations to perform. < 1 means filter will go until fully cleansed.";
    }

    /**
     * Sets the maximum number of cleansing iterations.
     *
     * @param iterations the maximum; below 1 means no limit
     */
    public void setMaxIterations(int iterations) {
        this.m_numOfCleansingIterations = iterations;
    }

    /**
     * Returns the maximum number of cleansing iterations.
     *
     * @return the maximum number of iterations
     */
    public int getMaxIterations() {
        return this.m_numOfCleansingIterations;
    }

    /**
     * Returns the tip text for the invert property.
     *
     * @return the tip text
     */
    public String invertTipText() {
        return "Whether or not to invert the selection. If true, correctly classified instances will be discarded.";
    }

    /**
     * Sets whether the selection is inverted.
     *
     * @param invert true to output the misclassified instances instead
     */
    public void setInvert(boolean invert) {
        this.m_invertMatching = invert;
    }

    /**
     * Returns whether the selection is inverted.
     *
     * @return true if the misclassified instances are output
     */
    public boolean getInvert() {
        return this.m_invertMatching;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision$");
    }

    /**
     * Runs the filter from the command line.
     *
     * @param argv the command-line arguments
     */
    public static void main(String[] argv) {
        RemoveMisclassified.runFilter(new RemoveMisclassified(), argv);
    }
}

