Package aloha :: Module aloha_lib
[hide private]
[frames] | no frames]

Source Code for Module aloha.aloha_lib

   1  ################################################################################ 
   2  # 
   3  # Copyright (c) 2010 The MadGraph5_aMC@NLO Development team and Contributors 
   4  # 
   5  # This file is a part of the MadGraph5_aMC@NLO project, an application which  
   6  # automatically generates Feynman diagrams and matrix elements for arbitrary 
   7  # high-energy processes in the Standard Model and beyond. 
   8  # 
   9  # It is subject to the MadGraph5_aMC@NLO license which should accompany this  
  10  # distribution. 
  11  # 
  12  # For more information, visit madgraph.phys.ucl.ac.be and amcatnlo.web.cern.ch 
  13  # 
  14  ################################################################################ 
  15  ##   Diagram of Class 
  16  ## 
  17  ##    Variable (vartype:0)<--- ScalarVariable  
  18  ##                          | 
  19  ##                          +- LorentzObject  
  20  ##                                 
  21  ## 
  22  ##    list <--- AddVariable (vartype :1)    
  23  ##            
  24  ##    array <--- MultVariable  <--- MultLorentz (vartype:2)  
  25  ##            
  26  ##    list <--- LorentzObjectRepresentation (vartype :4) <-- ConstantObject 
  27  ##                                                               (vartype:5) 
  28  ## 
  29  ##    FracVariable (vartype:3) 
  30  ## 
  31  ##    MultContainer (vartype:6) 
  32  ## 
  33  ################################################################################ 
  34  ## 
  35  ##   Variable is in fact a factory which adds a reference to the variable name 
  36  ##   into the KERNEL (of class Computation), instantiates a real variable object 
  37  ##   (of class C_Variable or R_Variable for complex/real) and returns a MultVariable 
  38  ##   with a single element. 
  39  ## 
  40  ##   Lorentz Object works in the same way. 
  41  ## 
  42  ################################################################################ 
  43   
  44   
  45  from __future__ import division 
  46  from array import array 
  47  import collections 
  48  from fractions import Fraction 
  49  import numbers 
  50  import re 
  51  import aloha # define mode of writting 
  52   
  53  try: 
  54      import madgraph.various.misc as misc 
  55  except Exception: 
  56      import aloha.misc as misc 
class defaultdict(collections.defaultdict):
    """``collections.defaultdict`` whose instances are themselves callable.

    Calling an instance returns a brand-new ``defaultdict(int)``, so an
    instance can be used as the ``default_factory`` of another one, e.g.
    ``defaultdict(defaultdict(int))`` gives a two-level counter.
    """

    def __call__(self, *args):
        """Return a fresh, empty ``defaultdict(int)``; arguments are ignored."""
        fresh_counter = defaultdict(int)
        return fresh_counter
62
class Computation(dict):
    """A class to encapsulate all computation. Limit side effect.

    The mapping itself stores ``name -> id``; the object registered under a
    name is stored in ``self.objs`` at the position given by its id.
    """

    # Function tags that are recognised. Any other tag which does not start
    # with 'cmath.' is recorded in ``unknow_fct``.
    known_fct = ['/', 'log', 'pow', 'sin', 'cos', 'asin', 'acos', 'tan', 'cot', 'acot',
                 'theta_function', 'exp']

    def __init__(self):
        self.objs = []             # registered objects, indexed by id
        self.use_tag = set()       # tags collected through add_tag
        self.id = -1               # last id handed out by add (-1: none yet)
        self.reduced_expr = {}     # str(expression) -> [TMP variable, tag]
        self.fct_expr = {}         # 'FCT<n>' -> (fct_tag, argument)
        self.reduced_expr2 = {}    # tag -> factorized expression / function call
        self.inverted_fct = {}     # str(fct_tag)+str(argument) -> 'FCT(<n>)'
        self.has_pi = False  # logical to check if pi is used in at least one fct
        self.unknow_fct = []       # (fct_tag, number of arguments) of unknown functions
        dict.__init__(self)

    def clean(self):
        """Reset the instance to its freshly-constructed state."""
        self.__init__()
        self.clear()

    def add(self, name, obj):
        """Register obj under name and return the id assigned to it."""
        self.id += 1
        self.objs.append(obj)
        self[name] = self.id
        return self.id

    def get(self, name):
        """Return the object registered under name (KeyError if unknown).

        Note that this replaces dict.get: it returns the registered object,
        not the stored id, and it takes no default value.
        """
        return self.objs[self[name]]

    def add_tag(self, tag):
        """Add every element of the iterable tag to the set of used tags."""
        self.use_tag.update(tag)

    def get_ids(self, variables):
        """Return the list of identification number associate to the
        given variables names. If a variable didn't exists, create it (in complex).
        """
        out = []
        for var in variables:
            try:
                var_id = self[var]
            except KeyError:
                assert var not in ['M', 'W']
                var_id = Variable(var).get_id()
            out.append(var_id)
        return out

    def add_expression_contraction(self, expression):
        """Replace a scalar expression by a temporary variable 'TMP<n>'.

        Return 0 when the expression is (or simplifies to) zero. Otherwise
        return the variable standing for the expression; the factorized
        expression is stored in ``reduced_expr2`` under the tag and the tag is
        added to the used tags. An expression with an already known string
        form reuses the variable created for it.
        """
        # The key is computed before simplify(), which may modify the expression.
        str_expr = str(expression)
        if str_expr in self.reduced_expr:
            out, tag = self.reduced_expr[str_expr]
            self.add_tag((tag,))
            return out
        if expression == 0:
            return 0
        new_2 = expression.simplify()
        if new_2 == 0:
            return 0
        # Add a new variable
        tag = 'TMP%s' % len(self.reduced_expr)
        new = Variable(tag)
        self.reduced_expr[str_expr] = [new, tag]
        new_2 = new_2.factorize()
        self.reduced_expr2[tag] = new_2
        self.add_tag((tag,))
        return new

    def add_function_expression(self, fct_tag, *args):
        """Register the call fct_tag(*args) and return its representation.

        Symbolic arguments (MultLorentz, AddVariable, LorentzObject) are
        expanded to their scalar component, simplified and factorized; an
        aloha.ALOHAERROR is raised when such an argument is not a scalar.
        When the function is known (or prefixed by 'cmath.') and all arguments
        are numbers, the call is evaluated with cmath and the string of the
        result is returned. Otherwise the string 'FCT(<n>)' is returned, where
        identical calls share the same <n>.
        """
        if not (fct_tag.startswith('cmath.') or fct_tag in self.known_fct or
                (fct_tag, len(args)) in self.unknow_fct):
            self.unknow_fct.append((fct_tag, len(args)))

        argument = []
        for expression in args:
            if isinstance(expression, (MultLorentz, AddVariable, LorentzObject)):
                try:
                    expr = expression.expand().get_rep([0])
                except KeyError as error:
                    if error.args != ((0,),):
                        raise
                    else:
                        raise aloha.ALOHAERROR('''Error in input format.
    Argument of function (or denominator) should be scalar.
    We found %s''' % expression)
                new = expr.simplify()
                if not isinstance(new, numbers.Number):
                    new = new.factorize()
                argument.append(new)
            else:
                argument.append(expression)

        # every FCT<n> appearing inside an argument is marked as used
        for arg in argument:
            for v in re.findall(r'''\bFCT(\d*)\b''', str(arg)):
                self.add_tag(('FCT%s' % v,))

        # check if the function is a pure numerical function.
        if (fct_tag.startswith('cmath.') or fct_tag in self.known_fct) and \
                all(isinstance(x, numbers.Number) for x in argument):
            import cmath  # looked up by name inside the eval below
            if fct_tag.startswith('cmath.'):
                module = ''
            else:
                module = 'cmath.'
            call_text = "%s%s(%s)" % (module, fct_tag,
                                      ','.join([repr(x) for x in argument]))
            try:
                # NOTE(review): eval of a string built from fct_tag; this is
                # only safe as long as the tag comes from a trusted source.
                return str(eval(call_text))
            except Exception as error:
                # best effort: report and fall back to a symbolic FCT entry
                print(error)
                print(call_text)

        key = str(fct_tag) + str(argument)
        if key in self.inverted_fct:
            tag = self.inverted_fct[key]
            v = tag.split('(')[1][:-1]   # 'FCT(<n>)' -> '<n>'
            self.add_tag(('FCT%s' % v,))
            return tag
        else:
            fct_id = len(self.fct_expr)
            tag = 'FCT%s' % fct_id
            self.inverted_fct[key] = 'FCT(%s)' % fct_id
            self.fct_expr[tag] = (fct_tag, argument)
            self.reduced_expr2[tag] = (fct_tag, argument)
            self.add_tag((tag,))
            return 'FCT(%s)' % fct_id
