Package aloha :: Module aloha_lib
[hide private]
[frames] | no frames]

Source Code for Module aloha.aloha_lib

   1  ################################################################################ 
   2  # 
   3  # Copyright (c) 2010 The MadGraph5_aMC@NLO Development team and Contributors 
   4  # 
   5  # This file is a part of the MadGraph5_aMC@NLO project, an application which  
   6  # automatically generates Feynman diagrams and matrix elements for arbitrary 
   7  # high-energy processes in the Standard Model and beyond. 
   8  # 
   9  # It is subject to the MadGraph5_aMC@NLO license which should accompany this  
  10  # distribution. 
  11  # 
  12  # For more information, visit madgraph.phys.ucl.ac.be and amcatnlo.web.cern.ch 
  13  # 
  14  ################################################################################ 
  15  ##   Diagram of Class 
  16  ## 
  17  ##    Variable (vartype:0)<--- ScalarVariable  
  18  ##                          | 
  19  ##                          +- LorentzObject  
  20  ##                                 
  21  ## 
  22  ##    list <--- AddVariable (vartype :1)    
  23  ##            
  24  ##    array <--- MultVariable  <--- MultLorentz (vartype:2)  
  25  ##            
  26  ##    list <--- LorentzObjectRepresentation (vartype :4) <-- ConstantObject 
  27  ##                                                               (vartype:5) 
  28  ## 
  29  ##    FracVariable (vartype:3) 
  30  ## 
  31  ##    MultContainer (vartype:6) 
  32  ## 
  33  ################################################################################ 
  34  ## 
   35  ##   Variable is in fact a factory which adds a reference to the variable name
   36  ##   into the KERNEL (of class Computation), instantiates an actual variable object
   37  ##   (of class C_Variable/R_Variable for complex/real) and returns a MultVariable
   38  ##   with a single element.
  39  ## 
  40  ##   Lorentz Object works in the same way. 
  41  ## 
  42  ################################################################################ 
  43   
  44   
  45  from __future__ import division 
  46  from array import array 
  47  import collections 
  48  from fractions import Fraction 
  49  import numbers 
  50  import re 
  51  import aloha # define mode of writting 
class defaultdict(collections.defaultdict):
    """A ``collections.defaultdict`` whose instances are themselves callable.

    Calling an instance returns a fresh, empty ``defaultdict`` built with the
    same default factory (``int`` when the instance has no factory). This
    lets nested counters be declared as ``defaultdict(defaultdict(int))``:
    the outer dictionary uses the inner instance as its factory.
    """

    def __call__(self, *args):
        # Positional arguments are accepted (and ignored) for compatibility
        # with the historical signature.
        factory = self.default_factory
        if factory is None:
            factory = int
        return defaultdict(factory)
57
class Computation(dict):
    """ a class to encapsulate all computation. Limit side effect

    The dictionary part maps a variable name to its integer id, while
    ``objs`` stores the registered objects indexed by that id.
    """

    known_fct = ['/', 'log', 'pow', 'sin', 'cos', 'asin', 'acos', 'tan', 'cot',
                 'acot', 'theta_function', 'exp']

    def __init__(self):
        self.objs = []            # registered objects, indexed by their id
        self.use_tag = set()      # tags (TMPn / FCTn) used by the computation
        self.id = -1              # id of the last registered object
        self.reduced_expr = {}    # str(expression) -> [Variable, tag]
        self.fct_expr = {}        # 'FCTn' -> (function name, arguments)
        self.reduced_expr2 = {}   # tag -> reduced expression / (function, arguments)
        self.inverted_fct = {}    # function-call key -> 'FCT(n)'
        self.has_pi = False # logical to check if pi is used in at least one fct
        self.unknow_fct = []      # (name, nb of arguments) of unknown functions
        dict.__init__(self)

    def clean(self):
        """Reset the kernel to its initial, empty state."""
        self.__init__()
        # dict.__init__ keeps existing keys, so they are dropped explicitly.
        self.clear()

    def add(self, name, obj):
        """Register ``obj`` under ``name`` and return its new integer id."""
        self.id += 1
        self.objs.append(obj)
        self[name] = self.id
        return self.id

    def get(self, name):
        """Return the object registered under ``name`` (KeyError if unknown)."""
        return self.objs[self[name]]

    def add_tag(self, tag):
        """Mark every tag of the iterable ``tag`` as used."""
        self.use_tag.update(tag)

    def get_ids(self, variables):
        """return the list of identification number associate to the
        given variables names. If a variable didn't exists, create it (in complex).
        """
        out = []
        for var in variables:
            try:
                var_id = self[var]
            except KeyError:
                assert var not in ['M','W']
                var_id = Variable(var).get_id()
            out.append(var_id)
        return out

    def add_expression_contraction(self, expression):
        """Replace ``expression`` by a new 'TMPn' variable and return it.

        The simplified/factorized expression is stored in ``reduced_expr2``
        under the tag; identical expressions (same str) reuse the same
        variable. Zero, and expressions reducing to a plain number, are
        returned as numbers.
        """
        str_expr = str(expression)
        if str_expr in self.reduced_expr:
            out, tag = self.reduced_expr[str_expr]
            self.add_tag((tag,))
            return out
        if expression == 0:
            return 0
        if not hasattr(expression, 'simplify'):
            # a plain number: nothing to contract
            return expression
        new_2 = expression.simplify()
        if new_2 == 0:
            return 0
        if not hasattr(new_2, 'factorize'):
            # the expression reduces to a plain number: no TMP variable needed
            return new_2
        # Add a new variable
        tag = 'TMP%s' % len(self.reduced_expr)
        new = Variable(tag)
        self.reduced_expr[str_expr] = [new, tag]
        new_2 = new_2.factorize()
        self.reduced_expr2[tag] = new_2
        self.add_tag((tag,))
        return new

    def add_function_expression(self, fct_tag, *args):
        """Register the call ``fct_tag(*args)`` and return its tag 'FCT(n)'.

        Functions that are neither 'cmath.*' nor in ``known_fct`` are recorded
        in ``unknow_fct``. Lorentz-type arguments are expanded, simplified and
        factorized; they must be scalar. Identical calls share the same tag.

        Raises aloha.ALOHAERROR when a Lorentz-type argument is not a scalar.
        """
        if not (fct_tag.startswith('cmath.') or fct_tag in self.known_fct or
                (fct_tag, len(args)) in self.unknow_fct):
            self.unknow_fct.append( (fct_tag, len(args)) )

        argument = []
        for expression in args:
            if isinstance(expression, (MultLorentz, AddVariable, LorentzObject)):
                try:
                    expr = expression.expand().get_rep([0])
                except KeyError as error:
                    if error.args != ((0,),):
                        raise
                    raise aloha.ALOHAERROR('''Error in input format.
    Argument of function (or denominator) should be scalar.
    We found %s''' % expression)
                # keep the simplified form (the simplify() result used to be
                # discarded); plain numbers have neither method.
                if hasattr(expr, 'simplify'):
                    expr = expr.simplify()
                if hasattr(expr, 'factorize'):
                    expr = expr.factorize()
                argument.append(expr)
            else:
                argument.append(expression)
        for arg in argument:
            val = re.findall(r'''\bFCT(\d*)\b''', str(arg))
            for v in val:
                self.add_tag(('FCT%s' % v,))

        # Cache key: str() for symbolic arguments (their repr(), used by
        # str(list), omits the prefactor so cos(2*x) and cos(3*x) collided),
        # repr() for everything else (same text as str(list) produced).
        key = str(fct_tag) + '[%s]' % ', '.join(
            str(arg) if hasattr(arg, 'vartype') else repr(arg)
            for arg in argument)

        if key in self.inverted_fct:
            tag = self.inverted_fct[key]
            fct_id = tag.split('(')[1][:-1]
            self.add_tag(('FCT%s' % fct_id,))
            return tag

        fct_id = len(self.fct_expr)
        tag = 'FCT%s' % fct_id
        self.inverted_fct[key] = 'FCT(%s)' % fct_id
        self.fct_expr[tag] = (fct_tag, argument)
        self.reduced_expr2[tag] = (fct_tag, argument)
        self.add_tag((tag,))
        return 'FCT(%s)' % fct_id


