Changeset 538


Ignore:
Timestamp:
08/01/12 10:44:36 (4 years ago)
Author:
mmckerns
Message:

added 'info' option for termination conditions

Location:
mystic
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • mystic/mystic/termination.py

    r536 r538  
    77import numpy 
    88from numpy import absolute 
    9 from __builtin__ import bool as _bool 
    109abs = absolute 
    1110Inf = numpy.Inf 
     
    2120cost[-1] <= tolerance""" 
    2221    doc = "VTR with %s" % {'tolerance':tolerance, 'target':target} 
    23     def _VTR(inst, bool=True): 
    24          if bool: bool = _bool 
    25          else: bool = lambda x:x 
    26          hist = inst.energy_history 
    27          if not len(hist): return bool(null) 
    28          if abs(hist[-1] - target) <= tolerance: msg = doc 
    29          else: msg = null 
    30          return bool(msg) 
     22    def _VTR(inst, info=False): 
     23        if info: info = lambda x:x 
     24        else: info = bool 
     25        hist = inst.energy_history 
     26        if not len(hist): return info(null) 
     27        if abs(hist[-1] - target) <= tolerance: return info(doc) 
     28        return info(null) 
    3129   #_VTR.__doc__ = "%s(**%s)" % tuple(doc.split(" with ")) 
    3230    _VTR.__doc__ = doc 
     
    3937    doc = "ChangeOverGeneration with %s" % {'tolerance':tolerance, 
    4038                                            'generations':generations} 
    41     def _ChangeOverGeneration(inst): 
    42          hist = inst.energy_history 
    43          lg = len(hist) 
    44          if lg <= generations: return "" 
    45          if (hist[-generations]-hist[-1]) <= tolerance: return doc 
    46          return "" 
     39    def _ChangeOverGeneration(inst, info=False): 
     40        if info: info = lambda x:x 
     41        else: info = bool 
     42        hist = inst.energy_history 
     43        lg = len(hist) 
     44        if lg <= generations: return info(null) 
     45        if (hist[-generations]-hist[-1]) <= tolerance: return info(doc) 
     46        return info(null) 
    4747    _ChangeOverGeneration.__doc__ = doc 
    4848    return _ChangeOverGeneration 
     
    5555    doc = "NormalizedChangeOverGeneration with %s" % {'tolerance':tolerance, 
    5656                                                      'generations':generations} 
    57     def _NormalizedChangeOverGeneration(inst): 
    58          hist = inst.energy_history 
    59          lg = len(hist) 
    60          if lg <= generations: return "" 
    61          diff = tolerance*(abs(hist[-generations])+abs(hist[-1])) + eta 
    62          if 2.0*(hist[-generations]-hist[-1]) <= diff: return doc 
    63          return "" 
     57    def _NormalizedChangeOverGeneration(inst, info=False): 
     58        if info: info = lambda x:x 
     59        else: info = bool 
     60        hist = inst.energy_history 
     61        lg = len(hist) 
     62        if lg <= generations: return info(null) 
     63        diff = tolerance*(abs(hist[-generations])+abs(hist[-1])) + eta 
     64        if 2.0*(hist[-generations]-hist[-1]) <= diff: return info(doc) 
     65        return info(null) 
    6466    _NormalizedChangeOverGeneration.__doc__ = doc 
    6567    return _NormalizedChangeOverGeneration 
     
    7173    #NOTE: this termination expects nPop > 1 
    7274    doc = "CandidateRelativeTolerance with %s" % {'xtol':xtol, 'ftol':ftol} 
    73     def _CandidateRelativeTolerance(inst): 
    74          sim = numpy.array(inst.population) 
    75          fsim = numpy.array(inst.popEnergy) 
    76          if not len(fsim[1:]): 
    77              warn = "Warning: Invalid termination condition (nPop < 2)" 
    78              print warn 
    79              return warn 
    80          #   raise ValueError, "Invalid termination condition (nPop < 2)" 
    81          #FIXME: abs(inf - inf) will raise a warning... 
    82          errdict = numpy.seterr(invalid='ignore') #FIXME: turn off warning  
    83          answer = max(numpy.ravel(abs(sim[1:]-sim[0]))) <= xtol 
    84          answer = answer and max(abs(fsim[0]-fsim[1:])) <= ftol 
    85          numpy.seterr(invalid=errdict['invalid']) #FIXME: turn on warnings 
    86          if answer: return doc 
    87          return "" 
     75    def _CandidateRelativeTolerance(inst, info=False): 
     76        sim = numpy.array(inst.population) 
     77        fsim = numpy.array(inst.popEnergy) 
     78        if not len(fsim[1:]): 
     79            warn = "Warning: Invalid termination condition (nPop < 2)" 
     80            print warn 
     81            return warn 
     82        #   raise ValueError, "Invalid termination condition (nPop < 2)" 
     83        if info: info = lambda x:x 
     84        else: info = bool 
     85        #FIXME: abs(inf - inf) will raise a warning... 
     86        errdict = numpy.seterr(invalid='ignore') #FIXME: turn off warning  
     87        answer = max(numpy.ravel(abs(sim[1:]-sim[0]))) <= xtol 
     88        answer = answer and max(abs(fsim[0]-fsim[1:])) <= ftol 
     89        numpy.seterr(invalid=errdict['invalid']) #FIXME: turn on warnings 
     90        if answer: return info(doc) 
     91        return info(null) 
    8892    _CandidateRelativeTolerance.__doc__ = doc 
    8993    return _CandidateRelativeTolerance 
     
    9498sum(abs(last_params - current_params)) <= tolerance""" 
    9599    doc = "SolutionImprovement with %s" % {'tolerance':tolerance} 
    96     def _SolutionImprovement(inst): 
     100    def _SolutionImprovement(inst, info=False): 
     101        if info: info = lambda x:x 
     102        else: info = bool 
    97103        best = numpy.array(inst.bestSolution) 
    98104        trial = numpy.array(inst.trialSolution) 
    99105        update = best - trial #XXX: if inf - inf ? 
    100106        answer = numpy.add.reduce(abs(update)) <= tolerance 
    101         if answer: return doc 
    102         return "" 
     107        if answer: return info(doc) 
     108        return info(null) 
    103109    _SolutionImprovement.__doc__ = doc 
    104110    return _SolutionImprovement 
     
    114120    doc = "NormalizedCostTarget with %s" % {'fval':fval, 'tolerance':tolerance, 
    115121                                            'generations':generations} 
    116     def _NormalizedCostTarget(inst): 
    117          if generations and fval == None: 
    118              hist = inst.energy_history 
    119              lg = len(hist) 
    120              #XXX: throws error when hist is shorter than generations ? 
    121              if lg > generations and (hist[-generations]-hist[-1]) <= 0: 
    122                  return doc 
    123              return "" 
    124          if not generations and fval == None: return doc 
    125          if abs(inst.bestEnergy-fval) <= abs(tolerance * fval): return doc 
    126          return "" 
     122    def _NormalizedCostTarget(inst, info=False): 
     123        if info: info = lambda x:x 
     124        else: info = bool 
     125        if generations and fval == None: 
     126            hist = inst.energy_history 
     127            lg = len(hist) 
     128            #XXX: throws error when hist is shorter than generations ? 
     129            if lg > generations and (hist[-generations]-hist[-1]) <= 0: 
     130                return info(doc) 
     131            return info(null) 
     132        if not generations and fval == None: return info(doc) 
     133        if abs(inst.bestEnergy-fval) <= abs(tolerance * fval): return info(doc) 
     134        return info(null) 
    127135    _NormalizedCostTarget.__doc__ = doc 
    128136    return _NormalizedCostTarget 
     
    137145                                               'generations':generations, 
    138146                                               'target':target} 
    139     def _VTRChangeOverGeneration(inst): 
    140          hist = inst.energy_history 
    141          lg = len(hist) 
    142          #XXX: throws error when hist is shorter than generations ? 
    143          if (lg > generations and (hist[-generations]-hist[-1]) <= gtol)\ 
    144                 or ( abs(hist[-1] - target) <= ftol ): return doc 
    145          return "" 
     147    def _VTRChangeOverGeneration(inst, info=False): 
     148        if info: info = lambda x:x 
     149        else: info = bool 
     150        hist = inst.energy_history 
     151        lg = len(hist) 
     152        #XXX: throws error when hist is shorter than generations ? 
     153        if (lg > generations and (hist[-generations]-hist[-1]) <= gtol)\ 
     154               or ( abs(hist[-1] - target) <= ftol ): return info(doc) 
     155        return info(null) 
    146156    _VTRChangeOverGeneration.__doc__ = doc 
    147157    return _VTRChangeOverGeneration 
     
    152162abs(params - params[0]) <= tolerance""" 
    153163    doc = "PopulationSpread with %s" % {'tolerance':tolerance} 
    154     def _PopulationSpread(inst): 
    155          sim = numpy.array(inst.population) 
    156          #if not len(sim[1:]): 
    157          #    print "Warning: Invalid termination condition (nPop < 2)" 
    158          #    return True 
    159          if numpy.all(abs(sim - sim[0]) <= abs(tolerance * sim[0])): return doc 
    160          return "" 
     164    def _PopulationSpread(inst, info=False): 
     165        if info: info = lambda x:x 
     166        else: info = bool 
     167        sim = numpy.array(inst.population) 
     168        #if not len(sim[1:]): 
     169        #    print "Warning: Invalid termination condition (nPop < 2)" 
     170        #    return True 
     171        if numpy.all(abs(sim - sim[0]) <= abs(tolerance * sim[0])): return info(doc) 
     172        return info(null) 
    161173    _PopulationSpread.__doc__ = doc 
    162174    return _PopulationSpread 
     
    167179sum( abs(gradient)**norm )**(1.0/norm) <= tolerance""" 
    168180    doc = "GradientNormTolerance with %s" % {'tolerance':tolerance, 'norm':norm} 
    169     def _GradientNormTolerance(inst): 
     181    def _GradientNormTolerance(inst, info=False): 
    170182        try: 
    171183            gfk = inst.gfk #XXX: need to ensure that gfk is an array ? 
     
    174186            print warn 
    175187            return warn 
     188        if info: info = lambda x:x 
     189        else: info = bool 
    176190        if norm == Inf: 
    177191            gnorm = numpy.amax(abs(gfk)) 
     
    182196           #XXX: as norm < -large, gnorm approaches amin(abs(gfk)) --> then -inf 
    183197            gnorm = numpy.sum(abs(gfk)**norm,axis=0)**(1.0/norm) 
    184         if gnorm <= tolerance: return doc 
    185         return "" 
     198        if gnorm <= tolerance: return info(doc) 
     199        return info(null) 
    186200    _GradientNormTolerance.__doc__ = doc 
    187201    return _GradientNormTolerance 
  • mystic/tests/test_termination.py

    r465 r538  
    1515from numpy import inf 
    1616 
    17 def test_terminators(test, func=lambda x:x[0], verbose=False): 
    18   print test(lambda x:"", func, verbose) #XXX: just print settings 
    19   print "VTR():", test(VTR(), func) 
    20   print "VTR(inf):", test(VTR(inf), func) 
    21   print "COG():", test(ChangeOverGeneration(), func) 
    22   print "COG(gen=5):", test(ChangeOverGeneration(generations=5), func) 
    23   print "NCOG():", test(NormalizedChangeOverGeneration(), func) 
    24   print "NCOG(gen=5):", test(NormalizedChangeOverGeneration(generations=5),func) 
    25   print "CTR():", test(CandidateRelativeTolerance(), func) 
    26   print "CTR(ftol=inf):", test(CandidateRelativeTolerance(ftol=inf), func) 
    27   print "CTR(inf):", test(CandidateRelativeTolerance(inf), func) 
    28   print "SI():", test(SolutionImprovement(), func) 
    29   print "SI(inf):", test(SolutionImprovement(inf), func) 
    30   print "NCT():", test(NormalizedCostTarget(), func) 
    31   print "NCT(gen=5):", test(NormalizedCostTarget(generations=5), func) 
    32   print "NCT(gen=None):", test(NormalizedCostTarget(generations=None), func) 
    33   print "NCT(inf,inf):", test(NormalizedCostTarget(inf,inf), func) 
    34   print "VCOG():", test(VTRChangeOverGeneration(), func) 
    35   print "VCOG(gen=5):", test(VTRChangeOverGeneration(generations=5), func) 
    36   print "VCOG(inf):", test(VTRChangeOverGeneration(inf), func) 
    37   print "PS():", test(PopulationSpread(), func) 
    38   print "PS(inf):", test(PopulationSpread(inf), func) 
    39  #print "GNT():", test(GradientNormTolerance(), func) 
     17def test_terminators(test, func=lambda x:x[0], info=False, verbose=False): 
     18  print test(lambda x,y:"", func, info, verbose) #XXX: just print settings 
     19  print "VTR():", test(VTR(), func, info) 
     20  print "VTR(inf):", test(VTR(inf), func, info) 
     21  print "COG():", test(ChangeOverGeneration(), func, info) 
     22  print "COG(gen=5):", test(ChangeOverGeneration(generations=5), func, info) 
     23  print "NCOG():", test(NormalizedChangeOverGeneration(), func, info) 
     24  print "NCOG(gen=5):", test(NormalizedChangeOverGeneration(generations=5), func, info) 
     25  print "CTR():", test(CandidateRelativeTolerance(), func, info) 
     26  print "CTR(ftol=inf):", test(CandidateRelativeTolerance(ftol=inf), func, info) 
     27  print "CTR(inf):", test(CandidateRelativeTolerance(inf), func, info) 
     28  print "SI():", test(SolutionImprovement(), func, info) 
     29  print "SI(inf):", test(SolutionImprovement(inf), func, info) 
     30  print "NCT():", test(NormalizedCostTarget(), func, info) 
     31  print "NCT(gen=5):", test(NormalizedCostTarget(generations=5), func, info) 
     32  print "NCT(gen=None):", test(NormalizedCostTarget(generations=None), func, info) 
     33  print "NCT(inf,inf):", test(NormalizedCostTarget(inf,inf), func, info) 
     34  print "VCOG():", test(VTRChangeOverGeneration(), func, info) 
     35  print "VCOG(gen=5):", test(VTRChangeOverGeneration(generations=5), func, info) 
     36  print "VCOG(inf):", test(VTRChangeOverGeneration(inf), func, info) 
     37  print "PS():", test(PopulationSpread(), func, info) 
     38  print "PS(inf):", test(PopulationSpread(inf), func, info) 
     39 #print "GNT():", test(GradientNormTolerance(), func, info) 
    4040  return 
    4141 
     
    4949    return 
    5050 
    51 def test01(terminate, func=lambda x:x[0], debug=False): 
     51def test01(terminate, func=lambda x:x[0], info=False, debug=False): 
    5252  from mystic.solvers import DifferentialEvolutionSolver2 as DE2 
    5353  solver = DE2(3,5) 
     
    5656  solver.Solve(func, VTR()) 
    5757  if debug: verbosity(solver) 
    58   return terminate(solver) 
     58  return terminate(solver, info) 
    5959 
    60 def test02(terminate, func=lambda x:x[0], debug=False): 
     60def test02(terminate, func=lambda x:x[0], info=False, debug=False): 
    6161  from mystic.solvers import DifferentialEvolutionSolver2 as DE2 
    6262 #solver = DE2(3,1) #Solver throws ValueError "sample larger than population" 
     
    6767  solver.Solve(func, VTR()) 
    6868  if debug: verbosity(solver) 
    69   return terminate(solver) 
     69  return terminate(solver, info) 
    7070 
    71 def test03(terminate, func=lambda x:x[0], debug=False): 
     71def test03(terminate, func=lambda x:x[0], info=False, debug=False): 
    7272  from mystic.solvers import DifferentialEvolutionSolver as DE 
    7373  solver = DE(3,5) 
     
    7676  solver.Solve(func, VTR()) 
    7777  if debug: verbosity(solver) 
    78   return terminate(solver) 
     78  return terminate(solver, info) 
    7979 
    80 def test04(terminate, func=lambda x:x[0], debug=False): 
     80def test04(terminate, func=lambda x:x[0], info=False, debug=False): 
    8181  from mystic.solvers import DifferentialEvolutionSolver as DE 
    8282  solver = DE(1,5) 
     
    8585  solver.Solve(func, VTR()) 
    8686  if debug: verbosity(solver) 
    87   return terminate(solver) 
     87  return terminate(solver, info) 
    8888 
    89 def test05(terminate, func=lambda x:x[0], debug=False): 
     89def test05(terminate, func=lambda x:x[0], info=False, debug=False): 
    9090  from mystic.solvers import NelderMeadSimplexSolver as NM 
    9191  solver = NM(3) 
     
    9494  solver.Solve(func, VTR()) 
    9595  if debug: verbosity(solver) 
    96   return terminate(solver) 
     96  return terminate(solver, info) 
    9797 
    98 def test06(terminate, func=lambda x:x[0], debug=False): 
     98def test06(terminate, func=lambda x:x[0], info=False, debug=False): 
    9999  from mystic.solvers import NelderMeadSimplexSolver as NM 
    100100  solver = NM(1) 
     
    103103  solver.Solve(func, VTR()) 
    104104  if debug: verbosity(solver) 
    105   return terminate(solver) 
     105  return terminate(solver, info) 
    106106 
    107 def test07(terminate, func=lambda x:x[0], debug=False): 
     107def test07(terminate, func=lambda x:x[0], info=False, debug=False): 
    108108  from mystic.solvers import PowellDirectionalSolver as PDS 
    109109  solver = PDS(3) 
     
    112112  solver.Solve(func, VTR()) 
    113113  if debug: verbosity(solver) 
    114   return terminate(solver) 
     114  return terminate(solver, info) 
    115115 
    116 def test08(terminate, func=lambda x:x[0], debug=False): 
     116def test08(terminate, func=lambda x:x[0], info=False, debug=False): 
    117117  from mystic.solvers import PowellDirectionalSolver as PDS 
    118118  solver = PDS(1) 
     
    121121  solver.Solve(func, VTR()) 
    122122  if debug: verbosity(solver) 
    123   return terminate(solver) 
     123  return terminate(solver, info) 
    124124 
    125125 
    126126if __name__ == "__main__": 
    127127  verbose = True 
     128  info = False 
    128129  """NOTES: For x:x[0], test01-test04 returns either lists or floats; 
    129130while test05-test06 returns a ndarray for population, popEnergy, bestSolution; 
     
    150151 #function = lambda x:-inf 
    151152 
    152  #test_terminators(test01,function,verbose) 
    153  #test_terminators(test02,function,verbose) 
    154  #test_terminators(test03,function,verbose) 
    155  #test_terminators(test04,function,verbose) 
    156  #test_terminators(test05,function,verbose) 
    157  #test_terminators(test06,function,verbose) 
    158   test_terminators(test07,function,verbose) 
    159  #test_terminators(test08,function,verbose) 
     153 #test_terminators(test01,function,info,verbose) 
     154 #test_terminators(test02,function,info,verbose) 
     155 #test_terminators(test03,function,info,verbose) 
     156 #test_terminators(test04,function,info,verbose) 
     157 #test_terminators(test05,function,info,verbose) 
     158 #test_terminators(test06,function,info,verbose) 
     159  test_terminators(test07,function,info,verbose) 
     160 #test_terminators(test08,function,info,verbose) 
    160161 
    161162# EOF 
Note: See TracChangeset for help on using the changeset viewer.