/*
 * Decompiled with CFR 0.152.
 */
package org.apache.commons.math4.neuralnet.sofm;

import java.util.Collection;
import java.util.HashSet;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.DoubleUnaryOperator;
import org.apache.commons.math4.neuralnet.DistanceMeasure;
import org.apache.commons.math4.neuralnet.MapRanking;
import org.apache.commons.math4.neuralnet.Network;
import org.apache.commons.math4.neuralnet.Neuron;
import org.apache.commons.math4.neuralnet.UpdateAction;
import org.apache.commons.math4.neuralnet.sofm.LearningFactorFunction;
import org.apache.commons.math4.neuralnet.sofm.NeighbourhoodSizeFunction;

/**
 * Kohonen-style training step for a self-organizing map.
 *
 * <p>Each call to {@link #update(Network, double[])}:
 * <ol>
 *  <li>selects the top-ranked neuron for the sample (via {@link MapRanking} with the configured
 *      {@link DistanceMeasure}) and moves its features toward the sample by the current learning
 *      rate;</li>
 *  <li>if the current neighbourhood size is positive, walks outward from that winner one ring of
 *      neighbours at a time (up to the neighbourhood size) and moves each neighbour toward the
 *      sample with a learning rate scaled by a Gaussian of its ring index.</li>
 * </ol>
 *
 * <p>Learning rate and neighbourhood size are looked up from their functions using the number of
 * previous calls. That count is kept in an {@link AtomicLong}. Neuron features are written through
 * compare-and-set retry loops ({@code Neuron.compareAndSetFeatures}), so concurrent calls are
 * intended to tolerate interleaved updates.
 */
public class KohonenUpdateAction implements UpdateAction {
    /** Distance used to rank neurons against a training sample. */
    private final DistanceMeasure distance;
    /** Learning rate as a function of the number of previous calls. */
    private final LearningFactorFunction learningFactor;
    /** Neighbourhood radius (in rings of neighbours) as a function of the number of previous calls. */
    private final NeighbourhoodSizeFunction neighbourhoodSize;
    /** Number of calls to {@link #update(Network, double[])} started so far. */
    private final AtomicLong numberOfCalls = new AtomicLong(0L);

    /**
     * @param distance distance used to find the winning neuron.
     * @param learningFactor learning-rate schedule.
     * @param neighbourhoodSize neighbourhood-size schedule.
     * @throws NullPointerException if any argument is {@code null}.
     */
    public KohonenUpdateAction(DistanceMeasure distance,
                               LearningFactorFunction learningFactor,
                               NeighbourhoodSizeFunction neighbourhoodSize) {
        // Fail fast instead of deferring the NPE to the first call to update().
        this.distance = Objects.requireNonNull(distance, "distance");
        this.learningFactor = Objects.requireNonNull(learningFactor, "learningFactor");
        this.neighbourhoodSize = Objects.requireNonNull(neighbourhoodSize, "neighbourhoodSize");
    }

    /**
     * Performs one training step with the given sample.
     *
     * @param net network whose neurons are updated.
     * @param features training sample.
     */
    @Override
    public void update(Network net, double[] features) {
        // The schedules are evaluated at the number of calls made before this one (0 on the first call).
        final long numCalls = numberOfCalls.incrementAndGet() - 1L;
        final double currentLearning = learningFactor.value(numCalls);
        final Neuron best = findAndUpdateBestNeuron(net, features, currentLearning);

        final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
        if (currentNeighbourhood <= 0) {
            // Only the winning neuron is updated at this step.
            return;
        }

        // The farther a ring is from the winner, the smaller its learning rate.
        final Gaussian neighbourhoodDecay = new Gaussian(currentLearning, currentNeighbourhood);

        // Start the ring expansion from the winner alone.
        Collection<Neuron> neighbours = new HashSet<>();
        neighbours.add(best);
        // Neurons already updated during this step; they must not be updated again.
        final HashSet<Neuron> exclude = new HashSet<>();
        exclude.add(best);

        for (int radius = 1; radius <= currentNeighbourhood; radius++) {
            // Next ring: immediate neighbours of the previous ring, minus everything already visited.
            neighbours = net.getNeighbours(neighbours, exclude);
            final double ringLearningRate = neighbourhoodDecay.applyAsDouble(radius);
            for (final Neuron n : neighbours) {
                updateNeighbouringNeuron(n, features, ringLearningRate);
            }
            exclude.addAll(neighbours);
        }
    }