# Module-level Computation instance: FactoryVar registers every new name in it
# and looks up the id of the already known ones.
KERNEL = Computation()
#===============================================================================
# AddVariable
#===============================================================================
class AddVariable(list):
    """ A list of Variable/ConstantObject/... This object represent the sum
    of those object, multiplied by the overall ``prefactor``."""

    #variable to fastenize class recognition
    vartype = 1

    def __init__(self, old_data=[], prefactor=1):
        """ initialization of the object with default value

        The default list is only copied by list.__init__, never modified.
        """
        self.prefactor = prefactor
        list.__init__(self, old_data)

    def simplify(self):
        """ apply rule of simplification

        Numerical terms are summed into a single constant, terms with the same
        sorted content are merged, and the most frequent prefactor is factored
        out into self.prefactor. The object is modified in place. Returns self,
        or the single remaining term times the prefactor, or 0 when nothing
        remains.
        """
        # deal with one length object (a bare number has no simplify method
        # and is handled by the generic code below)
        if len(self) == 1 and hasattr(self[0], 'vartype'):
            return self.prefactor * self[0].simplify()
        constant = 0
        items = {}
        pos = -1
        for term in self[:]:
            pos += 1 # current position in the real self
            if not hasattr(term, 'vartype'):
                if isinstance(term, dict):
                    # allow term of type{(0,):x}
                    assert list(term.values()) == [0]
                    term = term[(0,)]
                constant += term
                del self[pos]
                pos -= 1
                continue
            tag = tuple(term.sort())
            if tag in items:
                orig_prefac = items[tag].prefactor # to assume to zero 0.33333 -0.3333
                items[tag].prefactor += term.prefactor
                if items[tag].prefactor and \
                    abs(items[tag].prefactor) / (abs(orig_prefac)+abs(term.prefactor)) < 1e-8:
                    items[tag].prefactor = 0
                del self[pos]
                pos -= 1
            else:
                items[tag] = term.__class__(term, term.prefactor)
                self[pos] = items[tag]

        # get the optimized prefactor
        countprefact = defaultdict(int)
        nbplus, nbminus = 0, 0
        if constant not in [0, 1, -1]:
            countprefact[constant] += 1
            if constant.real + constant.imag > 0:
                nbplus += 1
            else:
                nbminus += 1

        for var in items.values():
            if var.prefactor == 0:
                self.remove(var)
            else:
                nb = var.prefactor
                if nb in [1, -1]:
                    continue
                countprefact[abs(nb)] += 1
                if nb.real + nb.imag > 0:
                    nbplus += 1
                else:
                    nbminus += 1
        if countprefact and max(countprefact.values()) > 1:
            fact_prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
        else:
            fact_prefactor = 1
        if nbplus < nbminus:
            fact_prefactor *= -1
        self.prefactor *= fact_prefactor

        if fact_prefactor != 1:
            for i, a in enumerate(self):
                try:
                    a.prefactor /= fact_prefactor
                except AttributeError:
                    # plain number stored in the list
                    self[i] /= fact_prefactor

        if constant:
            self.append(constant / fact_prefactor)

        # deal with one/zero length object
        varlen = len(self)
        if varlen == 1:
            if hasattr(self[0], 'vartype'):
                return self.prefactor * self[0].simplify()
            else:
                #self[0] is a number
                return self.prefactor * self[0]
        elif varlen == 0:
            return 0 #ConstantObject()
        return self

    def split(self, variables_id):
        """return a dict with the key being the power associated to each variables
           and the value being the object remaining after the suppression of all
           the variable"""

        out = defaultdict(int)
        for obj in self:
            for key, value in obj.split(variables_id).items():
                out[key] += self.prefactor * value
        return out

    def contains(self, variables):
        """returns true if one of the variables is in the expression"""

        return any((v in obj for obj in self for v in variables))

    def get_all_var_names(self):
        """Return the concatenated variable names of every term providing
        a get_all_var_names method."""

        out = []
        for term in self:
            if hasattr(term, 'get_all_var_names'):
                out += term.get_all_var_names()
        return out

    def replace(self, id, expression):
        """replace one object (identify by his id) by a given expression.
            Note that expression cann't be zero.
            Note that this should be canonical form (this should contains ONLY
            MULTVARIABLE) --so this should be called before a factorize.
        """
        new = self.__class__()

        for obj in self:
            assert isinstance(obj, MultVariable)
            tmp = obj.replace(id, expression)
            new += tmp
        new.prefactor = self.prefactor
        return new

    def expand(self, veto=[]):
        """Pass from High level object to low level object"""

        if not self:
            return self
        if self.prefactor == 1:
            new = self[0].expand(veto)
        else:
            new = self.prefactor * self[0].expand(veto)

        for item in self[1:]:
            if self.prefactor == 1:
                try:
                    new += item.expand(veto)
                except AttributeError:
                    new = new + item
            else:
                new += (self.prefactor) * item.expand(veto)
        return new

    def __mul__(self, obj):
        """define the multiplication of
            - a AddVariable with a number
            - a AddVariable with an AddVariable
        other type of multiplication are define via the symmetric operation base
        on the obj  class."""

        if not hasattr(obj, 'vartype'): #  obj is a number
            if not obj:
                return 0
            return self.__class__(self, self.prefactor*obj)
        elif obj.vartype == 1: # obj is an AddVariable
            new = self.__class__([], self.prefactor * obj.prefactor)
            new[:] = [i*j for i in self for j in obj]
            return new
        else:
            #force the program to look at obj + self
            return NotImplemented

    def __imul__(self, obj):
        """define the multiplication of
            - a AddVariable with a number
            - a AddVariable with an AddVariable
        other type of multiplication are define via the symmetric operation base
        on the obj  class."""

        if not hasattr(obj, 'vartype'): #  obj is a number
            if not obj:
                return 0
            self.prefactor *= obj
            return self
        elif obj.vartype == 1: # obj is an AddVariable
            new = self.__class__([], self.prefactor * obj.prefactor)
            new[:] = [i*j for i in self for j in obj]
            return new
        else:
            #force the program to look at obj + self
            return NotImplemented

    def __neg__(self):
        """Flip the sign of the prefactor IN PLACE and return self."""
        self.prefactor *= -1
        return self

    def __add__(self, obj):
        """Define all the different addition."""

        if not hasattr(obj, 'vartype'):
            if not obj: # obj is zero
                return self
            new = self.__class__(self, self.prefactor)
            # stored terms are implicitly multiplied by the prefactor
            new.append(obj/self.prefactor)
            return new
        elif obj.vartype == 2: # obj is a MultVariable
            new = AddVariable(self, self.prefactor)
            if self.prefactor == 1:
                new.append(obj)
            else:
                new.append((1/self.prefactor)*obj)
            return new
        elif obj.vartype == 1: # obj is a AddVariable
            new = AddVariable(self, self.prefactor)
            for item in obj:
                new.append(obj.prefactor/self.prefactor * item)
            return new
        else:
            #force to look at obj + self
            return NotImplemented

    def __iadd__(self, obj):
        """Define all the different addition."""

        if not hasattr(obj, 'vartype'):
            if not obj: # obj is zero
                return self
            self.append(obj/self.prefactor)
            return self
        elif obj.vartype == 2: # obj is a MultVariable
            if self.prefactor == 1:
                self.append(obj)
            else:
                self.append((1/self.prefactor)*obj)
            return self
        elif obj.vartype == 1: # obj is a AddVariable
            for item in obj:
                self.append(obj.prefactor/self.prefactor * item)
            return self
        else:
            #force to look at obj + self
            return NotImplemented

    def __sub__(self, obj):
        return self + (-1) * obj

    def __rsub__(self, obj):
        return (-1) * self + obj

    __radd__ = __add__
    __rmul__ = __mul__

    def __div__(self, obj):
        """Division by a number: multiply by its inverse."""
        return self.__mul__(1/obj)

    __truediv__ = __div__

    def __rdiv__(self, obj):
        """Division OF obj BY this sum is not representable.

        Returning NotImplemented makes Python raise a TypeError for
        ``obj / self`` (the previous code called an undefined method).
        """
        return NotImplemented

    __rtruediv__ = __rdiv__

    def __str__(self):
        text = ''
        if self.prefactor != 1:
            text += str(self.prefactor) + ' * '
        text += '( '
        text += ' + '.join([str(item) for item in self])
        text += ' )'
        return text

    def count_term(self):
        """Return (count, id) for the variable to factorize.

        count is the number of terms containing the most present variable and
        id the kernel id of that variable. Returns (1, None) when the highest
        count is one. Ties are resolved with the correlation weight and then
        with the lowest string expression of the kernel object.
        """
        # Count the number of appearance of each variable and find the most
        #present one in order to factorize her
        count = defaultdict(int)
        correlation = defaultdict(defaultdict(int))
        for i, term in enumerate(self):
            try:
                set_term = set(term)
            except TypeError:
                #constant term
                continue
            for val1 in set_term:
                count[val1] += 1
                # allow to find optimized factorization for identical count
                for val2 in set_term:
                    correlation[val1][val2] += 1

        maxnb = max(count.values()) if count else 0
        possibility = [v for v, val in count.items() if val == maxnb]
        if maxnb == 1:
            return 1, None
        elif len(possibility) == 1:
            return maxnb, possibility[0]

        max_wgt, maxvar = 0, None
        for var in possibility:
            wgt = sum(w**2 for w in correlation[var].values())/len(correlation[var])
            if wgt > max_wgt:
                maxvar = var
                max_wgt = wgt
                str_maxvar = str(KERNEL.objs[var])
            elif wgt == max_wgt:
                # keep the one with the lowest string expr
                new_str = str(KERNEL.objs[var])
                if new_str < str_maxvar:
                    maxvar = var
                    str_maxvar = new_str
        return maxnb, maxvar

    def factorize(self):
        """ try to factorize as much as possible the expression

        The terms containing the most present variable are modified in place
        (one power of the variable is removed from them).
        """

        max_count, maxvar = self.count_term()
        if max_count <= 1:
            #no factorization possible
            return self

        # split in MAXVAR * NEWADD + CONSTANT
        newadd = AddVariable()
        constant = AddVariable()
        #fill NEWADD and CONSTANT
        for term in self:
            try:
                term.remove(maxvar)
            except Exception:
                # term without maxvar (or term which is a plain number)
                constant.append(term)
            else:
                if len(term):
                    newadd.append(term)
                else:
                    newadd.append(term.prefactor)
        newadd = newadd.factorize()

        # optimize the prefactor
        if isinstance(newadd, AddVariable):
            countprefact = defaultdict(int)
            nbplus, nbminus = 0, 0
            for nb in [a.prefactor for a in newadd if hasattr(a, 'prefactor')]:
                countprefact[abs(nb)] += 1
                if nb.real + nb.imag > 0:
                    nbplus += 1
                else:
                    nbminus += 1

            # nothing to extract when no term carries a prefactor
            if countprefact:
                newadd.prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
            if nbplus < nbminus:
                newadd.prefactor *= -1
            if newadd.prefactor != 1:
                for i, a in enumerate(newadd):
                    try:
                        a.prefactor /= newadd.prefactor
                    except AttributeError:
                        newadd[i] /= newadd.prefactor

        if len(constant) > 1:
            constant = constant.factorize()
        elif constant:
            constant = constant[0]
        else:
            # no constant part: the result is the pure product
            out = MultContainer([KERNEL.objs[maxvar], newadd])
            out.prefactor = self.prefactor
            if newadd.prefactor != 1:
                out.prefactor *= newadd.prefactor
                newadd.prefactor = 1
            return out
        out = AddVariable([MultContainer([KERNEL.objs[maxvar], newadd]), constant],
                          self.prefactor)
        return out
