source: mystic/examples/rosetta_parabola.py @ 855

Revision 855, 5.2 KB checked in by mmckerns, 5 months ago (diff)

updated copyright to 2016

Line 
#!/usr/bin/env python
#
# Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
# Copyright (c) 1997-2016 California Institute of Technology.
# License: 3-clause BSD.  The full license text is available at:
#  - http://mmckerns.github.io/project/mystic/browser/mystic/LICENSE
"""
Example use of Forward Poly Model
in mystic and PARK optimization frameworks.
(built for mystic "trunk" and with park-1.2)

Fits the coefficients of a quadratic, a*x**2 + b*x + c, to data generated
from known "target" coefficients, then plots the target and fitted curves.

for help, type "python rosetta_parabola.py --help"
"""

from math import pi
from numpy import array, real, conjugate
import numpy

try: # check if park is installed
    import park
    # import park.parksnob   # only needed for the alternative Snobfit fitter
    import park.parkde
    Model = park.Model
    __park = True
except ImportError:
    # without park, fall back to a plain base class so PolyModel can still
    # be defined; the park solver path is then disabled via __park
    Model = object
    __park = False

def ForwardPolyFactory(params):
    """Build the quadratic forward model ``f(x) = a*x*x + b*x + c``.

    Input:
      params -- a sequence of exactly three coefficients ``(a, b, c)``

    Returns a function of ``x`` that evaluates the quadratic elementwise.
    """
    a, b, c = params
    def forward_poly(x):
        """Evaluate the quadratic at ``x``.

        ``x`` may be a scalar, a list, or a numpy array of any shape. It is
        converted with ``numpy.asarray`` so that python lists are evaluated
        elementwise instead of raising (or doing sequence repetition).
        Returns a numpy array with the same shape as ``x``.
        """
        x = numpy.asarray(x)
        return array(a*x*x + b*x + c)
    return forward_poly

def data(params):
    """Sample the quadratic with coefficients ``params`` on a fixed grid.

    The grid is the 101 evenly spaced floats -50.0, -49.0, ..., 50.0.
    Returns the pair ``(x, y)``, where ``y`` is the forward model
    evaluated at ``x``.
    """
    model = ForwardPolyFactory(params)
    grid = numpy.arange(101) - 50.
    return grid, model(grid)


# --- Cost Function stuff ---
# The residual (vector-valued) cost function
def vec_cost_function(params):
    """Return the residuals: model prediction minus the observed data.

    NOTE: reads the module-level ``datapts`` (the observed data), which is
    only assigned in the ``__main__`` section of this script.
    """
    _, predicted = data(params)
    return predicted - datapts

# The scalar (sum of squares) version
def cost_function(params):
    """Return the sum of the squared magnitudes of the residuals."""
    resid = vec_cost_function(params)
    squared = real(conjugate(resid) * resid)
    return numpy.sum(squared)
# --- Cost Function end ---


# --- Plotting stuff ---
import pylab
def plot_sol(params, linestyle='b-', view=None):
    """Plot the quadratic ``params`` over the sample grid used by ``data``.

    Input:
      params    -- the coefficients ``(a, b, c)``
      linestyle -- a matplotlib format string (default: blue solid line)
      view      -- axis limits ``[xmin, xmax, ymin, ymax]``. If None, the
                   module-level ``plotview`` is used when it has been
                   defined (the ``__main__`` section sets it); otherwise
                   the axis limits are left to matplotlib.
    """
    x, y = data(params)
    pylab.plot(x, y, linestyle, linewidth=2.0)
    if view is None:
        # plotview only exists when run as a script; don't NameError
        # when this module is imported and plot_sol is called directly
        view = globals().get('plotview')
    if view is not None:
        pylab.axis(view)
# --- Plotting end ---


# --- Call to Mystic ---
def mystic_optimize(point):
    """Minimize ``cost_function`` with mystic's Nelder-Mead simplex solver.

    Input:
      point -- initial guess for the coefficients

    Each coefficient is confined to the strict range [-100, 100]. Solving
    stops on CandidateRelativeTolerance(1e-7, 1e-7). Returns the solver's
    solution.
    """
    from mystic.monitors import Monitor, VerboseMonitor
    from mystic.solvers import NelderMeadSimplexSolver as fmin
    from mystic.termination import CandidateRelativeTolerance as CRT
    # generation monitor prints progress (interval of 50); the evaluation
    # monitor just records every cost evaluation
    stepmon, evalmon = VerboseMonitor(50), Monitor()
    ndim = len(point)
    solver = fmin(ndim)
    solver.SetInitialPoints(point)
    # size the bounds from the problem dimension; named lower/upper so the
    # builtins min/max are not shadowed
    lower, upper = [-100]*ndim, [100]*ndim
    solver.SetStrictRanges(lower, upper)
    solver.SetEvaluationMonitor(evalmon)
    solver.SetGenerationMonitor(stepmon)
    solver.Solve(cost_function, CRT(1e-7, 1e-7))
    return solver.Solution()