# Module-wide kernel shared by every variable factory of this module.
KERNEL = Computation()
#===============================================================================
# AddVariable
#===============================================================================
class AddVariable(list):
    """ A list of Variable/ConstantObject/... This object represent the operation
    between those object: their sum, multiplied by ``prefactor``."""

    #variable to fastenize class recognition
    vartype = 1

    def __init__(self, old_data=(), prefactor=1):
        """ initialization of the object with default value """
        self.prefactor = prefactor
        list.__init__(self, old_data)

    def simplify(self):
        """ apply rule of simplification

        Merges identical products, collects numerical constants, extracts the
        most common prefactor and returns self, a single term or a number.
        """
        # deal with one length object (a lone number has no simplify() and is
        # handled by the general treatment below)
        if len(self) == 1 and hasattr(self[0], 'vartype'):
            return self.prefactor * self[0].simplify()
        constant = 0
        items = {}
        pos = -1
        for term in self[:]:
            pos += 1 # current position in the real self
            if not hasattr(term, 'vartype'):
                if isinstance(term, dict):
                    # allow term of type{(0,):x}
                    assert list(term.values()) == [0]
                    term = term[(0,)]
                constant += term
                del self[pos]
                pos -= 1
                continue
            tag = tuple(term.sort())
            if tag in items:
                items[tag].prefactor += term.prefactor
                del self[pos]
                pos -= 1
            else:
                items[tag] = term.__class__(term, term.prefactor)
                self[pos] = items[tag]

        # get the optimized prefactor
        countprefact = defaultdict(int)
        nbplus, nbminus = 0, 0
        if constant not in [0, 1, -1]:
            countprefact[constant] += 1
            if constant.real + constant.imag > 0:
                nbplus += 1
            else:
                nbminus += 1

        for var in items.values():
            if var.prefactor == 0:
                self.remove(var)
            else:
                nb = var.prefactor
                if nb in [1, -1]:
                    continue
                countprefact[abs(nb)] += 1
                if nb.real + nb.imag > 0:
                    nbplus += 1
                else:
                    nbminus += 1
        if countprefact and max(countprefact.values()) > 1:
            fact_prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
        else:
            fact_prefactor = 1
        if nbplus < nbminus:
            fact_prefactor *= -1
        self.prefactor *= fact_prefactor

        if fact_prefactor != 1:
            for i, a in enumerate(self):
                try:
                    a.prefactor /= fact_prefactor
                except AttributeError:
                    self[i] /= fact_prefactor

        if constant:
            self.append(constant / fact_prefactor)

        # deal with one/zero length object
        varlen = len(self)
        if varlen == 1:
            if hasattr(self[0], 'vartype'):
                return self.prefactor * self[0].simplify()
            else:
                #self[0] is a number
                return self.prefactor * self[0]
        elif varlen == 0:
            return 0 #ConstantObject()
        return self

    def split(self, variables_id):
        """return a dict with the key being the power associated to each variables
        and the value being the object remaining after the suppression of all
        the variable"""

        out = defaultdict(int)
        for obj in self:
            for key, value in obj.split(variables_id).items():
                out[key] += self.prefactor * value
        return out

    def contains(self, variables):
        """returns true if one of the variables is in the expression"""

        return any((v in obj for obj in self for v in variables))

    def get_all_var_names(self):
        """Return the names of the variables used in the terms that expose
        get_all_var_names() (numbers are skipped)."""
        out = []
        for term in self:
            if hasattr(term, 'get_all_var_names'):
                out += term.get_all_var_names()
        return out

    def replace(self, id, expression):
        """replace one object (identify by his id) by a given expression.
           Note that expression cann't be zero.
           Note that this should be canonical form (this should contains ONLY
           MULTVARIABLE) --so this should be called before a factorize.
        """
        new = self.__class__()

        for obj in self:
            assert isinstance(obj, MultVariable)
            tmp = obj.replace(id, expression)
            new += tmp
        new.prefactor = self.prefactor
        return new

    def expand(self, veto=[]):
        """Pass from High level object to low level object"""

        if not self:
            return self
        if self.prefactor == 1:
            new = self[0].expand(veto)
        else:
            new = self.prefactor * self[0].expand(veto)

        for item in self[1:]:
            if self.prefactor == 1:
                try:
                    new += item.expand(veto)
                except AttributeError:
                    new = new + item
            else:
                new += (self.prefactor) * item.expand(veto)
        return new

    def __mul__(self, obj):
        """define the multiplication of
            - a AddVariable with a number
            - a AddVariable with an AddVariable
        other type of multiplication are define via the symmetric operation base
        on the obj class."""

        if not hasattr(obj, 'vartype'): # obj is a number
            if not obj:
                return 0
            return self.__class__(self, self.prefactor*obj)
        elif obj.vartype == 1: # obj is an AddVariable
            new = self.__class__([], self.prefactor * obj.prefactor)
            new[:] = [i*j for i in self for j in obj]
            return new
        else:
            #force the program to look at obj + self
            return NotImplemented

    def __imul__(self, obj):
        """define the multiplication of
            - a AddVariable with a number
            - a AddVariable with an AddVariable
        other type of multiplication are define via the symmetric operation base
        on the obj class."""

        if not hasattr(obj, 'vartype'): # obj is a number
            if not obj:
                return 0
            self.prefactor *= obj
            return self
        elif obj.vartype == 1: # obj is an AddVariable
            new = self.__class__([], self.prefactor * obj.prefactor)
            new[:] = [i*j for i in self for j in obj]
            return new
        else:
            #force the program to look at obj + self
            return NotImplemented

    def __neg__(self):
        """Return the opposite expression without modifying self."""
        return self.__class__(self, -self.prefactor)

    def __add__(self, obj):
        """Define all the different addition."""

        if not hasattr(obj, 'vartype'):
            if not obj: # obj is zero
                return self
            new = self.__class__(self, self.prefactor)
            new.append(obj/self.prefactor)
            return new
        elif obj.vartype == 2: # obj is a MultVariable
            new = AddVariable(self, self.prefactor)
            if self.prefactor == 1:
                new.append(obj)
            else:
                new.append((1/self.prefactor)*obj)
            return new
        elif obj.vartype == 1: # obj is a AddVariable
            new = AddVariable(self, self.prefactor)
            for item in obj:
                new.append(obj.prefactor/self.prefactor * item)
            return new
        else:
            #force to look at obj + self
            return NotImplemented

    def __iadd__(self, obj):
        """Define all the different addition."""

        if not hasattr(obj, 'vartype'):
            if not obj: # obj is zero
                return self
            self.append(obj/self.prefactor)
            return self
        elif obj.vartype == 2: # obj is a MultVariable
            if self.prefactor == 1:
                self.append(obj)
            else:
                self.append((1/self.prefactor)*obj)
            return self
        elif obj.vartype == 1: # obj is a AddVariable
            for item in obj:
                self.append(obj.prefactor/self.prefactor * item)
            return self
        else:
            #force to look at obj + self
            return NotImplemented

    def __sub__(self, obj):
        return self + (-1) * obj

    def __rsub__(self, obj):
        return (-1) * self + obj

    __radd__ = __add__
    __rmul__ = __mul__

    def __div__(self, obj):
        return self.__mul__(1/obj)

    __truediv__ = __div__

    def __rdiv__(self, obj):
        # Dividing by a sum cannot be represented by this class (the previous
        # code called the nonexistent __rmult__); let Python report an
        # unsupported operation instead.
        return NotImplemented

    def __str__(self):
        text = ''
        if self.prefactor != 1:
            text += str(self.prefactor) + ' * '
        text += '( '
        text += ' + '.join([str(item) for item in self])
        text += ' )'
        return text

    def count_term(self):
        """Return (count, id) of the KERNEL id appearing in most terms.

        Returns (1, None) when no id appears twice. Ties are broken with the
        correlation between ids, then with the lowest string representation.
        """
        count = defaultdict(int)
        correlation = defaultdict(defaultdict(int))
        for term in self:
            try:
                set_term = set(term)
            except TypeError:
                #constant term
                continue
            for val1 in set_term:
                count[val1] += 1
                # allow to find optimized factorization for identical count
                for val2 in set_term:
                    correlation[val1][val2] += 1

        maxnb = max(count.values()) if count else 0
        possibility = [v for v, val in count.items() if val == maxnb]
        if maxnb == 1:
            return 1, None
        elif len(possibility) == 1:
            return maxnb, possibility[0]

        max_wgt, maxvar = 0, None
        for var in possibility:
            wgt = sum(w**2 for w in correlation[var].values())/len(correlation[var])
            if wgt > max_wgt:
                maxvar = var
                max_wgt = wgt
                str_maxvar = str(KERNEL.objs[var])
            elif wgt == max_wgt:
                # keep the one with the lowest string expr
                new_str = str(KERNEL.objs[var])
                if new_str < str_maxvar:
                    maxvar = var
                    str_maxvar = new_str
        return maxnb, maxvar

    def factorize(self):
        """ try to factorize as much as possible the expression """

        maxnb, maxvar = self.count_term()
        if maxnb <= 1:
            #no factorization possible
            return self

        # split in MAXVAR * NEWADD + CONSTANT
        newadd = AddVariable()
        constant = AddVariable()
        #fill NEWADD and CONSTANT
        for term in self:
            try:
                term.remove(maxvar)
            except Exception:
                constant.append(term)
            else:
                if len(term):
                    newadd.append(term)
                else:
                    newadd.append(term.prefactor)
        newadd = newadd.factorize()

        # optimize the prefactor
        if isinstance(newadd, AddVariable):
            countprefact = defaultdict(int)
            nbplus, nbminus = 0, 0
            for nb in [a.prefactor for a in newadd if hasattr(a, 'prefactor')]:
                countprefact[abs(nb)] += 1
                if nb.real + nb.imag > 0:
                    nbplus += 1
                else:
                    nbminus += 1

            # only numbers left: nothing to extract (avoids an IndexError)
            if countprefact:
                newadd.prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
                if nbplus < nbminus:
                    newadd.prefactor *= -1
                if newadd.prefactor != 1:
                    for i, a in enumerate(newadd):
                        try:
                            a.prefactor /= newadd.prefactor
                        except AttributeError:
                            newadd[i] /= newadd.prefactor

        if len(constant) > 1:
            constant = constant.factorize()
        elif constant:
            constant = constant[0]
        else:
            out = MultContainer([KERNEL.objs[maxvar], newadd])
            out.prefactor = self.prefactor
            if newadd.prefactor != 1:
                out.prefactor *= newadd.prefactor
                newadd.prefactor = 1
            return out
        out = AddVariable([MultContainer([KERNEL.objs[maxvar], newadd]), constant],
                          self.prefactor)
        return out