578
class MultContainer(list):
    """List of factors representing their product, times ``prefactor``."""

    vartype = 6

    def __init__(self, *args):
        self.prefactor = 1
        list.__init__(self, *args)

    def __str__(self):
        """ String representation """
        factors = ' * '.join([str(t) for t in self])
        if self.prefactor != 1:
            text = '(%s * %s)' % (self.prefactor, factors)
        else:
            text = '(%s)' % factors
        return text

    def factorize(self):
        """Factorize every factor in place and return self.

        Factors without a factorize method are kept unchanged. Returning self
        keeps the ``x = x.factorize()`` convention of the other classes.
        """
        self[:] = [term.factorize() if hasattr(term, 'factorize') else term
                   for term in self]
        return self
597
class MultVariable(array):
    """ A list of Variable with multiplication as operator between themselves.
    Represented by array for speed optimization: the content is the kernel id
    of each factor, and ``prefactor`` is the overall numerical factor.
    """
    vartype = 2
    addclass = AddVariable

    # numerical types accepted for the prefactor (``long`` only exists in Python 2)
    try:
        _prefactor_types = (float, int, long, complex)
    except NameError:
        _prefactor_types = (float, int, complex)

    def __new__(cls, old=[], prefactor=1):
        return array.__new__(cls, 'i', old)

    def __init__(self, old=[], prefactor=1):
        """ initialization of the object with default value """
        #array.__init__(self, 'i', old) <- done already in new !!
        self.prefactor = prefactor
        assert isinstance(self.prefactor, self._prefactor_types)

    def get_id(self):
        """Return the id of the single factor (the product must have length one)."""
        assert len(self) == 1
        return self[0]

    def sort(self):
        """Sort the ids in place and return self."""
        a = list(self)
        a.sort()
        self[:] = array('i', a)
        return self

    def simplify(self):
        """ simplify the product"""
        if not len(self):
            return self.prefactor
        return self

    def split(self, variables_id):
        """return a dict with the key being the power associated to each variables
           and the value being the object remaining after the suppression of all
           the variable (self is modified in place)"""

        key = tuple([self.count(i) for i in variables_id])
        arg = [id for id in self if id not in variables_id]
        self[:] = array('i', arg)
        return SplitCoefficient([(key, self)])

    def replace(self, id, expression):
        """replace one object (identify by his id) by a given expression.
            Note that expression cann't be zero.
        """
        assert hasattr(expression, 'vartype'), 'expression should be of type Add or Mult'

        if expression.vartype == 1: # AddVariable
            nb = self.count(id)
            if not nb:
                return self
            for i in range(nb):
                self.remove(id)
            new = self
            for i in range(nb):
                new *= expression
            return new
        elif expression.vartype == 2: # MultLorentz
            # be carefull about A -> A * B
            nb = self.count(id)
            for i in range(nb):
                self.remove(id)
                self.__imul__(expression)
            return self
        else:
            raise Exception('Cann\'t replace a Variable by %s' % type(expression))

    def get_all_var_names(self):
        """return the list of variable used in this multiplication"""
        return ['%s' % KERNEL.objs[n] for n in self]

    #Defining rule of Multiplication
    def __mul__(self, obj):
        """Define the multiplication with different object"""

        if not hasattr(obj, 'vartype'): # should be a number
            if obj:
                return self.__class__(self, obj*self.prefactor)
            else:
                return 0
        elif obj.vartype == 1: # obj is an AddVariable
            new = obj.__class__([], self.prefactor*obj.prefactor)
            # the prefactor is carried by the sum: temporarily reset ours
            old, self.prefactor = self.prefactor, 1
            new[:] = [self * term for term in obj]
            self.prefactor = old
            return new
        elif obj.vartype == 4:
            return NotImplemented

        return self.__class__(array.__add__(self, obj), self.prefactor * obj.prefactor)

    __rmul__ = __mul__

    def __imul__(self, obj):
        """Define the multiplication with different object"""

        if not hasattr(obj, 'vartype'): # should be a number
            if obj:
                self.prefactor *= obj
                return self
            else:
                return 0
        elif obj.vartype == 1: # obj is an AddVariable
            new = obj.__class__([], self.prefactor * obj.prefactor)
            self.prefactor = 1
            new[:] = [self * term for term in obj]
            return new
        elif obj.vartype == 4:
            return NotImplemented

        self.prefactor *= obj.prefactor
        return array.__iadd__(self, obj)

    def __pow__(self, value):
        out = 1
        for i in range(value):
            out *= self
        return out

    def __add__(self, obj):
        """ define the adition with different object"""

        if not obj:
            return self
        elif not hasattr(obj, 'vartype') or obj.vartype == 2:
            new = self.addclass([self, obj])
            return new
        else:
            #call the implementation of addition implemented in obj
            return NotImplemented
    __radd__ = __add__
    __iadd__ = __add__

    def __sub__(self, obj):
        return self + (-1) * obj

    def __neg__(self):
        """Flip the sign of the prefactor IN PLACE and return self."""
        self.prefactor *= -1
        return self

    def __rsub__(self, obj):
        return (-1) * self + obj

    def __idiv__(self, obj):
        """ ONLY NUMBER DIVISION ALLOWED (the prefactor is modified in place)"""
        assert not hasattr(obj, 'vartype')
        self.prefactor /= obj
        return self

    __div__ = __idiv__
    __truediv__ = __div__

    def __str__(self):
        """ String representation """
        t = ['%s' % KERNEL.objs[n] for n in self]
        if self.prefactor != 1:
            text = '(%s * %s)' % (self.prefactor, ' * '.join(t))
        else:
            text = '(%s)' % (' * '.join(t))
        return text

    # NOTE(review): probably meant to be __repr__; the name is kept because
    # changing repr() would alter the string keys built from these objects.
    __rep__ = __str__

    def factorize(self):
        """A pure product is already factorized: return self."""
        return self
