/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.types;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.awt.geom.Point2D;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.logging.Logger;

public class GainRatio
extends RankedFeatureVector {
    /** Logger used for progress and summary messages emitted while gain ratios are computed. */
    private static final Logger logger = MalletLogger.getLogger(GainRatio.class.getName());
    private static final long serialVersionUID = 1L;
    /** Natural logarithm of 2; dividing a natural log by this value converts it to base 2. */
    public static final double log2 = Math.log(2.0);
    /** Chosen split point per feature, indexed by feature index; NaN where no split point was selected. */
    double[] m_splitPoints;
    /** Entropy (base 2) of the label distribution over the instances the gain ratios were computed from. */
    double m_baseEntropy;
    /** Label distribution over the instances the gain ratios were computed from. */
    LabelVector m_baseLabelDistribution;
    /** Number of recorded split points for the feature with the highest gain ratio. */
    int m_numSplitPointsForBestFeature;
    /** Minimum number of instances required on each side of a split, as passed to the constructor. */
    int m_minNumInsts;

    /**
     * Computes, for every feature in the data alphabet, the best gain ratio and the
     * split point that achieves it, using only the instances selected by
     * {@code instIndices}.
     *
     * <p>For each feature the selected instances are sorted by feature value. A
     * candidate split is the midpoint between two consecutive instances whose feature
     * values differ (per {@code Maths.almostEquals}) and whose labelings have different
     * string forms. Every such candidate is counted in the total number of split
     * points; it only contributes an information gain / gain ratio if both sides of
     * the split hold at least {@code minNumInsts} instances. Per feature, the split
     * kept is the one with the highest gain ratio among those whose information gain
     * is at least the average (gain sum divided by the total candidate count); ties on
     * gain ratio are broken by higher information gain.
     *
     * @param ilist       instance list supplying instances and the data/target alphabets
     * @param instIndices indices into {@code ilist} of the instances to use
     * @param minNumInsts minimum number of instances required on each side of a split
     * @return a five-element array: {@code [0]} {@code double[]} gain ratio per feature,
     *         {@code [1]} {@code double[]} split point per feature, {@code [2]}
     *         {@code Double} base entropy, {@code [3]} {@code LabelVector} base label
     *         distribution, {@code [4]} {@code Integer} number of recorded splits for
     *         the best feature. When there are no candidate splits, or the summed
     *         information gain is (almost) zero, both arrays are all zeros and the
     *         split count is 0.
     */
    protected static Object[] calcGainRatios(InstanceList ilist, int[] instIndices, int minNumInsts) {
        final int numInsts = instIndices.length;
        final Alphabet dataDict = ilist.getDataAlphabet();
        final LabelAlphabet targetDict = (LabelAlphabet)ilist.getTargetAlphabet();
        final int numLabels = targetDict.size();

        // Accumulate the (possibly fractional) label mass of every selected instance.
        double[] targetCounts = new double[numLabels];
        for (int instIndex : instIndices) {
            Instance inst = (Instance)ilist.get(instIndex);
            Labeling labeling = inst.getLabeling();
            double weightTotal = 0.0;
            for (int loc = 0; loc < labeling.numLocations(); ++loc) {
                int labelIndex = labeling.indexAtLocation(loc);
                double weight = labeling.valueAtLocation(loc);
                weightTotal += weight;
                targetCounts[labelIndex] += weight;
            }
            assert Maths.almostEquals(weightTotal, 1.0);
        }

        // Normalise to a distribution and compute its entropy in bits.
        double[] targetDistribution = new double[numLabels];
        double baseEntropy = 0.0;
        for (int ci = 0; ci < numLabels; ++ci) {
            double p = targetCounts[ci] / (double)numInsts;
            targetDistribution[ci] = p;
            if (p > 0.0) {
                baseEntropy -= p * Math.log(p) / log2;
            }
        }
        LabelVector baseLabelDistribution = new LabelVector(targetDict, targetDistribution);

        // For each feature: split point -> (information gain, gain ratio).
        double infoGainSum = 0.0;
        int totalNumSplitPoints = 0;
        double[] passCounts = new double[numLabels];
        final int numFeatures = dataDict.size();
        ArrayList<HashMap<Double, Point2D.Double>> featureToInfo =
                new ArrayList<HashMap<Double, Point2D.Double>>(numFeatures);
        // Each sort starts from the ordering produced for the previous feature.
        int[] order = instIndices;
        for (int fi = 0; fi < numFeatures; ++fi) {
            int featureNumber = fi + 1;
            if (featureNumber % 1000 == 0) {
                logger.info("at feature " + featureNumber + " / " + numFeatures);
            }
            HashMap<Double, Point2D.Double> splitInfo = new HashMap<Double, Point2D.Double>();
            featureToInfo.add(splitInfo);
            Arrays.fill(passCounts, 0.0);
            order = GainRatio.sortInstances(ilist, order, fi);

            for (int ii = 0; ii + 1 < numInsts; ++ii) {
                Instance current = (Instance)ilist.get(order[ii]);
                Instance following = (Instance)ilist.get(order[ii + 1]);
                FeatureVector currentFv = (FeatureVector)current.getData();
                FeatureVector followingFv = (FeatureVector)following.getData();
                double lower = currentFv.value(fi);
                double higher = followingFv.value(fi);

                // Instance ii always lands on the "pass" side of any later split.
                Labeling labeling = current.getLabeling();
                for (int loc = 0; loc < labeling.numLocations(); ++loc) {
                    int labelIndex = labeling.indexAtLocation(loc);
                    double weight = labeling.valueAtLocation(loc);
                    passCounts[labelIndex] += weight;
                }

                // Only boundaries where both the value and the labeling change are candidates.
                if (Maths.almostEquals(lower, higher)
                        || current.getLabeling().toString().equals(following.getLabeling().toString())) {
                    continue;
                }
                ++totalNumSplitPoints;

                double numPassInsts = ii + 1;
                double numFailInsts = (double)numInsts - numPassInsts;
                if (numPassInsts < (double)minNumInsts || numFailInsts < (double)minNumInsts) {
                    continue;
                }
                double passProportion = numPassInsts / (double)numInsts;
                if (Maths.almostEquals(passProportion, 0.0) || Maths.almostEquals(passProportion, 1.0)) {
                    continue;
                }

                double passEntropy = 0.0;
                double failEntropy = 0.0;
                for (int ci = 0; ci < numLabels; ++ci) {
                    if (numPassInsts > 0.0) {
                        double pPass = passCounts[ci] / numPassInsts;
                        if (pPass > 0.0) {
                            passEntropy -= pPass * Math.log(pPass) / log2;
                        }
                    }
                    if (numFailInsts > 0.0) {
                        double pFail = (targetCounts[ci] - passCounts[ci]) / numFailInsts;
                        if (pFail > 0.0) {
                            failEntropy -= pFail * Math.log(pFail) / log2;
                        }
                    }
                }

                double gainDT = baseEntropy - passProportion * passEntropy - (1.0 - passProportion) * failEntropy;
                infoGainSum += gainDT;
                double splitDT = -passProportion * Math.log(passProportion) / log2
                        - (1.0 - passProportion) * Math.log(1.0 - passProportion) / log2;
                double gainRatio = gainDT / splitDT;
                double splitPoint = (lower + higher) / 2.0;
                splitInfo.put(splitPoint, new Point2D.Double(gainDT, gainRatio));
            }
        }

        double[] gainRatios = new double[numFeatures];
        double[] splitPoints = new double[numFeatures];
        int numSplitsForBestFeature = 0;
        if (totalNumSplitPoints == 0 || Maths.almostEquals(infoGainSum, 0.0)) {
            return new Object[]{gainRatios, splitPoints, baseEntropy, baseLabelDistribution, numSplitsForBestFeature};
        }

        // Note: the divisor counts every candidate boundary, including those rejected by minNumInsts.
        double avgInfoGain = infoGainSum / (double)totalNumSplitPoints;
        double maxGainRatio = 0.0;
        double gainForMaxGainRatio = 0.0;
        int belowAverageSplits = 0;
        for (int fi = 0; fi < numFeatures; ++fi) {
            HashMap<Double, Point2D.Double> splitInfo = featureToInfo.get(fi);
            double featureMaxGainRatio = 0.0;
            double featureGainForMaxGainRatio = 0.0;
            double bestSplitPoint = Double.NaN;
            for (Double key : splitInfo.keySet()) {
                Point2D.Double pt = splitInfo.get(key);
                double splitPoint = key;
                double infoGain = pt.getX();
                double gainRatio = pt.getY();
                if (infoGain >= avgInfoGain) {
                    boolean better = gainRatio > featureMaxGainRatio
                            || (gainRatio == featureMaxGainRatio && infoGain > featureGainForMaxGainRatio);
                    if (better) {
                        featureMaxGainRatio = gainRatio;
                        featureGainForMaxGainRatio = infoGain;
                        bestSplitPoint = splitPoint;
                    }
                } else {
                    ++belowAverageSplits;
                }
            }
            // NOTE(review): with assertions enabled this fires for any feature that has no
            // recorded split with at-least-average gain (bestSplitPoint is still NaN) -- confirm intended.
            assert !Double.isNaN(bestSplitPoint);
            gainRatios[fi] = featureMaxGainRatio;
            splitPoints[fi] = bestSplitPoint;

            boolean newBest = featureMaxGainRatio > maxGainRatio
                    || (featureMaxGainRatio == maxGainRatio && featureGainForMaxGainRatio > gainForMaxGainRatio);
            if (newBest) {
                maxGainRatio = featureMaxGainRatio;
                gainForMaxGainRatio = featureGainForMaxGainRatio;
                numSplitsForBestFeature = splitInfo.size();
            }
        }
        logger.info("label distrib:\n" + baseLabelDistribution);
        logger.info("base entropy=" + baseEntropy + ", info gain sum=" + infoGainSum + ", total num split points=" + totalNumSplitPoints + ", avg info gain=" + avgInfoGain + ", num splits with < avg gain=" + belowAverageSplits);
        return new Object[]{gainRatios, splitPoints, baseEntropy, baseLabelDistribution, numSplitsForBestFeature};
    }

    /**
     * Returns the given instance indices ordered by ascending value of one feature.
     *
     * <p>Instances with equal feature values are ordered by ascending instance index,
     * so the result does not depend on the order of {@code instIndices}. The ordering
     * is a proper total order: {@code -0.0} and {@code 0.0} count as equal values,
     * NaN feature values sort after every other value, and a repeated instance index
     * compares equal to itself. The input array is not modified.
     *
     * @param ilist        instance list the indices refer to; each instance's data is
     *                     read as a {@code FeatureVector}
     * @param instIndices  indices into {@code ilist} of the instances to sort
     * @param featureIndex index of the feature whose value is the sort key
     * @return a new array holding the same indices in sorted order
     */
    public static int[] sortInstances(InstanceList ilist, int[] instIndices, int featureIndex) {
        // Each point carries (x = instance index, y = feature value); int -> double is exact.
        ArrayList<Point2D.Double> list = new ArrayList<Point2D.Double>(instIndices.length);
        for (int ii = 0; ii < instIndices.length; ++ii) {
            Instance inst = (Instance)ilist.get(instIndices[ii]);
            FeatureVector fv = (FeatureVector)inst.getData();
            list.add(new Point2D.Double(instIndices[ii], fv.value(featureIndex)));
        }
        Collections.sort(list, new Comparator<Point2D.Double>(){

            @Override
            public int compare(Point2D.Double p1, Point2D.Double p2) {
                // '!=' keeps -0.0 and 0.0 tied (as before); Double.compare then gives NaN a
                // consistent place instead of answering "less than" in both directions.
                if (p1.y != p2.y) {
                    int byValue = Double.compare(p1.y, p2.y);
                    if (byValue != 0) {
                        return byValue;
                    }
                }
                // Tie on value: fall back to the instance index; identical entries yield 0.
                return Double.compare(p1.x, p2.x);
            }
        });
        int[] sorted = new int[instIndices.length];
        for (int i = 0; i < sorted.length; ++i) {
            sorted[i] = (int)list.get(i).getX();
        }
        return sorted;
    }

    /**
     * Builds a {@code GainRatio} over every instance in the list, requiring at least
     * 2 instances on each side of a split.
     *
     * @param ilist instance list to compute gain ratios from
     * @return the gain ratios computed over all instances of {@code ilist}
     */
    public static GainRatio createGainRatio(InstanceList ilist) {
        // Identity selection: position k selects instance k.
        int[] allIndices = new int[ilist.size()];
        int next = 0;
        while (next < allIndices.length) {
            allIndices[next] = next;
            ++next;
        }
        return createGainRatio(ilist, allIndices, 2);
    }

    /**
     * Builds a {@code GainRatio} from the selected instances of the list.
     *
     * @param ilist       instance list to compute gain ratios from
     * @param instIndices indices into {@code ilist} of the instances to use
     * @param minNumInsts minimum number of instances required on each side of a split
     * @return a {@code GainRatio} wrapping the results of {@code calcGainRatios}
     */
    public static GainRatio createGainRatio(InstanceList ilist, int[] instIndices, int minNumInsts) {
        // Slots: 0 = gain ratios, 1 = split points, 2 = base entropy,
        //        3 = base label distribution, 4 = split count of the best feature.
        Object[] results = calcGainRatios(ilist, instIndices, minNumInsts);
        double[] ratios = (double[])results[0];
        double[] thresholds = (double[])results[1];
        double entropy = ((Double)results[2]).doubleValue();
        LabelVector labelDistribution = (LabelVector)results[3];
        int bestFeatureSplitCount = ((Integer)results[4]).intValue();
        return new GainRatio(
                ilist.getDataAlphabet(),
                ratios,
                thresholds,
                entropy,
                labelDistribution,
                bestFeatureSplitCount,
                minNumInsts);
    }

    /**
     * Creates a gain-ratio vector from precomputed results. The alphabet and the gain
     * ratios are handed to the superclass; the remaining values are stored as given
     * (the arrays are not copied).
     *
     * @param dataAlphabet                 alphabet of the features
     * @param gainRatios                   gain ratio per feature
     * @param splitPoints                  split point per feature
     * @param baseEntropy                  entropy of the base label distribution
     * @param baseLabelDistribution        label distribution of the instances used
     * @param numSplitPointsForBestFeature number of split points of the best feature
     * @param minNumInsts                  minimum number of instances per split side
     */
    protected GainRatio(Alphabet dataAlphabet, double[] gainRatios, double[] splitPoints, double baseEntropy, LabelVector baseLabelDistribution, int numSplitPointsForBestFeature, int minNumInsts) {
        super(dataAlphabet, gainRatios);
        m_minNumInsts = minNumInsts;
        m_numSplitPointsForBestFeature = numSplitPointsForBestFeature;
        m_baseLabelDistribution = baseLabelDistribution;
        m_baseEntropy = baseEntropy;
        m_splitPoints = splitPoints;
    }

    /**
     * Returns the split point of the feature at rank 0.
     *
     * @return the threshold stored for the rank-0 feature
     */
    public double getMaxValuedThreshold() {
        final int topRank = 0;
        return getThresholdAtRank(topRank);
    }

    /**
     * Returns the split point of the feature at the given rank.
     *
     * @param rank rank to look up; translated to an index by {@code getIndexAtRank}
     * @return the split point stored at that index
     */
    public double getThresholdAtRank(int rank) {
        return m_splitPoints[getIndexAtRank(rank)];
    }

    /**
     * Returns the base entropy this object was constructed with.
     *
     * @return the stored base entropy
     */
    public double getBaseEntropy() {
        return m_baseEntropy;
    }

    /**
     * Returns the base label distribution this object was constructed with.
     *
     * @return the stored label distribution (the same object, not a copy)
     */
    public LabelVector getBaseLabelDistribution() {
        return m_baseLabelDistribution;
    }

    /**
     * Returns the number of split points recorded for the best feature.
     *
     * @return the stored split-point count
     */
    public int getNumSplitPointsForBestFeature() {
        return m_numSplitPointsForBestFeature;
    }
}