555
class MultContainer(list):
    """Product of already-factorized factors (vartype 6).

    Built by AddVariable.factorize() to hold ``variable * sub-expression``;
    ``prefactor`` is an overall numerical factor.
    """

    vartype = 6

    def __init__(self, *args):
        self.prefactor = 1
        list.__init__(self, *args)

    def __str__(self):
        """ String representation """
        body = ' * '.join([str(t) for t in self])
        if self.prefactor != 1:
            return '(%s * %s)' % (self.prefactor, body)
        return '(%s)' % body

    def factorize(self):
        """Factorize, in place, every factor that supports it; return self.

        Factors without a factorize() method (e.g. plain variable names) are
        kept unchanged.
        """
        self[:] = [term.factorize() if hasattr(term, 'factorize') else term
                   for term in self]
        return self
574
class MultVariable(array):
    """ A list of Variable with multiplication as operator between themselves.
    Represented by array for speed optimization: the array holds KERNEL ids
    and ``prefactor`` the numerical coefficient.
    """
    vartype=2
    addclass = AddVariable

    def __new__(cls, old=[], prefactor=1):
        # the typed ('i') content of an array has to be set at creation time
        return array.__new__(cls, 'i', old)

    def __init__(self, old=[], prefactor=1):
        """ initialization of the object with default value """
        #array.__init__(self, 'i', old) <- done already in new !!
        self.prefactor = prefactor
        # numbers.Number covers int/long/float/complex on Python 2 and 3
        # (the name `long` does not exist on Python 3).
        assert isinstance(self.prefactor, numbers.Number)

    def get_id(self):
        """Return the id of a single-variable product."""
        assert len(self) == 1
        return self[0]

    def sort(self):
        """Sort the ids in place and return self."""
        a = list(self)
        a.sort()
        self[:] = array('i', a)
        return self

    def simplify(self):
        """ simplify the product"""
        if not len(self):
            return self.prefactor
        return self

    def split(self, variables_id):
        """return a dict with the key being the power associated to each variables
        and the value being the object remaining after the suppression of all
        the variable"""

        key = tuple([self.count(i) for i in variables_id])
        arg = [id for id in self if id not in variables_id]
        self[:] = array('i', arg)
        return SplitCoefficient([(key,self)])

    def replace(self, id, expression):
        """replace one object (identify by his id) by a given expression.
           Note that expression cann't be zero.
        """
        assert hasattr(expression, 'vartype') , 'expression should be of type Add or Mult'

        if expression.vartype == 1: # AddVariable
            nb = self.count(id)
            if not nb:
                return self
            for i in range(nb):
                self.remove(id)
            new = self
            for i in range(nb):
                new *= expression
            return new
        elif expression.vartype == 2: # MultLorentz
            # be carefull about A -> A * B
            nb = self.count(id)
            for i in range(nb):
                self.remove(id)
            self.__imul__(expression)
            return self
        else:
            raise Exception('Cann\'t replace a Variable by %s' % type(expression))

    def get_all_var_names(self):
        """return the list of variable used in this multiplication"""
        return ['%s' % KERNEL.objs[n] for n in self]

    #Defining rule of Multiplication
    def __mul__(self, obj):
        """Define the multiplication with different object"""

        if not hasattr(obj, 'vartype'): # should be a number
            if obj:
                return self.__class__(self, obj*self.prefactor)
            else:
                return 0
        elif obj.vartype == 1: # obj is an AddVariable
            new = obj.__class__([], self.prefactor*obj.prefactor)
            old, self.prefactor = self.prefactor, 1
            new[:] = [self * term for term in obj]
            self.prefactor = old
            return new
        elif obj.vartype == 4:
            return NotImplemented

        return self.__class__(array.__add__(self, obj), self.prefactor * obj.prefactor)

    __rmul__ = __mul__

    def __imul__(self, obj):
        """Define the multiplication with different object"""

        if not hasattr(obj, 'vartype'): # should be a number
            if obj:
                self.prefactor *= obj
                return self
            else:
                return 0
        elif obj.vartype == 1: # obj is an AddVariable
            new = obj.__class__([], self.prefactor * obj.prefactor)
            # the prefactor is already carried by `new`; neutralize it only
            # while building the terms, then restore it.
            old, self.prefactor = self.prefactor, 1
            new[:] = [self * term for term in obj]
            self.prefactor = old
            return new
        elif obj.vartype == 4:
            return NotImplemented

        self.prefactor *= obj.prefactor
        return array.__iadd__(self, obj)

    def __pow__(self,value):
        out = 1
        for i in range(value):
            out *= self
        return out

    def __add__(self, obj):
        """ define the adition with different object"""

        if not obj:
            return self
        elif not hasattr(obj, 'vartype') or obj.vartype == 2:
            new = self.addclass([self, obj])
            return new
        else:
            #call the implementation of addition implemented in obj
            return NotImplemented
    __radd__ = __add__
    __iadd__ = __add__

    def __sub__(self, obj):
        return self + (-1) * obj

    def __neg__(self):
        """Return the opposite product without modifying self."""
        return self.__class__(self, -self.prefactor)

    def __rsub__(self, obj):
        return (-1) * self + obj

    def __idiv__(self,obj):
        """ ONLY NUMBER DIVISION ALLOWED (in place)"""
        assert not hasattr(obj, 'vartype')
        self.prefactor /= obj
        return self

    def __div__(self, obj):
        """ ONLY NUMBER DIVISION ALLOWED (returns a new product)"""
        assert not hasattr(obj, 'vartype')
        return self.__class__(self, self.prefactor / obj)

    __truediv__ = __div__
    __itruediv__ = __idiv__

    def __str__(self):
        """ String representation """
        t = ['%s' % KERNEL.objs[n] for n in self]
        if self.prefactor != 1:
            text = '(%s * %s)' % (self.prefactor,' * '.join(t))
        else:
            text = '(%s)' % (' * '.join(t))
        return text

    # NOTE(review): `__rep__` is not a Python hook (probably meant __repr__);
    # kept as-is so that repr() output is unchanged.
    __rep__ = __str__

    def factorize(self):
        return self