782
#===============================================================================
# FactoryVar
#===============================================================================
class C_Variable(str):
    """Variable name (a str) tagged as a complex quantity."""

    vartype = 0
    type = 'complex'
790
class R_Variable(str):
    """Variable name (a str) tagged as a double (real) quantity."""

    vartype = 0
    type = 'double'
794
class ExtVariable(str):
    """Variable name (a str) tagged as a parameter."""

    vartype = 0
    type = 'parameter'
798
class FactoryVar(object):
    """Factory for all the variables linked to an expression.

    Instantiating it never returns a FactoryVar: it returns an object of
    ``mult_class`` holding the single kernel id associated to the name.
    """
    mult_class = MultVariable # The class for the multiplication

    def __new__(cls, name, baseclass, *args):
        """Return ``mult_class([id])`` for name, registering it if needed.

        An unknown name is created as ``baseclass(name, *args)``, added to
        KERNEL, and receives its kernel id in the attribute ``id``.
        """
        if name not in KERNEL:
            created = baseclass(name, *args)
            ident = KERNEL.add(name, created)
            created.id = ident
            return cls.mult_class([ident])
        return cls.mult_class([KERNEL[name]])
815
class Variable(FactoryVar):
    """Factory of a variable; the kernel object is a C_Variable by default."""

    def __new__(cls, name, type=C_Variable):
        # delegate to the generic factory with the requested kernel class
        product = FactoryVar(name, type)
        return product
