/*
 * Decompiled with CFR 0.152.
 */
package moa.recommender.rc.predictor.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import moa.recommender.rc.data.RecommenderData;
import moa.recommender.rc.utils.Pair;
import moa.recommender.rc.utils.Rating;
import moa.recommender.rc.utils.SparseVector;
import moa.recommender.rc.utils.Updatable;

/**
 * Matrix-factorization rating predictor that learns one {@code float[]} feature
 * vector per user and per item using regularized stochastic gradient descent.
 *
 * <p>A prediction is {@code data.getGlobalMean()} plus the dot product of the
 * user and item vectors, clamped to
 * {@code [data.getMinRating(), data.getMaxRating()]}. Feature 0 of every user
 * vector and feature 1 of every item vector are pinned to {@code 1.0f} and are
 * never updated, so item feature 0 and user feature 1 behave as per-item and
 * per-user bias terms.
 *
 * <p>The constructor registers this predictor with its {@link RecommenderData}
 * through {@code attachUpdatable}; presumably the data object then invokes the
 * {@code update*} callbacks below when ratings, users or items change
 * (NOTE(review): confirm against RecommenderData). Initial feature values come
 * from a {@link Random} with a fixed seed, so results are reproducible for the
 * same sequence of calls.
 */
public class BRISMFPredictor
implements Updatable {
    /** Learning rate used by the constructor that does not take one. */
    private static final double DEFAULT_L_RATE = 0.01;
    /** Regularization factor used by the constructor that does not take one. */
    private static final double DEFAULT_R_FACTOR = 0.02;

    /** Rating store this predictor learns from. */
    protected RecommenderData data;
    /** Length of every user and item feature vector (at least 2: slots 0 and 1 are bias slots). */
    protected int nFeatures;
    /** User id to feature vector; index 0 of each vector is pinned to 1. */
    protected HashMap<Integer, float[]> userFeature;
    /** Item id to feature vector; index 1 of each vector is pinned to 1. */
    protected HashMap<Integer, float[]> itemFeature;
    /** Source of the small random initial feature values (fixed seed). */
    protected Random rnd;
    /** SGD learning rate. */
    protected double lRate = DEFAULT_L_RATE;
    /** L2 regularization factor applied in every gradient step. */
    protected double rFactor = DEFAULT_R_FACTOR;
    /** Number of passes used when (re)training a single user or item without an explicit count. */
    protected int nIterations = 30;

    /**
     * Creates a predictor with the default learning rate (0.01) and
     * regularization factor (0.02).
     *
     * @param nFeatures length of each feature vector; must be at least 2
     * @param data      rating store to learn from; this predictor attaches itself to it
     * @param train     if {@code true}, run {@link #train()} immediately
     * @throws IllegalArgumentException if {@code nFeatures < 2}
     */
    public BRISMFPredictor(int nFeatures, RecommenderData data, boolean train) {
        this(nFeatures, data, DEFAULT_L_RATE, DEFAULT_R_FACTOR, train);
    }

    /**
     * Creates a predictor with an explicit learning rate and regularization factor.
     *
     * @param nFeatures length of each feature vector; must be at least 2
     * @param data      rating store to learn from; this predictor attaches itself to it
     * @param lRate     SGD learning rate
     * @param rFactor   L2 regularization factor
     * @param train     if {@code true}, run {@link #train()} immediately
     * @throws IllegalArgumentException if {@code nFeatures < 2}
     */
    public BRISMFPredictor(int nFeatures, RecommenderData data, double lRate, double rFactor, boolean train) {
        // Slots 0 and 1 are the pinned bias slots; smaller vectors would fail on the
        // first feature initialization, so reject them before attaching to data.
        if (nFeatures < 2) {
            throw new IllegalArgumentException("nFeatures must be at least 2, got " + nFeatures);
        }
        this.data = data;
        this.nFeatures = nFeatures;
        this.userFeature = new HashMap<Integer, float[]>();
        this.itemFeature = new HashMap<Integer, float[]>();
        this.rnd = new Random(12345L);
        this.lRate = lRate;
        this.rFactor = rFactor;
        data.attachUpdatable(this);
        if (train) {
            this.train();
        }
    }

    /** Sets the SGD learning rate used by subsequent training. */
    public void setLRate(double lRate) {
        this.lRate = lRate;
    }

    /** Sets the L2 regularization factor used by subsequent training. */
    public void setRFactor(double rFactor) {
        this.rFactor = rFactor;
    }

    /** Sets the default number of passes for single-user / single-item retraining. */
    public void setNIterations(int nIterations) {
        this.nIterations = nIterations;
    }

    /** Returns the rating store this predictor was created with. */
    public RecommenderData getData() {
        return this.data;
    }

    /**
     * Fills {@code feats} with random values in roughly [-0.01, 0.01) and pins
     * the bias slot (index 0 for users, index 1 for items) to 1.
     */
    private void resetFeatures(float[] feats, boolean userFeats) {
        for (int i = 0; i < feats.length; ++i) {
            feats[i] = 0.01f * (this.rnd.nextFloat() * 2.0f - 1.0f);
        }
        feats[userFeats ? 0 : 1] = 1.0f;
    }

    /** Allocates a fresh, randomly initialized user or item feature vector. */
    private float[] newFeatures(boolean userFeats) {
        float[] feats = new float[this.nFeatures];
        this.resetFeatures(feats, userFeats);
        return feats;
    }

    /**
     * Performs one regularized gradient step on a single feature.
     *
     * @param value current value of the feature being updated
     * @param other matching feature of the opposite (user/item) vector
     * @param err   observed rating minus predicted rating
     * @return {@code value + lRate * (err * other - rFactor * value)}
     */
    private float step(float value, float other, double err) {
        return (float) (value + this.lRate * (err * other - this.rFactor * value));
    }

    /**
     * Predicts the rating of {@code itemID} by {@code userID} from the stored vectors.
     * Falls back to the clamped global mean if either vector is unknown.
     */
    public double predictRating(int userID, int itemID) {
        float[] userFeats = this.userFeature.get(userID);
        float[] itemFeats = this.itemFeature.get(itemID);
        return this.predictRating(userFeats, itemFeats);
    }

    /**
     * Predicts a rating from explicit feature vectors: global mean plus their dot
     * product over the first {@code nFeatures} entries, clamped to the data's
     * rating range. If either vector is {@code null}, only the clamped global
     * mean is returned.
     */
    public double predictRating(float[] userFeats, float[] itemFeats) {
        double ret = this.data.getGlobalMean();
        if (userFeats != null && itemFeats != null) {
            for (int i = 0; i < this.nFeatures; ++i) {
                ret += (double)(userFeats[i] * itemFeats[i]);
            }
        }
        if (ret < this.data.getMinRating()) {
            ret = this.data.getMinRating();
        } else if (ret > this.data.getMaxRating()) {
            ret = this.data.getMaxRating();
        }
        return ret;
    }

    /**
     * Learns a new user vector from scratch against the current item vectors,
     * which are held fixed.
     *
     * @param itm  ids of the items the user rated
     * @param rat  ratings, parallel to {@code itm}
     * @param nIts number of passes over the ratings
     * @return the freshly trained user vector (not stored)
     */
    public float[] trainUserFeats(List<Integer> itm, List<Double> rat, int nIts) {
        float[] userFeats = this.newFeatures(true);
        int n = itm.size();
        for (int k = 0; k < nIts; ++k) {
            for (int i = 0; i < n; ++i) {
                int itemID = itm.get(i);
                float[] itemFeats = this.itemFeature.get(itemID);
                if (itemFeats == null) {
                    // No learned vector for this item yet: no gradient to follow.
                    continue;
                }
                double err = rat.get(i) - this.predictRating(userFeats, itemFeats);
                // Start at 1: user slot 0 is the pinned bias slot.
                for (int j = 1; j < this.nFeatures; ++j) {
                    userFeats[j] = this.step(userFeats[j], itemFeats[j], err);
                }
            }
        }
        return userFeats;
    }

    /**
     * Learns a new item vector from scratch against the current user vectors,
     * which are held fixed.
     *
     * @param itemID id of the item (not used by the computation itself)
     * @param usr    ids of the users who rated the item
     * @param rat    ratings, parallel to {@code usr}
     * @param nIts   number of passes over the ratings
     * @return the freshly trained item vector (not stored)
     */
    public float[] trainItemFeats(int itemID, List<Integer> usr, List<Double> rat, int nIts) {
        float[] itemFeats = this.newFeatures(false);
        int n = usr.size();
        for (int k = 0; k < nIts; ++k) {
            for (int i = 0; i < n; ++i) {
                int userID = usr.get(i);
                float[] userFeats = this.userFeature.get(userID);
                if (userFeats == null) {
                    // No learned vector for this user yet: no gradient to follow.
                    continue;
                }
                double err = rat.get(i) - this.predictRating(userFeats, itemFeats);
                // Skip item slot 1, the pinned bias slot.
                itemFeats[0] = this.step(itemFeats[0], userFeats[0], err);
                for (int j = 2; j < this.nFeatures; ++j) {
                    itemFeats[j] = this.step(itemFeats[j], userFeats[j], err);
                }
            }
        }
        return itemFeats;
    }

    /**
     * Copies every (id, rating) pair of {@code vec} into the parallel lists
     * {@code ids} and {@code ratings}.
     */
    private static void collectRatings(SparseVector vec, List<Integer> ids, List<Double> ratings) {
        Iterator<Pair<Integer, Double>> it = vec.iterator();
        while (it.hasNext()) {
            Pair<Integer, Double> p = it.next();
            ids.add(p.getFirst());
            ratings.add(p.getSecond());
        }
    }

    /**
     * Like {@link #collectRatings}, but uses {@code value} as the rating for
     * {@code key}. If {@code key} is not present in {@code vec}, the pair is
     * appended at the end.
     */
    private static void collectRatingsOverriding(SparseVector vec, int key, double value,
                                                 List<Integer> ids, List<Double> ratings) {
        boolean found = false;
        Iterator<Pair<Integer, Double>> it = vec.iterator();
        while (it.hasNext()) {
            Pair<Integer, Double> p = it.next();
            int id = p.getFirst();
            ids.add(id);
            if (id == key) {
                found = true;
                ratings.add(value);
            } else {
                ratings.add(p.getSecond());
            }
        }
        if (!found) {
            ids.add(key);
            ratings.add(value);
        }
    }

    /** Retrains and stores a user's vector from the given ratings, with {@code nIts} passes. */
    public void trainUser(int userID, List<Integer> itm, List<Double> rat, int nIts) {
        this.userFeature.put(userID, this.trainUserFeats(itm, rat, nIts));
    }

    /** Retrains and stores a user's vector from the given ratings, with {@link #nIterations} passes. */
    public void trainUser(int userID, List<Integer> itm, List<Double> rat) {
        this.userFeature.put(userID, this.trainUserFeats(itm, rat, this.nIterations));
    }

    /** Retrains and stores a user's vector from the ratings in {@code data}, with {@code nIts} passes. */
    public void trainUser(int userID, int nIts) {
        ArrayList<Integer> itm = new ArrayList<Integer>();
        ArrayList<Double> rat = new ArrayList<Double>();
        collectRatings(this.data.getRatingsUser(userID), itm, rat);
        this.trainUser(userID, itm, rat, nIts);
    }

    /** Retrains and stores a user's vector from the ratings in {@code data}, with {@link #nIterations} passes. */
    public void trainUser(int userID) {
        ArrayList<Integer> itm = new ArrayList<Integer>();
        ArrayList<Double> rat = new ArrayList<Double>();
        collectRatings(this.data.getRatingsUser(userID), itm, rat);
        this.trainUser(userID, itm, rat);
    }

    /** Retrains and stores an item's vector from the given ratings, with {@code nIts} passes. */
    public void trainItem(int itemID, List<Integer> usr, List<Double> rat, int nIts) {
        this.itemFeature.put(itemID, this.trainItemFeats(itemID, usr, rat, nIts));
    }

    /** Retrains and stores an item's vector from the given ratings, with {@link #nIterations} passes. */
    public void trainItem(int itemID, List<Integer> usr, List<Double> rat) {
        this.itemFeature.put(itemID, this.trainItemFeats(itemID, usr, rat, this.nIterations));
    }

    /** Retrains and stores an item's vector from the ratings in {@code data}, with {@code nIts} passes. */
    public void trainItem(int itemID, int nIts) {
        ArrayList<Integer> usr = new ArrayList<Integer>();
        ArrayList<Double> rat = new ArrayList<Double>();
        collectRatings(this.data.getRatingsItem(itemID), usr, rat);
        this.trainItem(itemID, usr, rat, nIts);
    }

    /** Retrains and stores an item's vector from the ratings in {@code data}, with {@link #nIterations} passes. */
    public void trainItem(int itemID) {
        ArrayList<Integer> usr = new ArrayList<Integer>();
        ArrayList<Double> rat = new ArrayList<Double>();
        collectRatings(this.data.getRatingsItem(itemID), usr, rat);
        this.trainItem(itemID, usr, rat);
    }

    /**
     * Re-initializes every user and item vector and trains all of them jointly
     * by SGD over {@code data.ratingIterator()}.
     *
     * <p>Every {@code trainDiv}-th rating, where
     * {@code trainDiv = max(20, numRatings / 1_000_000)}, is held out. It is
     * collected on the first pass and never trained on. After each pass, the
     * RMSE on the held-out ratings and the elapsed seconds are printed to stdout.
     * Training stops at the first pass that does not improve the RMSE by at
     * least 1e-4. It also stops when the RMSE is NaN, e.g. when there are no
     * ratings or training diverged.
     *
     * <p>NOTE(review): this assumes {@code ratingIterator()} yields ratings in the
     * same order on every pass; otherwise held-out ratings leak into training.
     */
    public void train() {
        this.userFeature.clear();
        this.itemFeature.clear();
        Iterator<Integer> users = this.data.getUsers().iterator();
        while (users.hasNext()) {
            float[] feats = this.newFeatures(true);
            this.userFeature.put(users.next(), feats);
        }
        Iterator<Integer> items = this.data.getItems().iterator();
        while (items.hasNext()) {
            float[] feats = this.newFeatures(false);
            this.itemFeature.put(items.next(), feats);
        }
        int nRatings = this.data.getNumRatings();
        int trainDiv = Math.max(20, nRatings / 1000000);
        ArrayList<Rating> ratTest = new ArrayList<Rating>(nRatings / trainDiv);
        double lastRMSE = 1.0E20;
        boolean firstPass = true;
        while (true) {
            long start = System.currentTimeMillis();
            Iterator<Rating> ratIt = this.data.ratingIterator();
            for (int idx = 0; ratIt.hasNext(); ++idx) {
                Rating rat = ratIt.next();
                if (idx % trainDiv == 0) {
                    if (firstPass) {
                        ratTest.add(rat);
                    }
                } else {
                    this.trainOnRating(rat);
                }
            }
            double curRMSE = this.validationRMSE(ratTest);
            System.out.println(curRMSE + " " + (System.currentTimeMillis() - start) / 1000L);
            firstPass = false;
            // Written as !(a < b) rather than a >= b so that a NaN RMSE also stops
            // training instead of looping forever.
            if (!(curRMSE + 1.0E-4 < lastRMSE)) {
                return;
            }
            lastRMSE = curRMSE;
        }
    }

    /**
     * Applies one joint SGD step for a single rating, updating both the user and
     * the item vector. The pinned bias slots (user 0, item 1) are left untouched.
     */
    private void trainOnRating(Rating rat) {
        float[] userFeats = this.userFeature.get(rat.userID);
        float[] itemFeats = this.itemFeature.get(rat.itemID);
        double err = rat.rating - this.predictRating(userFeats, itemFeats);
        itemFeats[0] = this.step(itemFeats[0], userFeats[0], err);
        userFeats[1] = this.step(userFeats[1], itemFeats[1], err);
        for (int j = 2; j < this.nFeatures; ++j) {
            // The item update must see the user value from before this step.
            float oldUser = userFeats[j];
            userFeats[j] = this.step(oldUser, itemFeats[j], err);
            itemFeats[j] = this.step(itemFeats[j], oldUser, err);
        }
    }

    /** Returns the RMSE of the current model on {@code ratings}; NaN if the list is empty. */
    private double validationRMSE(List<Rating> ratings) {
        double sum = 0.0;
        for (Rating r : ratings) {
            double pred = this.predictRating(r.userID, r.itemID);
            sum += Math.pow(r.rating - pred, 2.0);
        }
        return Math.sqrt(sum / (double) ratings.size());
    }

    /** Returns the stored (live, mutable) vector for {@code userID}, or {@code null} if unknown. */
    public float[] getUserFeatures(int userID) {
        return this.userFeature.get(userID);
    }

    /** Returns the stored (live, mutable) vector for {@code itemID}, or {@code null} if unknown. */
    public float[] getItemFeatures(int itemID) {
        return this.itemFeature.get(itemID);
    }

    /** Returns the length of every feature vector. */
    public int getNumFeatures() {
        return this.nFeatures;
    }

    /** Trains a vector for a new user, unless the user has no ratings yet. */
    @Override
    public void updateNewUser(int userID, List<Integer> ratedItems, List<Double> ratings) {
        if (!ratedItems.isEmpty()) {
            this.trainUser(userID, ratedItems, ratings);
        }
    }

    /** Trains a vector for a new item, unless the item has no ratings yet. */
    @Override
    public void updateNewItem(int itemID, List<Integer> ratingUsers, List<Double> ratings) {
        if (!ratingUsers.isEmpty()) {
            this.trainItem(itemID, ratingUsers, ratings);
        }
    }

    /** Drops the stored vector of a removed user. */
    @Override
    public void updateRemoveUser(int userID) {
        this.userFeature.remove(userID);
    }

    /** Drops the stored vector of a removed item. */
    @Override
    public void updateRemoveItem(int itemID) {
        this.itemFeature.remove(itemID);
    }

    /**
     * Possibly retrains the user and item vectors after {@code (userID, itemID)}
     * is rated {@code rating}.
     *
     * <p>The user is always retrained if it has fewer than 5 ratings, and
     * otherwise with probability {@code 0.99^count}. The item is handled the same
     * way, presumably to bound the cost of updates for well-established entities.
     * Retraining uses the ratings stored in {@code data}, with {@code rating}
     * substituted for the existing one or appended if absent.
     */
    @Override
    public void updateSetRating(int userID, int itemID, double rating) {
        double nUsr = this.data.countRatingsUser(userID);
        double nItm = this.data.countRatingsItem(itemID);
        if (nUsr < 5.0 || this.rnd.nextDouble() < Math.pow(0.99, nUsr)) {
            ArrayList<Integer> itm = new ArrayList<Integer>();
            ArrayList<Double> rat = new ArrayList<Double>();
            collectRatingsOverriding(this.data.getRatingsUser(userID), itemID, rating, itm, rat);
            this.trainUser(userID, itm, rat);
        }
        if (nItm < 5.0 || this.rnd.nextDouble() < Math.pow(0.99, nItm)) {
            ArrayList<Integer> usr = new ArrayList<Integer>();
            ArrayList<Double> rat = new ArrayList<Double>();
            // Keyed on userID: the item's rating vector is indexed by user.
            collectRatingsOverriding(this.data.getRatingsItem(itemID), userID, rating, usr, rat);
            this.trainItem(itemID, usr, rat);
        }
    }

    /** No-op: learned vectors are left unchanged when a rating is removed. */
    @Override
    public void updateRemoveRating(int userID, int itemID) {
    }

    /** Predicts {@code userID}'s rating for each item in {@code itemIDS}, in order. */
    public List<Double> predictRatings(int userID, List<Integer> itemIDS) {
        ArrayList<Double> ret = new ArrayList<Double>(itemIDS.size());
        for (Integer itemID : itemIDS) {
            ret.add(this.predictRating(userID, itemID));
        }
        return ret;
    }
}