759
#===============================================================================
# FactoryVar
#===============================================================================
class C_Variable(str):
    """Variable name whose value is typed 'complex' (vartype 0)."""
    type = 'complex'
    vartype = 0
767
class R_Variable(str):
    """Variable name whose value is typed 'double' (vartype 0)."""
    type = 'double'
    vartype = 0
771
class ExtVariable(str):
    """Variable name typed 'parameter' (vartype 0)."""
    type = 'parameter'
    vartype = 0
775
class FactoryVar(object):
    """Factory returning a single-element ``mult_class`` for a named variable.

    The first request for a name builds ``baseclass(name, *args)``, registers
    it in KERNEL and stores the KERNEL id on the object; later requests for
    the same name reuse the registered id.
    """
    mult_class = MultVariable # The class for the multiplication

    def __new__(cls, name, baseclass, *args):
        """Factory class return a MultVariable."""
        if name not in KERNEL:
            new_obj = baseclass(name, *args)
            new_obj.id = KERNEL.add(name, new_obj)
            return cls.mult_class([new_obj.id])
        return cls.mult_class([KERNEL[name]])
792
class Variable(FactoryVar):
    """Named variable, complex (C_Variable) unless another ``type`` is given.

    Returns the single-element product built by FactoryVar.
    """

    def __new__(cls, name, type=C_Variable):
        return FactoryVar.__new__(FactoryVar, name, type)