820
class DVariable(FactoryVar):
    """Factory of a variable which is real (R_Variable) in the normal case.

    The kernel object is a C_Variable instead when aloha.complex_mass is set
    and the name starts with 'M', 'W' or 'OM', or when aloha.loop_mode is set
    and the name starts with 'P'.
    """

    def __new__(cls, name):
        # some parameter are pass to complex
        complex_by_mass = aloha.complex_mass and \
                          (name[0] in ['M', 'W'] or name.startswith('OM'))
        if complex_by_mass or (aloha.loop_mode and name.startswith('P')):
            kernel_class = C_Variable
        else:
            # Normal case
            kernel_class = R_Variable
        return FactoryVar(name, kernel_class)
833
834 835 836 837 #=============================================================================== 838 # Object for Analytical Representation of Lorentz object (not scalar one) 839 #=============================================================================== 840 841 842 #=============================================================================== 843 # MultLorentz 844 #=============================================================================== 845 -class MultLorentz(MultVariable):
846 """Specific class for LorentzObject Multiplication""" 847 848 add_class = AddVariable # Define which class describe the addition 849
850 - def find_lorentzcontraction(self):
851 """return of (pos_object1, indice1) ->(pos_object2,indices2) defining 852 the contraction in this Multiplication.""" 853 854 out = {} 855 len_mult = len(self) 856 # Loop over the element 857 for i, fact in enumerate(self): 858 # and over the indices of this element 859 for j in range(len(fact.lorentz_ind)): 860 # in order to compare with the other element of the multiplication 861 for k in range(i+1,len_mult): 862 fact2 = self[k] 863 try: 864 l = fact2.lorentz_ind.index(fact.lorentz_ind[j]) 865 except Exception: 866 pass 867 else: 868 out[(i, j)] = (k, l) 869 out[(k, l)] = (i, j) 870 return out
871
872 - def find_spincontraction(self):
873 """return of (pos_object1, indice1) ->(pos_object2,indices2) defining 874 the contraction in this Multiplication.""" 875 876 out = {} 877 len_mult = len(self) 878 # Loop over the element 879 for i, fact in enumerate(self): 880 # and over the indices of this element 881 for j in range(len(fact.spin_ind)): 882 # in order to compare with the other element of the multiplication 883 for k in range(i+1, len_mult): 884 fact2 = self[k] 885 try: 886 l = fact2.spin_ind.index(fact.spin_ind[j]) 887 except Exception: 888 pass 889 else: 890 out[(i, j)] = (k, l) 891 out[(k, l)] = (i, j) 892 893 return out
894
895 - def neighboor(self, home):
896 """return one variable which are contracted with var and not yet expanded""" 897 898 for var in self.unused: 899 obj = KERNEL.objs[var] 900 if obj.has_component(home.lorentz_ind, home.spin_ind): 901 return obj 902 return None
903 904 905 906
907 - def expand(self, veto=[]):
908 """ expand each part of the product and combine them. 909 Try to use a smart order in order to minimize the number of uncontracted indices. 910 Veto forbids the use of sub-expression if it contains some of the variable in the 911 expression. Veto contains the id of the vetoed variables 912 """ 913 914 self.unused = self[:] # list of not expanded 915 # made in a list the interesting starting point for the computation 916 basic_end_point = [var for var in self if KERNEL.objs[var].contract_first] 917 product_term = [] #store result of intermediate chains 918 current = None # current point in the working chain 919 920 while self.unused: 921 #Loop untill we have expand everything 922 if not current: 923 # First we need to have a starting point 924 try: 925 # look in priority in basic_end_point (P/S/fermion/...) 926 current = basic_end_point.pop() 927 except Exception: 928 #take one of the remaining 929 current = self.unused.pop() 930 else: 931 #check that this one is not already use 932 if current not in self.unused: 933 current = None 934 continue 935 #remove of the unuse (usualy done in the pop) 936 self.unused.remove(current) 937 cur_obj = KERNEL.objs[current] 938 # initialize the new chain 939 product_term.append(cur_obj.expand()) 940 941 # We have a point -> find the next one 942 var_obj = self.neighboor(product_term[-1]) 943 # provide one term which is contracted with current and which is not 944 #yet expanded. 945 if var_obj: 946 product_term[-1] *= var_obj.expand() 947 cur_obj = var_obj 948 self.unused.remove(cur_obj.id) 949 continue 950 951 current = None 952 953 954 # Multiply all those current 955 # For Fermion/Vector only one can carry index. 
956 out = self.prefactor 957 for fact in product_term[:]: 958 if hasattr(fact, 'vartype') and fact.lorentz_ind == fact.spin_ind == []: 959 scalar = fact.get_rep([0]) 960 if hasattr(scalar, 'vartype') and scalar.vartype == 1: 961 if not veto or not scalar.contains(veto): 962 scalar = scalar.simplify() 963 prefactor = 1 964 965 if hasattr(scalar, 'vartype') and scalar.prefactor not in [1,-1]: 966 prefactor = scalar.prefactor 967 scalar.prefactor = 1 968 new = KERNEL.add_expression_contraction(scalar) 969 fact.set_rep([0], prefactor * new) 970 out *= fact 971 return out
972
973 - def __copy__(self):
974 """ create a shadow copy """ 975 new = MultLorentz(self) 976 new.prefactor = self.prefactor 977 return new
978
#===============================================================================
# LorentzObject
#===============================================================================
class LorentzObject(object):
    """ A symbolic Object for All Helas object. All Helas Object Should
    derivated from this class"""

    contract_first = 0
    mult_class = MultLorentz # The class for the multiplication
    add_class = AddVariable # The class for the addition

    class VariableError(Exception):
        """Error raised when the object cannot build its representation."""

    def __init__(self, name, lor_ind, spin_ind, tags=[]):
        """ initialization of the object with default value

        lor_ind and spin_ind must be lists; tags are added to the tags used
        by KERNEL (the default list is only read, never modified).
        """
        assert isinstance(lor_ind, list)
        assert isinstance(spin_ind, list)

        self.name = name
        self.lorentz_ind = lor_ind
        self.spin_ind = spin_ind
        KERNEL.add_tag(set(tags))

    def expand(self):
        """Expand the content information into LorentzObjectRepresentation.

        The representation is created on the first call and reused afterwards.
        """
        try:
            return self.representation
        except AttributeError:
            self.create_representation()
            return self.representation

    def create_representation(self):
        """Build ``self.representation``; not defined for this base class.

        Raises VariableError (previously an undefined name was raised here).
        """
        raise self.VariableError("This Object %s doesn't have define representation" % self.__class__.__name__)

    def has_component(self, lor_list, spin_list):
        """check if this Lorentz Object have some of those indices"""

        return any(ind in self.lorentz_ind for ind in lor_list) or \
               any(ind in self.spin_ind for ind in spin_list)

    def __str__(self):
        return '%s' % self.name
