source: mystic/examples_other/qld_circle_dual.py @ 855

Revision 855, 1.9 KB checked in by mmckerns, 5 months ago (diff)

updated copyright to 2016

  • Property svn:executable set to *
Line 
1#!/usr/bin/env python
2#
3# Author: Patrick Hung (patrickh @caltech)
4# Copyright (c) 1997-2016 California Institute of Technology.
5# License: 3-clause BSD.  The full license text is available at:
6#  - http://mmckerns.github.io/project/mystic/browser/mystic/LICENSE
7"""
8Solve the dual form of test_circle.py.
9
10Currently, it uses a package called "qld" that I wrote but that is not
11yet in the repository. It wraps IQP from Bell Labs (the code is not GPL and
12has export restrictions).
13"""
14
15from numpy import *
16import pylab
17from test_circle import sparse_circle, sv, x0, y0, R0
18getpoints = sparse_circle.forward
19import qld
20
def getobjective(H, f, x):
    """Evaluate the quadratic objective ``0.5 * x'Hx + f'x`` at ``x``.

    ``H`` is the quadratic-term matrix and ``f`` the linear-term vector,
    matching the (H, f) pair passed to the QP solver in this script.
    """
    xH = dot(x, H)
    quadratic = dot(xH, x)
    linear = dot(f, x)
    return 0.5 * quadratic + linear
23
def chop(x, tol=1e-6):
    """Return ``x`` unchanged if ``abs(x) > tol``, otherwise the integer 0.

    ``tol`` defaults to the original hard-coded threshold of 1e-6, so
    existing callers see identical results. The comparison is strict: a
    value whose magnitude equals ``tol`` is chopped to 0.
    """
    return x if abs(x) > tol else 0
29
def round(x):
    """Return ``x`` as a numpy array with each element passed through ``chop``.

    Entries with magnitude at or below chop's threshold become 0.
    NOTE: this name shadows the builtin ``round`` within this module.
    """
    chopped = list(map(chop, x))
    return array(chopped)
32
33
def plot(xy, sv, x0, y0, R0, center, R, sv0=None):
    """Plot the sample points, the reference circle and the fitted circle.

    Parameters:
        xy: 2-D array of points, read as xy[i, 0] (x) and xy[i, 1] (y);
            drawn as black '+' markers.
        sv: indices into ``xy`` of the support vectors (red circles).
        x0, y0, R0: center and radius of the reference circle (red line).
        center: (x, y) center of the fitted circle (blue dot).
        R: radius of the fitted circle (blue line).
        sv0: index of the support vector from which a dashed line is drawn
            to ``center``. Defaults to ``sv[0]``. Previously this value was
            read from a module global that only existed when the file ran
            as a script, so calling plot() from elsewhere raised NameError.
            If it is None and ``sv`` is empty, the dashed line is skipped.

    Calls pylab.show() at the end.
    """
    import pylab
    if sv0 is None and len(sv) > 0:
        sv0 = sv[0]
    theta = arange(0, 2*pi, 0.02)
    pylab.plot(xy[:, 0], xy[:, 1], 'k+')
    pylab.plot(xy[sv, 0], xy[sv, 1], 'ro')
    pylab.plot([center[0]], [center[1]], 'bo')
    if sv0 is not None:
        # Dashed "radius" from one support vector to the fitted center.
        pylab.plot([xy[sv0, 0], center[0]], [xy[sv0, 1], center[1]], 'r--')
    pylab.plot(R0 * cos(theta) + x0, R0 * sin(theta) + y0, 'r-', linewidth=2)
    pylab.plot(R * cos(theta) + center[0], R * sin(theta) + center[1],
               'b-', linewidth=2)
    pylab.axis('equal')
    pylab.show()
45
46
if __name__ == '__main__':
    npt = 20
    from test_circle import xy
    # Regenerate the sample points if test_circle supplies a different count.
    # Compare by value: `is not` tests object identity and only appeared to
    # work because CPython caches small integers.
    if xy.shape[0] != npt:
        xy = getpoints((x0, y0, R0), npt)

    # Dual of the smallest-enclosing-circle problem, posed as a QP:
    #   minimize 0.5*x'Hx + f'x  with H = 2*Q, f = -diag(Q) + 10,
    # where Q = xy xy' is the Gram matrix of the points.
    # NOTE(review): quadprog2's argument order is presumed to be
    # (H, f, A_ineq, b_ineq, A_eq, b_eq, lb, ub), i.e. sum(x) == 1 and
    # 0 <= x <= 1 here. Under that constraint the "+ 10" only shifts the
    # objective by a constant. Confirm against the qld package.
    Q = dot(xy, transpose(xy))
    f = -diag(Q) + 10
    H = Q * 2
    A = ones((1, npt))
    b = ones(1)
    x = qld.quadprog2(H, f, None, None, A, b, zeros(npt), ones(npt))

    # The fitted center is the x-weighted combination of the points.
    center = dot(x, xy)
    print("center: %s" % (center,))

    # Support vectors: points whose dual weight is non-negligible.
    sv = flatnonzero(asarray(x) > 0.001).tolist()
    if not sv:
        raise RuntimeError("no support vectors found (all weights <= 0.001)")
    # Kept at module scope: older versions of plot() read `sv0` as a global.
    sv0 = sv[0]

    print(sv)
    # Radius = distance from the center to one support vector (which, in
    # theory, lies on the fitted circle).
    R = linalg.norm(xy[sv0, :] - center)

    plot(xy, sv, x0, y0, R0, center, R)
75
76# $Id$
77#
Note: See TracBrowser for help on using the repository browser.