001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.demos.sandbox.ml.regression;
031
032import java.io.IOException;
033
034import javax.swing.JFrame;
035
036import org.jfree.chart.ChartFactory;
037import org.jfree.chart.ChartPanel;
038import org.jfree.chart.JFreeChart;
039import org.jfree.data.time.Day;
040import org.jfree.data.time.TimeSeries;
041import org.jfree.data.time.TimeSeriesCollection;
042import org.joda.time.DateTime;
043import org.joda.time.format.DateTimeFormat;
044import org.joda.time.format.DateTimeFormatter;
045import org.openimaj.io.Cache;
046import org.openimaj.ml.timeseries.IncompatibleTimeSeriesException;
047import org.openimaj.ml.timeseries.aggregator.MeanSquaredDifferenceAggregator;
048import org.openimaj.ml.timeseries.processor.MovingAverageProcessor;
049import org.openimaj.ml.timeseries.processor.WindowedLinearRegressionProcessor;
050import org.openimaj.ml.timeseries.series.DoubleTimeSeries;
051import org.openimaj.twitter.finance.YahooFinanceData;
052import org.openimaj.util.pair.IndependentPair;
053
054/**
055 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
056 * 
057 */
058public class LinearRegressionPlayground {
059        /**
060         * @param args
061         * @throws IOException
062         * @throws IncompatibleTimeSeriesException
063         */
064        public static void main(String[] args) throws IOException, IncompatibleTimeSeriesException {
065                final String stock = "AAPL";
066                final String start = "2010-01-01";
067                final String end = "2010-12-31";
068                final String learns = "2010-01-01";
069                final String learne = "2010-05-01";
070                final DateTimeFormatter parser = DateTimeFormat.forPattern("YYYY-MM-dd");
071                final long learnstart = parser.parseDateTime(learns).getMillis();
072                final long learnend = parser.parseDateTime(learne).getMillis();
073                YahooFinanceData data = new YahooFinanceData(stock, start, end, "YYYY-MM-dd");
074                data = Cache.load(data);
075                final DoubleTimeSeries highseries = data.seriesMap().get("High");
076                final DoubleTimeSeries yearFirstHalf = highseries.get(learnstart, learnend);
077                TimeSeriesCollection dataset = new TimeSeriesCollection();
078                dataset.addSeries(timeSeriesToChart("High Value", highseries));
079                final DoubleTimeSeries movingAverage = highseries.process(new MovingAverageProcessor(30l * 24l * 60l * 60l
080                                * 1000l));
081                final DoubleTimeSeries halfYearMovingAverage = yearFirstHalf.process(new MovingAverageProcessor(30l * 24l * 60l
082                                * 60l * 1000l));
083
084                dataset.addSeries(
085                                timeSeriesToChart(
086                                                "High Value MA",
087                                                movingAverage
088                                ));
089                dataset.addSeries(
090                                timeSeriesToChart(
091                                                "High Value MA Regressed (all seen)",
092                                                movingAverage.process(new WindowedLinearRegressionProcessor(10, 7))
093                                ));
094                dataset.addSeries(
095                                timeSeriesToChart(
096                                                "High Value MA Regressed (latter half unseen)",
097                                                movingAverage.process(new WindowedLinearRegressionProcessor(halfYearMovingAverage, 10, 7))
098                                ));
099                displayTimeSeries(dataset, stock, "Date", "Price");
100                dataset = new TimeSeriesCollection();
101                dataset.addSeries(timeSeriesToChart("High Value", highseries));
102                // final DoubleTimeSeries linearRegression = highseries.process(new
103                // LinearRegressionProcessor());
104
105                // double lrmsd =
106                // MeanSquaredDifferenceAggregator.error(linearRegression,highseries);
107                // dataset.addSeries(timeSeriesToChart(String.format("OLR (MSE=%.2f)",lrmsd),linearRegression));
108                final DoubleTimeSeries windowedLinearRegression107 = highseries.process(new WindowedLinearRegressionProcessor(10,
109                                7));
110                final DoubleTimeSeries windowedLinearRegression31 = highseries
111                                .process(new WindowedLinearRegressionProcessor(3, 1));
112                final DoubleTimeSeries windowedLinearRegression107unseen = highseries
113                                .process(new WindowedLinearRegressionProcessor(yearFirstHalf, 10, 7));
114
115                final double e107 = MeanSquaredDifferenceAggregator.error(windowedLinearRegression107, highseries);
116                final double e31 = MeanSquaredDifferenceAggregator.error(windowedLinearRegression31, highseries);
117                final double e107u = MeanSquaredDifferenceAggregator.error(windowedLinearRegression107unseen, highseries);
118
119                dataset.addSeries(timeSeriesToChart(String.format("OLR (m=7,n=10) (MSE=%.2f)", e107), windowedLinearRegression107));
120                dataset.addSeries(timeSeriesToChart(String.format("OLR (m=1,n=3) (MSE=%.2f)", e31), windowedLinearRegression31));
121                dataset.addSeries(timeSeriesToChart(String.format("OLR unseen (m=7,n=10) (MSE=%.2f)", e107u),
122                                windowedLinearRegression107unseen));
123                displayTimeSeries(dataset, stock, "Date", "Price");
124
125        }
126
127        private static void displayTimeSeries(TimeSeriesCollection dataset, String name, String xname, String yname) {
128                final JFreeChart chart = ChartFactory.createTimeSeriesChart(name, xname, yname, dataset, true, false, false);
129                final ChartPanel panel = new ChartPanel(chart);
130                panel.setFillZoomRectangle(true);
131                final JFrame j = new JFrame();
132                j.setContentPane(panel);
133                j.pack();
134                j.setVisible(true);
135                j.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
136        }
137
138        private static org.jfree.data.time.TimeSeries timeSeriesToChart(String name, DoubleTimeSeries highseries) {
139                final TimeSeries ret = new TimeSeries(name);
140                for (final IndependentPair<Long, Double> pair : highseries) {
141                        final DateTime dt = new DateTime(pair.firstObject());
142                        final Day d = new Day(dt.getDayOfMonth(), dt.getMonthOfYear(), dt.getYear());
143                        ret.add(d, pair.secondObject());
144                }
145                return ret;
146        }
147}