1023
class FactoryLorentz(FactoryVar):
    """Factory of the Helas objects: the kernel object is an instance of
    ``object_class`` and the returned product is a ``mult_class``."""

    mult_class = MultLorentz # The class for the multiplication
    object_class = LorentzObject # Define How to create the basic object.

    def __new__(cls, *args):
        # the kernel name is derived from the class and its arguments
        unique_name = cls.get_unique_name(*args)
        return FactoryVar.__new__(cls, unique_name, cls.object_class, *args)

    @classmethod
    def get_unique_name(cls, *args):
        """default way to have a unique name"""
        fields = {'class': cls.__name__,
                  'args': '_'.join(args)
                  }
        return '_L_%(class)s_%(args)s' % fields
1042
1043 1044 #=============================================================================== 1045 # LorentzObjectRepresentation 1046 #=============================================================================== 1047 -class LorentzObjectRepresentation(dict):
1048 """A concrete representation of the LorentzObject.""" 1049 1050 vartype = 4 # Optimization for instance recognition 1051
class LorentzObjectRepresentationError(Exception):
    """Specify error for LorentzObjectRepresentation.

    Raised for a scalar representation given as a dict without the single
    key (0,), and when two objects do not share the same indices.
    """
1054
1055 - def __init__(self, representation, lorentz_indices, spin_indices):
1056 """ initialize the lorentz object representation""" 1057 1058 self.lorentz_ind = lorentz_indices #lorentz indices 1059 self.nb_lor = len(lorentz_indices) #their number 1060 self.spin_ind = spin_indices #spin indices 1061 self.nb_spin = len(spin_indices) #their number 1062 self.nb_ind = self.nb_lor + self.nb_spin #total number of indices 1063 1064 1065 #store the representation 1066 if self.lorentz_ind or self.spin_ind: 1067 dict.__init__(self, representation) 1068 elif isinstance(representation,dict): 1069 if len(representation) == 0: 1070 self[(0,)] = 0 1071 elif len(representation) == 1 and (0,) in representation: 1072 self[(0,)] = representation[(0,)] 1073 else: 1074 raise self.LorentzObjectRepresentationError("There is no key of (0,) in representation.") 1075 else: 1076 if isinstance(representation,dict): 1077 try: 1078 self[(0,)] = representation[(0,)] 1079 except Exception: 1080 if representation: 1081 raise LorentzObjectRepresentation.LorentzObjectRepresentationError("There is no key of (0,) in representation.") 1082 else: 1083 self[(0,)] = 0 1084 else: 1085 self[(0,)] = representation
1086
1087 - def __str__(self):
1088 """ string representation """ 1089 text = 'lorentz index :' + str(self.lorentz_ind) + '\n' 1090 text += 'spin index :' + str(self.spin_ind) + '\n' 1091 #text += 'other info ' + str(self.tag) + '\n' 1092 for ind in self.listindices(): 1093 ind = tuple(ind) 1094 text += str(ind) + ' --> ' 1095 text += str(self.get_rep(ind)) + '\n' 1096 return text
1097
1098 - def get_rep(self, indices):
1099 """return the value/Variable associate to the indices""" 1100 return self[tuple(indices)]
1101
1102 - def set_rep(self, indices, value):
1103 """assign 'value' at the indices position""" 1104 1105 self[tuple(indices)] = value
1106
1107 - def listindices(self):
1108 """Return an iterator in order to be able to loop easily on all the 1109 indices of the object.""" 1110 return IndicesIterator(self.nb_ind)
1111 1112 @staticmethod
1113 - def get_mapping(l1,l2, switch_order=[]):
1114 shift = len(switch_order) 1115 for value in l1: 1116 try: 1117 index = l2.index(value) 1118 except Exception: 1119 raise LorentzObjectRepresentation.LorentzObjectRepresentationError( 1120 "Invalid addition. Object doen't have the same lorentz "+ \ 1121 "indices : %s != %s" % (l1, l2)) 1122 else: 1123 switch_order.append(shift + index) 1124 return switch_order
1125 1126
1127 - def __add__(self, obj, fact=1):
1128 1129 if not obj: 1130 return self 1131 1132 if not hasattr(obj, 'vartype'): 1133 assert self.lorentz_ind == [] 1134 assert self.spin_ind == [] 1135 new = self[(0,)] + obj * fact 1136 out = LorentzObjectRepresentation(new, [], []) 1137 return out 1138 1139 assert(obj.vartype == 4 == self.vartype) # are LorentzObjectRepresentation 1140 1141 if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind: 1142 # if the order of indices are different compute a mapping 1143 switch_order = [] 1144 self.get_mapping(self.lorentz_ind, obj.lorentz_ind, switch_order) 1145 self.get_mapping(self.spin_ind, obj.spin_ind, switch_order) 1146 switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))]) 1147 else: 1148 # no mapping needed (define switch as identity) 1149 switch = lambda ind : (ind) 1150 1151 # Some sanity check 1152 assert tuple(self.lorentz_ind+self.spin_ind) == tuple(switch(obj.lorentz_ind+obj.spin_ind)), '%s!=%s' % (self.lorentz_ind+self.spin_ind, switch(obj.lorentz_ind+self.spin_ind)) 1153 assert tuple(self.lorentz_ind) == tuple(switch(obj.lorentz_ind)), '%s!=%s' % (tuple(self.lorentz_ind), switch(obj.lorentz_ind)) 1154 1155 # define an empty representation 1156 new = LorentzObjectRepresentation({}, obj.lorentz_ind, obj.spin_ind) 1157 1158 # loop over all indices and fullfill the new object 1159 if fact == 1: 1160 for ind in self.listindices(): 1161 value = obj.get_rep(ind) + self.get_rep(switch(ind)) 1162 new.set_rep(ind, value) 1163 else: 1164 for ind in self.listindices(): 1165 value = fact * obj.get_rep(switch(ind)) + self.get_rep(ind) 1166 new.set_rep(ind, value) 1167 1168 return new
1169
1170 - def __iadd__(self, obj, fact=1):
1171 1172 if not obj: 1173 return self 1174 1175 assert(obj.vartype == 4 == self.vartype) # are LorentzObjectRepresentation 1176 1177 if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind: 1178 1179 # if the order of indices are different compute a mapping 1180 switch_order = [] 1181 self.get_mapping(obj.lorentz_ind, self.lorentz_ind, switch_order) 1182 self.get_mapping(obj.spin_ind, self.spin_ind, switch_order) 1183 switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))]) 1184 else: 1185 # no mapping needed (define switch as identity) 1186 switch = lambda ind : (ind) 1187 1188 # Some sanity check 1189 assert tuple(switch(self.lorentz_ind+self.spin_ind)) == tuple(obj.lorentz_ind+obj.spin_ind), '%s!=%s' % (switch(self.lorentz_ind+self.spin_ind), (obj.lorentz_ind+obj.spin_ind)) 1190 assert tuple(switch(self.lorentz_ind) )== tuple(obj.lorentz_ind), '%s!=%s' % (switch(self.lorentz_ind), tuple(obj.lorentz_ind)) 1191 1192 # loop over all indices and fullfill the new object 1193 if fact == 1: 1194 for ind in self.listindices(): 1195 self[tuple(ind)] += obj.get_rep(switch(ind)) 1196 else: 1197 for ind in self.listindices(): 1198 self[tuple(ind)] += fact * obj.get_rep(switch(ind)) 1199 return self
1200
1201 - def __sub__(self, obj):
1202 return self.__add__(obj, fact= -1)
1203
1204 - def __rsub__(self, obj):
1205 return obj.__add__(self, fact= -1)
1206
1207 - def __isub__(self, obj):
1208 return self.__add__(obj, fact= -1)
1209
1210 - def __neg__(self):
1211 self *= -1 1212 return self
1213
1214 - def __mul__(self, obj):
1215 """multiplication performing directly the einstein/spin sommation. 1216 """ 1217 1218 if not hasattr(obj, 'vartype'): 1219 out = LorentzObjectRepresentation({}, self.lorentz_ind, self.spin_ind) 1220 for ind in out.listindices(): 1221 out.set_rep(ind, obj * self.get_rep(ind)) 1222 return out 1223 1224 # Sanity Check 1225 assert(obj.__class__ == LorentzObjectRepresentation), \ 1226 '%s is not valid class for this operation' %type(obj) 1227 1228 # compute information on the status of the index (which are contracted/ 1229 #not contracted 1230 l_ind, sum_l_ind = self.compare_indices(self.lorentz_ind, \ 1231 obj.lorentz_ind) 1232 s_ind, sum_s_ind = self.compare_indices(self.spin_ind, \ 1233 obj.spin_ind) 1234 if not(sum_l_ind or sum_s_ind): 1235 # No contraction made a tensor product 1236 return self.tensor_product(obj) 1237 1238 # elsewher made a spin contraction 1239 # create an empty representation but with correct indices 1240 new_object = LorentzObjectRepresentation({}, l_ind, s_ind) 1241 #loop and fullfill the representation 1242 for indices in new_object.listindices(): 1243 #made a dictionary (pos -> index_value) for how call the object 1244 dict_l_ind = self.pass_ind_in_dict(indices[:len(l_ind)], l_ind) 1245 dict_s_ind = self.pass_ind_in_dict(indices[len(l_ind):], s_ind) 1246 #add the new value 1247 new_object.set_rep(indices, \ 1248 self.contraction(obj, sum_l_ind, sum_s_ind, \ 1249 dict_l_ind, dict_s_ind)) 1250 1251 return new_object
1252 1253 __rmul__ = __mul__ 1254 __imul__ = __mul__ 1255
1256 - def contraction(self, obj, l_sum, s_sum, l_dict, s_dict):
1257 """ make the Lorentz/spin contraction of object self and obj. 1258 l_sum/s_sum are the position of the sum indices 1259 l_dict/s_dict are dict given the value of the fix indices (indices->value) 1260 """ 1261 out = 0 # initial value for the output 1262 len_l = len(l_sum) #store len for optimization 1263 len_s = len(s_sum) # same 1264 1265 # loop over the possibility for the sum indices and update the dictionary 1266 # (indices->value) 1267 for l_value in IndicesIterator(len_l): 1268 l_dict.update(self.pass_ind_in_dict(l_value, l_sum)) 1269 for s_value in IndicesIterator(len_s): 1270 #s_dict_final = s_dict.copy() 1271 s_dict.update(self.pass_ind_in_dict(s_value, s_sum)) 1272 1273 #return the indices in the correct order 1274 self_ind = self.combine_indices(l_dict, s_dict) 1275 obj_ind = obj.combine_indices(l_dict, s_dict) 1276 1277 # call the object 1278 factor = obj.get_rep(obj_ind) * self.get_rep(self_ind) 1279 1280 if factor: 1281 #compute the prefactor due to the lorentz contraction 1282 try: 1283 factor.prefactor *= (-1) ** (len(l_value) - l_value.count(0)) 1284 except Exception: 1285 factor *= (-1) ** (len(l_value) - l_value.count(0)) 1286 out += factor 1287 return out
1288
1289 - def tensor_product(self, obj):
1290 """ return the tensorial product of the object""" 1291 assert(obj.vartype == 4) #isinstance(obj, LorentzObjectRepresentation)) 1292 1293 new_object = LorentzObjectRepresentation({}, \ 1294 self.lorentz_ind + obj.lorentz_ind, \ 1295 self.spin_ind + obj.spin_ind) 1296 1297 #some shortcut 1298 lor1 = self.nb_lor 1299 lor2 = obj.nb_lor 1300 spin1 = self.nb_spin 1301 spin2 = obj.nb_spin 1302 1303 #define how to call build the indices first for the first object 1304 if lor1 == 0 == spin1: 1305 #special case for scalar 1306 selfind = lambda indices: [0] 1307 else: 1308 selfind = lambda indices: indices[:lor1] + \ 1309 indices[lor1 + lor2: lor1 + lor2 + spin1] 1310 1311 #then for the second 1312 if lor2 == 0 == spin2: 1313 #special case for scalar 1314 objind = lambda indices: [0] 1315 else: 1316 objind = lambda indices: indices[lor1: lor1 + lor2] + \ 1317 indices[lor1 + lor2 + spin1:] 1318 1319 # loop on the indices and assign the product 1320 for indices in new_object.listindices(): 1321 1322 fac1 = self.get_rep(tuple(selfind(indices))) 1323 fac2 = obj.get_rep(tuple(objind(indices))) 1324 new_object.set_rep(indices, fac1 * fac2) 1325 1326 return new_object
1327
1328 - def factorize(self):
1329 """Try to factorize each component""" 1330 for ind, fact in self.items(): 1331 if fact: 1332 self.set_rep(ind, fact.factorize()) 1333 1334 1335 return self
1336
1337 - def simplify(self):
1338 """Check if we can simplify the object (check for non treated Sum)""" 1339 1340 #Look for internal simplification 1341 for ind, term in self.items(): 1342 if hasattr(term, 'vartype'): 1343 self[ind] = term.simplify() 1344 #no additional simplification 1345 return self
1346 1347 @staticmethod
1348 - def compare_indices(list1, list2):
1349 """return two list, the first one contains the position of non summed 1350 index and the second one the position of summed index.""" 1351 #init object 1352 1353 # equivalent set call --slightly slower 1354 #return list(set(list1) ^ set(list2)), list(set(list1) & set(list2)) 1355 1356 1357 are_unique, are_sum = [], [] 1358 # loop over the first list and check if they are in the second list 1359 1360 for indice in list1: 1361 if indice in list2: 1362 are_sum.append(indice) 1363 else: 1364 are_unique.append(indice) 1365 # loop over the second list for additional unique item 1366 1367 for indice in list2: 1368 if indice not in are_sum: 1369 are_unique.append(indice) 1370 1371 # return value 1372 return are_unique, are_sum
1373 1374 @staticmethod
1375 - def pass_ind_in_dict(indices, key):
1376 """made a dictionary (pos -> index_value) for how call the object""" 1377 if not key: 1378 return {} 1379 out = {} 1380 for i, ind in enumerate(indices): 1381 out[key[i]] = ind 1382 return out
1383
1384 - def combine_indices(self, l_dict, s_dict):
1385 """return the indices in the correct order following the dicts rules""" 1386 1387 out = [] 1388 # First for the Lorentz indices 1389 for value in self.lorentz_ind: 1390 out.append(l_dict[value]) 1391 # Same for the spin 1392 for value in self.spin_ind: 1393 out.append(s_dict[value]) 1394 1395 return out
1396
1397 - def split(self, variables_id):
1398 """return a dict with the key being the power associated to each variables 1399 and the value being the object remaining after the suppression of all 1400 the variable""" 1401 1402 out = SplitCoefficient() 1403 zero_rep = {} 1404 for ind in self.listindices(): 1405 zero_rep[tuple(ind)] = 0 1406 1407 for ind in self.listindices(): 1408 # There is no function split if the element is just a simple number 1409 if isinstance(self.get_rep(ind), numbers.Number): 1410 if tuple([0]*len(variables_id)) in out: 1411 out[tuple([0]*len(variables_id))][tuple(ind)] += self.get_rep(ind) 1412 else: 1413 out[tuple([0]*len(variables_id))] = \ 1414 LorentzObjectRepresentation(dict(zero_rep), 1415 self.lorentz_ind, self.spin_ind) 1416 out[tuple([0]*len(variables_id))][tuple(ind)] += self.get_rep(ind) 1417 continue 1418 1419 for key, value in self.get_rep(ind).split(variables_id).items(): 1420 if key in out: 1421 out[key][tuple(ind)] += value 1422 else: 1423 out[key] = LorentzObjectRepresentation(dict(zero_rep), 1424 self.lorentz_ind, self.spin_ind) 1425 out[key][tuple(ind)] += value 1426 1427 return out
1428



