001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.demos.sandbox.ml.regression; 031 032import java.io.IOException; 033 034import javax.swing.JFrame; 035 036import org.jfree.chart.ChartFactory; 037import org.jfree.chart.ChartPanel; 038import org.jfree.chart.JFreeChart; 039import org.jfree.data.time.Day; 040import org.jfree.data.time.TimeSeries; 041import org.jfree.data.time.TimeSeriesCollection; 042import org.joda.time.DateTime; 043import org.joda.time.format.DateTimeFormat; 044import org.joda.time.format.DateTimeFormatter; 045import org.openimaj.io.Cache; 046import org.openimaj.ml.timeseries.IncompatibleTimeSeriesException; 047import org.openimaj.ml.timeseries.aggregator.MeanSquaredDifferenceAggregator; 048import org.openimaj.ml.timeseries.processor.MovingAverageProcessor; 049import org.openimaj.ml.timeseries.processor.WindowedLinearRegressionProcessor; 050import org.openimaj.ml.timeseries.series.DoubleTimeSeries; 051import org.openimaj.twitter.finance.YahooFinanceData; 052import org.openimaj.util.pair.IndependentPair; 053 054/** 055 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 056 * 057 */ 058public class LinearRegressionPlayground { 059 /** 060 * @param args 061 * @throws IOException 062 * @throws IncompatibleTimeSeriesException 063 */ 064 public static void main(String[] args) throws IOException, IncompatibleTimeSeriesException { 065 final String stock = "AAPL"; 066 final String start = "2010-01-01"; 067 final String end = "2010-12-31"; 068 final String learns = "2010-01-01"; 069 final String learne = "2010-05-01"; 070 final DateTimeFormatter parser = DateTimeFormat.forPattern("YYYY-MM-dd"); 071 final long learnstart = parser.parseDateTime(learns).getMillis(); 072 final long learnend = parser.parseDateTime(learne).getMillis(); 073 YahooFinanceData data = new YahooFinanceData(stock, start, end, "YYYY-MM-dd"); 074 data = Cache.load(data); 075 final DoubleTimeSeries highseries = data.seriesMap().get("High"); 076 final DoubleTimeSeries yearFirstHalf = highseries.get(learnstart, learnend); 077 TimeSeriesCollection dataset = new TimeSeriesCollection(); 078 dataset.addSeries(timeSeriesToChart("High Value", highseries)); 079 final DoubleTimeSeries movingAverage = highseries.process(new MovingAverageProcessor(30l * 24l * 60l * 60l 080 * 1000l)); 081 final DoubleTimeSeries halfYearMovingAverage = yearFirstHalf.process(new MovingAverageProcessor(30l * 24l * 60l 082 * 60l * 1000l)); 083 084 dataset.addSeries( 085 timeSeriesToChart( 086 "High Value MA", 087 movingAverage 088 )); 089 dataset.addSeries( 090 timeSeriesToChart( 091 "High Value MA Regressed (all seen)", 092 movingAverage.process(new WindowedLinearRegressionProcessor(10, 7)) 093 )); 094 dataset.addSeries( 095 timeSeriesToChart( 096 "High Value MA Regressed (latter half unseen)", 097 movingAverage.process(new WindowedLinearRegressionProcessor(halfYearMovingAverage, 10, 7)) 098 )); 099 displayTimeSeries(dataset, stock, "Date", "Price"); 100 dataset = new TimeSeriesCollection(); 101 dataset.addSeries(timeSeriesToChart("High Value", highseries)); 102 // final DoubleTimeSeries linearRegression = highseries.process(new 103 // LinearRegressionProcessor()); 104 105 // double lrmsd = 106 // MeanSquaredDifferenceAggregator.error(linearRegression,highseries); 107 // dataset.addSeries(timeSeriesToChart(String.format("OLR (MSE=%.2f)",lrmsd),linearRegression)); 108 final DoubleTimeSeries windowedLinearRegression107 = highseries.process(new WindowedLinearRegressionProcessor(10, 109 7)); 110 final DoubleTimeSeries windowedLinearRegression31 = highseries 111 .process(new WindowedLinearRegressionProcessor(3, 1)); 112 final DoubleTimeSeries windowedLinearRegression107unseen = highseries 113 .process(new WindowedLinearRegressionProcessor(yearFirstHalf, 10, 7)); 114 115 final double e107 = MeanSquaredDifferenceAggregator.error(windowedLinearRegression107, highseries); 116 final double e31 = MeanSquaredDifferenceAggregator.error(windowedLinearRegression31, highseries); 117 final double e107u = MeanSquaredDifferenceAggregator.error(windowedLinearRegression107unseen, highseries); 118 119 dataset.addSeries(timeSeriesToChart(String.format("OLR (m=7,n=10) (MSE=%.2f)", e107), windowedLinearRegression107)); 120 dataset.addSeries(timeSeriesToChart(String.format("OLR (m=1,n=3) (MSE=%.2f)", e31), windowedLinearRegression31)); 121 dataset.addSeries(timeSeriesToChart(String.format("OLR unseen (m=7,n=10) (MSE=%.2f)", e107u), 122 windowedLinearRegression107unseen)); 123 displayTimeSeries(dataset, stock, "Date", "Price"); 124 125 } 126 127 private static void displayTimeSeries(TimeSeriesCollection dataset, String name, String xname, String yname) { 128 final JFreeChart chart = ChartFactory.createTimeSeriesChart(name, xname, yname, dataset, true, false, false); 129 final ChartPanel panel = new ChartPanel(chart); 130 panel.setFillZoomRectangle(true); 131 final JFrame j = new JFrame(); 132 j.setContentPane(panel); 133 j.pack(); 134 j.setVisible(true); 135 j.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 136 } 137 138 private static org.jfree.data.time.TimeSeries timeSeriesToChart(String name, DoubleTimeSeries highseries) { 139 final TimeSeries ret = new TimeSeries(name); 140 for (final IndependentPair<Long, Double> pair : highseries) { 141 final DateTime dt = new DateTime(pair.firstObject()); 142 final Day d = new Day(dt.getDayOfMonth(), dt.getMonthOfYear(), dt.getYear()); 143 ret.add(d, pair.secondObject()); 144 } 145 return ret; 146 } 147}