797
class DVariable(FactoryVar):
    """Named variable, real (R_Variable) by default.

    With aloha.complex_mass, names starting with 'M', 'W' or 'OM' become
    complex; with aloha.loop_mode, names starting with 'P' become complex.
    """

    def __new__(cls, name):
        is_complex = False
        if aloha.complex_mass and (name[0] in ['M', 'W'] or name.startswith('OM')):
            # some parameters are passed to complex
            is_complex = True
        elif aloha.loop_mode and name.startswith('P'):
            is_complex = True
        return FactoryVar(name, C_Variable if is_complex else R_Variable)
810
#===============================================================================
# Object for Analytical Representation of Lorentz object (not scalar one)
#===============================================================================


#===============================================================================
# MultLorentz
#===============================================================================
class MultLorentz(MultVariable):
    """Specific class for LorentzObject Multiplication"""

    add_class = AddVariable # Define which class describe the addition

    def _find_contraction(self, attr):
        """Map (pos_object1, indice1) <-> (pos_object2, indice2) for indices
        stored in attribute ``attr`` ('lorentz_ind' or 'spin_ind')."""
        # the product stores KERNEL ids: work on the registered objects
        factors = [KERNEL.objs[var] for var in self]
        out = {}
        for i, fact in enumerate(factors):
            for j, index in enumerate(getattr(fact, attr, [])):
                # compare with the other elements of the multiplication
                for k in range(i + 1, len(factors)):
                    try:
                        l = getattr(factors[k], attr, []).index(index)
                    except ValueError:
                        continue
                    out[(i, j)] = (k, l)
                    out[(k, l)] = (i, j)
        return out

    def find_lorentzcontraction(self):
        """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
        the contraction in this Multiplication."""
        return self._find_contraction('lorentz_ind')

    def find_spincontraction(self):
        """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
        the contraction in this Multiplication."""
        return self._find_contraction('spin_ind')

    def neighboor(self, home):
        """return one variable which are contracted with var and not yet expanded"""

        for var in self.unused:
            obj = KERNEL.objs[var]
            if obj.has_component(home.lorentz_ind, home.spin_ind):
                return obj
        return None

    def expand(self, veto=[]):
        """ expand each part of the product and combine them.
            Try to use a smart order in order to minimize the number of uncontracted indices.
            Veto forbids the use of sub-expression if it contains some of the variable in the
            expression. Veto contains the id of the vetoed variables
        """

        self.unused = self[:] # list of not expanded
        # made in a list the interesting starting point for the computation
        basic_end_point = [var for var in self if KERNEL.objs[var].contract_first]
        product_term = [] #store result of intermediate chains
        current = None # current point in the working chain

        while self.unused:
            #Loop untill we have expand everything
            if current is None: # (id 0 is a valid starting point)
                # First we need to have a starting point
                try:
                    # look in priority in basic_end_point (P/S/fermion/...)
                    current = basic_end_point.pop()
                except IndexError:
                    #take one of the remaining
                    current = self.unused.pop()
                else:
                    #check that this one is not already use
                    if current not in self.unused:
                        current = None
                        continue
                    #remove of the unuse (usualy done in the pop)
                    self.unused.remove(current)
                cur_obj = KERNEL.objs[current]
                # initialize the new chain
                product_term.append(cur_obj.expand())

            # We have a point -> find the next one
            var_obj = self.neighboor(product_term[-1])
            # provide one term which is contracted with current and which is not
            #yet expanded.
            if var_obj:
                product_term[-1] *= var_obj.expand()
                cur_obj = var_obj
                self.unused.remove(cur_obj.id)
                continue

            current = None

        # Multiply all those current
        # For Fermion/Vector only one can carry index.
        out = self.prefactor
        for fact in product_term[:]:
            if hasattr(fact, 'vartype') and fact.lorentz_ind == fact.spin_ind == []:
                scalar = fact.get_rep([0])
                if hasattr(scalar, 'vartype') and scalar.vartype == 1:
                    if not veto or not scalar.contains(veto):
                        scalar = scalar.simplify()
                        if not hasattr(scalar, 'vartype'):
                            # reduced to a plain number: nothing to contract
                            fact.set_rep([0], scalar)
                        else:
                            prefactor = 1
                            if scalar.prefactor not in [1, -1]:
                                prefactor = scalar.prefactor
                                scalar.prefactor = 1
                            new = KERNEL.add_expression_contraction(scalar)
                            fact.set_rep([0], prefactor * new)
            out *= fact
        return out

    def __copy__(self):
        """ create a shadow copy """
        new = MultLorentz(self)
        new.prefactor = self.prefactor
        return new
955
#===============================================================================
# LorentzObject
#===============================================================================
class LorentzObject(object):
    """ A symbolic Object for All Helas object. All Helas Object Should
    derivated from this class"""

    contract_first = 0
    mult_class = MultLorentz # The class for the multiplication
    add_class = AddVariable # The class for the addition

    class VariableError(Exception):
        """Raised when an object does not define its representation."""

    def __init__(self, name, lor_ind, spin_ind, tags=()):
        """ initialization of the object with default value """
        assert isinstance(lor_ind, list)
        assert isinstance(spin_ind, list)

        self.name = name
        self.lorentz_ind = lor_ind
        self.spin_ind = spin_ind
        KERNEL.add_tag(set(tags))

    def expand(self):
        """Expand the content information into LorentzObjectRepresentation.

        The representation is built once by create_representation() and then
        reused.
        """
        try:
            return self.representation
        except AttributeError:
            self.create_representation()
            return self.representation

    def create_representation(self):
        """Build ``self.representation``; must be provided by subclasses.

        Raises VariableError.
        """
        raise self.VariableError("This Object %s doesn't have define representation" % self.__class__.__name__)

    def has_component(self, lor_list, spin_list):
        """check if this Lorentz Object have some of those indices"""
        return (any(ind in self.lorentz_ind for ind in lor_list) or
                any(ind in self.spin_ind for ind in spin_list))

    def __str__(self):
        return '%s' % self.name
1000
class FactoryLorentz(FactoryVar):
    """Factory for Lorentz objects: registers an ``object_class`` instance in
    KERNEL under a unique name and returns a single-element MultLorentz."""

    mult_class = MultLorentz # The class for the multiplication
    object_class = LorentzObject # Define How to create the basic object.

    def __new__(cls, *args):
        name = cls.get_unique_name(*args)
        return FactoryVar.__new__(cls, name, cls.object_class, *args)

    @classmethod
    def get_unique_name(cls, *args):
        """default way to have a unique name

        Arguments are converted with str() so non-string arguments (e.g.
        integer indices) are accepted.
        """
        return '_L_%(class)s_%(args)s' % {
            'class': cls.__name__,
            'args': '_'.join(str(arg) for arg in args),
        }
