001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.demos.sandbox.ml.regression; 031 032import java.io.File; 033import java.io.IOException; 034 035import javax.swing.JFrame; 036 037import org.apache.commons.lang.StringUtils; 038import org.jfree.chart.ChartFactory; 039import org.jfree.chart.ChartPanel; 040import org.jfree.chart.JFreeChart; 041import org.jfree.data.time.Day; 042import org.jfree.data.time.TimeSeries; 043import org.joda.time.DateTime; 044import org.joda.time.format.DateTimeFormat; 045import org.joda.time.format.DateTimeFormatter; 046import org.openimaj.hadoop.tools.twitter.utils.WordDFIDF; 047import org.openimaj.hadoop.tools.twitter.utils.WordDFIDFTimeSeries; 048import org.openimaj.hadoop.tools.twitter.utils.WordDFIDFTimeSeriesCollection; 049import org.openimaj.io.Cache; 050import org.openimaj.io.IOUtils; 051import org.openimaj.ml.timeseries.IncompatibleTimeSeriesException; 052import org.openimaj.ml.timeseries.aggregator.MeanSquaredDifferenceAggregator; 053import org.openimaj.ml.timeseries.aggregator.SquaredSummedDifferenceAggregator; 054import org.openimaj.ml.timeseries.aggregator.WindowedLinearRegressionAggregator; 055import org.openimaj.ml.timeseries.processor.IntervalSummationProcessor; 056import org.openimaj.ml.timeseries.processor.MovingAverageProcessor; 057import org.openimaj.ml.timeseries.processor.WindowedLinearRegressionProcessor; 058import org.openimaj.ml.timeseries.series.DoubleSynchronisedTimeSeriesCollection; 059import org.openimaj.ml.timeseries.series.DoubleTimeSeries; 060import org.openimaj.twitter.finance.YahooFinanceData; 061import org.openimaj.util.pair.IndependentPair; 062 063/** 064 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 065 * 066 */ 067public class MultipleLinearRegressionPlayground { 068 /** 069 * @param args 070 * @throws IOException 071 * @throws IncompatibleTimeSeriesException 072 */ 073 public static void main(String[] args) throws IOException, IncompatibleTimeSeriesException { 074 final String start = "2010-01-01"; 075 final String end = "2010-12-31"; 076 final String learns = "2010-01-01"; 077 final String learne = "2010-05-01"; 078 linearRegressStocks(start, end, learns, learne, "MSFT", "AAPL"); 079 080 } 081 082 @SuppressWarnings("unchecked") 083 private static void linearRegressStocks(String start, String end, String learns, String learne, String... stocks) 084 throws IncompatibleTimeSeriesException, IOException 085 { 086 final DoubleSynchronisedTimeSeriesCollection dstsc = new DoubleSynchronisedTimeSeriesCollection(); 087 for (final String stock : stocks) { 088 YahooFinanceData data = new YahooFinanceData(stock, start, end, "YYYY-MM-dd"); 089 data = Cache.load(data); 090 final DoubleTimeSeries highseries = data.seriesMap().get("High"); 091 dstsc.addTimeSeries(stock, highseries); 092 } 093 TSCollection dataset = new TSCollection(); 094 timeSeriesToChart(dstsc, dataset); 095 final DoubleSynchronisedTimeSeriesCollection movingAverage = dstsc.processInternal(new MovingAverageProcessor(30l 096 * 24l * 60l * 60l * 1000l)); 097 timeSeriesToChart(movingAverage, dataset, "-MA"); 098 displayTimeSeries(dataset, StringUtils.join(stocks, " & "), "Date", "Price"); 099 100 dataset = new TSCollection(); 101 timeSeriesToChart("AAPL", dstsc.series("AAPL"), dataset); 102 103 final DoubleTimeSeries interp = dstsc.series("AAPL").process(new WindowedLinearRegressionProcessor(10, 7)); 104 timeSeriesToChart("AAPL-interp", interp, dataset); 105 long[] interpTimes = interp.getTimes(); 106 DoubleTimeSeries importantAAPL = dstsc.series("AAPL").get(interpTimes[0], interpTimes[interpTimes.length - 1]); 107 final DoubleSynchronisedTimeSeriesCollection aaplinterp = new DoubleSynchronisedTimeSeriesCollection( 108 IndependentPair.pair("AAPL", importantAAPL), 109 IndependentPair.pair("AAPL-interp", interp) 110 ); 111 System.out 112 .println("AAPL linear regression SSE: " + new SquaredSummedDifferenceAggregator().aggregate(aaplinterp)); 113 114 final DoubleTimeSeries interpmsft = new WindowedLinearRegressionAggregator("AAPL", 10, 7, true).aggregate(dstsc); 115 timeSeriesToChart("AAPL-interpmstf", interpmsft, dataset); 116 interpTimes = interpmsft.getTimes(); 117 importantAAPL = dstsc.series("AAPL").get(interpTimes[0], interpTimes[interpTimes.length - 1]); 118 final DoubleSynchronisedTimeSeriesCollection aaplmsftinterp = new DoubleSynchronisedTimeSeriesCollection( 119 IndependentPair.pair("AAPL", importantAAPL), 120 IndependentPair.pair("AAPLMSFT-interp", interpmsft) 121 ); 122 System.out.println("AAPL+MSFT linear regression SSE: " 123 + new SquaredSummedDifferenceAggregator().aggregate(aaplmsftinterp)); 124 displayTimeSeries(dataset, StringUtils.join(stocks, " & ") + " Interp", "Date", "Price"); 125 126 dataset = new TSCollection(); 127 final DoubleTimeSeries highseries = dstsc.series("AAPL"); 128 final DateTimeFormatter parser = DateTimeFormat.forPattern("YYYY-MM-dd"); 129 final long learnstart = parser.parseDateTime(learns).getMillis(); 130 final long learnend = parser.parseDateTime(learne).getMillis(); 131 final DoubleSynchronisedTimeSeriesCollection aaplworddfidf = loadwords("AAPL", dstsc.series("AAPL")); 132 final DoubleSynchronisedTimeSeriesCollection yearFirstHalf = aaplworddfidf.get(learnstart, learnend); 133 final DoubleTimeSeries interpidf107 = new WindowedLinearRegressionAggregator("AAPL", 10, 7, true) 134 .aggregate(aaplworddfidf); 135 final DoubleTimeSeries interpidf31 = new WindowedLinearRegressionAggregator("AAPL", 3, 1, true) 136 .aggregate(aaplworddfidf); 137 final DoubleTimeSeries interpidf107unseen = new WindowedLinearRegressionAggregator("AAPL", 10, 7, true, 138 yearFirstHalf).aggregate(aaplworddfidf); 139 140 final double e107 = MeanSquaredDifferenceAggregator.error(interpidf107, highseries); 141 final double e31 = MeanSquaredDifferenceAggregator.error(interpidf31, highseries); 142 final double e107u = MeanSquaredDifferenceAggregator.error(interpidf107unseen, highseries); 143 144 // dataset.addSeries(timeSeriesToChart(String.format("OLR (m=7,n=10) (MSE=%.2f)",e107),windowedLinearRegression107)); 145 // dataset.addSeries(timeSeriesToChart(String.format("OLR (m=1,n=3) (MSE=%.2f)",e31),windowedLinearRegression31)); 146 // dataset.addSeries(timeSeriesToChart(String.format("OLR unseen (m=7,n=10) (MSE=%.2f)",e107u),windowedLinearRegression107unseen)); 147 timeSeriesToChart("High Value", highseries, dataset); 148 timeSeriesToChart(String.format("OLR (m=7,n=10) (MSE=%.2f)", e107), interpidf107, dataset); 149 timeSeriesToChart(String.format("OLR (m=1,n=3) (MSE=%.2f)", e31), interpidf31, dataset); 150 timeSeriesToChart(String.format("OLR unseen (m=7,n=10) (MSE=%.2f)", e107u), interpidf107unseen, dataset); 151 displayTimeSeries(dataset, StringUtils.join(stocks, " & ") + " Interp", "Date", "Price"); 152 } 153 154 private static DoubleSynchronisedTimeSeriesCollection loadwords(String name, DoubleTimeSeries stocks) 155 throws IOException, IncompatibleTimeSeriesException 156 { 157 final WordDFIDFTimeSeriesCollection AAPLwords = IOUtils.read(new File( 158 "/Users/ss/Development/data/trendminer-data/datasets/sheffield/2010/part-r-00000"), 159 WordDFIDFTimeSeriesCollection.class); 160 AAPLwords.processInternalInplace(new IntervalSummationProcessor<WordDFIDF[], WordDFIDF, WordDFIDFTimeSeries>( 161 stocks.getTimes())); 162 163 final DoubleSynchronisedTimeSeriesCollection coll = new DoubleSynchronisedTimeSeriesCollection(); 164 coll.addTimeSeries(name, stocks); 165 for (final String aname : AAPLwords.getNames()) { 166 coll.addTimeSeries(aname, AAPLwords.series(aname).doubleTimeSeries()); 167 } 168 return coll; 169 } 170 171 private static void displayTimeSeries(TSCollection dataset, String name, String xname, String yname) { 172 final JFreeChart chart = ChartFactory.createTimeSeriesChart(name, xname, yname, dataset, true, false, false); 173 final ChartPanel panel = new ChartPanel(chart); 174 panel.setFillZoomRectangle(true); 175 final JFrame j = new JFrame(); 176 j.setContentPane(panel); 177 j.pack(); 178 j.setVisible(true); 179 j.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 180 } 181 182 private static void timeSeriesToChart(DoubleSynchronisedTimeSeriesCollection dstsc, TSCollection coll, 183 String... append) 184 { 185 for (final String seriesName : dstsc.getNames()) { 186 final DoubleTimeSeries series = dstsc.series(seriesName); 187 final TimeSeries ret = new TimeSeries(seriesName + StringUtils.join(append, "-")); 188 for (final IndependentPair<Long, Double> pair : series) { 189 final DateTime dt = new DateTime(pair.firstObject()); 190 final Day d = new Day(dt.getDayOfMonth(), dt.getMonthOfYear(), dt.getYear()); 191 ret.add(d, pair.secondObject()); 192 } 193 coll.addSeries(ret); 194 } 195 } 196 197 private static void timeSeriesToChart(String name, DoubleTimeSeries highseries, TSCollection coll) { 198 final TimeSeries ret = new TimeSeries(name); 199 for (final IndependentPair<Long, Double> pair : highseries) { 200 final DateTime dt = new DateTime(pair.firstObject()); 201 final Day d = new Day(dt.getDayOfMonth(), dt.getMonthOfYear(), dt.getYear()); 202 ret.add(d, pair.secondObject()); 203 } 204 coll.addSeries(ret); 205 } 206}