/*
 * Decompiled with CFR 0.152.
 */
package be.ac.ulg.montefiore.run.jahmm.learn;

import be.ac.ulg.montefiore.run.jahmm.CentroidFactory;
import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.Observation;
import be.ac.ulg.montefiore.run.jahmm.Opdf;
import be.ac.ulg.montefiore.run.jahmm.OpdfFactory;
import be.ac.ulg.montefiore.run.jahmm.ViterbiCalculator;
import be.ac.ulg.montefiore.run.jahmm.learn.Clusters;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Learns an {@link Hmm} from observation sequences using a k-means style procedure.
 *
 * <p>All observations are first partitioned into one cluster per state. Each iteration then:
 * <ol>
 *   <li>estimates the HMM parameters (initial probabilities, transition probabilities and
 *       observation distributions) from the current partition;</li>
 *   <li>re-assigns every observation to the state chosen for it by the Viterbi algorithm.</li>
 * </ol>
 * Learning stops when an iteration moves no observation to a different cluster.
 *
 * @param <O> the observation type. Observations are stored in {@link Clusters}, whose
 *        {@code put} method works on {@link CentroidFactory} instances, so {@code O} must
 *        be both an {@link Observation} and a {@link CentroidFactory}.
 */
public class KMeansLearner<O extends Observation & CentroidFactory<? super O>> {
    /** Current partition of all observations, one cluster per state. */
    private final Clusters<O> clusters;
    /** Number of states (and therefore of clusters) of the learnt HMM. */
    private final int nbStates;
    /** The training sequences (kept by reference, not copied). */
    private final List<? extends List<? extends O>> obsSeqs;
    /** Factory used to create each state's observation distribution. */
    private final OpdfFactory<? extends Opdf<O>> opdfFactory;
    /** True once an iteration has left every observation in its cluster. */
    private boolean terminated;

    /**
     * Creates a learner and builds the initial partition of the observations.
     *
     * @param n the number of states of the HMM to learn
     * @param opdfFactory factory creating the observation distribution of each state
     * @param list the training sequences; stored by reference, not copied
     */
    public KMeansLearner(int n, OpdfFactory<? extends Opdf<O>> opdfFactory, List<? extends List<? extends O>> list) {
        this.obsSeqs = list;
        this.opdfFactory = opdfFactory;
        this.nbStates = n;
        List<O> observations = KMeansLearner.flat(list);
        this.clusters = new Clusters<O>(n, observations);
        this.terminated = false;
    }

    /**
     * Performs one iteration of the algorithm.
     *
     * <p>A fresh HMM is estimated from the current clusters. Its Viterbi state sequences are
     * then used to re-assign observations to clusters.
     *
     * @return the HMM estimated from the clusters as they were at the start of this call
     */
    public Hmm<O> iterate() {
        Hmm<O> hmm = new Hmm<O>(this.nbStates, this.opdfFactory);
        learnPi(hmm);
        learnAij(hmm);
        learnOpdf(hmm);
        this.terminated = optimizeCluster(hmm);
        return hmm;
    }

    /**
     * Tells whether the algorithm has converged.
     *
     * @return {@code true} if the last call to {@link #iterate()} moved no observation to a
     *         different cluster; {@code false} before the first iteration
     */
    public boolean isTerminated() {
        return this.terminated;
    }

    /**
     * Calls {@link #iterate()} until {@link #isTerminated()} holds.
     *
     * <p>No iteration limit is imposed.
     *
     * @return the HMM produced by the final iteration
     */
    public Hmm<O> learn() {
        Hmm<O> hmm;
        do {
            hmm = iterate();
        } while (!isTerminated());
        return hmm;
    }

    /**
     * Sets each initial probability pi(i) to the fraction of sequences whose first
     * observation lies in cluster i.
     *
     * <p>Assumes every sequence is non-empty, because it reads {@code get(0)}.
     */
    private void learnPi(Hmm<?> hmm) {
        // Java arrays are zero-initialised, so the entries can be used as counters directly.
        double[] pi = new double[this.nbStates];
        for (List<? extends O> sequence : this.obsSeqs) {
            pi[this.clusters.clusterNb(sequence.get(0))]++;
        }
        for (int i = 0; i < this.nbStates; ++i) {
            hmm.setPi(i, pi[i] / (double) this.obsSeqs.size());
        }
    }

    /**
     * Estimates transition probabilities by counting consecutive cluster pairs in every
     * sequence, then normalising each row.
     *
     * <p>A state with no observed outgoing transition gets a uniform row.
     */
    private void learnAij(Hmm<O> hmm) {
        int states = hmm.nbStates();
        for (int i = 0; i < states; ++i) {
            for (int j = 0; j < states; ++j) {
                hmm.setAij(i, j, 0.0);
            }
        }

        // Use the HMM's transition matrix as a counter of (previous, current) cluster pairs.
        for (List<? extends O> sequence : this.obsSeqs) {
            if (sequence.size() < 2) {
                continue;
            }
            int current = this.clusters.clusterNb(sequence.get(0));
            for (int t = 1; t < sequence.size(); ++t) {
                int previous = current;
                current = this.clusters.clusterNb(sequence.get(t));
                hmm.setAij(previous, current, hmm.getAij(previous, current) + 1.0);
            }
        }

        for (int i = 0; i < states; ++i) {
            double sum = 0.0;
            for (int j = 0; j < states; ++j) {
                sum += hmm.getAij(i, j);
            }
            for (int j = 0; j < states; ++j) {
                // Without data for this state, the uniform distribution is used arbitrarily.
                hmm.setAij(i, j, sum == 0.0 ? 1.0 / (double) states : hmm.getAij(i, j) / sum);
            }
        }
    }

    /**
     * Fits each state's observation distribution to the observations of its cluster.
     *
     * <p>An empty cluster gets a fresh distribution from the factory instead.
     */
    private void learnOpdf(Hmm<O> hmm) {
        for (int i = 0; i < hmm.nbStates(); ++i) {
            Collection<O> clusterObservations = this.clusters.cluster(i);
            if (clusterObservations.isEmpty()) {
                hmm.setOpdf(i, this.opdfFactory.factor());
            } else {
                hmm.getOpdf(i).fit(clusterObservations);
            }
        }
    }

    /**
     * Moves every observation to the cluster of the state that the Viterbi path of
     * {@code hmm} assigns to it.
     *
     * @return {@code true} if no observation changed cluster (the algorithm has converged)
     */
    private boolean optimizeCluster(Hmm<O> hmm) {
        boolean modified = false;
        for (List<? extends O> sequence : this.obsSeqs) {
            int[] states = new ViterbiCalculator(sequence, hmm).stateSequence();
            for (int t = 0; t < states.length; ++t) {
                O observation = sequence.get(t);
                int oldCluster = this.clusters.clusterNb(observation);
                if (oldCluster != states[t]) {
                    modified = true;
                    this.clusters.remove(observation, oldCluster);
                    this.clusters.put(observation, states[t]);
                }
            }
        }
        return !modified;
    }

    /**
     * Concatenates a list of lists into a single new list, preserving order.
     *
     * @param list the lists to concatenate
     * @return a new mutable list holding every element of every inner list
     */
    static <T> List<T> flat(List<? extends List<? extends T>> list) {
        List<T> result = new ArrayList<T>();
        for (List<? extends T> sequence : list) {
            result.addAll(sequence);
        }
        return result;
    }
}