1019
1020 1021 #=============================================================================== 1022 # LorentzObjectRepresentation 1023 #=============================================================================== 1024 -class LorentzObjectRepresentation(dict):
1025 """A concrete representation of the LorentzObject.""" 1026 1027 vartype = 4 # Optimization for instance recognition 1028
1029 - class LorentzObjectRepresentationError(Exception):
1030 """Specify error for LorentzObjectRepresentation"""
1031
1032 - def __init__(self, representation, lorentz_indices, spin_indices):
1033 """ initialize the lorentz object representation""" 1034 1035 self.lorentz_ind = lorentz_indices #lorentz indices 1036 self.nb_lor = len(lorentz_indices) #their number 1037 self.spin_ind = spin_indices #spin indices 1038 self.nb_spin = len(spin_indices) #their number 1039 self.nb_ind = self.nb_lor + self.nb_spin #total number of indices 1040 1041 #store the representation 1042 if self.lorentz_ind or self.spin_ind: 1043 dict.__init__(self, representation) 1044 elif isinstance(representation,dict): 1045 if len(representation) == 0: 1046 self[(0,)] = 0 1047 elif len(representation) == 1 and (0,) in representation: 1048 self[(0,)] = representation[(0,)] 1049 else: 1050 raise self.LorentzObjectRepresentationError("There is no key of (0,) in representation.") 1051 else: 1052 if isinstance(representation,dict): 1053 try: 1054 self[(0,)] = representation[(0,)] 1055 except Exception: 1056 if representation: 1057 raise LorentzObjectRepresentation.LorentzObjectRepresentationError("There is no key of (0,) in representation.") 1058 else: 1059 self[(0,)] = 0 1060 else: 1061 self[(0,)] = representation
1062
1063 - def __str__(self):
1064 """ string representation """ 1065 text = 'lorentz index :' + str(self.lorentz_ind) + '\n' 1066 text += 'spin index :' + str(self.spin_ind) + '\n' 1067 #text += 'other info ' + str(self.tag) + '\n' 1068 for ind in self.listindices(): 1069 ind = tuple(ind) 1070 text += str(ind) + ' --> ' 1071 text += str(self.get_rep(ind)) + '\n' 1072 return text
1073
1074 - def get_rep(self, indices):
1075 """return the value/Variable associate to the indices""" 1076 return self[tuple(indices)]
1077
1078 - def set_rep(self, indices, value):
1079 """assign 'value' at the indices position""" 1080 1081 self[tuple(indices)] = value
1082
def listindices(self):
    """Return a fresh IndicesIterator over every index assignment of this
    object (self.nb_ind indices, each taking the values 0..3)."""
    total_indices = self.nb_ind
    return IndicesIterator(total_indices)
@staticmethod
def get_mapping(l1, l2, switch_order=None):
    """Append to switch_order the position in l2 of every element of l1.

    Positions are offset by the initial length of switch_order, so that
    consecutive calls (Lorentz indices first, then spin indices) build one
    permutation over the concatenated index list.

    l1, l2: lists of index labels; every element of l1 must appear in l2.
    switch_order: list extended in place; a fresh list is used when it is
        omitted (previously a shared mutable default list accumulated
        entries across calls).

    Returns switch_order.
    Raises LorentzObjectRepresentation.LorentzObjectRepresentationError if
    an element of l1 is missing from l2.
    """
    if switch_order is None:
        switch_order = []
    shift = len(switch_order)
    for value in l1:
        try:
            index = l2.index(value)
        except ValueError:
            raise LorentzObjectRepresentation.LorentzObjectRepresentationError(
                "Invalid addition. Object doen't have the same lorentz "+ \
                "indices : %s != %s" % (l1, l2))
        switch_order.append(shift + index)
    return switch_order
1101 1102
def __add__(self, obj, fact=1):
    """Return a new object equal to self + fact * obj.

    obj may be:
      - falsy: self is returned unchanged;
      - a plain value (no `vartype` attribute): only allowed when self has
        no index; the result is a new scalar representation;
      - a LorentzObjectRepresentation carrying the same indices, possibly
        in a different order. The result uses obj's index order.
    fact: multiplicative factor applied to obj (-1 for subtraction).
    """
    if not obj:
        return self

    if not hasattr(obj, 'vartype'):
        assert self.lorentz_ind == []
        assert self.spin_ind == []
        new = self[(0,)] + obj * fact
        out = LorentzObjectRepresentation(new, [], [])
        return out

    assert(obj.vartype == 4 == self.vartype)  # are LorentzObjectRepresentation

    if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind:
        # switch converts an index tuple written in obj's order into the
        # same assignment written in self's order
        switch_order = []
        self.get_mapping(self.lorentz_ind, obj.lorentz_ind, switch_order)
        self.get_mapping(self.spin_ind, obj.spin_ind, switch_order)
        switch = lambda ind: tuple([ind[switch_order[i]] for i in range(len(ind))])
    else:
        # no mapping needed (define switch as identity)
        switch = lambda ind: (ind)

    # Some sanity check
    assert tuple(self.lorentz_ind + self.spin_ind) == tuple(switch(obj.lorentz_ind + obj.spin_ind)), \
        '%s!=%s' % (self.lorentz_ind + self.spin_ind, switch(obj.lorentz_ind + obj.spin_ind))
    assert tuple(self.lorentz_ind) == tuple(switch(obj.lorentz_ind)), \
        '%s!=%s' % (tuple(self.lorentz_ind), switch(obj.lorentz_ind))

    # the result is labelled with obj's indices, so every `ind` below is an
    # assignment in obj's order: obj is read at ind, self at switch(ind)
    new = LorentzObjectRepresentation({}, obj.lorentz_ind, obj.spin_ind)

    if fact == 1:
        for ind in self.listindices():
            value = obj.get_rep(ind) + self.get_rep(switch(ind))
            new.set_rep(ind, value)
    else:
        for ind in self.listindices():
            # same index convention as above; the former code read obj at
            # switch(ind) and self at ind, which stored mislabelled values
            # whenever the two index orders differed
            value = fact * obj.get_rep(ind) + self.get_rep(switch(ind))
            new.set_rep(ind, value)

    return new