# --- Mystic end ---


# --- Call to Park ---
class PolyModel(Model):
    """Quadratic forward model wrapped for the PARK fitting framework.

    PARK conventions this class relies on:
     - ``parameters`` lists the fit parameter names; their current values
       are read back as attributes (``self.a``, ``self.b``, ``self.c``)
     - the evaluation method must be named ``eval``
     - instances are created as ``PolyModel(name, a=..., b=..., c=...)``
    """
    parameters = ["a", "b", "c"]

    def eval(self, x):
        """Evaluate the quadratic with the current parameter values at ``x``."""
        coeffs = (self.a, self.b, self.c)
        return ForwardPolyFactory(coeffs)(x)

class Data1D(object):
    """1D observed data exposing the ``residuals`` method PARK requires.

    Input:
      z -- the observed values at the abscissae ``x``
      x -- the abscissae; defaults to the grid -50.0, ..., 50.0 (101 points),
           which matches the grid used by ``data``
    """
    def __init__(self, z, x=None):
        self.z = z
        # build the grid once here rather than on every residuals() call
        if x is None:
            self.x = numpy.arange(101) - 50.
        else:
            self.x = numpy.asarray(x)

    def residuals(self, model):
        """Return ``model(x) - z`` flattened to a 1D array."""
        return (model(self.x) - self.z).flatten()


def park_optimize(point):
    """Fit the quadratic to ``datapts`` with PARK's differential evolution.

    Input:
      point -- initial guess ``(a, b, c)``

    Each coefficient is bounded to [-100, 100], and progress is reported
    through park's ConsoleUpdate handler. Returns ``[a, b, c]`` read from
    the fit result.

    NOTE: reads the module-level ``datapts``, which is only assigned in the
    ``__main__`` section of this script.
    """
    # park.fit and park.fitresult are used below, but only park.parkde is
    # imported at module level; import them explicitly rather than relying
    # on park's package __init__ to load them.
    import park.fit
    import park.fitresult

    # build the data instance
    data1d = Data1D(datapts)

    # build the model instance
    a, b, c = point
    model = PolyModel("mymodel", a=a, b=b, c=c)
    # required to set bounds on the parameters
    model.a = [-100, 100]
    model.b = [-100, 100]
    model.c = [-100, 100]

    # monitor that prints fit progress to the console
    handler = park.fitresult.ConsoleUpdate()

    # differential evolution fitter; park.parksnob.Snobfit() is an
    # alternative (it needs "import park.parksnob")
    fitter = park.parkde.DiffEv()
    # 'fit' requires a list of tuples of (model, data)
    result = park.fit.fit([(model, data1d)], fitter=fitter, handler=handler)
    # result.calls (number of function calls) and result.print_summary()
    # are available for reporting

    # fitted values are keyed by "<model name>.<parameter name>"
    fitted = {p.name: p.value for p in result.parameters}
    return [fitted['mymodel.%s' % name] for name in ('a', 'b', 'c')]
# --- Park end ---


if __name__ == '__main__':
    import sys
    # parse user selection to solve with "mystic" [default] or "park"
    from optparse import OptionParser
    parser = OptionParser()
    parser.add_option("-p", "--park", action="store_true", dest="park",
                      default=False, help="solve with park (instead of mystic)")
    parsed_opts, parsed_args = parser.parse_args()

    # set plot window (read by plot_sol)
    from mystic.tools import getch
    plotview = [-10, 10, 0, 100]

    # Let the "actual parameters" be :
    target = [1., 2., 1.]
    print("Target: %s" % target)

    # Here is the "observed data" (datapts is read by the cost functions)
    x, datapts = data(target)
    pylab.ion()
    plot_sol(target, 'r-')
    pylab.draw()

    # initial values
    point = [100, -100, 0]

    # DO OPTIMIZATION STUFF HERE TO GET SOLUTION
    if parsed_opts.park:
        if not __park:
            print('This option requires park to be installed')
            # exit() is the site module's interactive helper and exits with
            # status 0; signal the failure with a nonzero status instead
            sys.exit(1)
        print("Solving with park's DE optimizer...")
        solution = park_optimize(point)
    else:
        print("Solving with mystic's fmin optimizer...")
        solution = mystic_optimize(point)
    print("Solved: %s" % solution)

    # plot the solution
    plot_sol(solution, 'g-')
    pylab.draw()

    getch()

# End of file
Note: See TracBrowser for help on using the repository browser.