001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.workinprogress.sgdsvm;
031
032import static java.lang.Math.exp;
033import static java.lang.Math.log;
034
035public enum LossFunctions implements Loss {
036        LogLoss
037        {
038                // logloss(a,y) = log(1+exp(-a*y))
039                @Override
040                public double loss(double a, double y) {
041                        final double z = a * y;
042                        if (z > 18)
043                                return exp(-z);
044                        if (z < -18)
045                                return -z;
046                        return log(1 + exp(-z));
047                }
048
049                // -dloss(a,y)/da
050                @Override
051                public double dloss(double a, double y) {
052                        final double z = a * y;
053                        if (z > 18)
054                                return y * exp(-z);
055                        if (z < -18)
056                                return y;
057                        return y / (1 + exp(z));
058                }
059        },
060        HingeLoss
061        {
062                // hingeloss(a,y) = max(0, 1-a*y)
063                @Override
064                public double loss(double a, double y) {
065                        final double z = a * y;
066                        if (z > 1)
067                                return 0;
068                        return 1 - z;
069                }
070
071                // -dloss(a,y)/da
072                @Override
073                public double dloss(double a, double y) {
074                        final double z = a * y;
075                        if (z > 1)
076                                return 0;
077                        return y;
078                }
079        },
080        SquaredHingeLoss
081        {
082                // squaredhingeloss(a,y) = 1/2 * max(0, 1-a*y)^2
083                @Override
084                public double loss(double a, double y) {
085                        final double z = a * y;
086                        if (z > 1)
087                                return 0;
088                        final double d = 1 - z;
089                        return 0.5 * d * d;
090
091                }
092
093                // -dloss(a,y)/da
094                @Override
095                public double dloss(double a, double y) {
096                        final double z = a * y;
097                        if (z > 1)
098                                return 0;
099                        return y * (1 - z);
100                }
101        },
102        SmoothHingeLoss
103        {
104                // smoothhingeloss(a,y) = ...
105                @Override
106                public double loss(double a, double y) {
107                        final double z = a * y;
108                        if (z > 1)
109                                return 0;
110                        if (z < 0)
111                                return 0.5 - z;
112                        final double d = 1 - z;
113                        return 0.5 * d * d;
114                }
115
116                // -dloss(a,y)/da
117                @Override
118                public double dloss(double a, double y) {
119                        final double z = a * y;
120                        if (z > 1)
121                                return 0;
122                        if (z < 0)
123                                return y;
124                        return y * (1 - z);
125                }
126        };
127}