1145
def __iadd__(self, obj, fact=1):
    """In-place self += fact * obj; returns self.

    obj may be:
      - falsy: nothing is done;
      - a plain value (no `vartype` attribute): only allowed when self has
        no index, as in __add__ (previously this path crashed with an
        AttributeError on obj.vartype);
      - a LorentzObjectRepresentation carrying the same indices, possibly
        in a different order. self keeps its own index order.
    fact: multiplicative factor applied to obj.
    """
    if not obj:
        return self

    if not hasattr(obj, 'vartype'):
        assert self.lorentz_ind == []
        assert self.spin_ind == []
        self[(0,)] += obj * fact
        return self

    assert(obj.vartype == 4 == self.vartype)  # are LorentzObjectRepresentation

    if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind:
        # switch converts an index tuple written in self's order into the
        # same assignment written in obj's order
        switch_order = []
        self.get_mapping(obj.lorentz_ind, self.lorentz_ind, switch_order)
        self.get_mapping(obj.spin_ind, self.spin_ind, switch_order)
        switch = lambda ind: tuple([ind[switch_order[i]] for i in range(len(ind))])
    else:
        # no mapping needed (define switch as identity)
        switch = lambda ind: (ind)

    # Some sanity check
    assert tuple(switch(self.lorentz_ind + self.spin_ind)) == tuple(obj.lorentz_ind + obj.spin_ind), \
        '%s!=%s' % (switch(self.lorentz_ind + self.spin_ind), (obj.lorentz_ind + obj.spin_ind))
    assert tuple(switch(self.lorentz_ind)) == tuple(obj.lorentz_ind), \
        '%s!=%s' % (switch(self.lorentz_ind), tuple(obj.lorentz_ind))

    # loop over all indices and update self in place
    if fact == 1:
        for ind in self.listindices():
            self[tuple(ind)] += obj.get_rep(switch(ind))
    else:
        for ind in self.listindices():
            self[tuple(ind)] += fact * obj.get_rep(switch(ind))
    return self
1176
def __sub__(self, obj):
    """Return self - obj, computed as __add__ with fact=-1."""
    difference = self.__add__(obj, fact= -1)
    return difference
1179
def __rsub__(self, obj):
    """Return obj - self (reflected subtraction, e.g. `2 - lor`).

    Plain values (no `vartype` attribute, such as Python numbers) are
    handled here. Their own __add__ does not accept the `fact` keyword, so
    delegating to them raised TypeError. They are computed as
    (-1 * self) + obj, which goes through __add__'s scalar handling.
    Other objects keep the original delegation obj.__add__(self, fact=-1).
    """
    if not hasattr(obj, 'vartype'):
        return self.__mul__(-1).__add__(obj)
    return obj.__add__(self, fact= -1)
1182
def __isub__(self, obj):
    """Handle `self -= obj`.

    Delegates to __add__ with fact=-1, so a new object is returned and
    self itself is not modified; Python rebinds the target name.
    """
    difference = self.__add__(obj, fact= -1)
    return difference
1185
def __neg__(self):
    """Return -self.

    Goes through __imul__ (an alias of __mul__ on this class), which builds
    a new object; self is left unchanged.
    """
    negated = self.__imul__(-1)
    return negated
1189
def __mul__(self, obj):
    """Multiply, performing the Einstein/spin summation directly.

    - obj without a `vartype` attribute (plain value): every component is
      multiplied by obj and a new object with the same indices is returned.
    - obj must otherwise be exactly a LorentzObjectRepresentation. Indices
      present in both operands are summed over (see contraction). The
      result carries the remaining, non-repeated indices. When no index is
      shared, the tensor product is returned.
    """
    if not hasattr(obj, 'vartype'):
        scaled = LorentzObjectRepresentation({}, self.lorentz_ind, self.spin_ind)
        for position in scaled.listindices():
            scaled.set_rep(position, obj * self.get_rep(position))
        return scaled

    # Sanity Check
    assert(obj.__class__ == LorentzObjectRepresentation), \
        '%s is not valid class for this operation' % type(obj)

    # split each index family into free (non repeated) and summed labels
    free_lor, summed_lor = self.compare_indices(self.lorentz_ind, obj.lorentz_ind)
    free_spin, summed_spin = self.compare_indices(self.spin_ind, obj.spin_ind)

    if not summed_lor and not summed_spin:
        # nothing to contract: plain tensor product
        return self.tensor_product(obj)

    result = LorentzObjectRepresentation({}, free_lor, free_spin)
    n_free_lor = len(free_lor)
    for position in result.listindices():
        # label -> value dictionaries for the free indices of this component
        lor_values = self.pass_ind_in_dict(position[:n_free_lor], free_lor)
        spin_values = self.pass_ind_in_dict(position[n_free_lor:], free_spin)
        result.set_rep(position,
                       self.contraction(obj, summed_lor, summed_spin,
                                        lor_values, spin_values))
    return result

__rmul__ = __mul__
__imul__ = __mul__
def contraction(self, obj, l_sum, s_sum, l_dict, s_dict):
    """Sum self * obj over the contracted (repeated) indices.

    l_sum / s_sum: labels of the summed Lorentz / spin indices.
    l_dict / s_dict: dicts mapping the free index labels to their fixed
        values. They are updated in place with the running values of the
        summed indices.

    Each non-zero term gets a factor (-1)**k, where k is the number of
    summed Lorentz indices whose current value is non-zero (presumably the
    diagonal (+,-,-,-) metric; TODO confirm). Falsy terms are skipped.
    Returns the accumulated sum, which is the int 0 when every term is
    falsy.
    """
    total = 0
    for lor_values in IndicesIterator(len(l_sum)):
        l_dict.update(self.pass_ind_in_dict(lor_values, l_sum))
        # the sign only depends on the Lorentz values: one computation per
        # outer step
        sign = (-1) ** (len(lor_values) - lor_values.count(0))
        for spin_values in IndicesIterator(len(s_sum)):
            s_dict.update(self.pass_ind_in_dict(spin_values, s_sum))

            # index tuples in each operand's own order
            self_key = self.combine_indices(l_dict, s_dict)
            obj_key = obj.combine_indices(l_dict, s_dict)

            term = obj.get_rep(obj_key) * self.get_rep(self_key)
            if not term:
                continue
            try:
                # objects exposing a .prefactor get the sign applied there;
                # anything else is multiplied directly
                term.prefactor *= sign
            except Exception:
                term *= sign
            total += term
    return total