#===============================================================================
# IndicesIterator
#===============================================================================
class IndicesIterator:
    """Class needed for the iterator

    Loop over all the values of a list of indices, each index taking the
    values 0, 1, 2 and 3 (the first index is the one changing the fastest).
    With zero index, the iterator returns [0] once.

    The list returned at each step is the internal one: it is modified by the
    following step, so it should be copied (tuple(...)) if it has to be kept.
    """

    def __init__(self, len):
        """ create an iterator looping over the indices of a list of len "len"
            with each value can take value between 0 and 3 """

        self.len = len # number of indices
        if len:
            # initialize the position. The first position is -1 due to the method
            #in place which start by rising an index before returning smtg
            self.data = [-1] + [0] * (len - 1)
        else:
            # Special case for Scalar object
            self.data = 0
            self.next = self.nextscalar

    def __iter__(self):
        return self

    def next(self):
        """Move to the following set of values and return it"""
        for i in range(self.len):
            if self.data[i] < 3:
                self.data[i] += 1
                return self.data
            else:
                self.data[i] = 0
        raise StopIteration

    def nextscalar(self):
        """Version of next for an object without index: return [0] the first
        time and stop afterwards"""
        if self.data:
            raise StopIteration
        else:
            self.data = True
            return [0]

    def __next__(self):
        """Python 3 iterator protocol. This method is looked up on the class,
        so the redirection of 'next' done in __init__ for the scalar case is
        not seen: go explicitly through the instance attribute."""
        return self.next()
class SplitCoefficient(dict):
    """Dictionary carrying an additional 'tag' attribute (a set)."""

    def __init__(self, *args, **opt):
        """Accept the same arguments as dict and start with an empty tag."""
        super(SplitCoefficient, self).__init__(*args, **opt)
        self.tag = set()

    def get_max_rank(self):
        """return the highest rank of the coefficient, i.e. the largest value
        found in the first four entries of the keys"""

        return max(max(key[:4]) for key in self)


if __name__ == '__main__':

    import cProfile

    def create():
        """Call compare_indices 10000 times, the first list having between
        0 and 9 elements and the second one being always [4,3,5]"""
        for count in range(10000):
            nb_first = count % 10
            LorentzObjectRepresentation.compare_indices(range(nb_first), [4,3,5])

    # profile the loop defined above
    cProfile.run('create()')
