001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.demos.sandbox.ml.regression;
031
032import java.io.File;
033import java.io.IOException;
034
035import javax.swing.JFrame;
036
037import org.apache.commons.lang.StringUtils;
038import org.jfree.chart.ChartFactory;
039import org.jfree.chart.ChartPanel;
040import org.jfree.chart.JFreeChart;
041import org.jfree.data.time.Day;
042import org.jfree.data.time.TimeSeries;
043import org.joda.time.DateTime;
044import org.joda.time.format.DateTimeFormat;
045import org.joda.time.format.DateTimeFormatter;
046import org.openimaj.hadoop.tools.twitter.utils.WordDFIDF;
047import org.openimaj.hadoop.tools.twitter.utils.WordDFIDFTimeSeries;
048import org.openimaj.hadoop.tools.twitter.utils.WordDFIDFTimeSeriesCollection;
049import org.openimaj.io.Cache;
050import org.openimaj.io.IOUtils;
051import org.openimaj.ml.timeseries.IncompatibleTimeSeriesException;
052import org.openimaj.ml.timeseries.aggregator.MeanSquaredDifferenceAggregator;
053import org.openimaj.ml.timeseries.aggregator.SquaredSummedDifferenceAggregator;
054import org.openimaj.ml.timeseries.aggregator.WindowedLinearRegressionAggregator;
055import org.openimaj.ml.timeseries.processor.IntervalSummationProcessor;
056import org.openimaj.ml.timeseries.processor.MovingAverageProcessor;
057import org.openimaj.ml.timeseries.processor.WindowedLinearRegressionProcessor;
058import org.openimaj.ml.timeseries.series.DoubleSynchronisedTimeSeriesCollection;
059import org.openimaj.ml.timeseries.series.DoubleTimeSeries;
060import org.openimaj.twitter.finance.YahooFinanceData;
061import org.openimaj.util.pair.IndependentPair;
062
063/**
064 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
065 * 
066 */
067public class MultipleLinearRegressionPlayground {
068        /**
069         * @param args
070         * @throws IOException
071         * @throws IncompatibleTimeSeriesException
072         */
073        public static void main(String[] args) throws IOException, IncompatibleTimeSeriesException {
074                final String start = "2010-01-01";
075                final String end = "2010-12-31";
076                final String learns = "2010-01-01";
077                final String learne = "2010-05-01";
078                linearRegressStocks(start, end, learns, learne, "MSFT", "AAPL");
079
080        }
081
082        @SuppressWarnings("unchecked")
083        private static void linearRegressStocks(String start, String end, String learns, String learne, String... stocks)
084                        throws IncompatibleTimeSeriesException, IOException
085        {
086                final DoubleSynchronisedTimeSeriesCollection dstsc = new DoubleSynchronisedTimeSeriesCollection();
087                for (final String stock : stocks) {
088                        YahooFinanceData data = new YahooFinanceData(stock, start, end, "YYYY-MM-dd");
089                        data = Cache.load(data);
090                        final DoubleTimeSeries highseries = data.seriesMap().get("High");
091                        dstsc.addTimeSeries(stock, highseries);
092                }
093                TSCollection dataset = new TSCollection();
094                timeSeriesToChart(dstsc, dataset);
095                final DoubleSynchronisedTimeSeriesCollection movingAverage = dstsc.processInternal(new MovingAverageProcessor(30l
096                                * 24l * 60l * 60l * 1000l));
097                timeSeriesToChart(movingAverage, dataset, "-MA");
098                displayTimeSeries(dataset, StringUtils.join(stocks, " & "), "Date", "Price");
099
100                dataset = new TSCollection();
101                timeSeriesToChart("AAPL", dstsc.series("AAPL"), dataset);
102
103                final DoubleTimeSeries interp = dstsc.series("AAPL").process(new WindowedLinearRegressionProcessor(10, 7));
104                timeSeriesToChart("AAPL-interp", interp, dataset);
105                long[] interpTimes = interp.getTimes();
106                DoubleTimeSeries importantAAPL = dstsc.series("AAPL").get(interpTimes[0], interpTimes[interpTimes.length - 1]);
107                final DoubleSynchronisedTimeSeriesCollection aaplinterp = new DoubleSynchronisedTimeSeriesCollection(
108                                IndependentPair.pair("AAPL", importantAAPL),
109                                IndependentPair.pair("AAPL-interp", interp)
110                                );
111                System.out
112                                .println("AAPL linear regression SSE: " + new SquaredSummedDifferenceAggregator().aggregate(aaplinterp));
113
114                final DoubleTimeSeries interpmsft = new WindowedLinearRegressionAggregator("AAPL", 10, 7, true).aggregate(dstsc);
115                timeSeriesToChart("AAPL-interpmstf", interpmsft, dataset);
116                interpTimes = interpmsft.getTimes();
117                importantAAPL = dstsc.series("AAPL").get(interpTimes[0], interpTimes[interpTimes.length - 1]);
118                final DoubleSynchronisedTimeSeriesCollection aaplmsftinterp = new DoubleSynchronisedTimeSeriesCollection(
119                                IndependentPair.pair("AAPL", importantAAPL),
120                                IndependentPair.pair("AAPLMSFT-interp", interpmsft)
121                                );
122                System.out.println("AAPL+MSFT linear regression SSE: "
123                                + new SquaredSummedDifferenceAggregator().aggregate(aaplmsftinterp));
124                displayTimeSeries(dataset, StringUtils.join(stocks, " & ") + " Interp", "Date", "Price");
125
126                dataset = new TSCollection();
127                final DoubleTimeSeries highseries = dstsc.series("AAPL");
128                final DateTimeFormatter parser = DateTimeFormat.forPattern("YYYY-MM-dd");
129                final long learnstart = parser.parseDateTime(learns).getMillis();
130                final long learnend = parser.parseDateTime(learne).getMillis();
131                final DoubleSynchronisedTimeSeriesCollection aaplworddfidf = loadwords("AAPL", dstsc.series("AAPL"));
132                final DoubleSynchronisedTimeSeriesCollection yearFirstHalf = aaplworddfidf.get(learnstart, learnend);
133                final DoubleTimeSeries interpidf107 = new WindowedLinearRegressionAggregator("AAPL", 10, 7, true)
134                                .aggregate(aaplworddfidf);
135                final DoubleTimeSeries interpidf31 = new WindowedLinearRegressionAggregator("AAPL", 3, 1, true)
136                                .aggregate(aaplworddfidf);
137                final DoubleTimeSeries interpidf107unseen = new WindowedLinearRegressionAggregator("AAPL", 10, 7, true,
138                                yearFirstHalf).aggregate(aaplworddfidf);
139
140                final double e107 = MeanSquaredDifferenceAggregator.error(interpidf107, highseries);
141                final double e31 = MeanSquaredDifferenceAggregator.error(interpidf31, highseries);
142                final double e107u = MeanSquaredDifferenceAggregator.error(interpidf107unseen, highseries);
143
144                // dataset.addSeries(timeSeriesToChart(String.format("OLR (m=7,n=10) (MSE=%.2f)",e107),windowedLinearRegression107));
145                // dataset.addSeries(timeSeriesToChart(String.format("OLR (m=1,n=3) (MSE=%.2f)",e31),windowedLinearRegression31));
146                // dataset.addSeries(timeSeriesToChart(String.format("OLR unseen (m=7,n=10) (MSE=%.2f)",e107u),windowedLinearRegression107unseen));
147                timeSeriesToChart("High Value", highseries, dataset);
148                timeSeriesToChart(String.format("OLR (m=7,n=10) (MSE=%.2f)", e107), interpidf107, dataset);
149                timeSeriesToChart(String.format("OLR (m=1,n=3) (MSE=%.2f)", e31), interpidf31, dataset);
150                timeSeriesToChart(String.format("OLR unseen (m=7,n=10) (MSE=%.2f)", e107u), interpidf107unseen, dataset);
151                displayTimeSeries(dataset, StringUtils.join(stocks, " & ") + " Interp", "Date", "Price");
152        }
153
154        private static DoubleSynchronisedTimeSeriesCollection loadwords(String name, DoubleTimeSeries stocks)
155                        throws IOException, IncompatibleTimeSeriesException
156        {
157                final WordDFIDFTimeSeriesCollection AAPLwords = IOUtils.read(new File(
158                                "/Users/ss/Development/data/trendminer-data/datasets/sheffield/2010/part-r-00000"),
159                                WordDFIDFTimeSeriesCollection.class);
160                AAPLwords.processInternalInplace(new IntervalSummationProcessor<WordDFIDF[], WordDFIDF, WordDFIDFTimeSeries>(
161                                stocks.getTimes()));
162
163                final DoubleSynchronisedTimeSeriesCollection coll = new DoubleSynchronisedTimeSeriesCollection();
164                coll.addTimeSeries(name, stocks);
165                for (final String aname : AAPLwords.getNames()) {
166                        coll.addTimeSeries(aname, AAPLwords.series(aname).doubleTimeSeries());
167                }
168                return coll;
169        }
170
171        private static void displayTimeSeries(TSCollection dataset, String name, String xname, String yname) {
172                final JFreeChart chart = ChartFactory.createTimeSeriesChart(name, xname, yname, dataset, true, false, false);
173                final ChartPanel panel = new ChartPanel(chart);
174                panel.setFillZoomRectangle(true);
175                final JFrame j = new JFrame();
176                j.setContentPane(panel);
177                j.pack();
178                j.setVisible(true);
179                j.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
180        }
181
182        private static void timeSeriesToChart(DoubleSynchronisedTimeSeriesCollection dstsc, TSCollection coll,
183                        String... append)
184        {
185                for (final String seriesName : dstsc.getNames()) {
186                        final DoubleTimeSeries series = dstsc.series(seriesName);
187                        final TimeSeries ret = new TimeSeries(seriesName + StringUtils.join(append, "-"));
188                        for (final IndependentPair<Long, Double> pair : series) {
189                                final DateTime dt = new DateTime(pair.firstObject());
190                                final Day d = new Day(dt.getDayOfMonth(), dt.getMonthOfYear(), dt.getYear());
191                                ret.add(d, pair.secondObject());
192                        }
193                        coll.addSeries(ret);
194                }
195        }
196
197        private static void timeSeriesToChart(String name, DoubleTimeSeries highseries, TSCollection coll) {
198                final TimeSeries ret = new TimeSeries(name);
199                for (final IndependentPair<Long, Double> pair : highseries) {
200                        final DateTime dt = new DateTime(pair.firstObject());
201                        final Day d = new Day(dt.getDayOfMonth(), dt.getMonthOfYear(), dt.getYear());
202                        ret.add(d, pair.secondObject());
203                }
204                coll.addSeries(ret);
205        }
206}