    /**
     * @return the number of calls to {@link #update(Network, double[])} started so far
     * (the counter is incremented at the beginning of each call).
     */
    public long getNumberOfCalls() {
        return numberOfCalls.get();
    }

    /**
     * Tries once to move the neuron's features toward the sample.
     *
     * @param n neuron to update.
     * @param features training sample.
     * @param learningRate fraction of the distance to the sample to move.
     * @return the result of {@code n.compareAndSetFeatures}: {@code true} if the neuron still held
     * the features read at the start of this attempt and the new features were stored.
     */
    private boolean attemptNeuronUpdate(Neuron n, double[] features, double learningRate) {
        final double[] expect = n.getFeatures();
        final double[] update = computeFeatures(expect, features, learningRate);
        return n.compareAndSetFeatures(expect, update);
    }

    /**
     * Updates a (non-winning) neuron, retrying until the compare-and-set succeeds.
     *
     * @param n neuron to update.
     * @param features training sample.
     * @param learningRate learning rate to apply.
     */
    private void updateNeighbouringNeuron(Neuron n, double[] features, double learningRate) {
        boolean updated;
        do {
            // Each retry re-reads the neuron's current features before recomputing.
            updated = attemptNeuronUpdate(n, features, learningRate);
        } while (!updated);
    }

    /**
     * Finds the top-ranked neuron for the sample and updates it.
     *
     * @param net network to search.
     * @param features training sample.
     * @param learningRate learning rate to apply to the winner.
     * @return the neuron that was successfully updated.
     */
    private Neuron findAndUpdateBestNeuron(Network net, double[] features, double learningRate) {
        final MapRanking rank = new MapRanking(net, distance);
        while (true) {
            final Neuron best = rank.rank(features, 1).get(0);
            if (attemptNeuronUpdate(best, features, learningRate)) {
                return best;
            }
            // The compare-and-set failed, so the candidate's features changed in the meantime and it
            // may no longer be the best match: rank again rather than blindly retrying it.
        }
    }

    /**
     * Computes {@code current + learningRate * (sample - current)} component-wise.
     *
     * <p>The result has {@code current.length} components. A {@code sample} shorter than
     * {@code current} raises {@link ArrayIndexOutOfBoundsException}; extra trailing sample
     * components are ignored.
     *
     * @param current current features.
     * @param sample training sample.
     * @param learningRate interpolation factor.
     * @return a new array holding the updated features.
     */
    private double[] computeFeatures(double[] current, double[] sample, double learningRate) {
        final int len = current.length;
        final double[] r = new double[len];
        for (int i = 0; i < len; i++) {
            final double c = current[i];
            final double s = sample[i];
            r[i] = c + learningRate * (s - c);
        }
        return r;
    }

    /**
     * Scaled Gaussian {@code f(x) = norm * exp(-x^2 / (2 * sigma^2))}.
     */
    private static class Gaussian implements DoubleUnaryOperator {
        /** Precomputed {@code 1 / (2 * sigma^2)}. */
        private final double i2s2;
        /** Value at {@code x = 0}. */
        private final double norm;

        /**
         * @param norm value of the function at {@code x = 0}.
         * @param sigma standard deviation; expected to be non-zero (callers only use positive values).
         */
        Gaussian(double norm, double sigma) {
            this.norm = norm;
            this.i2s2 = 1.0 / (2.0 * sigma * sigma);
        }

        @Override
        public double applyAsDouble(double x) {
            return norm * Math.exp(-x * x * i2s2);
        }
    }
}