1264
def tensor_product(self, obj):
    """Return the tensor product self (x) obj.

    The result's Lorentz indices are self's followed by obj's, and likewise
    for the spin indices. A component of the result is the product of the
    matching component of self and of obj. An operand without any index
    contributes its (0,) component.
    """
    assert(obj.vartype == 4)  # isinstance(obj, LorentzObjectRepresentation)

    product = LorentzObjectRepresentation({},
                                          self.lorentz_ind + obj.lorentz_ind,
                                          self.spin_ind + obj.spin_ind)

    n_lor_self, n_lor_obj = self.nb_lor, obj.nb_lor
    n_spin_self, n_spin_obj = self.nb_spin, obj.nb_spin

    # a component tuple of the result is laid out as
    # [self lorentz | obj lorentz | self spin | obj spin]
    lor_end = n_lor_self + n_lor_obj
    spin_self_end = lor_end + n_spin_self
    self_is_scalar = n_lor_self == 0 and n_spin_self == 0
    obj_is_scalar = n_lor_obj == 0 and n_spin_obj == 0

    for position in product.listindices():
        if self_is_scalar:
            self_key = (0,)
        else:
            self_key = tuple(position[:n_lor_self] + position[lor_end:spin_self_end])
        if obj_is_scalar:
            obj_key = (0,)
        else:
            obj_key = tuple(position[n_lor_self:lor_end] + position[spin_self_end:])

        left = self.get_rep(self_key)
        right = obj.get_rep(obj_key)
        product.set_rep(position, left * right)

    return product
1303
def factorize(self):
    """Replace every truthy component by the result of its factorize()
    method (falsy components are left untouched) and return self."""
    for key, component in self.items():
        if not component:
            continue
        self.set_rep(key, component.factorize())
    return self
1312
def simplify(self):
    """Simplify in place every component that is an aloha object (i.e. has
    a `vartype` attribute) by calling its simplify() method; return self.
    No further simplification is attempted at this level."""
    for key, component in self.items():
        if not hasattr(component, 'vartype'):
            continue
        self[key] = component.simplify()
    return self
@staticmethod
def compare_indices(list1, list2):
    """Split two index-label lists into free and summed labels.

    Returns (free, summed):
      - summed: labels of list1 that also appear in list2 (list1 order);
      - free: labels of list1 absent from list2, followed by labels of
        list2 not already in summed (each in its list's order).
    """
    free, summed = [], []
    for label in list1:
        (summed if label in list2 else free).append(label)
    free += [label for label in list2 if label not in summed]
    return free, summed
@staticmethod
def pass_ind_in_dict(indices, key):
    """Build a dict mapping key[i] -> indices[i] for every position i of
    `indices`; returns {} when `key` is empty."""
    if not key:
        return {}
    return dict((key[pos], value) for pos, value in enumerate(indices))
1359
def combine_indices(self, l_dict, s_dict):
    """Return the list of index values in this object's own order: the
    values of its Lorentz labels (looked up in l_dict) followed by those of
    its spin labels (looked up in s_dict)."""
    lorentz_part = [l_dict[label] for label in self.lorentz_ind]
    spin_part = [s_dict[label] for label in self.spin_ind]
    return lorentz_part + spin_part
1372
def split(self, variables_id):
    """Split every component according to the powers of the given variables.

    Returns a SplitCoefficient whose keys are the power keys produced by the
    components' split(variables_id) method. Each value is a
    LorentzObjectRepresentation (same indices as self) collecting, for every
    component, the coefficient associated with that key; components without
    a contribution stay 0.

    Components may be plain numbers (contraction() returns the int 0 when
    every term vanishes; a scalar __init__ stores 0). Such values have no
    split() method and previously raised AttributeError. Zeros are now
    skipped, and non-zero constants are assigned to the all-zero power key.
    """
    out = SplitCoefficient()
    # template with every component at 0, copied for each new key
    zero_rep = dict((tuple(ind), 0) for ind in self.listindices())

    for ind in self.listindices():
        position = tuple(ind)
        component = self.get_rep(ind)
        if hasattr(component, 'split'):
            pieces = component.split(variables_id).items()
        elif component:
            # NOTE(review): assumes split() keys hold one power per entry of
            # variables_id; a constant has power 0 in each -- confirm.
            constant_key = tuple([0] * len(variables_id))
            pieces = [(constant_key, component)]
        else:
            # a plain zero contributes nothing: zero_rep already holds 0
            continue

        for key, value in pieces:
            if key not in out:
                out[key] = LorentzObjectRepresentation(dict(zero_rep),
                                                       self.lorentz_ind, self.spin_ind)
            out[key][position] += value

    return out
1393
#===============================================================================
# IndicesIterator
#===============================================================================
class IndicesIterator:
    """Iterator over every assignment of the values 0..3 to `len` indices.

    The first index varies fastest: (0,0), (1,0), (2,0), (3,0), (0,1), ...
    Each step yields the iterator's internal list, which is modified in place
    by the next step; copy it (e.g. tuple(ind)) to keep a value.
    For len == 0 (scalar object) the list [0] is yielded exactly once.

    Both the Python 2 (`next`) and Python 3 (`__next__`) protocols are
    supported.
    """

    def __init__(self, len):
        """Create an iterator over `len` indices, each taking a value
        between 0 and 3."""
        self.len = len  # number of indices
        if len:
            # the first position starts at -1 because next() raises an
            # index before returning, so the first value is all zeros
            self.data = [-1] + [0] * (len - 1)
        else:
            # Special case for Scalar object
            self.data = 0
            # per-instance override of next() for the scalar case
            self.next = self.nextscalar

    def __iter__(self):
        return self

    def next(self):
        """Advance to the next assignment (odometer style, first index
        fastest); raise StopIteration once every assignment was produced."""
        for i in range(self.len):
            if self.data[i] < 3:
                self.data[i] += 1
                return self.data
            else:
                self.data[i] = 0
        raise StopIteration

    def __next__(self):
        # Python 3 looks __next__ up on the type, never on the instance, so
        # forward to self.next to honour the per-instance scalar override.
        return self.next()

    def nextscalar(self):
        """Scalar case: yield [0] once, then stop."""
        if self.data:
            raise StopIteration
        else:
            self.data = True
            return [0]
class SplitCoefficient(dict):
    """dict of split coefficients (as built by
    LorentzObjectRepresentation.split) with an extra `tag` set attribute,
    initially empty."""

    def __init__(self, *args, **opt):
        dict.__init__(self, *args, **opt)
        self.tag = set()

    def get_max_rank(self):
        """Return the largest value found among the first four entries of
        every key. Raises ValueError when the dict is empty."""
        return max(max(powers[:4]) for powers in self)
if __name__ == '__main__':

    import cProfile

    def create():
        """Profiling workload: many small compare_indices calls."""
        for counter in range(10000):
            first_list = range(counter % 10)
            LorentzObjectRepresentation.compare_indices(first_list, [4, 3, 5])

    cProfile.run('create()')